import gradio as gr

from transformers import pipeline
from typing import Dict, Union
from gliner import GLiNER

# Hugging Face Hub checkpoints served by this demo.
# NOTE(review): the original source also mentions "numind/NuNER_Zero" in a
# comment, presumably as an alternative NER checkpoint that was tried.
_NER_MODEL_ID = "urchade/gliner_multi-v2.1"
_CLASSIFIER_MODEL_ID = "MoritzLaurer/deberta-v3-base-zeroshot-v1"

# Zero-shot NER model, loaded once at import time and used by ner().
model = GLiNER.from_pretrained(_NER_MODEL_ID)

# Zero-shot text-classification pipeline, loaded once and used by zero_shot().
classifier = pipeline(
    "zero-shot-classification",
    model=_CLASSIFIER_MODEL_ID,
)

# Page-level CSS passed to gr.Blocks: centers the <h1> page title.
css = """
h1 {
    text-align: center;
    display:block;
}
"""

def zero_shot(doc, candidates):
    """Classify *doc* against user-supplied candidate labels.

    Args:
        doc: Text to classify.
        candidates: Comma-separated candidate labels, e.g. "positive, negative".
            Whitespace around each label is ignored and empty entries are
            dropped, so "a,b", "a, b" and "a , b," all give ["a", "b"].

    Returns:
        A ``{label: score}`` dict built from the pipeline's ``labels`` and
        ``scores`` outputs, suitable for ``gr.Label``. Empty when there is no
        text or no usable label.
    """
    # Split on bare commas and strip each piece instead of splitting on ", "
    # exactly: otherwise "a,b" becomes the single label "a,b" and stray spaces
    # end up inside label names.
    given_labels = [label.strip() for label in candidates.split(",") if label.strip()]
    if not doc or not given_labels:
        # The zero-shot pipeline raises ValueError when given no labels or an
        # empty sequence; show an empty result instead of an error.
        return {}
    result = classifier(doc, given_labels)
    return dict(zip(result["labels"], result["scores"]))

# Example rows for the text-classification tab: [text, comma-separated labels].
# Row 0 also seeds the tab's initial textbox values.
# NOTE(review): the non-English rows in both example lists appear to be
# Sundanese, Indonesian and Javanese, showing multilingual zero-shot use.
examples_text = [
    [
        "Pasar saham ngalaman panurunan nu signifikan akibat kateupastian global.",
        "ékonomi, pulitik, bisnis, kauangan, téknologi"
    ],
    [
        "I am very happy today but suddenly sad because of the recent news.",
        "positive, negative, neutral"
    ],
    [
        "I just received the best news ever! I got the job I always wanted!",
        "joy, sadness, anger, surprise, fear, disgust"
    ],
]
# Example rows for the NER tab: [text, comma-separated entity labels, threshold].
# Row 0 also seeds the tab's initial text, labels and threshold values.
examples_ner = [
    [
        "Pada tahun 1945, Indonesia memproklamasikan kemerdekaannya dari penjajahan Belanda. Proklamasi tersebut dibacakan oleh Soekarno dan Mohammad Hatta di Jakarta.",
        "tahun, negara, tokoh, lokasi",
        0.3
    ],
    [
        "Mount Everest is the highest mountain above sea level, located in the Himalayas. It stands at 8,848 meters (29,029 ft) and attracts many climbers.",
        "location, measurement, person",
        0.3
    ],
     [
        "Perusahaan teknologi raksasa, Google, mbukak kantor cabang anyar ing Jakarta ing wulan Januari 2020 kanggo nggedhekake operasine ing Asia Tenggara",
        "perusahaan, lokasi, wulan, taun",
        0.3
    ],
]

def merge_entities(entities, text=None):
    """Merge runs of adjacent entities that share the same label.

    Two consecutive entities are merged when they have the same ``"entity"``
    label and the second starts at, or one character after, the end of the
    first. The merged entity keeps the first entity's ``"start"`` and any
    other keys (e.g. ``"score"``), and takes the last entity's ``"end"``.

    Args:
        entities: Entity dicts with at least ``"entity"``, ``"word"``,
            ``"start"`` and ``"end"`` keys, in positional order. The input
            list and its dicts are not modified.
        text: Optional source text that the offsets index into. When given,
            the merged ``"word"`` is the exact slice ``text[start:end]``, so
            the real separator (e.g. "-" in "Jean-Paul") is preserved.

    Returns:
        A new list of copied entity dicts.
    """
    if not entities:
        return []
    merged = []
    # Work on copies so the caller's entity dicts are never mutated.
    current = dict(entities[0])
    for next_entity in entities[1:]:
        gap = next_entity["start"] - current["end"]
        if next_entity["entity"] == current["entity"] and gap in (0, 1):
            if text is not None:
                current["word"] = text[current["start"]:next_entity["end"]]
            else:
                # Without the source text, rebuild the separator from the gap:
                # touching spans are concatenated, a one-char gap becomes a space.
                current["word"] += (" " if gap else "") + next_entity["word"]
            current["end"] = next_entity["end"]
        else:
            merged.append(current)
            current = dict(next_entity)
    merged.append(current)
    return merged

