import os

# Questions for Gradio
# - Chat share button is enabled by default but throws an error when clicked.
# - How to add local images in HTML? (https://github.com/gradio-app/gradio/issues/884)
# - How to allow Chatbot to fill the vertical space? (https://github.com/gradio-app/gradio/issues/4001)
# TODO:
# - Add the 1MB models, keras/gemma_1.1_instruct_7b_en
# - Add retry button, for each model individually
# - Add ability to route a message to a single model only.
# - log_applied_layout_map: make it work for Llama3CausalLM and LlamaCausalLM (vicuna)
# - display context length

os.environ["KERAS_BACKEND"] = "jax"

import gradio as gr
from gradio import ChatMessage
import keras_hub

from chatstate import ChatState
from enum import Enum
from models import (
    model_presets,
    load_model,
    model_labels,
    preset_to_website_url,
    get_appropriate_chat_template,
)


class TextRoute(Enum):
    """Which chatbot(s) a submitted message is sent to."""

    LEFT = 0  # only the left-hand chatbot
    RIGHT = 1  # only the right-hand chatbot
    BOTH = 2  # both chatbots at once


model_labels_list = list(model_labels)

# Load every preset up front and run one short "Hello" exchange through each
# model, so the compilation cost is paid at startup rather than on the first
# user request.
models = []
for preset in model_presets:
    model = load_model(preset)
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, "", chat_template)
    _, response = chat_state.send_message("Hello")
    print(f"model {preset} loaded and initialized.")
    print(f"The model responded: {response}")
    models.append(model)

# For local debugging
# model = keras_hub.models.Llama3CausalLM.from_preset(
#     # "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
#     "../misc-code/ari_tiny_llama3"
# )
# models = [model, model, model, model, model]


def chat_turn_assistant(
    message,
    sel,
    history,
    system_message,
    # max_tokens,
    # temperature,
    # top_p,
):
    """Generate the assistant's reply to `message` with the selected model.

    Args:
        message: The user's latest message text.
        sel: Index into `models` / `model_presets` of the model to use.
        history: Chat history in Gradio "messages" format (list of dicts with
            "role" and "content"). In this app, `chat_turn_user` runs first
            and has already appended `message` as the last user turn.
        system_message: System prompt passed to the ChatState.

    Returns:
        `history` with the assistant's reply appended (modified in place).
    """
    model = models[sel]
    preset = model_presets[sel]
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, system_message, chat_template)

    past_turns = [ChatMessage(**msg) for msg in history]
    # send_message(message) below puts `message` into the prompt by itself, so
    # the copy chat_turn_user just appended must not be replayed as well --
    # otherwise the model would see the same user turn twice.
    if (
        past_turns
        and past_turns[-1].role == "user"
        and past_turns[-1].content == message
    ):
        past_turns.pop()

    for turn in past_turns:
        if turn.role == "user":
            chat_state.add_to_history_as_user(turn.content)
        elif turn.role == "assistant":
            chat_state.add_to_history_as_model(turn.content)

    _, response = chat_state.send_message(message)
    history.append(ChatMessage(role="assistant", content=response))
    return history


def chat_turn_both_assistant(
    message, sel1, sel2, history1, history2, system_message
):
    """Ask both selected models to answer `message`.

    The left model is queried first, then the right one.

    Returns:
        The two updated histories as a (left, right) tuple.
    """
    left_history = chat_turn_assistant(message, sel1, history1, system_message)
    right_history = chat_turn_assistant(message, sel2, history2, system_message)
    return left_history, right_history


def chat_turn_user(message, history):
    """Append `message` to `history` as a user turn and return the same list."""
    user_turn = ChatMessage(role="user", content=message)
    history.append(user_turn)
    return history


def chat_turn_both_user(message, history1, history2):
    """Append `message` as a user turn to both histories.

    Returns:
        The two updated histories as a (left, right) tuple.
    """
    left_history = chat_turn_user(message, history1)
    right_history = chat_turn_user(message, history2)
    return left_history, right_history


def bot_icon_select(model_name):
    """Return the avatar image path for a model name.

    Keywords are matched as case-sensitive substrings in the order listed
    below; the first match wins. Names matching nothing get a generic bot icon.
    """
    keyword_icons = (
        ("gemma", "img/gemma.png"),
        ("llama", "img/meta.png"),
        ("vicuna", "img/vicuna.png"),
        ("mistral", "img/mistral.png"),
    )
    return next(
        (icon for keyword, icon in keyword_icons if keyword in model_name),
        "img/bot.png",
    )


def instantiate_select_box(sel, model_labels):
    """Create the model-picker dropdown with model index `sel` preselected.

    Each choice shows a model label and carries that model's index as its
    value. The info line links to the website of the selected preset.
    """
    url = preset_to_website_url(model_presets[sel])
    link_html = (
        f"<span style='color:black'>Selected model:</span> "
        f"<a href='{url}'>{url}</a>"
    )
    options = [(label, idx) for idx, label in enumerate(model_labels)]
    return gr.Dropdown(
        choices=options,
        value=sel,
        show_label=False,
        info=link_html,
    )


