from functools import partial
import os
import json
import re
import tempfile
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict
from datatrove.io import get_datafolder
from datatrove.utils.stats import MetricStatsDict
import gradio as gr
import tenacity

from src.logic.graph_settings import Grouping

def find_folders(base_folder: str, path: str) -> List[str]:
    """Return the sorted immediate sub-directories of ``path`` within ``base_folder``.

    A ``path`` that does not exist yields an empty list instead of an error.
    The entry for ``path`` itself is dropped from the listing, comparing with
    trailing slashes stripped.

    Args:
        base_folder: Root location handed to ``get_datafolder``.
        path: Folder relative to ``base_folder`` whose children are listed.

    Returns:
        Directory paths as reported by ``find`` (keys of its detailed listing),
        sorted lexicographically.
    """
    folder = get_datafolder(base_folder)
    if not folder.exists(path):
        return []

    target = path.rstrip("/")
    listing = folder.find(path, maxdepth=1, withdirs=True, detail=True)

    subdirs = []
    for name, info in listing.items():
        if info["type"] != "directory":
            continue
        # Skip the folder being listed itself; keep only its children.
        if name.rstrip("/") == target:
            continue
        subdirs.append(name)

    subdirs.sort()
    return subdirs

def fetch_datasets(base_folder: str):
    """List the dataset folders found directly under ``base_folder``.

    Args:
        base_folder: Root location handed to ``find_folders``.

    Returns:
        The sorted list of top-level folder paths.

    Raises:
        ValueError: if ``base_folder`` has no sub-folders.
    """
    found = find_folders(base_folder, "")
    if not found:
        raise ValueError("No datasets found")
    return sorted(found)

def fetch_groups(base_folder: str, datasets: List[str], old_groups: str, type: str = "intersection"):
    """Build the grouping dropdown for the currently selected datasets.

    Each dataset folder is listed in parallel for its immediate sub-folders.
    With ``type == "intersection"`` only group names present in every dataset
    are offered; any other ``type`` offers the union across datasets.

    Args:
        base_folder: Root location handed to ``find_folders``.
        datasets: Dataset folder names relative to ``base_folder``.
        old_groups: Previously selected group; kept if still offered.
        type: ``"intersection"`` or anything else for union.

    Returns:
        ``gr.update`` with no choices when nothing is selected. Otherwise a
        ``gr.Dropdown`` with sorted choices. When there is no kept previous
        value and exactly one choice exists, that choice is pre-selected.
    """
    if not datasets:
        return gr.update(choices=[], value=None)

    def _group_names(run):
        return {Path(folder).name for folder in find_folders(base_folder, run)}

    with ThreadPoolExecutor() as executor:
        per_dataset = list(executor.map(_group_names, datasets))
    if not per_dataset:
        return gr.update(choices=[], value=None)

    combine = set.intersection if type == "intersection" else set.union
    available = combine(*per_dataset)

    # Preserve the user's previous choice only if it is still offered.
    selected = old_groups if old_groups and old_groups in available else None

    if not selected and len(available) == 1:
        (selected,) = available

    return gr.Dropdown(choices=sorted(available), value=selected)

def fetch_metrics(base_folder: str, datasets: List[str], group: str, old_metrics: str, type: str = "intersection"):
    """Build the metric dropdown for the selected datasets and group.

    For every dataset, the sub-folders of ``{dataset}/{group}`` are listed in
    parallel. With ``type == "intersection"`` only metric names present in every
    dataset are offered; any other ``type`` offers the union.

    Args:
        base_folder: Root location handed to ``find_folders``.
        datasets: Dataset folder names relative to ``base_folder``.
        group: Grouping folder inside each dataset.
        old_metrics: Previously selected metric; kept if still offered.
        type: ``"intersection"`` or anything else for union.

    Returns:
        ``gr.update`` with no choices when no group or no datasets are
        selected. Otherwise a ``gr.Dropdown`` with sorted choices. When there
        is no kept previous value and exactly one choice exists, that choice
        is pre-selected.
    """
    # Guard on datasets too (as fetch_groups does): a cleared selection may
    # arrive as None, which executor.map cannot iterate.
    if not group or not datasets:
        return gr.update(choices=[], value=None)

    def _metric_names(run):
        return {Path(x).name for x in find_folders(base_folder, f"{run}/{group}")}

    with ThreadPoolExecutor() as executor:
        metrics = list(executor.map(_metric_names, datasets))
    if not metrics:
        return gr.update(choices=[], value=None)

    combine = set.intersection if type == "intersection" else set.union
    new_possibles_choices = combine(*metrics)

    # Preserve the user's previous choice only if it is still offered.
    value = old_metrics if old_metrics and old_metrics in new_possibles_choices else None

    if not value and len(new_possibles_choices) == 1:
        (value,) = new_possibles_choices

    return gr.Dropdown(choices=sorted(new_possibles_choices), value=value)

def reverse_search(base_folder: str, possible_datasets: List[str], grouping: str, metric_name: str) -> str:
    """Find which candidate datasets contain a given metric.

    Existence checks run in parallel via ``metric_exists``.

    Args:
        base_folder: Root location handed to ``metric_exists``.
        possible_datasets: Candidate dataset folder names.
        grouping: Grouping folder inside each dataset.
        metric_name: Metric folder to look for.

    Returns:
        The matching dataset names, newline-separated, in input order.
    """
    def _keep_if_present(dataset):
        return dataset if metric_exists(base_folder, dataset, metric_name, grouping) else None

    with ThreadPoolExecutor() as pool:
        hits = [name for name in pool.map(_keep_if_present, possible_datasets) if name is not None]
    return "\n".join(hits)

