import argparse
import csv
import glob
import json
import sys
from dataclasses import dataclass
from enum import Enum

@dataclass(frozen=True)
class Task:
    benchmark: str          # task key under "results" in the results JSON
    metric: str             # metric key within that task's entry (e.g. "acc,none")
    col_name: str           # column name to display in the leaderboard
    type: str               # "generate_until", "multiple_choice", or "other"
    baseline: float = 0.0   # random-baseline score (0-1) used for adjusted averages

# Local copies of Tasks and get_tasks (originally: from src.about import Tasks, get_tasks)

class Tasks(Enum):
    task3 = Task("polemo2_in", "exact_match,score-first", "polemo2-in_g", "generate_until", 0.416)
    task4 = Task("polemo2_in_multiple_choice", "acc,none", "polemo2-in_mc", "multiple_choice", 0.416)
    task5 = Task("polemo2_out", "exact_match,score-first", "polemo2-out_g", "generate_until", 0.368)
    task6 = Task("polemo2_out_multiple_choice", "acc,none", "polemo2-out_mc", "multiple_choice", 0.368)
    task7 = Task("polish_8tags_multiple_choice", "acc,none", "8tags_mc", "multiple_choice", 0.143)
    task8 = Task("polish_8tags_regex", "exact_match,score-first", "8tags_g", "generate_until", 0.143)
    task9a = Task("polish_belebele_mc", "acc,none", "belebele_mc", "multiple_choice", 0.279)
    task9 = Task("polish_belebele_regex", "exact_match,score-first", "belebele_g", "generate_until", 0.279)
    task10 = Task("polish_dyk_multiple_choice", "f1,none", "dyk_mc", "multiple_choice", 0.289)
    task11 = Task("polish_dyk_regex", "f1,score-first", "dyk_g", "generate_until", 0.289)
    task12 = Task("polish_ppc_multiple_choice", "acc,none", "ppc_mc", "multiple_choice", 0.419)
    task13 = Task("polish_ppc_regex", "exact_match,score-first", "ppc_g", "generate_until", 0.419)
    task14 = Task("polish_psc_multiple_choice", "f1,none", "psc_mc", "multiple_choice", 0.466)
    task15 = Task("polish_psc_regex", "f1,score-first", "psc_g", "generate_until", 0.466)
    task16 = Task("polish_cbd_multiple_choice", "f1,none", "cbd_mc", "multiple_choice", 0.149)
    task17 = Task("polish_cbd_regex", "f1,score-first", "cbd_g", "generate_until", 0.149)
    task18 = Task("polish_klej_ner_multiple_choice", "acc,none", "klej_ner_mc", "multiple_choice", 0.343)
    task19 = Task("polish_klej_ner_regex", "exact_match,score-first", "klej_ner_g", "generate_until", 0.343)
    task21 = Task("polish_polqa_reranking_multiple_choice", "acc,none", "polqa_reranking_mc", "multiple_choice", 0.5335588952710677)
    task22 = Task("polish_polqa_open_book", "levenshtein,none", "polqa_open_book_g", "generate_until", 0.0)
    task23 = Task("polish_polqa_closed_book", "levenshtein,none", "polqa_closed_book_g", "generate_until", 0.0)
    task24 = Task("polish_poquad_open_book", "levenshtein,none", "poquad_open_book", "generate_until", 0.0)
    task25 = Task("polish_eq_bench_first_turn", "first_eqbench,none", "eq_bench_first_turn", "generate_until", 0.0)
    task26 = Task("polish_eq_bench", "average_eqbench,none", "eq_bench", "other", 0.0)
    task20 = Task("polish_poleval2018_task3_test_10k", "word_perplexity,none", "poleval2018_task3_test_10k", "other")
    task27 = Task("polish_poquad_reranking", "acc,none", "poquad_reranking", "other", 0.0)
    task28 = Task("polish_abstractive_poquad_rag", "levenshtein,none", "abstractive_poquad_rag", "other", 0.0)
    task29 = Task("polish_abstractive_poquad_open_book", "levenshtein,none", "abstractive_poquad_open_book", "other", 0.0)
    task30 = Task("polish_pes", "exact_match,score-first", "pes", "other", 0.2)
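    # Example: Tasks.task3.value.benchmark is "polemo2_in" and
    # Tasks.task3.value.metric is "exact_match,score-first"; the loop in
    # __main__ uses exactly these keys to pull scores out of the results JSON.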


def get_tasks():
    # Group benchmarks by task type; rag_tasks is a hand-picked subset of
    # retrieval-related tasks and overlaps with the other groups.
    g_tasks = [task.value.benchmark for task in Tasks if task.value.type == "generate_until"]
    mc_tasks = [task.value.benchmark for task in Tasks if task.value.type == "multiple_choice"]
    rag_tasks = ['polish_polqa_reranking_multiple_choice', 'polish_polqa_open_book', 'polish_poquad_open_book']
    all_tasks = g_tasks + mc_tasks
    return g_tasks, mc_tasks, rag_tasks, all_tasks

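# Materialize the task groups once at import time; __main__ below uses them
# to compute the per-group averages.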
g_tasks, mc_tasks, rag_tasks, all_tasks = get_tasks()

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Calculate average scores from results JSON files')
    parser.add_argument('json', type=str, help='Path to a results JSON file, or a directory searched recursively for results*.json')
    parser.add_argument('--header', action='store_true', help='Print a CSV header row before the results row')
    parser.add_argument('-d', '--delimiter', type=str, default=',', help='Delimiter for CSV output')
    args = parser.parse_args()
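    # Example invocations (script and result paths are hypothetical):
    #   python average_scores.py results/my-model/results_2024-01-01.json --header
    #   python average_scores.py results/my-model -d ';'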

    if args.json.endswith('.json'):
        paths = [args.json]
    else:
        # Treat the argument as a directory and search it recursively.
        paths = glob.glob(args.json + '/**/results*.json', recursive=True)

    # Diagnostics go to stderr so stdout stays valid CSV.
    print(paths, file=sys.stderr)
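    # Sketch of the expected file contents (presumably lm-evaluation-harness
    # output; metric keys follow its "<metric>,<filter>" convention):
    #   {
    #     "results": {
    #       "polemo2_in": {"exact_match,score-first": 0.52, ...},
    #       "polish_belebele_mc": {"acc,none": 0.61, ...},
    #       ...
    #     }
    #   }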

    # Scores from all files are merged into one dict; later files overwrite
    # earlier ones for the same benchmark.
    results = {}
    for path in paths:
        print(path, file=sys.stderr)
        with open(path) as f:
            data = json.load(f)

        for task in Tasks:
            try:
                results[task.value.benchmark] = data['results'][task.value.benchmark][task.value.metric]
                # Convert fractional scores to percentages; perplexity and
                # EQ-Bench values are already on their own scales.
                if 'perplexity' not in task.value.metric and 'eqbench' not in task.value.metric:
                    results[task.value.benchmark] *= 100
            except KeyError:
                print(f'No data for {task.value.benchmark}', file=sys.stderr)
    print(results, file=sys.stderr)
    # The legacy average below excludes the polqa tasks.
    all_tasks_wo_polqa = [task for task in all_tasks if 'polqa' not in task]

    # Per-task random baselines, rescaled to percentages to match `results`.
    baselines = {task.value.benchmark: task.value.baseline * 100 for task in Tasks}
    print(baselines, file=sys.stderr)
    # Legacy average: a plain mean over the non-polqa tasks, without baseline
    # adjustment.
    average_old = sum(v for task, v in results.items()
                      if v is not None and task in all_tasks_wo_polqa) / len(all_tasks_wo_polqa)

    def baseline_adjusted_average(tasks):
        # Rescale each score so the random baseline maps to 0 and a perfect
        # score maps to 100: (score - baseline) / (100 - baseline) * 100.
        # E.g. a belebele score of 63.95 against its 27.9 baseline becomes 50.0.
        # A task missing from `results` counts as a raw score of 0.
        return sum((results.get(task, 0) - baselines.get(task, 0))
                   / (100 - baselines.get(task, 0)) * 100 for task in tasks) / len(tasks)

    average = baseline_adjusted_average(all_tasks)

    for task in all_tasks:
        print(task, results.get(task, 0), baselines.get(task, 0), file=sys.stderr)

    average_g = baseline_adjusted_average(g_tasks)
    average_mc = baseline_adjusted_average(mc_tasks)
    average_rag = baseline_adjusted_average(rag_tasks)

    # One CSV row: file path, a None placeholder for the 'name' column, the
    # aggregate averages, then the raw per-task scores.
    row = [args.json, None, average, average_old, average_g, average_mc, average_rag]
    for task in Tasks:
        row.append(results.get(task.value.benchmark, None))

    writer = csv.writer(sys.stdout, delimiter=args.delimiter)
    if args.header:
        # Keep the header in sync with the layout of `row`.
        writer.writerow(['file', 'name', 'average', 'average_old', 'average_g', 'average_mc', 'average_rag']
                        + [task.value.benchmark for task in Tasks])
    writer.writerow(row)