from typing import List, Tuple, Optional

# Importing Llama here prevents the error "exception: access violation reading 0x0000000000000000", see
# https://github.com/abetlen/llama-cpp-python/issues/1581
from llama_cpp import Llama

import gradio as gr
from langchain_core.vectorstores import VectorStore

from config import (
    LLM_MODEL_REPOS,
    START_LLM_MODEL_FILE,
    EMBED_MODEL_REPOS,
    SUBTITLES_LANGUAGES,
    GENERATE_KWARGS,
    CONTEXT_TEMPLATE,
)

from utils import (
    load_llm_model,
    load_embed_model,
    load_documents_and_create_db,
    user_message_to_chatbot,
    update_user_message_with_context,
    get_llm_response,
    get_gguf_model_names,
    add_new_model_repo,
    clear_llm_folder,
    clear_embed_folder,
    get_memory_usage,
)


# ============ INTERFACE COMPONENT INITIALIZATION FUNCS ============

def get_rag_mode_component(db: Optional[VectorStore]) -> gr.Checkbox:
    """Build the 'RAG Mode' checkbox.

    The checkbox is checked and visible when ``db`` is not None,
    and unchecked and hidden otherwise.

    Args:
        db: The current vector store, or None when no database has been created.

    Returns:
        A ``gr.Checkbox`` whose ``value`` and ``visible`` both equal ``db is not None``.
    """
    db_exists = db is not None
    return gr.Checkbox(
        label='RAG Mode',
        value=db_exists,
        visible=db_exists,
        scale=1,
    )
    
    
def get_rag_settings(
    rag_mode: bool,
    context_template_value: str,
    render: bool = True,
) -> Tuple[gr.Radio, gr.Slider, gr.Textbox]:
    """Build the three RAG setting components.

    All three components are visible only when ``rag_mode`` is true. The function
    is used both to create the components at layout time and as the handler of
    the RAG-mode checkbox ``change`` event, in which case the returned components
    update the existing ones.

    Args:
        rag_mode: Whether RAG mode is enabled; controls the ``visible`` flag.
        context_template_value: Text placed into the context template textbox.
        render: Forwarded to each component's ``render`` argument; pass False to
            create the components now and place them in the layout later
            with ``.render()``.

    Returns:
        A tuple ``(k, score_threshold, context_template)``:
        the radio selecting the number of documents to search (1-5 or 'all',
        default 2), the relevance score threshold slider (0 to 1, default 0.5),
        and the context template textbox.
    """
    k = gr.Radio(
        choices=[1, 2, 3, 4, 5, 'all'],
        value=2,
        label='Number of relevant documents for search',
        visible=rag_mode,
        render=render,
    )
    score_threshold = gr.Slider(
        minimum=0,
        maximum=1,
        value=0.5,
        step=0.05,
        label='relevance_scores_threshold',
        visible=rag_mode,
        render=render,
    )
    # Size the textbox so that every line of the template is shown.
    template_lines = len(context_template_value.split('\n'))
    context_template = gr.Textbox(
        value=context_template_value,
        label='Context Template',
        lines=template_lines,
        visible=rag_mode,
        render=render,
    )
    return k, score_threshold, context_template


def get_user_message_with_context(text: str, rag_mode: bool) -> gr.Textbox:
    """Build the read-only textbox showing the user message with its context.

    Args:
        text: The text to display.
        rag_mode: Whether RAG mode is enabled; the textbox is visible only then.

    Returns:
        A non-interactive ``gr.Textbox`` whose height equals the number of
        lines in ``text``, capped at 10 lines.
    """
    max_lines = 10
    # Grow with the text, but never beyond max_lines.
    num_lines = min(len(text.split('\n')), max_lines)
    return gr.Textbox(
        text,
        visible=rag_mode,
        interactive=False,
        label='User Message With Context',
        lines=num_lines,
    )