def instantiate_chatbot(sel, key):
    """Create a messages-format Chatbot whose bot avatar matches model `sel`."""
    avatars = ("img/usr.png", bot_icon_select(model_presets[sel]))
    return gr.Chatbot(
        type="messages",
        key=key,
        avatar_images=avatars,
        show_label=False,
        show_share_button=False,
        show_copy_all_button=True,
    )


def instantiate_arrow_button(route, text_route):
    """Create a small icon button that switches message routing to `route`.

    Clicking the button writes `route` into the `text_route` state.
    """
    icon_path = {
        TextRoute.LEFT: "img/arrowL.png",
        TextRoute.RIGHT: "img/arrowR.png",
        TextRoute.BOTH: "img/arrowRL.png",
    }[route]
    arrow = gr.Button("", icon=icon_path, size="sm", scale=0, min_width=40)
    arrow.click(lambda: route, outputs=[text_route])
    return arrow


def instantiate_retry_button(route):
    """Create the small "retry" icon button.

    `route` is accepted by the signature but not used here; the click handler
    is attached by the caller.
    """
    compact = dict(size="sm", scale=0, min_width=40)
    return gr.Button("", icon="img/retry.png", **compact)


def instantiate_trash_button():
    """Create the small "clear conversation" (trash) icon button.

    The click handler is attached by the caller.
    """
    compact = dict(size="sm", scale=0, min_width=40)
    return gr.Button("", icon="img/trash.png", **compact)


def instantiate_text_box():
    """Create the message input box, with its own built-in submit button."""
    return gr.Textbox(
        key="msg",
        label="Your message:",
        submit_btn=True,
    )


def instantiate_additional_settings():
    """Create a collapsed "Additional settings" accordion.

    The accordion contains the system-prompt textbox. Components are created
    in whatever Gradio layout context is active when this is called.

    Returns:
        The system-prompt Textbox, so callers can feed it to event handlers.
    """
    with gr.Accordion("Additional settings", open=False):
        system_message = gr.Textbox(
            label="System prompt",  # was misspelled "Sytem prompt"
            key="system_prompt",
            value="You are a helpful assistant and your name is Eliza.",
        )
    return system_message


def retry_fn(history):
    """Undo the last exchange so the user's last message can be resent.

    Removes the most recent user turn, and everything after it (normally the
    assistant's reply), from `history`. Returns that user message so it can
    be placed back into the input box.

    Args:
        history: Chat history in Gradio "messages" format (list of dicts with
            "role" and "content"); truncated in place.

    Returns:
        (message_content, history) on success. If `history` contains no user
        turn, returns (gr.skip(), gr.skip()) so both outputs stay unchanged.
    """
    # Search backwards for the last user turn instead of assuming the history
    # ends with a [user, assistant] pair. If generation failed, the history
    # ends with the user turn, and blindly popping two entries would discard
    # it and return the previous assistant reply as the "message".
    for idx in range(len(history) - 1, -1, -1):
        if history[idx]["role"] == "user":
            content = history[idx]["content"]
            del history[idx:]
            return content, history
    return gr.skip(), gr.skip()


def retry_fn_both(history1, history2):
    """Run `retry_fn` on both chat histories and merge the recovered messages.

    If both sides recover the same text, it is used once. If they differ, the
    two are joined with " / ". If only one side recovers a message, that one
    is used. If neither does, the left side's result (a skip) is passed through.

    Returns:
        (message, history1, history2)
    """
    left_msg, history1 = retry_fn(history1)
    right_msg, history2 = retry_fn(history2)
    recovered = [m for m in (left_msg, right_msg) if isinstance(m, str)]
    if not recovered:
        merged = left_msg
    elif len(recovered) == 1 or recovered[0] == recovered[1]:
        merged = recovered[0]
    else:
        merged = " / ".join(recovered)
    return merged, history1, history2


# Model pickers and their chatbots are built before the layout and placed
# into it later with .render(). Each chatbot's avatar follows the model
# preselected in its dropdown.
sel1, sel2 = [instantiate_select_box(idx, model_labels_list) for idx in (0, 1)]
chatbot1, chatbot2 = [
    instantiate_chatbot(dropdown.value, chat_key)
    for dropdown, chat_key in ((sel1, "chat1"), (sel2, "chat2"))
]

# CSS class used to right-align the left/right arrow buttons.
CSS = ".stick-to-the-right {align-items: end; justify-content: end}"

