import pprint
import re

from huggingface_hub import snapshot_download, delete_inference_endpoint

from src.backend.inference_endpoint import create_endpoint
from src.backend.manage_requests import check_completed_evals, \
    get_eval_requests, set_eval_request, PENDING_STATUS, FINISHED_STATUS, \
    FAILED_STATUS, RUNNING_STATUS
from src.backend.run_toxicity_eval import compute_results
from src.backend.sort_queue import sort_models_by_priority
from src.envs import (REQUESTS_REPO, EVAL_REQUESTS_PATH_BACKEND, RESULTS_REPO,
                      EVAL_RESULTS_PATH_BACKEND, API, TOKEN)
from src.logging import setup_logger

logger = setup_logger(__name__)

pp = pprint.PrettyPrinter(width=80)


# Download the "main" revision of the results and requests dataset repos
# into their local backend directories. This happens at import time, before
# any function defined below can run.
snapshot_download(
    repo_id=RESULTS_REPO,
    repo_type="dataset",
    revision="main",
    local_dir=EVAL_RESULTS_PATH_BACKEND,
    token=TOKEN,
    max_workers=60,
)
snapshot_download(
    repo_id=REQUESTS_REPO,
    repo_type="dataset",
    revision="main",
    local_dir=EVAL_REQUESTS_PATH_BACKEND,
    token=TOKEN,
    max_workers=60,
)


def run_auto_eval():
    """Run one pass of the automatic toxicity-evaluation backend.

    1. Reconcile previously started evals with ``check_completed_evals``.
    2. Load the PENDING requests, order them with ``sort_models_by_priority``
       and take the first one; return ``None`` immediately if there are none.
    3. Mark that request RUNNING, create an Inference Endpoint for its model,
       compute the results against it and mark the request FINISHED.

    At most one request is evaluated per call.

    Failure handling: if creating the endpoint or computing the results
    raises, the request is marked FAILED and the exception is re-raised
    (rather than leaving the request stuck in RUNNING). An endpoint that was
    successfully created is always deleted afterwards, on success or failure,
    so a crashed evaluation does not leave it running.
    """
    # Reconcile evals started earlier: check_completed_evals receives both the
    # finished and failed statuses plus the requests and results repos.
    check_completed_evals(
        api=API,
        completed_status=FINISHED_STATUS,
        failed_status=FAILED_STATUS,
        hf_repo=REQUESTS_REPO,
        local_dir=EVAL_REQUESTS_PATH_BACKEND,
        hf_repo_results=RESULTS_REPO,
        local_dir_results=EVAL_RESULTS_PATH_BACKEND
    )

    # Get all eval requests that are PENDING
    eval_requests = get_eval_requests(hf_repo=REQUESTS_REPO,
                                      local_dir=EVAL_REQUESTS_PATH_BACKEND)
    # Sort the evals by priority (first submitted, first run)
    eval_requests = sort_models_by_priority(api=API, models=eval_requests)

    logger.info(
        f"Found {len(eval_requests)} {PENDING_STATUS} eval requests")

    if not eval_requests:
        return

    eval_request = eval_requests[0]
    logger.info(pp.pformat(eval_request))

    _set_eval_status(eval_request, RUNNING_STATUS)

    logger.info(
        f'Starting Evaluation of {eval_request.json_filepath} on Inference endpoints')
    endpoint_name = _make_endpoint_name(eval_request)
    endpoint_created = False
    try:
        endpoint_url = create_endpoint(endpoint_name, eval_request.model)
        # NOTE(review): if create_endpoint can raise after the endpoint
        # already exists (e.g. while waiting for it to come up), that endpoint
        # is not cleaned up here -- confirm against create_endpoint.
        endpoint_created = True
        logger.info("Created an endpoint url at %s", endpoint_url)
        results = compute_results(endpoint_url, eval_request)
    except Exception:
        # Don't leave the request stuck in RUNNING; record the failure and
        # let the caller see the original error.
        logger.exception("Evaluation of %s failed",
                         eval_request.json_filepath)
        _set_eval_status(eval_request, FAILED_STATUS)
        raise
    else:
        logger.info("FINISHED!")
        logger.info(results)
        logger.info(f'Completed Evaluation of {eval_request.json_filepath}')
        _set_eval_status(eval_request, FINISHED_STATUS)
    finally:
        # Delete the endpoint when we're done, whatever the outcome.
        if endpoint_created:
            delete_inference_endpoint(endpoint_name)


def _set_eval_status(eval_request, status):
    """Set ``eval_request`` to ``status`` via ``set_eval_request``.

    Always targets ``REQUESTS_REPO`` and ``EVAL_REQUESTS_PATH_BACKEND``.
    """
    set_eval_request(
        api=API,
        eval_request=eval_request,
        set_to_status=status,
        hf_repo=REQUESTS_REPO,
        local_dir=EVAL_REQUESTS_PATH_BACKEND,
    )


def _make_endpoint_name(eval_request):
    model_repository = eval_request.model
    # Naming convention for endpoints
    endpoint_name_tmp = re.sub("[/.]", "-",
                               model_repository.lower()) + "-toxicity-eval"
    # Endpoints apparently can't have more than 32 characters.
    endpoint_name = endpoint_name_tmp[:32]
    return endpoint_name


if __name__ == "__main__":
    # Script entry point: run a single auto-eval pass. Note that the
    # module-level snapshot_download calls above also run on plain import.
    run_auto_eval()