import json
from pathlib import Path

import gradio as gr
import pandas as pd

TITLE = """<h1 align="center" id="space-title">Large Reasoning Models Leaderboard</h1>"""

DESCRIPTION = """
Evaluation of Open R1 models across a diverse range of benchmarks from [LightEval](https://github.com/huggingface/lighteval). All scores are reported as accuracy.
"""

BENCHMARKS_TO_SKIP = ["math", "mini_math", "aimo_math_integer_lvl4-5", "mini_math_v2"]


def get_leaderboard_df():
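    """Build the raw leaderboard DataFrame from the JSON result files under `eval_results/`.

    The path parsing below assumes a layout along the lines of
    `eval_results/<org>/<model>/<revision>/<task>/<run>_<timestamp>.json` (the exact
    filename pattern is an assumption); one column is created per benchmark metric
    and one row per model/revision/date.
    """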
    filepaths = list(Path("eval_results").rglob("*.json"))

    # Parse filepaths to get unique models
    models = set()
    for filepath in filepaths:
        path_parts = Path(filepath).parts
        model_revision = "_".join(path_parts[1:4])
        models.add(model_revision)

    # Initialize DataFrame
    df = pd.DataFrame(index=list(models))

    # Extract data from each file and populate the DataFrame
    for filepath in filepaths:
        path_parts = Path(filepath).parts
        # The timestamp is the last underscore-separated chunk of the filename stem
        date = filepath.stem.split("_")[-1][:-3]
        model_revision = "_".join(path_parts[1:4]) + "_" + date
        task = path_parts[4]

        # Skip benchmarks that we don't want to include in the leaderboard
        # (checked before opening the file so skipped tasks aren't parsed at all)
        if task.lower() in BENCHMARKS_TO_SKIP:
            continue

        df.loc[model_revision, "Date"] = date

        with open(filepath, "r") as file:
            data = json.load(file)
            # MixEval doesn't have a `results` key, so we read the overall score directly
            if task.lower() in ["mixeval", "mixeval_hard"]:
                value = data["overall score (final score)"]
                df.loc[model_revision, f"{task}"] = value
            else:
                first_result_key = next(iter(data["results"]))  # gets the first key in 'results'
                # TruthfulQA has two metrics, so we need to pick the `mc2` one that's reported on the leaderboard
                if task.lower() == "truthfulqa":
                    value = data["results"][first_result_key]["truthfulqa_mc2"]
                    df.loc[model_revision, task] = float(value)
                # IFEval has several metrics, but we report the average, as in the Llama 3 paper
                elif task.lower() == "ifeval":
                    values = 0.0
                    for metric in [
                        "prompt_level_loose",
                        "prompt_level_strict",
                        "inst_level_strict",
                        "inst_level_loose",
                    ]:
                        values += data["results"][first_result_key][f"{metric}_acc"]
                    value = values / 4
                    df.loc[model_revision, f"{task}"] = float(value)
                # MMLU has several metrics but we report just the average one
                elif task.lower() == "mmlu":
                    value = [v["acc"] for k, v in data["results"].items() if "_average" in k.lower()][0]
                    df.loc[model_revision, task] = float(value)
                # HellaSwag and ARC report acc_norm
                elif task.lower() in ["hellaswag", "arc"]:
                    value = data["results"][first_result_key]["acc_norm"]
                    df.loc[model_revision, task] = float(value)
                # BBH has several metrics but we report just the average one
                elif task.lower() == "bbh":
                    if "all" in data["results"]:
                        value = data["results"]["all"]["acc"]
                    else:
                        value = -100
                    df.loc[model_revision, task] = float(value)
                # AGIEval reports acc_norm
                elif task.lower() == "agieval":
                    value = data["results"]["all"]["acc_norm"]
                    df.loc[model_revision, task] = float(value)
                # AIME24 and 25 report pass@1
                elif task.lower() in ["aime24", "aime25"]:
                    # Check for 32 samples
                    if "math_pass@1:32_samples" in data["results"]["all"]:
                        value = data["results"]["all"]["math_pass@1:32_samples"]
                        df.loc[model_revision, f"{task} (n=32)"] = float(value)

                    # Check for 64 samples
                    if "math_pass@1:64_samples" in data["results"]["all"]:
                        value = data["results"]["all"]["math_pass@1:64_samples"]
                        df.loc[model_revision, f"{task} (n=64)"] = float(value)

                    # For backward compatibility, also store in the original column name if any value exists
                    if "math_pass@1:32_samples" in data["results"]["all"]:
                        df.loc[model_revision, task] = float(data["results"]["all"]["math_pass@1:32_samples"])
                    elif "math_pass@1:64_samples" in data["results"]["all"]:
                        df.loc[model_revision, task] = float(data["results"]["all"]["math_pass@1:64_samples"])
                # GPQA now reports pass@1
                elif task.lower() == "gpqa":
                    # Check for 8 samples
                    if "gpqa_pass@1:8_samples" in data["results"]["all"]:
                        value = data["results"]["all"]["gpqa_pass@1:8_samples"]
                        df.loc[model_revision, f"{task} (n=8)"] = float(value)

                    # For backward compatibility, also store in the original column name if any value exists
                    if "extractive_match" in data["results"]["all"]:
                        df.loc[model_revision, task] = float(data["results"]["all"]["extractive_match"])
                    elif "gpqa_pass@1:8_samples" in data["results"]["all"]:
                        df.loc[model_revision, task] = float(data["results"]["all"]["gpqa_pass@1:8_samples"])
                # MATH-500 now reports pass@1
                elif task.lower() == "math_500":
                    # Check for 4 samples
                    if "math_pass@1:4_samples" in data["results"]["all"]:
                        value = data["results"]["all"]["math_pass@1:4_samples"]
                        df.loc[model_revision, f"{task} (n=4)"] = float(value)

                    # For backward compatibility, also store in the original column name if any value exists
                    if "extractive_match" in data["results"]["all"]:
                        df.loc[model_revision, task] = float(data["results"]["all"]["extractive_match"])
                    elif "math_pass@1:4_samples" in data["results"]["all"]:
                        df.loc[model_revision, task] = float(data["results"]["all"]["math_pass@1:4_samples"])
                # MATH (DeepSeek CoT variants) and AIMO Kaggle report qem
                elif task.lower() in ["aimo_kaggle", "math_deepseek_cot", "math_deepseek_rl_cot"]:
                    value = data["results"]["all"]["qem"]
                    df.loc[model_revision, task] = float(value)
                # For mini_math we report 5 metrics, one per level, and store each one as a separate column in the dataframe
                elif task.lower() in ["mini_math_v2"]:
                    for k, v in data["results"].items():
                        if k != "all":
                            level = k.split("|")[1].split(":")[-1]
                            value = v["qem"]
                            df.loc[model_revision, f"{task}_{level}"] = value
                # For PoT we report N metrics, one per prompt, and store each one as a separate column in the dataframe
                elif task.lower() in ["aimo_kaggle_medium_pot", "aimo_kaggle_hard_pot"]:
                    for k, v in data["results"].items():
                        if k != "all" and "_average" not in k:
                            version = k.split("|")[1].split(":")[-1]
                            value = v["qem"] if "qem" in v else v["score"]
                            df.loc[model_revision, f"{task}_{version}"] = value
                # kaggle_tora reports accuracy as a percentage, so we divide by 100 to get a fraction
                elif task.lower() in [
                    "aimo_tora_eval_kaggle_medium",
                    "aimo_tora_eval_kaggle_hard",
                    "aimo_kaggle_fast_eval_hard",
                    "aimo_kaggle_tora_medium",
                    "aimo_kaggle_tora_hard",
                    "aimo_kaggle_tora_medium_extended",
                    "aimo_kaggle_tora_hard_extended",
                    "aimo_math_integer_lvl4",
                    "aimo_math_integer_lvl5",
                ]:
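                    # Each of these result files is expected to hold a single entry; if several are present, the last value wins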
                    for k, v in data["results"].items():
                        value = float(v["qem"]) / 100.0
                        df.loc[model_revision, f"{task}"] = value
                # For AlpacaEval we report the base win rate and the length-controlled one
                elif task.lower() == "alpaca_eval":
                    value = data["results"][first_result_key]["win_rate"]
                    df.loc[model_revision, "Alpaca_eval"] = value / 100.0
                    value = data["results"][first_result_key]["length_controlled_winrate"]
                    df.loc[model_revision, "Alpaca_eval_lc"] = value / 100.0
                else:
                    first_metric_key = next(
                        iter(data["results"][first_result_key])
                    )  # gets the first key in the first result
                    value = data["results"][first_result_key][first_metric_key]  # gets the value of the first metric
                    df.loc[model_revision, task] = float(value)

    # Drop rows where every benchmark entry (i.e. everything except Date) is NaN
    df = df.dropna(how="all", axis=0, subset=[c for c in df.columns if c != "Date"])

    # Trim minimath column names
    df.columns = [c.replace("_level_", "_l") for c in df.columns]

    # Trim AIMO column names
    df.columns = [c.replace("aimo_", "") for c in df.columns]

    df = df.reset_index().rename(columns={"index": "Model"})
    # Apply rounding only to numeric columns
    numeric_cols = df.select_dtypes(include=["float64", "float32", "int64", "int32"]).columns
    df[numeric_cols] = df[numeric_cols].round(4)
    # Strip off date from model name
    df["Model"] = df["Model"].apply(lambda x: x.rsplit("_", 1)[0])

    return df


leaderboard_df = get_leaderboard_df()


def agg_df(df, agg: str = "max"):
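    """Aggregate results per model across dates using `agg` (min/max/mean), append an
    Average column, and convert all scores to percentages."""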
    df = df.copy()
    # Drop date and aggregate results by model name
    df = df.drop("Date", axis=1).groupby("Model").agg(agg).reset_index()

    df.insert(loc=len(df.columns), column="Average", value=df.mean(axis=1, numeric_only=True))

    # Convert all values to percentages
    df[df.select_dtypes(include=["number"]).columns] *= 100.0
    # Apply rounding only to numeric columns
    numeric_cols = df.select_dtypes(include=["float64", "float32", "int64", "int32"]).columns
    df[numeric_cols] = df[numeric_cols].round(4)
    df = df.sort_values(by=["Average"], ascending=False)
    return df


# Update the table based on the selected columns, search query and aggregation mode
def filter_and_search(cols: list[str], search_query: str, agg: str):
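    """Aggregate the leaderboard, then filter it by the selected columns and search query.

    Search terms are separated by ";" and matched case-insensitively as an OR pattern
    against the model names.
    """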
    df = leaderboard_df
    df = agg_df(df, agg)
    if len(search_query) > 0:
        search_terms = search_query.split(";")
        search_terms = [term.strip().lower() for term in search_terms]
        pattern = "|".join(search_terms)
        df = df[df["Model"].str.lower().str.contains(pattern, regex=True)]
        # Drop any columns which are all NaN
        df = df.dropna(how="all", axis=1)
    if len(cols) > 0:
        index_cols = list(leaderboard_df.columns[:1])
        new_cols = index_cols + cols
        df = df.copy()[new_cols]
        # Drop rows where all of the selected benchmark columns are NaN
        df = df.copy().dropna(how="all", axis=0, subset=[c for c in df.columns if c in cols])
        # Recompute average
        df.insert(loc=len(df.columns), column="Average", value=df.mean(axis=1, numeric_only=True))
        # Apply rounding only to numeric columns
        numeric_cols = df.select_dtypes(include=["float64", "float32", "int64", "int32"]).columns
        df[numeric_cols] = df[numeric_cols].round(4)
    return df


demo = gr.Blocks()

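# Assemble the Gradio UI: search bar, aggregation selector, column filter, and the leaderboard table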
with demo:
    gr.HTML(TITLE)
    with gr.Column():
        gr.Markdown(DESCRIPTION, elem_classes="markdown-text")
        with gr.Row():
            search_bar = gr.Textbox(
                placeholder="Search for your model. Use semicolons for multiple terms", show_label=False
            )
            agg = gr.Radio(
                ["min", "max", "mean"],
                value="max",
                label="Aggregation",
                info="How to aggregate results for each model",
            )
        with gr.Row():
            cols_bar = gr.CheckboxGroup(
                choices=sorted([c for c in leaderboard_df.columns[1:] if c not in ["Average", "Date"]]),
                show_label=False,
                info="Select columns to display",
            )
        with gr.Group():
            leaderboard_table = gr.Dataframe(
                value=leaderboard_df,
                wrap=True,
                max_height=1000,
                column_widths=[400, 110] + [(260 + len(c)) for c in leaderboard_df.columns[1:]],
                show_row_numbers=True,
                show_copy_button=True,
            )

    cols_bar.change(filter_and_search, inputs=[cols_bar, search_bar, agg], outputs=[leaderboard_table])
    agg.change(filter_and_search, inputs=[cols_bar, search_bar, agg], outputs=[leaderboard_table])
    search_bar.submit(filter_and_search, inputs=[cols_bar, search_bar, agg], outputs=[leaderboard_table])

demo.launch()