import argparse
import numpy as np
from pathlib import Path
import tqdm
from pprint import pprint
import torch
from torch.nn.utils.rnn import pad_sequence
from scrl.config import load_config
from scrl.training import setup_and_train
from scrl.model import labels_to_summary
from scrl.eval_metrics import compute_token_f1
import scrl.utils as utils
from nltk import word_tokenize


def evaluate_validation_reward(args, manager, model, tokenizer, reward_generator, dataset):
    """Return the mean reward of the model's greedy (argmax) summaries on ``dataset``.

    The dataset indices are split into batches of ``args.batch_size``. For each
    batch the model produces per-token logits. The argmax labels are decoded into
    summaries with ``labels_to_summary`` and scored by ``reward_generator``
    against the batch's ``"document"`` column.

    Args:
        args: config namespace; reads ``device``, ``batch_size`` and
            ``max_val_steps`` (``None`` means evaluate every batch).
        manager: unused here; kept so the evaluation helpers share a signature.
        model: token-labelling model. Its logits are assumed to be
            (batch, seq_len, num_labels), since the argmax is over dim 2 —
            TODO confirm.
        tokenizer: passed to ``labels_to_summary`` to decode summaries.
        reward_generator: callable ``(documents, summaries) -> (rewards, _)``.
        dataset: indexable with a list of indices, returning a mapping with
            ``"input_ids"`` and ``"document"`` entries.

    Returns:
        Mean reward over all evaluated examples. ``np.mean`` of an empty list
        gives ``nan``, so this is ``nan`` if no batch was evaluated.
    """
    device = args.device
    batches = utils.batchify(list(range(len(dataset))), args.batch_size)
    rewards = []
    # Pure evaluation: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for batch_idx, indices in enumerate(batches):
            if args.max_val_steps is not None and batch_idx >= args.max_val_steps:
                break
            batch = dataset[indices]
            # Variable-length id lists are right-padded with 0 (pad_sequence default).
            input_ids = pad_sequence(
                [torch.tensor(ids) for ids in batch["input_ids"]], batch_first=True
            )
            logits = model(input_ids.to(device))
            # Already on `device`, because the logits are.
            argmax_labels = torch.argmax(logits, dim=2)
            summaries = labels_to_summary(input_ids, argmax_labels, tokenizer)
            batch_rewards, _ = reward_generator(batch["document"], summaries)
            rewards += batch_rewards
    return np.mean(rewards)



def evaluate_validation_dataset(args, manager, model, tokenizer, reward_generator, dataset_path):
    """Score greedy summaries against the reference summaries in a JSONL file.

    Each record must have a ``"text"`` source and a list of reference
    ``"summaries"``. The model's argmax summary is compared with every
    reference using lower-cased, count-aware token F1 over NLTK
    ``word_tokenize`` tokens. Per-item scores are averaged over that item's
    references, and those averages are then averaged over the dataset.

    If ``args.dump`` is set, one record per item (label probabilities, source,
    first reference, its F1, predicted summary and labels) is written to
    ``manager.dir / "dump-<dataset name>" / "step-<manager.step>.jsonl"``.

    Args:
        args: config namespace; reads ``device`` and ``dump``.
        manager: provides ``dir`` and ``step`` for the dump location.
        model: token-labelling model; logits are indexed with the label axis at dim 2.
        tokenizer: used to encode the source and decode the summary.
        reward_generator: unused; kept for a signature shared with the other evaluators.
        dataset_path: ``pathlib.Path`` to the JSONL file (its ``.name`` is used).

    Returns:
        Mean token F1 over the dataset. An item with no references contributes ``nan``.
    """
    f1_scores = []
    dataset = list(utils.read_jsonl(dataset_path))
    dump_data = []

    # Pure evaluation: no autograd graph needed.
    with torch.no_grad():
        for item in tqdm.tqdm(dataset):
            src = item["text"]
            tgts = item["summaries"]

            input_ids = torch.tensor(tokenizer([src])["input_ids"]).to(args.device)
            # Call the module (not .forward) so registered hooks run.
            logits = model(input_ids)
            argmax_labels = torch.argmax(logits, dim=2)
            pred = labels_to_summary(input_ids, argmax_labels, tokenizer)[0]

            # Lower-case the prediction once; it is the same for every reference.
            pred_tokens = [t.lower() for t in word_tokenize(pred)]
            item_scores = [
                compute_token_f1(
                    [t.lower() for t in word_tokenize(tgt)],
                    pred_tokens,
                    use_counts=True,
                )
                for tgt in tgts
            ]

            if args.dump:
                dump_data.append({
                    "probs": torch.softmax(logits, dim=2)[0].tolist(),
                    "source": src,
                    "target": tgts[0],
                    "f1-score": item_scores[0],
                    "pred_summary": pred,
                    "pred_labels": argmax_labels[0].tolist(),
                })

            f1_scores.append(np.mean(item_scores))
    score = np.mean(f1_scores)

    if args.dump:
        dataset_name = dataset_path.name.split(".jsonl")[0]
        dump_dir = manager.dir / f"dump-{dataset_name}"
        dump_dir.mkdir(exist_ok=True)
        utils.write_jsonl(
            dump_data,
            dump_dir / f"step-{manager.step}.jsonl",
            "w"
        )
    return score