def reverse_search_add(datasets: List[str], reverse_search_results: str) -> List[str]:
    """Merge newline-separated reverse-search results into the dataset selection.

    Blank lines, including the whole string being empty, are ignored. Without
    this, an empty search result would add ``""`` as a dataset name.
    Duplicates are removed.

    Args:
        datasets: Currently selected datasets; ``None`` is treated as empty.
        reverse_search_results: Newline-separated dataset names; ``None`` is
            treated as empty.

    Returns:
        The sorted union of ``datasets`` and the non-blank result lines.
    """
    current = datasets or []
    found = [
        line.strip()
        for line in (reverse_search_results or "").splitlines()
        if line.strip()
    ]
    # Sort for a stable, deterministic order (a bare set has none).
    return sorted(set(current).union(found))

def metric_exists(base_folder: str, path: str, metric_name: str, group_by: str) -> bool:
    """Return whether ``{path}/{group_by}/{metric_name}/metric.json`` exists under ``base_folder``.

    Args:
        base_folder: Root location handed to ``get_datafolder``.
        path: Dataset folder relative to ``base_folder``.
        metric_name: Metric folder name.
        group_by: Grouping folder name.
    """
    folder = get_datafolder(base_folder)
    metric_file = f"{path}/{group_by}/{metric_name}/metric.json"
    return folder.exists(metric_file)

@tenacity.retry(stop=tenacity.stop_after_attempt(5), reraise=True)
def load_metrics(base_folder: str, path: str, metric_name: str, group_by: str) -> MetricStatsDict:
    """Read ``{path}/{group_by}/{metric_name}/metric.json`` and parse it.

    The whole read is attempted up to 5 times. ``reraise=True`` makes tenacity
    re-raise the exception from the final attempt instead of wrapping it in
    ``tenacity.RetryError``, so callers and the UI see the real cause.

    Args:
        base_folder: Root location handed to ``get_datafolder``.
        path: Dataset folder relative to ``base_folder``.
        metric_name: Metric folder name.
        group_by: Grouping folder name.

    Returns:
        The metric file's JSON content converted via ``MetricStatsDict.from_dict``.
    """
    folder = get_datafolder(base_folder)
    with folder.open(f"{path}/{group_by}/{metric_name}/metric.json") as f:
        json_metric = json.load(f)
    # Convert after the file is closed; the handle is no longer needed.
    return MetricStatsDict.from_dict(json_metric)

def load_data(dataset_path: str, base_folder: str, grouping: str, metric_name: str) -> MetricStatsDict:
    """Adapter around ``load_metrics`` that takes the dataset path first.

    The dataset path comes first so that ``functools.partial`` can bind the
    other arguments and the result can be mapped over a list of datasets.
    """
    return load_metrics(
        base_folder,
        dataset_path,
        metric_name=metric_name,
        group_by=grouping,
    )


def fetch_graph_data(
        base_folder: str,
        datasets: List[str],
        metric_name: str,
        grouping: Grouping,
        progress=gr.Progress(),
):
    """Load one metric for every selected dataset, in parallel.

    Args:
        base_folder: Root location handed to ``load_data``.
        datasets: Dataset folder names relative to ``base_folder``.
        metric_name: Metric folder to load.
        grouping: Grouping folder inside each dataset.
        progress: Gradio progress tracker used to report loading progress.

    Returns:
        ``None`` if datasets, metric or grouping is missing. Otherwise a
        tuple ``(data, None)``, where ``data`` maps each dataset path to the
        value returned by ``load_data``.
    """
    # `not datasets` also covers None (e.g. a cleared selection); the previous
    # `len(datasets) <= 0` raised TypeError in that case.
    if not datasets or not metric_name or not grouping:
        return None

    loader = partial(load_data, base_folder=base_folder, metric_name=metric_name, grouping=grouping)
    with ThreadPoolExecutor() as pool:
        results = list(
            progress.tqdm(
                pool.map(loader, datasets),
                total=len(datasets),
                desc="Loading data...",
            )
        )

    # Executor.map yields results in input order, so they line up with datasets.
    data = dict(zip(datasets, results))
    return data, None

def update_datasets_with_regex(regex: str, selected_runs: List[str], all_runs: List[str]):
    """Add every run whose name matches ``regex`` (via ``re.search``) to the selection.

    Args:
        regex: Pattern to search for in each run name.
        selected_runs: Currently selected runs; may be ``None``.
        all_runs: All available run names.

    Returns:
        ``[]`` when ``regex`` is empty. The unchanged ``selected_runs`` when
        the pattern is invalid or matches nothing. Otherwise the sorted union
        of the matches and ``selected_runs``.
    """
    if not regex:
        return []
    # The pattern is presumably typed by the user and may be incomplete
    # mid-edit (e.g. "foo["); treat an invalid pattern like "no match"
    # instead of letting re.error escape.
    try:
        pattern = re.compile(regex)
    except re.error:
        return selected_runs
    new_dsts = {run for run in all_runs if pattern.search(run)}
    if not new_dsts:
        return selected_runs
    return sorted(new_dsts.union(selected_runs or []))