import logging

import datasets
import huggingface_hub
import requests
import os

from app_env import HF_WRITE_TOKEN

# Module-level logger used by the helpers in this module.
logger = logging.getLogger(__name__)
# Hub endpoint queried by check_hf_token_validity to verify a token.
AUTH_CHECK_URL = "https://huggingface.co/api/whoami-v2"

class HuggingFaceInferenceAPIResponse:
    """Message describing an error payload returned by the Inference API.

    ``get_example_prediction`` returns an instance of this class instead of a
    ``{label: score}`` dict, so callers can tell an API-side error (for
    example one carrying an estimated wait time) apart from a real result.
    """

    def __init__(self, message):
        # Human-readable text intended to be shown to the user as-is.
        self.message = message

    def __repr__(self):
        return f"{type(self).__name__}(message={self.message!r})"


def get_labels_and_features_from_dataset(ds):
    """Split a dataset's columns into label names and input feature columns.

    The label column is the first column whose name starts with ``"label"``.
    Its class names are read from a ``datasets.ClassLabel`` directly, or from
    the ``.feature`` of a wrapping feature type otherwise (e.g. a ``Sequence``
    of ``ClassLabel`` for multi-label columns).

    Returns:
        ``(labels, features)``: *labels* is the list of class names and
        *features* the non-label column names. If there is no ``label*``
        column, every column name is returned in both positions so the caller
        can post-process them. ``(None, None)`` is returned, and a warning
        logged, when the label column has no resolvable class names or when
        anything else fails.
    """
    try:
        dataset_features = ds.features
        column_names = list(dataset_features.keys())
        label_keys = [name for name in column_names if name.startswith("label")]
        if not label_keys:
            # No label column: hand back every column for post-processing.
            return list(column_names), list(column_names)

        label_feature = dataset_features[label_keys[0]]
        if not isinstance(label_feature, datasets.ClassLabel):
            # Wrapped label type (e.g. Sequence(ClassLabel(...))): look inside.
            label_feature = getattr(label_feature, "feature", None)
        labels = getattr(label_feature, "names", None)
        if labels is None:
            # Previously this surfaced as an UnboundLocalError/AttributeError.
            logger.warning(
                f"Get Labels/Features Failed for dataset: label column "
                f"'{label_keys[0]}' has no class names"
            )
            return None, None

        features = [name for name in column_names if not name.startswith("label")]
        return labels, features
    except Exception as e:
        logger.warning(f"Get Labels/Features Failed for dataset: {e}")
        return None, None

def check_model_task(model_id):
    """Return the Hub ``pipeline_tag`` of *model_id*, or ``None``.

    ``None`` covers both a model without a pipeline tag and any failure while
    looking the model up on the Hub (the error itself is discarded).
    """
    try:
        return huggingface_hub.model_info(model_id).pipeline_tag
    except Exception:
        return None

def get_model_labels(model_id, example_input):
    """Return every ``"label"`` value the model emits for *example_input*.

    The Inference API is called with the token read from the environment
    variable named by ``HF_WRITE_TOKEN`` (empty string if unset). Returns
    ``None`` when the response contains an ``"error"`` entry.
    """
    token = os.environ.get(HF_WRITE_TOKEN, default="")
    response = hf_inference_api(
        model_id,
        token,
        {"inputs": example_input, "options": {"use_cache": True}},
    )
    return None if "error" in response else extract_from_response(response, "label")

def extract_from_response(data, key):
    """Collect every non-``None`` value stored under *key* in nested JSON data.

    Dicts and lists are walked depth-first, in order. A dict's own *key* value
    is emitted before anything found inside its values, and that value is
    itself searched too. Non-container inputs yield an empty list.
    """
    if isinstance(data, dict):
        own_value = data.get(key)
        found = [] if own_value is None else [own_value]
        children = data.values()
    elif isinstance(data, list):
        found = []
        children = data
    else:
        return []

    for child in children:
        found += extract_from_response(child, key)
    return found