def get_system_prompt_component(interactive: bool) -> gr.Textbox:
    """Build the system prompt textbox.

    Args:
        interactive: Whether the textbox is editable. In this file the value
            comes from the ``support_system_role`` state.

    Returns:
        An empty editable ``gr.Textbox`` when ``interactive`` is true; otherwise
        a non-editable one holding a notice that the system prompt is not
        supported by the model.
    """
    if interactive:
        value = ''
    else:
        value = 'System prompt is not supported by this model'
    return gr.Textbox(label='System prompt', value=value, interactive=interactive)


def get_generate_args(do_sample: bool) -> List[gr.Slider]:
    """Build the sliders for the sampling parameters.

    Each slider takes its initial value from ``GENERATE_KWARGS`` and is visible
    only when ``do_sample`` is true.

    Args:
        do_sample: Whether random sampling is enabled.

    Returns:
        A list of four ``gr.Slider`` components, in this order:
        temperature, top_p, top_k, repeat_penalty.
    """
    # (name, minimum, maximum, step); the name is both the GENERATE_KWARGS key and the label.
    slider_specs = [
        ('temperature', 0.1, 3, 0.1),
        ('top_p', 0, 1, 0.01),
        ('top_k', 1, 50, 1),
        ('repeat_penalty', 1, 5, 0.1),
    ]
    return [
        gr.Slider(
            minimum=minimum,
            maximum=maximum,
            value=GENERATE_KWARGS[name],
            step=step,
            label=name,
            visible=do_sample,
        )
        for name, minimum, maximum, step in slider_specs
    ]
    

# ================ LOADING AND INITIALIZING MODELS ========================

# Load the start LLM: the first configured repository and the configured start model file.
start_llm_model, start_support_system_role, load_log = load_llm_model(
    model_repo=LLM_MODEL_REPOS[0], model_file=START_LLM_MODEL_FILE)

# Without an LLM the application cannot start; report the loader's status message.
if start_llm_model['llm_model'] is None:
    raise Exception(f'LLM model not initialized, status message: {load_log}')


# Load the start embedding model from the first configured repository.
start_embed_model, load_log = load_embed_model(model_repo=EMBED_MODEL_REPOS[0])

# Same rule for the embedding model: fail at startup with the loader's status message.
if start_embed_model['embed_model'] is None:
    raise Exception(f'Embed model not initialized, status message: {load_log}')


# ================== APPLICATION WEB INTERFACE ============================

# Custom CSS passed to gr.Blocks below: limits the Gradio container
# to 70% of the page width and centers it horizontally.
css = '''
.gradio-container {
    width: 70% !important;
    margin: 0 auto !important;
}
'''