def ner(
    text, labels: str, threshold: float, nested_ner: bool = False
) -> Dict[str, Union[str, list]]:
    """Run zero-shot NER on *text* and build a ``gr.HighlightedText`` payload.

    Args:
        text: Input text.
        labels: Comma-separated entity labels, e.g. "location, person".
            Whitespace around each label is ignored and empty entries dropped.
        threshold: Minimum confidence forwarded to ``model.predict_entities``.
        nested_ner: When True, request overlapping entities
            (``flat_ner=False``). Defaults to False so callers that supply
            only ``text``, ``labels`` and ``threshold`` (e.g. a three-input
            Gradio event or cached examples) still work.

    Returns:
        ``{"text": text, "entities": [...]}``. Each entity dict has
        ``entity``, ``word``, ``start``, ``end`` and ``score`` keys, and
        adjacent same-label spans are merged by ``merge_entities``.
    """
    # Strip each label: "a, b".split(",") yields " b", and the leading space
    # would be sent to the model (and presumably echoed back as the label).
    label_list = [label.strip() for label in labels.split(",") if label.strip()]
    if not text or not label_list:
        # Nothing to tag; return the same payload shape with no entities.
        return {"text": text, "entities": []}
    predictions = model.predict_entities(
        text, label_list, flat_ner=not nested_ner, threshold=threshold
    )
    entities = [
        {
            "entity": entity["label"],
            "word": entity["text"],
            "start": entity["start"],
            "end": entity["end"],
            # Placeholder score, kept from the original payload shape.
            "score": 0,
        }
        for entity in predictions
    ]
    return {"text": text, "entities": merge_entities(entities)}


# Two-tab demo UI. Each tab wires its inputs to the model function with an
# explicit Submit button. cache_examples=True makes Gradio run the function on
# every example row when the app starts, so those rows must match `inputs`.
with gr.Blocks(title="Zero-Shot Demo", css=css) as demo:

    gr.Markdown(
            """
            # Zero-Shot Model Demo
            """
        )

    with gr.Tab("Zero-Shot Text Classification"):

        gr.Markdown(
            """
            ## Zero-Shot Text Classification
            """
        )

        input1 = gr.Textbox(label="Text", value=examples_text[0][0])
        input2 = gr.Textbox(label="Labels", value=examples_text[0][1])
        cls_output = gr.Label(label="Output")

        # Explicit button wiring (same pattern as the NER tab) instead of
        # nesting a gr.Interface around components already rendered here.
        classify_btn = gr.Button("Submit")

        gr.Examples(
            examples_text,
            fn=zero_shot,
            inputs=[input1, input2],
            outputs=cls_output,
            cache_examples=True,
        )

        classify_btn.click(
            fn=zero_shot, inputs=[input1, input2], outputs=cls_output
        )

    with gr.Tab("Zero-Shot NER"):
        gr.Markdown(
            """
            ## Zero-Shot Named Entity Recognition (NER)
            """
        )

        input_text = gr.Textbox(
            value=examples_ner[0][0], label="Text input", placeholder="Enter your text here", lines=3
        )
        with gr.Row():
            labels = gr.Textbox(
                value=examples_ner[0][1],
                label="Labels",
                placeholder="Enter your labels here (comma separated)",
                scale=2,
            )
            threshold = gr.Slider(
                0,
                1,
                value=examples_ner[0][2],
                step=0.01,
                label="Threshold",
                info="Lower the threshold to increase how many entities get predicted.",
                scale=1,
            )
            # ner() takes a fourth `nested_ner` argument; expose it so every
            # call (Submit and cached examples) supplies all four inputs.
            nested_ner = gr.Checkbox(
                value=False,
                label="Nested NER",
                info="Allow overlapping entities.",
                scale=1,
            )

        output = gr.HighlightedText(label="Predicted Entities")

        submit_btn = gr.Button("Submit")

        ner_inputs = [input_text, labels, threshold, nested_ner]

        examples = gr.Examples(
            # Rows are [text, labels, threshold]; append nested_ner=False so
            # each row lines up with ner_inputs.
            [[*row, False] for row in examples_ner],
            fn=ner,
            inputs=ner_inputs,
            outputs=output,
            cache_examples=True,
        )

        submit_btn.click(fn=ner, inputs=ner_inputs, outputs=output)

if __name__ == "__main__":
    demo.queue()
    demo.launch(debug=True)