#!/usr/bin/env python
"""
Скрипт для оценки качества различных стратегий чанкинга.
Сравнивает стратегии на основе релевантности чанков к вопросам.
"""

import argparse
import json
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from fuzzywuzzy import fuzz
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

# Константы для настройки
DATA_FOLDER = "data/docs"                   # Путь к папке с документами
MODEL_NAME = "intfloat/e5-base"             # Название модели для векторизации
DATASET_PATH = "data/dataset.xlsx"          # Путь к Excel-датасету с вопросами
BATCH_SIZE = 8                              # Размер батча для векторизации
DEVICE = "cuda:1" if torch.cuda.is_available() else "cpu"  # Устройство для вычислений
SIMILARITY_THRESHOLD = 0.7                  # Порог для нечеткого сравнения
OUTPUT_DIR = "data"                         # Директория для сохранения результатов
TOP_CHUNKS_DIR = "data/top_chunks"          # Директория для сохранения топ-чанков
TOP_N_VALUES = [5, 10, 20, 30, 50, 70, 100]  # Значения N для оценки

# Параметры стратегий чанкинга
FIXED_SIZE_CONFIG = {
    "words_per_chunk": 50,                  # Количество слов в чанке
    "overlap_words": 25                     # Количество слов перекрытия
}

sys.path.insert(0, str(Path(__file__).parent.parent))
from ntr_fileparser import UniversalParser

from ntr_text_fragmentation import Destructurer


