"""GLiNER-based query parser demo: zero-shot NER over text queries with a Gradio UI."""
import spaces
import os
import json
import gradio as gr
import pycountry
import torch
from datetime import datetime
from typing import Dict, Optional, Union
from gliner import GLiNER


# Lazily populated cache of loaded GLiNER models, keyed by model name.
_MODEL = {}
_CACHE_DIR = os.environ.get("CACHE_DIR", None)

# Default detection threshold, entity labels, sample query, and available GLiNER checkpoints.
THRESHOLD = 0.3
LABELS = ["country", "year", "statistical indicator", "geographic region"]
QUERY = "gdp, co2 emissions, and mortality rate of the philippines vs. south asia in 2024"
MODELS = ["urchade/gliner_base", "urchade/gliner_medium-v2.1", "urchade/gliner_multi-v2.1", "urchade/gliner_large-v2.1"]

print(f"Cache directory: {_CACHE_DIR}")


def get_model(model_name: Optional[str] = None):
    start = datetime.now()

    if model_name is None:
        model_name = "urchade/gliner_base"

    global _MODEL

    # Load the model on first use and keep it cached for subsequent requests.
    if _MODEL.get(model_name) is None:
        _MODEL[model_name] = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)

    # Move the model to the GPU if one is available and the model is still on the CPU.
    if torch.cuda.is_available() and next(_MODEL[model_name].parameters()).device.type != "cuda":
        _MODEL[model_name] = _MODEL[model_name].to("cuda")

    print(f"{datetime.now()} :: get_model :: {datetime.now() - start}")

    return _MODEL[model_name]


def get_country(country_name: str):
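    """Fuzzy-match a country name to pycountry records.

    pycountry's search_fuzzy returns a list of candidate country records;
    None is returned when no match is found.
    """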
    try:
        return pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        return None


@spaces.GPU(enable_queue=True, duration=15)
def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
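    """Run GLiNER zero-shot NER over the query for the given labels; ``labels`` may be a list or a comma-separated string."""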
    start = datetime.now()
    model = get_model(model_name)

    if isinstance(labels, str):
        labels = [i.strip() for i in labels.split(",")]

    # flat_ner=True disallows overlapping (nested) spans, so the UI's nested-NER flag is inverted.
    entities = model.predict_entities(query, labels, threshold=threshold, flat_ner=not nested_ner)

    print(f"{datetime.now()} :: predict_entities :: {datetime.now() - start}")

    return entities


def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: Optional[str] = None) -> Dict[str, Union[str, list]]:
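    """Run NER on the query and attach pycountry records to detected "country" entities."""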

    entities = []
    _entities = predict_entities(model_name=model_name, query=query, labels=labels, threshold=threshold, nested_ner=nested_ner)

    # Keep "country" entities only if pycountry can resolve them; pass all other labels through unchanged.
    for entity in _entities:
        if entity["label"] == "country":
            country = get_country(entity["text"])
            if country:
                entity["normalized"] = [dict(c) for c in country]
                entities.append(entity)
        else:
            entities.append(entity)
    
    payload = {"query": query, "entities": entities}
    print(f"{datetime.now()} :: parse_query :: {json.dumps(payload)}\n")

    return payload


def annotate_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: Optional[str] = None) -> Dict[str, Union[str, list]]:
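    """Reshape the parsed payload into the dict format expected by gr.HighlightedText."""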
    payload = parse_query(query, labels, threshold, nested_ner, model_name)

    return {
        "text": query,
        "entities": [
            {
                "entity": entity["label"],
                "word": entity["text"],
                "start": entity["start"],
                "end": entity["end"],
                "score": entity["score"],
            }
            for entity in payload["entities"]
        ],
    }


# Warm up all models at startup so the first user request does not pay the download/load cost.
print("Initializing models...")
for model_name in MODELS:
    predict_entities(model_name, QUERY, LABELS, threshold=THRESHOLD)


with gr.Blocks(title="GLiNER-query-parser") as demo:
    gr.Markdown(
        """
        # GLiNER-based Query Parser (a zero-shot NER model)

        This space demonstrates GLiNER's ability to identify arbitrary entity types in a text query (zero-shot NER). Given a set of entity labels to track, the model tags the spans in the query that match them, and the parsed entities are displayed in the output. Entities labeled "country" are additionally normalized to their ISO 3166 records (including the alpha-2 code) using the `pycountry` library. The GLiNER models are licensed under the Apache 2.0 license.

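        For reference, the JSON payload produced by the parser has roughly the following shape (illustrative only; offsets, scores, and normalized records come from the model and `pycountry` at run time):

        ```python
        {
            "query": "...",
            "entities": [
                {
                    "text": "philippines",       # matched span
                    "label": "country",          # one of the requested labels
                    "start": ..., "end": ...,    # character offsets in the query
                    "score": ...,                # model confidence
                    "normalized": [...],         # pycountry record(s), only for "country" entities
                },
            ],
        }
        ```
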
        ## Links
        * Model: https://huggingface.co/urchade/gliner_medium-v2.1, https://huggingface.co/urchade/gliner_base
        * All GLiNER models: https://huggingface.co/models?library=gliner
        * Paper: https://arxiv.org/abs/2311.08526
        * Repository: https://github.com/urchade/GLiNER
        """
    )

    query = gr.Textbox(
        value=QUERY, label="Query", placeholder="Enter your query here"
    )
    with gr.Row() as row:
        model_name = gr.Radio(
            choices=MODELS,
            value="urchade/gliner_base",
            label="Model",
        )
        entities = gr.Textbox(
            value=", ".join(LABELS),
            label="Entities",
            placeholder="Enter the entities to detect here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=THRESHOLD,
            step=0.01,
            label="Threshold",
            info="Lower threshold may extract more false-positive entities from the query.",
            scale=1,
        )
        is_nested = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Allow overlapping (nested) entities to be extracted",
            scale=0,
        )

    output = gr.HighlightedText(label="Annotated entities")
    submit_btn = gr.Button("Submit")

    json_output = gr.JSON(label="Extracted entities")
    json_button = gr.Button("Get JSON") 

    # Re-run annotation whenever the query is submitted, any setting changes, or the Submit button is clicked.
    query.submit(
        fn=annotate_query, inputs=[query, entities, threshold, is_nested, model_name], outputs=output
    )
    entities.submit(
        fn=annotate_query, inputs=[query, entities, threshold, is_nested, model_name], outputs=output
    )
    threshold.release(
        fn=annotate_query, inputs=[query, entities, threshold, is_nested, model_name], outputs=output
    )
    submit_btn.click(
        fn=annotate_query, inputs=[query, entities, threshold, is_nested, model_name], outputs=output
    )
    is_nested.change(
        fn=annotate_query, inputs=[query, entities, threshold, is_nested, model_name], outputs=output
    )
    model_name.change(
        fn=annotate_query, inputs=[query, entities, threshold, is_nested, model_name], outputs=output
    )
    json_button.click(
        fn=parse_query, inputs=[query, entities, threshold, is_nested, model_name], outputs=json_output
    )

demo.queue(default_concurrency_limit=5)
demo.launch(debug=True)


"""
from gradio_client import Client

client = Client("avsolatorio/query-parser")
result = client.predict(
		query="gdp, m3, and child mortality of india and southeast asia 2024",
		labels="country, year, statistical indicator, region",
		threshold=0.3,
		nested_ner=False,
		api_name="/parse_query"
)
print(result)
"""