with gr.Blocks(css=css) as interface:

    # ==================== GRADIO STATES ===============================

    # NOTE: no State is created for user_message_with_context; that name is bound
    # to a Textbox inside the 'Prompt' accordion of the Chatbot tab, and a State
    # created here would only be shadowed by it without ever being used.

    # Loaded document chunks (output of the document loading handler, shown on the view tab)
    documents = gr.State([])
    # Vector store; None until documents are loaded
    db = gr.State(None)
    # Flag returned by load_llm_model together with the model; drives the system prompt textbox
    support_system_role = gr.State(start_support_system_role)
    # Repository lists handed to add_new_model_repo on the model loading tabs
    llm_model_repos = gr.State(LLM_MODEL_REPOS)
    embed_model_repos = gr.State(EMBED_MODEL_REPOS)
    # Currently loaded models, initialized with the models loaded at startup
    llm_model = gr.State(start_llm_model)
    embed_model = gr.State(start_embed_model)



    # ==================== BOT PAGE =================================

    # Chatbot tab: chat window, generation parameters, RAG settings, prompt
    # controls and the event chain that produces the model response.
    with gr.Tab(label='Chatbot'):
        with gr.Row():
            # Left (wider) column: chat history, message input and control buttons
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    type='messages',  # new in gradio 5+
                    show_copy_button=True,
                    height=480,
                )
                user_message = gr.Textbox(label='User')

                with gr.Row():
                    user_message_btn = gr.Button('Send')
                    stop_btn = gr.Button('Stop')
                    clear_btn = gr.Button('Clear')

            # ------------- GENERATION PARAMETERS -------------------

            # Right (narrow) column: history size and sampling parameters
            with gr.Column(scale=1, min_width=80):
                with gr.Group():
                    gr.Markdown('History size')
                    history_len = gr.Slider(
                        minimum=0,
                        maximum=5,
                        value=0,
                        step=1,
                        info='Number of previous messages taken into account in history',
                        label='history_len',
                        show_label=False,
                        )

                    # This group is nested inside the 'History size' group above
                    with gr.Group():
                        gr.Markdown('Generation parameters')
                        do_sample = gr.Checkbox(
                            value=False,
                            label='do_sample',
                            info='Activate random sampling',
                            )
                        # Sliders are created from the checkbox's initial value (False),
                        # so they start hidden
                        generate_args = get_generate_args(do_sample.value)
                        # Toggling the checkbox replaces the sliders with freshly built
                        # ones whose visibility follows the new checkbox value
                        do_sample.change(
                            fn=get_generate_args,
                            inputs=do_sample,
                            outputs=generate_args,
                            show_progress=False,
                            )

        # db.value is None at startup, so the RAG checkbox starts unchecked and hidden
        rag_mode = get_rag_mode_component(db=db.value)
        # The RAG setting components are created with render=False and placed
        # into the layout further down via .render()
        k, score_threshold, context_template = get_rag_settings(
            rag_mode=rag_mode.value,
            context_template_value=CONTEXT_TEMPLATE,
            render=False,
            )
        # Toggling RAG mode rebuilds the three setting components; the current
        # template text is passed in so it is kept
        rag_mode.change(
            fn=get_rag_settings,
            inputs=[rag_mode, context_template],
            outputs=[k, score_threshold, context_template],
            )

        with gr.Row():
            k.render()
            score_threshold.render()

        # ---------------- SYSTEM PROMPT AND USER MESSAGE -----------

        with gr.Accordion('Prompt', open=True):
            # Editable only when the start model supports a system role
            system_prompt = get_system_prompt_component(interactive=support_system_role.value)
            context_template.render()
            # Read-only textbox; from here on user_message_with_context refers to it
            user_message_with_context = get_user_message_with_context(text='', rag_mode=rag_mode.value)

        # ---------------- SEND, CLEAR AND STOP BUTTONS ------------

        # Event chain started by pressing Enter in the input box or clicking 'Send':
        #   1. user_message_to_chatbot: takes the input text and chat history,
        #      updates both (presumably appends the message and clears the
        #      input; see utils)
        #   2. update_user_message_with_context: fills the read-only textbox from
        #      the chat history, RAG settings and database
        #   3. get_user_message_with_context: rebuilds the textbox so its height
        #      and visibility match the new text and the RAG mode
        #   4. get_llm_response: writes the model response into the chatbot
        generate_event = gr.on(
            triggers=[user_message.submit, user_message_btn.click],
            fn=user_message_to_chatbot,
            inputs=[user_message, chatbot],
            outputs=[user_message, chatbot],
            # queue=False,
        ).then(
            fn=update_user_message_with_context,
            inputs=[chatbot, rag_mode, db, k, score_threshold, context_template],
            outputs=[user_message_with_context],
        ).then(
            fn=get_user_message_with_context,
            inputs=[user_message_with_context, rag_mode],
            outputs=[user_message_with_context],
        ).then(
            fn=get_llm_response,
            inputs=[chatbot, llm_model, user_message_with_context, rag_mode, system_prompt,
                    support_system_role, history_len, do_sample, *generate_args],
            outputs=[chatbot],
        )

        # 'Stop' runs no function of its own; it only cancels the running generation chain
        stop_btn.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=generate_event,
            queue=False,
        )

        # 'Clear' resets the chat history (None) and empties the context textbox
        clear_btn.click(
            fn=lambda: (None, ''),
            inputs=None,
            outputs=[chatbot, user_message_with_context],
            queue=False,
            )



    # ================= DOCUMENT LOADING PAGE ======================

    # Tab for loading documents (files and web links) and building the database.
    with gr.Tab(label='Load documents'):
        # Document sources: uploaded files and links
        with gr.Row(variant='compact'):
            upload_files = gr.File(label='Loading text files', file_count='multiple')
            web_links = gr.Textbox(label='Links to Web sites or YouTube', lines=6)

        # Splitting parameters and subtitle language
        with gr.Row(variant='compact'):
            chunk_size = gr.Slider(minimum=50, maximum=2000, value=500, step=50, label='Chunk size')
            chunk_overlap = gr.Slider(minimum=0, maximum=200, value=20, step=10, label='Chunk overlap')

            subtitles_lang = gr.Radio(
                choices=SUBTITLES_LANGUAGES,
                value=SUBTITLES_LANGUAGES[0],
                label='YouTube subtitle language',
            )

        load_documents_btn = gr.Button(value='Upload documents and initialize database')
        load_docs_log = gr.Textbox(
            label='Status of loading and splitting documents',
            interactive=False,
        )

        # Load the documents and create the database; only if that handler
        # succeeds, rebuild the RAG checkbox from the new database state.
        load_documents_event = load_documents_btn.click(
            fn=load_documents_and_create_db,
            inputs=[upload_files, web_links, subtitles_lang, chunk_size, chunk_overlap, embed_model],
            outputs=[documents, db, load_docs_log],
        )
        load_documents_event.success(
            fn=get_rag_mode_component,
            inputs=[db],
            outputs=[rag_mode],
        )

        # Centered link to the project repository
        gr.HTML("""<h3 style='text-align: center'>
        <a href="https://github.com/sergey21000/chatbot-rag" target='_blank'>GitHub Repository</a></h3>
        """)



    # ================= VIEW PAGE FOR ALL DOCUMENTS =================

    # Tab that shows the text of all loaded document chunks.
    with gr.Tab(label='View documents'):
        view_documents_btn = gr.Button(value='Show downloaded text chunks')
        view_documents_textbox = gr.Textbox(
            label='Uploaded chunks',
            placeholder='To view chunks, load documents in the Load documents tab',
            lines=1,
        )
        # Line of '=' characters placed between consecutive chunks
        sep = '=' * 20
        # On click, join the page_content of every stored document with the separator
        view_documents_btn.click(
            fn=lambda docs: f'\n{sep}\n\n'.join(doc.page_content for doc in docs),
            inputs=[documents],
            outputs=[view_documents_textbox],
        )


    # ============== GGUF MODELS DOWNLOAD PAGE =====================

    # Tab for adding model repositories, choosing a GGUF file, loading the LLM
    # and cleaning up the model folder.
    with gr.Tab('Load LLM model'):
        # Input for registering an additional repository
        new_llm_model_repo = gr.Textbox(
            value='',
            label='Add repository',
            placeholder='Link to repository of HF models in GGUF format',
            )
        new_llm_model_repo_btn = gr.Button('Add repository')
        # Repository and model file selection; nothing is selected initially
        curr_llm_model_repo = gr.Dropdown(
            choices=LLM_MODEL_REPOS,
            value=None,
            label='HF Model Repository',
            )
        curr_llm_model_path = gr.Dropdown(
            choices=[],
            value=None,
            label='GGUF model file',
            )
        load_llm_model_btn = gr.Button('Loading and initializing model')
        # Status textbox shared by all actions on this tab
        load_llm_model_log = gr.Textbox(
            value=f'Model {LLM_MODEL_REPOS[0]} loaded at application startup',
            label='Model loading status',
            lines=6,
            )

        with gr.Group():
            gr.Markdown('Free up disk space by deleting all models except the currently selected one')
            clear_llm_folder_btn = gr.Button('Clear folder')

        # Add the repository, then clear the input box if adding succeeded
        new_llm_model_repo_btn.click(
            fn=add_new_model_repo,
            inputs=[new_llm_model_repo, llm_model_repos],
            outputs=[curr_llm_model_repo, load_llm_model_log],
        ).success(
            fn=lambda: '',
            inputs=None,
            outputs=[new_llm_model_repo],
        )

        # Selecting a repository refreshes the model file dropdown
        curr_llm_model_repo.change(
            fn=get_gguf_model_names,
            inputs=[curr_llm_model_repo],
            outputs=[curr_llm_model_path],
        )

        # Load the selected model; on success append the memory usage to the log,
        # then rebuild the system prompt textbox for the new model
        load_llm_model_btn.click(
            fn=load_llm_model,
            inputs=[curr_llm_model_repo, curr_llm_model_path],
            outputs=[llm_model, support_system_role, load_llm_model_log],
        ).success(
            fn=lambda log: log + get_memory_usage(),
            inputs=[load_llm_model_log],
            outputs=[load_llm_model_log],
        ).then(
            fn=get_system_prompt_component,
            inputs=[support_system_role],
            outputs=[system_prompt],
        )

        # Clear the model folder, then report the result in the status textbox.
        # The message used to be returned with outputs=None, so it was never shown.
        clear_llm_folder_btn.click(
            fn=clear_llm_folder,
            inputs=[curr_llm_model_path],
            outputs=None,
        ).success(
            fn=lambda model_path: f'Models other than {model_path} removed',
            inputs=[curr_llm_model_path],
            outputs=[load_llm_model_log],
        )


    # ============== EMBEDDING MODELS DOWNLOAD PAGE =============

    # Tab for adding embedding model repositories, loading an embedding model
    # and cleaning up the model folder.
    with gr.Tab('Load embed model'):
        # Input for registering an additional repository
        new_embed_model_repo = gr.Textbox(
            value='',
            label='Add repository',
            placeholder='Link to HF model repository',
            )
        new_embed_model_repo_btn = gr.Button('Add repository')
        # Repository selection; nothing is selected initially
        curr_embed_model_repo = gr.Dropdown(
            choices=EMBED_MODEL_REPOS,
            value=None,
            label='HF model repository',
            )

        load_embed_model_btn = gr.Button('Loading and initializing model')
        # Status textbox shared by all actions on this tab
        load_embed_model_log = gr.Textbox(
            value=f'Model {EMBED_MODEL_REPOS[0]} loaded at application startup',
            label='Model loading status',
            lines=7,
            )
        with gr.Group():
            gr.Markdown('Free up disk space by deleting all models except the currently selected one')
            clear_embed_folder_btn = gr.Button('Clear folder')

        # Add the repository, then clear the input box if adding succeeded
        new_embed_model_repo_btn.click(
            fn=add_new_model_repo,
            inputs=[new_embed_model_repo, embed_model_repos],
            outputs=[curr_embed_model_repo, load_embed_model_log],
        ).success(
            fn=lambda: '',
            inputs=None,
            outputs=new_embed_model_repo,
        )

        # Load the selected model; on success append the memory usage to the log
        load_embed_model_btn.click(
            fn=load_embed_model,
            inputs=[curr_embed_model_repo],
            outputs=[embed_model, load_embed_model_log],
        ).success(
            fn=lambda log: log + get_memory_usage(),
            inputs=[load_embed_model_log],
            outputs=[load_embed_model_log],
        )

        # Clear the model folder, then report the result in the status textbox.
        # The message used to be returned with outputs=None, so it was never shown.
        clear_embed_folder_btn.click(
            fn=clear_embed_folder,
            inputs=[curr_embed_model_repo],
            outputs=None,
        ).success(
            fn=lambda model_repo: f'Models other than {model_repo} removed',
            inputs=[curr_embed_model_repo],
            outputs=[load_embed_model_log],
        )


# Start the web server on all network interfaces, port 7860.
# For debugging, debug=True can be passed to launch() as well.
interface.launch(
    server_name='0.0.0.0',
    server_port=7860,
)