with gr.Blocks(fill_width=True, title="Keras demo", css=CSS) as demo:

    # Which chatbot(s) submitted messages go to (left, right or both).
    # render_text_area below is rebuilt whenever this state changes.
    text_route = gr.State(TextRoute.BOTH)

    # Header: Keras logo and a short description of the demo.
    with gr.Row():
        gr.Image(
            "img/keras_logo_k.png",
            width=80,
            height=80,
            min_width=80,
            show_label=False,
            show_download_button=False,
            show_fullscreen_button=False,
            show_share_button=False,
            interactive=False,
            scale=0,
            container=False,
        )
        gr.HTML(
            "<H2>Keras chatbot arena  - running with JAX on TPU</H2>"
            + "All the models are loaded into the TPU memory. "
            + "You can call any of them and compare their answers. "
            + "The entire chat<br/>history is fed to the models at every submission. "
            + "This demo is running on a Google TPU v5e 2x4 (8 cores) in bfloat16 precision."
        )

    # The model pickers and chatbots were created at module level; they are
    # only placed into the layout here.
    with gr.Row():
        sel1.render()
        sel2.render()

    with gr.Row():
        chatbot1.render()
        chatbot2.render()

    @gr.render(inputs=text_route)
    def render_text_area(route):
        """(Re)build the message input area for the current routing mode.

        Lays out the message box, the arrow buttons that switch to the other
        routes, and the retry/trash buttons. Then wires the submit, retry and
        clear events to the chatbot(s) selected by `route`.
        """
        # In every mode, the two arrow buttons offer the two routes that are
        # not currently selected.
        if route == TextRoute.BOTH:
            with gr.Row():
                msg = instantiate_text_box()
                with gr.Column(scale=0, min_width=100):
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.LEFT, text_route)
                        retry = instantiate_retry_button(route)
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.RIGHT, text_route)
                        trash = instantiate_trash_button()
            retry.click(
                retry_fn_both,
                inputs=[chatbot1, chatbot2],
                outputs=[msg, chatbot1, chatbot2],
            )
            trash.click(lambda: ("", [], []), outputs=[msg, chatbot1, chatbot2])

        elif route == TextRoute.LEFT:
            # Message box under the left chatbot, buttons to its right.
            with gr.Row():
                with gr.Column(scale=1):
                    msg = instantiate_text_box()
                with gr.Column(scale=1):
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.RIGHT, text_route)
                        retry = instantiate_retry_button(route)
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.BOTH, text_route)
                        trash = instantiate_trash_button()
            retry.click(retry_fn, inputs=[chatbot1], outputs=[msg, chatbot1])
            trash.click(lambda: ("", []), outputs=[msg, chatbot1])

        elif route == TextRoute.RIGHT:
            # Message box under the right chatbot, buttons right-aligned to
            # its left (via the "stick-to-the-right" CSS class).
            with gr.Row():
                with gr.Column(scale=1, elem_classes="stick-to-the-right"):
                    with gr.Row(elem_classes="stick-to-the-right"):
                        retry = instantiate_retry_button(route)
                        instantiate_arrow_button(TextRoute.LEFT, text_route)
                    with gr.Row(elem_classes="stick-to-the-right"):
                        trash = instantiate_trash_button()
                        instantiate_arrow_button(TextRoute.BOTH, text_route)
                with gr.Column(scale=1):
                    msg = instantiate_text_box()
            retry.click(retry_fn, inputs=[chatbot2], outputs=[msg, chatbot2])
            trash.click(lambda: ("", []), outputs=[msg, chatbot2])

        system_message = instantiate_additional_settings()

        # Route the submitted message to the left, right or both chatbots:
        # first echo the user turn, then generate the assistant reply.
        if route == TextRoute.LEFT:
            submission = msg.submit(
                chat_turn_user, inputs=[msg, chatbot1], outputs=[chatbot1]
            ).then(
                chat_turn_assistant,
                [msg, sel1, chatbot1, system_message],
                outputs=[chatbot1],
            )
        elif route == TextRoute.RIGHT:
            submission = msg.submit(
                chat_turn_user, inputs=[msg, chatbot2], outputs=[chatbot2]
            ).then(
                chat_turn_assistant,
                [msg, sel2, chatbot2, system_message],
                outputs=[chatbot2],
            )
        elif route == TextRoute.BOTH:
            submission = msg.submit(
                chat_turn_both_user,
                inputs=[msg, chatbot1, chatbot2],
                outputs=[chatbot1, chatbot2],
            ).then(
                chat_turn_both_assistant,
                [msg, sel1, sel2, chatbot1, chatbot2, system_message],
                outputs=[chatbot1, chatbot2],
            )
        # In all cases reset text box after submission
        submission.then(lambda: "", outputs=msg)

    # Changing a model rebuilds its chatbot (to update the bot avatar) and
    # then its dropdown (to update the "Selected model" link).
    sel1.select(
        lambda sel: instantiate_chatbot(sel, "chat1"),
        inputs=[sel1],
        outputs=[chatbot1],
    ).then(
        lambda sel: instantiate_select_box(sel, model_labels_list),
        inputs=[sel1],
        outputs=[sel1],
    )

    sel2.select(
        lambda sel: instantiate_chatbot(sel, "chat2"),
        inputs=[sel2],
        outputs=[chatbot2],
    ).then(
        lambda sel: instantiate_select_box(sel, model_labels_list),
        inputs=[sel2],
        outputs=[sel2],
    )


if __name__ == "__main__":
    demo.launch()