def _average_pool(
        last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Расчёт усредненного эмбеддинга по всем токенам

        Args:
            last_hidden_states: Матрица эмбеддингов отдельных токенов размерности (batch_size, seq_len, embedding_size) - последний скрытый слой
            attention_mask: Маска, чтобы не учитывать при усреднении пустые токены

        Returns:
            torch.Tensor - Усредненный эмбеддинг размерности (batch_size, embedding_size)
        """
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), 0.0
        )
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


def parse_args():
    """
    Парсит аргументы командной строки.
    
    Returns:
        Аргументы командной строки
    """
    parser = argparse.ArgumentParser(description="Скрипт для оценки качества чанкинга")
    
    parser.add_argument("--data-folder", type=str, default=DATA_FOLDER,
                        help=f"Путь к папке с документами (по умолчанию: {DATA_FOLDER})")
    parser.add_argument("--model-name", type=str, default=MODEL_NAME,
                        help=f"Название модели для векторизации (по умолчанию: {MODEL_NAME})")
    parser.add_argument("--dataset-path", type=str, default=DATASET_PATH,
                        help=f"Путь к Excel-датасету с вопросами (по умолчанию: {DATASET_PATH})")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
                        help=f"Размер батча для векторизации (по умолчанию: {BATCH_SIZE})")
    parser.add_argument("--similarity-threshold", type=float, default=SIMILARITY_THRESHOLD,
                        help=f"Порог для нечеткого сравнения (по умолчанию: {SIMILARITY_THRESHOLD})")
    parser.add_argument("--output-dir", type=str, default=OUTPUT_DIR,
                        help=f"Директория для сохранения результатов (по умолчанию: {OUTPUT_DIR})")
    parser.add_argument("--force-recompute", action="store_true",
                        help="Принудительно пересчитать эмбеддинги, игнорируя сохраненные")
    parser.add_argument("--use-sentence-transformers", action="store_true",
                        help="Использовать библиотеку sentence_transformers для извлечения эмбеддингов (для FRIDA и других моделей)")
    parser.add_argument("--device", type=str, default=DEVICE,
                        help=f"Устройство для вычислений (по умолчанию: {DEVICE})")
    
    # Параметры для fixed_size стратегии
    parser.add_argument("--words-per-chunk", type=int, default=FIXED_SIZE_CONFIG["words_per_chunk"],
                        help=f"Количество слов в чанке для fixed_size стратегии (по умолчанию: {FIXED_SIZE_CONFIG['words_per_chunk']})")
    parser.add_argument("--overlap-words", type=int, default=FIXED_SIZE_CONFIG["overlap_words"],
                        help=f"Количество слов перекрытия для fixed_size стратегии (по умолчанию: {FIXED_SIZE_CONFIG['overlap_words']})")
    
    return parser.parse_args()


def read_documents(folder_path: str) -> dict:
    """
    Читает все документы из указанной папки.
    
    Args:
        folder_path: Путь к папке с документами
        
    Returns:
        Словарь {имя_файла: parsed_document}
    """
    print(f"Чтение документов из {folder_path}...")
    parser = UniversalParser()
    documents = {}
    
    for file_path in tqdm(list(Path(folder_path).glob("*.docx")), desc="Чтение документов"):
        try:
            doc_name = file_path.stem
            documents[doc_name] = parser.parse_by_path(str(file_path))
        except Exception as e:
            print(f"Ошибка при чтении файла {file_path}: {e}")
    
    return documents


def process_documents(documents: dict, fixed_size_config: dict) -> pd.DataFrame:
    """
    Обрабатывает документы со стратегией fixed_size для чанкинга.
    
    Args:
        documents: Словарь с распарсенными документами
        fixed_size_config: Конфигурация для fixed_size стратегии
        
    Returns:
        DataFrame с чанками
    """
    print("Обработка документов стратегией fixed_size...")
    
    all_data = []
    
    for doc_name, document in tqdm(documents.items(), desc="Применение стратегии fixed_size"):
        # Стратегия fixed_size для чанкинга
        destructurer = Destructurer(document)
        destructurer.configure('fixed_size', 
                                 words_per_chunk=fixed_size_config["words_per_chunk"], 
                                 overlap_words=fixed_size_config["overlap_words"])
        fixed_size_entities, _ = destructurer.destructure()
        
        # Обрабатываем только сущности для поиска
        for entity in fixed_size_entities:
            if hasattr(entity, 'use_in_search') and entity.use_in_search:
                entity_data = {
                    'id': str(entity.id),
                    'doc_name': doc_name,
                    'name': entity.name,
                    'text': entity.text,
                    'type': entity.type,
                    'strategy': 'fixed_size',
                    'metadata': json.dumps(entity.metadata, ensure_ascii=False)
                }
                all_data.append(entity_data)
    
    # Создаем DataFrame
    df = pd.DataFrame(all_data)
    
    # Фильтруем по типу, исключая Document
    df = df[df['type'] != 'Document']
    
    return df


def load_questions_dataset(file_path: str) -> pd.DataFrame:
    """
    Загружает датасет с вопросами из Excel-файла.
    
    Args:
        file_path: Путь к Excel-файлу
        
    Returns:
        DataFrame с вопросами и пунктами
    """
    print(f"Загрузка датасета из {file_path}...")
    
    df = pd.read_excel(file_path)
    print(f"Загружен датасет со столбцами: {df.columns.tolist()}")
    
    # Преобразуем NaN в пустые строки для текстовых полей
    text_columns = ['question', 'text', 'item_type']
    for col in text_columns:
        if col in df.columns:
            df[col] = df[col].fillna('')
    
    return df


def setup_model_and_tokenizer(model_name: str, use_sentence_transformers: bool = False, device: str = DEVICE):
    """
    Инициализирует модель и токенизатор.
    
    Args:
        model_name: Название предобученной модели
        use_sentence_transformers: Использовать ли библиотеку sentence_transformers
        device: Устройство для вычислений
        
    Returns:
        Кортеж (модель, токенизатор) или объект SentenceTransformer
    """
    print(f"Загрузка модели {model_name} на устройство {device}...")
    
    if use_sentence_transformers:
        try:
            from sentence_transformers import SentenceTransformer
            model = SentenceTransformer(model_name, device=device)
            return model, None
        except ImportError:
            print("Библиотека sentence_transformers не установлена. Установите её с помощью pip install sentence-transformers")
            raise
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name).to(device)
        model.eval()
        
        return model, tokenizer


def get_embeddings(texts: list[str], model, tokenizer=None, batch_size: int = BATCH_SIZE, use_sentence_transformers: bool = False, device: str = DEVICE) -> np.ndarray:
    """
    Получает эмбеддинги для списка текстов с использованием average pooling или sentence_transformers.
    
    Args:
        texts: Список текстов
        model: Модель для векторизации или SentenceTransformer
        tokenizer: Токенизатор (None для sentence_transformers)
        batch_size: Размер батча
        use_sentence_transformers: Использовать ли библиотеку sentence_transformers
        device: Устройство для вычислений
        
    Returns:
        Массив эмбеддингов
    """
    if use_sentence_transformers:
        # Используем sentence_transformers для получения эмбеддингов
        all_embeddings = []
        
        for i in tqdm(range(0, len(texts), batch_size), desc="Векторизация текстов (sentence_transformers)"):
            batch_texts = texts[i:i+batch_size]
            
            # Получаем эмбеддинги с помощью sentence_transformers
            embeddings = model.encode(batch_texts, batch_size=batch_size, show_progress_bar=False)
            all_embeddings.append(embeddings)
        
        return np.vstack(all_embeddings)
    else:
        # Используем стандартный подход с average pooling
        all_embeddings = []
        
        for i in tqdm(range(0, len(texts), batch_size), desc="Векторизация текстов"):
            batch_texts = texts[i:i+batch_size]
            
            # Токенизация с обрезкой и padding
            encoding = tokenizer(
                batch_texts, 
                padding=True, 
                truncation=True, 
                max_length=512, 
                return_tensors="pt"
            ).to(device)
            
            # Получаем эмбеддинги с average pooling
            with torch.no_grad():
                outputs = model(**encoding)
                embeddings = _average_pool(outputs.last_hidden_state, encoding["attention_mask"])
                all_embeddings.append(embeddings.cpu().numpy())
        
        return np.vstack(all_embeddings)


def calculate_chunk_overlap(chunk_text: str, punct_text: str) -> float:
    """
    Рассчитывает степень перекрытия между чанком и пунктом с использованием partial_ratio.
    
    Args:
        chunk_text: Текст чанка
        punct_text: Текст пункта
        
    Returns:
        Коэффициент перекрытия от 0 до 1
    """
    # Если чанк входит в пункт, возвращаем 1.0 (полное вхождение)
    if chunk_text in punct_text:
        return 1.0
    
    # Если пункт входит в чанк, возвращаем соотношение длин
    if punct_text in chunk_text:
        return len(punct_text) / len(chunk_text)
    
    # Используем partial_ratio из fuzzywuzzy, который лучше обрабатывает
    # случаи, когда один текст является подстрокой другого, даже с небольшими различиями
    partial_ratio_score = fuzz.partial_ratio(chunk_text, punct_text) / 100.0
    
    return partial_ratio_score


def save_embeddings_and_data(embeddings: np.ndarray, data: pd.DataFrame, filename: str, output_dir: str):
    """
    Сохраняет эмбеддинги и соответствующие данные в файлы.
    
    Args:
        embeddings: Массив эмбеддингов
        data: DataFrame с данными
        filename: Базовое имя файла
        output_dir: Директория для сохранения
    """
    embeddings_path = os.path.join(output_dir, f"{filename}_embeddings.npy")
    data_path = os.path.join(output_dir, f"{filename}_data.csv")
    
    # Сохраняем эмбеддинги
    np.save(embeddings_path, embeddings)
    print(f"Эмбеддинги сохранены в {embeddings_path}")
    
    # Сохраняем данные
    data.to_csv(data_path, index=False)
    print(f"Данные сохранены в {data_path}")


def load_embeddings_and_data(filename: str, output_dir: str) -> tuple[np.ndarray | None, pd.DataFrame | None]:
    """
    Загружает эмбеддинги и соответствующие данные из файлов.
    
    Args:
        filename: Базовое имя файла
        output_dir: Директория, где хранятся файлы
        
    Returns:
        Кортеж (эмбеддинги, данные) или (None, None), если файлы не найдены
    """
    embeddings_path = os.path.join(output_dir, f"{filename}_embeddings.npy")
    data_path = os.path.join(output_dir, f"{filename}_data.csv")
    
    if os.path.exists(embeddings_path) and os.path.exists(data_path):
        print(f"Загрузка данных из {embeddings_path} и {data_path}...")
        embeddings = np.load(embeddings_path)
        data = pd.read_csv(data_path)
        return embeddings, data
    
    return None, None


def save_top_chunks_for_question(
    question_id: int,
    question_text: str,
    question_puncts: list[str],
    top_chunks: pd.DataFrame,
    similarities: dict,
    overlap_data: list,
    output_dir: str
):
    """
    Сохраняет топ-чанки для конкретного вопроса в JSON-файл.
    
    Args:
        question_id: ID вопроса
        question_text: Текст вопроса
        question_puncts: Список пунктов, относящихся к вопросу
        top_chunks: DataFrame с топ-чанками
        similarities: Словарь с косинусными схожестями для чанков
        overlap_data: Данные о перекрытии чанков с пунктами
        output_dir: Директория для сохранения
    """
    # Подготавливаем результаты для сохранения
    chunks_data = []
    
    for i, (idx, chunk) in enumerate(top_chunks.iterrows()):
        # Получаем данные о перекрытии для текущего чанка
        chunk_overlaps = overlap_data[i] if i < len(overlap_data) else []
        
        # Преобразуем numpy типы в стандартные типы Python
        similarity = float(similarities.get(idx, 0.0))
        
        # Формируем данные чанка
        chunk_data = {
            'chunk_id': chunk['id'],
            'doc_name': chunk['doc_name'],
            'text': chunk['text'],
            'similarity': similarity,
            'overlaps': chunk_overlaps
        }
        chunks_data.append(chunk_data)
    
    # Преобразуем numpy.int64 в int для question_id
    question_id = int(question_id)
    
    # Формируем общий результат
    result = {
        'question_id': question_id,
        'question_text': question_text,
        'puncts': question_puncts,
        'chunks': chunks_data
    }
    
    # Создаем имя файла
    filename = f"question_{question_id}_top_chunks.json"
    filepath = os.path.join(output_dir, filename)
    
    # Класс для сериализации numpy типов
    class NumpyEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, np.integer):
                return int(obj)
            if isinstance(obj, np.floating):
                return float(obj)
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            return super().default(obj)
    
    # Сохраняем в JSON с кастомным энкодером
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(result, f, ensure_ascii=False, indent=2, cls=NumpyEncoder)
    
    print(f"Топ-чанки для вопроса {question_id} сохранены в {filepath}")


def evaluate_for_top_n_with_mapping(
    questions_df: pd.DataFrame,
    chunks_df: pd.DataFrame,
    question_embeddings: np.ndarray,
    chunk_embeddings: np.ndarray,
    question_id_to_idx: dict,
    top_n: int,
    similarity_threshold: float,
    top_chunks_dir: str = None
) -> tuple[dict[str, float], pd.DataFrame]:
    """
    Оценивает качество чанкинга для заданного значения top_n с использованием маппинга id -> индекс.
    
    Args:
        questions_df: DataFrame с вопросами и релевантными пунктами (исходный датасет)
        chunks_df: DataFrame с чанками
        question_embeddings: Эмбеддинги вопросов
        chunk_embeddings: Эмбеддинги чанков
        question_id_to_idx: Словарь соответствия id вопроса и его индекса в массиве эмбеддингов
        top_n: Количество чанков в топе для каждого вопроса
        similarity_threshold: Порог для нечеткого сравнения
        top_chunks_dir: Директория для сохранения топ-чанков (если None, то не сохраняем)
        
    Returns:
        Кортеж (словарь с усредненными метриками, DataFrame с метриками по отдельным вопросам)
    """
    print(f"Оценка для top-{top_n}...")
    
    # Вычисляем косинусную близость между вопросами и чанками
    similarity_matrix = cosine_similarity(question_embeddings, chunk_embeddings)
    
    # Счетчики для метрик на основе текста
    total_puncts = 0
    found_puncts = 0
    total_chunks = 0
    relevant_chunks = 0
    
    # Счетчики для метрик на основе документов
    total_docs_required = 0
    found_relevant_docs = 0
    total_docs_found = 0
    
    # Для сохранения метрик по отдельным вопросам
    question_metrics = []
    
    # Выводим информацию о столбцах для отладки
    print(f"Столбцы в исходном датасете: {questions_df.columns.tolist()}")
    
    # Группируем вопросы по id (у нас 20 уникальных вопросов)
    for question_id in tqdm(questions_df['id'].unique(), desc=f"Оценка top-{top_n}"):
        # Получаем строки для текущего вопроса из исходного датасета
        question_rows = questions_df[questions_df['id'] == question_id]
        
        # Проверяем, есть ли вопрос с таким id в нашем маппинге
        if question_id not in question_id_to_idx:
            print(f"Предупреждение: вопрос с id {question_id} отсутствует в маппинге")
            continue
            
        # Если нет строк с таким id, пропускаем
        if len(question_rows) == 0:
            continue
        
        # Получаем индекс вопроса в массиве эмбеддингов
        question_idx = question_id_to_idx[question_id]
        
        # Получаем текст вопроса
        question_text = question_rows['question'].iloc[0]
        
        # Получаем все пункты для этого вопроса
        puncts = question_rows['text'].tolist()
        question_total_puncts = len(puncts)
        total_puncts += question_total_puncts
        
        # Получаем связанные документы
        relevant_docs = []
        if 'filename' in question_rows.columns:
            relevant_docs = [f for f in question_rows['filename'].unique() if f and not pd.isna(f)]
            question_total_docs_required = len(relevant_docs)
            total_docs_required += question_total_docs_required
            print(f"Найдено {question_total_docs_required} документов для вопроса {question_id}")
        else:
            print(f"Столбец 'filename' отсутствует. Используем все документы.")
            relevant_docs = chunks_df['doc_name'].unique().tolist()
            question_total_docs_required = len(relevant_docs)
            total_docs_required += question_total_docs_required
        
        # Если для вопроса нет релевантных документов, пропускаем
        if not relevant_docs:
            print(f"Для вопроса {question_id} нет связанных документов")
            continue
        
        # Флаги для отслеживания найденных пунктов
        punct_found = [False] * question_total_puncts
        
        # Для отслеживания найденных документов
        docs_found_for_question = set()
        
        # Для хранения всех чанков вопроса для ограничения top_n
        all_question_chunks = []
        all_question_similarities = []
        
        # Собираем чанки для всех документов по этому вопросу
        for filename in relevant_docs:
            if not filename or pd.isna(filename):
                continue
                
            # Фильтруем чанки по имени файла
            doc_chunks = chunks_df[chunks_df['doc_name'] == filename]
            
            if doc_chunks.empty:
                print(f"Предупреждение: документ {filename} не содержит чанков")
                continue
            
            # Индексы чанков для текущего файла
            doc_chunk_indices = doc_chunks.index.tolist()
            
            # Получаем значения близости для чанков текущего файла
            doc_similarities = [
                similarity_matrix[question_idx, chunks_df.index.get_loc(idx)] 
                for idx in doc_chunk_indices
            ]
            
            # Добавляем чанки и их схожести к общему списку для вопроса
            for i, idx in enumerate(doc_chunk_indices):
                all_question_chunks.append((idx, doc_chunks.iloc[doc_chunks.index.get_indexer([idx])[0]]))
                all_question_similarities.append(doc_similarities[i])
        
        # Сортируем все чанки по убыванию схожести и берем top_n
        sorted_indices = np.argsort(all_question_similarities)[-min(top_n, len(all_question_similarities)):][::-1]
        top_chunks_indices = [all_question_chunks[i][0] for i in sorted_indices]
        top_chunks = [all_question_chunks[i][1] for i in sorted_indices]
        
        # Увеличиваем счетчик общего числа чанков
        question_total_chunks = len(top_chunks)
        total_chunks += question_total_chunks
        
        # Для сохранения данных топ-чанков
        all_top_chunks = pd.DataFrame([chunk for chunk in top_chunks])
        all_chunk_similarities = {idx: all_question_similarities[i] for i, idx in enumerate([all_question_chunks[j][0] for j in sorted_indices])}
        all_chunk_overlaps = []
        
        # Для каждого чанка проверяем его релевантность к пунктам
        question_relevant_chunks = 0
        
        for i, chunk in enumerate(top_chunks):
            is_relevant = False
            chunk_overlaps = []
            
            # Добавляем документ в найденные
            docs_found_for_question.add(chunk['doc_name'])
            
            # Проверяем перекрытие с каждым пунктом
            for j, punct in enumerate(puncts):
                overlap = calculate_chunk_overlap(chunk['text'], punct)
                
                # Если нужно сохранить топ-чанки и top_n == 20
                if top_chunks_dir and top_n == 20:
                    chunk_overlaps.append({
                        'punct_index': j,
                        'punct_text': punct[:100] + '...' if len(punct) > 100 else punct,
                        'overlap': overlap
                    })
                
                # Если перекрытие больше порога, чанк релевантен
                if overlap >= similarity_threshold:
                    is_relevant = True
                    punct_found[j] = True
            
            if is_relevant:
                question_relevant_chunks += 1
            
            # Если нужно сохранить топ-чанки и top_n == 20
            if top_chunks_dir and top_n == 20:
                all_chunk_overlaps.append(chunk_overlaps)
        
        # Если нужно сохранить топ-чанки и top_n == 20
        if top_chunks_dir and top_n == 20 and not all_top_chunks.empty:
            save_top_chunks_for_question(
                question_id,
                question_text,
                puncts,
                all_top_chunks,
                all_chunk_similarities,
                all_chunk_overlaps,
                top_chunks_dir
            )
        
        # Подсчитываем метрики для текущего вопроса
        question_found_puncts = sum(punct_found)
        found_puncts += question_found_puncts
        
        relevant_chunks += question_relevant_chunks
        
        # Обновляем метрики для документов
        question_found_relevant_docs = sum(1 for doc in docs_found_for_question if doc in relevant_docs)
        found_relevant_docs += question_found_relevant_docs
        question_total_docs_found = len(docs_found_for_question)
        total_docs_found += question_total_docs_found
        
        # Вычисляем метрики для текущего вопроса
        question_text_precision = question_relevant_chunks / question_total_chunks if question_total_chunks > 0 else 0
        question_text_recall = question_found_puncts / question_total_puncts if question_total_puncts > 0 else 0
        question_text_f1 = 2 * question_text_precision * question_text_recall / (question_text_precision + question_text_recall) if question_text_precision + question_text_recall > 0 else 0
        
        question_doc_precision = question_found_relevant_docs / question_total_docs_found if question_total_docs_found > 0 else 0
        question_doc_recall = question_found_relevant_docs / question_total_docs_required if question_total_docs_required > 0 else 0
        question_doc_f1 = 2 * question_doc_precision * question_doc_recall / (question_doc_precision + question_doc_recall) if question_doc_precision + question_doc_recall > 0 else 0
        
        # Сохраняем метрики вопроса
        question_metrics.append({
            'question_id': question_id,
            'question_text': question_text,
            'top_n': top_n,
            'text_precision': question_text_precision,
            'text_recall': question_text_recall,
            'text_f1': question_text_f1,
            'doc_precision': question_doc_precision,
            'doc_recall': question_doc_recall,
            'doc_f1': question_doc_f1,
            'found_puncts': question_found_puncts,
            'total_puncts': question_total_puncts,
            'relevant_chunks': question_relevant_chunks,
            'total_chunks': question_total_chunks,
            'found_relevant_docs': question_found_relevant_docs,
            'total_docs_required': question_total_docs_required,
            'total_docs_found': question_total_docs_found
        })
    
    # Вычисляем метрики для текста
    text_precision = relevant_chunks / total_chunks if total_chunks > 0 else 0
    text_recall = found_puncts / total_puncts if total_puncts > 0 else 0
    text_f1 = 2 * text_precision * text_recall / (text_precision + text_recall) if text_precision + text_recall > 0 else 0
    
    # Вычисляем метрики для документов
    doc_precision = found_relevant_docs / total_docs_found if total_docs_found > 0 else 0
    doc_recall = found_relevant_docs / total_docs_required if total_docs_required > 0 else 0
    doc_f1 = 2 * doc_precision * doc_recall / (doc_precision + doc_recall) if doc_precision + doc_recall > 0 else 0
    
    aggregated_metrics = {
        'top_n': top_n,
        'text_precision': text_precision,
        'text_recall': text_recall,
        'text_f1': text_f1,
        'doc_precision': doc_precision,
        'doc_recall': doc_recall,
        'doc_f1': doc_f1,
        'found_puncts': found_puncts,
        'total_puncts': total_puncts,
        'relevant_chunks': relevant_chunks,
        'total_chunks': total_chunks,
        'found_relevant_docs': found_relevant_docs,
        'total_docs_required': total_docs_required,
        'total_docs_found': total_docs_found
    }
    
    return aggregated_metrics, pd.DataFrame(question_metrics)


def main():
    """
    Основная функция скрипта.
    """
    args = parse_args()
    
    # Устанавливаем устройство из аргументов
    device = args.device
    
    # Создаем выходной каталог, если его нет
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Создаем директорию для топ-чанков
    top_chunks_dir = os.path.join(args.output_dir, "top_chunks")
    os.makedirs(top_chunks_dir, exist_ok=True)
    
    # Загружаем датасет с вопросами
    questions_df = load_questions_dataset(args.dataset_path)
    
    # Формируем уникальное имя для сохраняемых файлов на основе параметров стратегии и модели
    strategy_config_str = f"fixed_size_w{args.words_per_chunk}_o{args.overlap_words}"
    chunks_filename = f"chunks_{strategy_config_str}_{args.model_name.replace('/', '_')}"
    questions_filename = f"questions_{args.model_name.replace('/', '_')}"
    
    # Пытаемся загрузить сохраненные эмбеддинги и данные
    chunk_embeddings, chunks_df = None, None
    question_embeddings, questions_df_with_embeddings = None, None
    
    if not args.force_recompute:
        chunk_embeddings, chunks_df = load_embeddings_and_data(chunks_filename, args.output_dir)
        question_embeddings, questions_df_with_embeddings = load_embeddings_and_data(questions_filename, args.output_dir)
    
    # Если не удалось загрузить данные или включен режим принудительного пересчета
    if chunk_embeddings is None or chunks_df is None:
        # Читаем и обрабатываем документы
        documents = read_documents(args.data_folder)
        
        # Формируем конфигурацию для стратегии fixed_size
        fixed_size_config = {
            "words_per_chunk": args.words_per_chunk,
            "overlap_words": args.overlap_words
        }
        
        # Получаем DataFrame с чанками
        chunks_df = process_documents(documents, fixed_size_config)
        
        # Настраиваем модель и токенизатор
        model, tokenizer = setup_model_and_tokenizer(args.model_name, args.use_sentence_transformers, device)
        
        # Получаем эмбеддинги для чанков
        chunk_embeddings = get_embeddings(chunks_df['text'].tolist(), model, tokenizer, args.batch_size, args.use_sentence_transformers, device)
        
        # Сохраняем эмбеддинги и данные
        save_embeddings_and_data(chunk_embeddings, chunks_df, chunks_filename, args.output_dir)
    
    # Если не удалось загрузить эмбеддинги вопросов или включен режим принудительного пересчета
    if question_embeddings is None or questions_df_with_embeddings is None:
        # Получаем уникальные вопросы (по id)
        unique_questions = questions_df.drop_duplicates(subset=['id'])[['id', 'question']]
        
        # Настраиваем модель и токенизатор (если еще не настроены)
        if 'model' not in locals() or 'tokenizer' not in locals():
            model, tokenizer = setup_model_and_tokenizer(args.model_name, args.use_sentence_transformers, device)
        
        # Получаем эмбеддинги для вопросов
        question_embeddings = get_embeddings(unique_questions['question'].tolist(), model, tokenizer, args.batch_size, args.use_sentence_transformers, device)
        
        # Сохраняем эмбеддинги и данные
        save_embeddings_and_data(question_embeddings, unique_questions, questions_filename, args.output_dir)
        
        # Устанавливаем questions_df_with_embeddings для дальнейшего использования
        questions_df_with_embeddings = unique_questions
    
    # Создаем словарь соответствия id вопроса и его индекса в эмбеддингах
    question_id_to_idx = {
        row['id']: i 
        for i, (_, row) in enumerate(questions_df_with_embeddings.iterrows())
    }
    
    # Оцениваем стратегию чанкинга для разных значений top_n
    aggregated_results = []
    all_question_metrics = []
    
    for top_n in TOP_N_VALUES:
        metrics, question_metrics = evaluate_for_top_n_with_mapping(
            questions_df,           # Исходный датасет с связью между вопросами и документами
            chunks_df,              # Датасет с чанками
            question_embeddings,    # Эмбеддинги вопросов
            chunk_embeddings,       # Эмбеддинги чанков
            question_id_to_idx,     # Маппинг id вопроса к индексу в эмбеддингах
            top_n,                  # Количество чанков в топе
            args.similarity_threshold, # Порог для определения перекрытия
            top_chunks_dir if top_n == 20 else None  # Сохраняем топ-чанки только для top_n=20
        )
        aggregated_results.append(metrics)
        all_question_metrics.append(question_metrics)
    
    # Объединяем все метрики по вопросам
    all_question_metrics_df = pd.concat(all_question_metrics)
    
    # Создаем DataFrame с агрегированными результатами
    aggregated_results_df = pd.DataFrame(aggregated_results)
    
    # Сохраняем результаты
    results_filename = f"results_{strategy_config_str}_{args.model_name.replace('/', '_')}.csv"
    results_path = os.path.join(args.output_dir, results_filename)
    aggregated_results_df.to_csv(results_path, index=False)
    
    # Сохраняем метрики по вопросам
    question_metrics_filename = f"question_metrics_{strategy_config_str}_{args.model_name.replace('/', '_')}.xlsx"
    question_metrics_path = os.path.join(args.output_dir, question_metrics_filename)
    all_question_metrics_df.to_excel(question_metrics_path, index=False)
    
    print(f"\nРезультаты сохранены в {results_path}")
    print(f"Метрики по вопросам сохранены в {question_metrics_path}")
    print(f"Топ-20 чанков для каждого вопроса сохранены в {top_chunks_dir}")
    print("\nМетрики для различных значений top_n:")
    print(aggregated_results_df[['top_n', 'text_precision', 'text_recall', 'text_f1', 'doc_precision', 'doc_recall', 'doc_f1']])


if __name__ == "__main__":
    main()