def hf_inference_api(model_id, hf_token, payload, timeout=None):
    """POST *payload* to the Hugging Face Inference API for *model_id*.

    The base URL is taken from the ``HF_INFERENCE_ENDPOINT`` environment
    variable and defaults to ``https://api-inference.huggingface.co``.

    Args:
        model_id: Hub model id, appended to ``<endpoint>/models/``.
        hf_token: Bearer token. When empty or ``None`` the request is sent
            without an ``Authorization`` header.
        payload: JSON-serialisable request body.
        timeout: Optional ``requests`` timeout in seconds. ``None`` (default)
            waits indefinitely, matching the previous behaviour.

    Returns:
        The decoded JSON body. If the body is not valid JSON, the dict
        ``{"error": <response text>}`` is returned instead, so callers can
        test uniformly for an ``"error"`` key.
    """
    hf_inference_api_endpoint = os.environ.get(
        "HF_INFERENCE_ENDPOINT", default="https://api-inference.huggingface.co"
    )
    url = f"{hf_inference_api_endpoint}/models/{model_id}"
    # "Bearer " with an empty token is a malformed credential; only send the
    # header when there is a token to send.
    headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}
    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    if response.status_code != 200:
        logger.warning(
            f"Request to inference API returns {response.status_code}: {response.text}"
        )
    try:
        return response.json()
    except ValueError:
        # Non-JSON body: return decoded text (not bytes) so it formats cleanly
        # in user-facing error messages.
        return {"error": response.text}
    
def preload_hf_inference_api(model_id):
    """Send a dummy request to the Inference API for *model_id*.

    The response is discarded; presumably this warms the model up ahead of
    real requests. The token is read from the environment variable named by
    ``HF_WRITE_TOKEN`` (empty string if unset).
    """
    token = os.environ.get(HF_WRITE_TOKEN, default="")
    warmup_payload = {"inputs": "This is a test", "options": {"use_cache": True}}
    hf_inference_api(model_id, token, warmup_payload)

def check_dataset_features_validity(d_id, config, split):
    """Load a dataset split and return it as a DataFrame with its feature schema.

    Returns:
        ``(df, features)``, or ``(None, None)`` when the loaded dataset has no
        ``features`` attribute (the caller must then supply everything).
    """
    # The dataset itself is assumed to be loadable at this point.
    ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
    try:
        dataset_features = ds.features
    except AttributeError:
        return None, None
    # Materialise the split as a pandas DataFrame next to its schema.
    return ds.to_pandas(), dataset_features

def select_the_first_string_column(ds):
    """Return the first column whose value in row 0 is a ``str``.

    Columns are examined in ``ds.features`` order. Returns ``None`` when no
    such column exists or the dataset has no rows.
    """
    column_names = list(ds.features.keys())
    if not column_names:
        return None
    try:
        # Fetch (and decode) the first row once, not once per column.
        first_row = ds[0]
    except IndexError:
        return None
    return next(
        (name for name in column_names if isinstance(first_row[name], str)),
        None,
    )