def _update_score_log(manager, model, step, score, log_path, checkpoint_name, verbose, title):
    """Append ``score`` to the JSONL log at ``log_path`` and checkpoint the model on a new best."""
    results = list(utils.read_jsonl(log_path)) if log_path.exists() else []
    # -inf (not 0) so the first logged score always yields a checkpoint, even if
    # it is zero or negative. `default` also covers an existing but empty log.
    prev_max = max((x["score"] for x in results), default=float("-inf"))
    if score > prev_max:
        manager.save_model(model, step, checkpoint_name)
    results.append({"step": step, "score": score})
    utils.write_jsonl(results, log_path, "w")
    if verbose:
        print(title)
        pprint(results)
        print()


def evaluate(args, manager, model, tokenizer, reward_generator, holdout_data):
    """Run all validation evaluations for the current training step.

    1. The mean greedy-summary reward on ``holdout_data`` is appended to
       ``manager.dir / "val_rewards.jsonl"``. A new best saves the model under
       ``"best_val_reward"``.
    2. For each path in ``args.validation_datasets``, the mean token F1 is
       appended to ``manager.dir / "val_data_results.<name>.jsonl"``. A new best
       saves the model under ``"best_on_<name>"``.

    A score is a new best when it strictly exceeds every score already in its
    log; the first score written to a log always counts as a new best.
    """
    step = manager.step
    val_reward = evaluate_validation_reward(
        args, manager, model, tokenizer, reward_generator, holdout_data
    )
    _update_score_log(
        manager, model, step, val_reward,
        log_path=manager.dir / "val_rewards.jsonl",
        checkpoint_name="best_val_reward",
        verbose=args.verbose,
        title="Validation Rewards:",
    )

    # Only used if validation datasets are specified in the config.
    for val_data_path in args.validation_datasets:
        val_data_path = Path(val_data_path)
        dataset_name = val_data_path.name.split(".jsonl")[0]
        dataset_score = evaluate_validation_dataset(
            args, manager, model, tokenizer, reward_generator, val_data_path
        )
        _update_score_log(
            manager, model, step, dataset_score,
            log_path=manager.dir / f"val_data_results.{dataset_name}.jsonl",
            checkpoint_name=f"best_on_{dataset_name}",
            verbose=args.verbose,
            title=f"Validation Dataset Results for {dataset_name}:",
        )


def main(args):
    """Seed randomness via ``utils.set_random_seed``, then train with ``evaluate`` as the eval hook."""
    seed = 0
    utils.set_random_seed(seed)
    setup_and_train(args, eval_func=evaluate)


if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--config", help="path to JSON config file")
    arg_parser.add_argument("--device", default="cuda")
    # Write per-example predictions when scoring validation datasets.
    arg_parser.add_argument("--dump", action="store_true")
    # Print the accumulated validation score logs after each evaluation.
    arg_parser.add_argument("--verbose", action="store_true")
    arg_parser.add_argument(
        "--fresh", action="store_true",
        help="delete model directory and start from scratch",
    )
    # The whole CLI namespace is handed to load_config, which builds the run config.
    cli_args = arg_parser.parse_args()
    main(load_config(cli_args))