def get_example_prediction(model_id, dataset_id, dataset_config, dataset_split, hf_token):
    """Query the Inference API with the first example of a dataset split.

    The input is the first row's ``"text"`` value when that column exists,
    otherwise the value of its first string-typed column.

    Returns:
        ``(prediction_input, prediction)`` where *prediction* is one of:

        * a ``{label: score}`` dict on success;
        * a ``HuggingFaceInferenceAPIResponse`` when the API replied with an
          ``"error"`` payload (with a retry hint if it also reported an
          ``"estimated_time"``);
        * ``None`` if anything else failed (the exception is logged).

        *prediction_input* is the text sent to the model, or ``None`` if it
        was never determined.
    """
    prediction_input = None
    try:
        ds = datasets.load_dataset(
            dataset_id, dataset_config, split=dataset_split, trust_remote_code=True
        )
        if "text" in ds.features.keys():
            prediction_input = ds[0]["text"]
        else:
            # No "text" column: fall back to the first string-valued column.
            prediction_input = ds[0][select_the_first_string_column(ds)]

        payload = {"inputs": prediction_input, "options": {"use_cache": True}}
        results = hf_inference_api(model_id, hf_token, payload)

        if isinstance(results, dict) and "error" in results.keys():
            if "estimated_time" in results.keys():
                message = f"Estimated time: {int(results['estimated_time'])}s. Please try again later."
            else:
                message = f"Inference Error: {results['error']}."
            return prediction_input, HuggingFaceInferenceAPIResponse(message)

        # Responses may be nested lists ([[{...}, ...]]): descend to the list of dicts.
        while isinstance(results, list) and not isinstance(results[0], dict):
            results = results[0]
        prediction_result = {f'{item["label"]}': item["score"] for item in results}
    except Exception as e:
        # Inference API prediction failed; surface the error in the logs.
        logger.error(f"Get example prediction failed {e}")
        return prediction_input, None
    else:
        return prediction_input, prediction_result


def get_sample_prediction(ppl, df, column_mapping, id2label_mapping):
    """Run a local pipeline on the first row of *df* and show mapped labels.

    Args:
        ppl: Pipeline callable, invoked as ``ppl({"text": value}, top_k=None)``
            and expected to return an iterable of ``{"label", "score"}`` dicts.
        df: DataFrame holding the dataset (pandas, presumably from
            ``Dataset.to_pandas``).
        column_mapping: Mapping whose ``"text"`` entry names the input column.
        id2label_mapping: Maps each label the pipeline emits to its mapped label.

    Returns:
        ``(prediction_input, prediction_result)``. *prediction_result* maps
        ``"<label>(original) - <mapped label>(mapped)"`` to the score. It is
        ``None`` when the pipeline call fails or a pipeline label is missing
        from *id2label_mapping*, signalling that labels must be provided.
        *prediction_input* is ``None`` if the first value could not be read.
    """
    prediction_input = None
    try:
        # Positional access: the DataFrame index need not start at label 0.
        prediction_input = df[column_mapping["text"]].iloc[0]
        results = ppl({"text": prediction_input}, top_k=None)
        # Single pass over results, so generator outputs are not exhausted early.
        prediction_result = {
            f"{r['label']}(original) - {id2label_mapping[r['label']]}(mapped)": r["score"]
            for r in results
        }
    except Exception:
        # Pipeline prediction or label mapping failed: labels need to be provided.
        return prediction_input, None
    return prediction_input, prediction_result


def strip_model_id_from_url(model_id):
    """Turn a Hugging Face Hub model URL into a bare model id.

    ``"https://huggingface.co/org/model"`` becomes ``"org/model"``. A trailing
    slash or extra path such as ``/tree/main`` is dropped. A model without a
    namespace (``".../gpt2"``) becomes ``"gpt2"``. Values that are not
    ``huggingface.co`` URLs are returned unchanged.

    NOTE: only the first two path segments are kept, so a namespace-less URL
    with extra path (``".../gpt2/tree/main"``) yields ``"gpt2/tree"``.
    """
    for prefix in ("https://huggingface.co/", "http://huggingface.co/"):
        if model_id.startswith(prefix):
            segments = [seg for seg in model_id[len(prefix):].split("/") if seg]
            return "/".join(segments[:2])
    return model_id

def check_hf_token_validity(hf_token):
    """Return ``True`` if the Hub accepts *hf_token*, else ``False``.

    Empty strings and non-string values are rejected without a network call.
    Otherwise the token is sent to ``AUTH_CHECK_URL``, and only an HTTP 200
    reply counts as valid.
    """
    if hf_token == "" or not isinstance(hf_token, str):
        return False
    auth_headers = {"Authorization": f"Bearer {hf_token}"}
    reply = requests.get(AUTH_CHECK_URL, headers=auth_headers)
    return reply.status_code == 200