from transformers import pipeline
import time
import gradio as gr


def get_visual_qa_tab():
    """Build the "Visual Q&A" Gradio tab and return it.

    Both visual-question-answering pipelines are created eagerly when this
    function is called, so the first call may download/load model weights.
    The tab lets the user upload an image, type a question, pick one of the
    two models, and see the top answer plus the inference time in seconds.

    Presumably called inside a ``gr.Blocks`` / ``gr.Tabs`` context, since
    ``gr.TabItem`` is used as a context manager here — confirm with caller.

    Returns:
        The ``gr.TabItem`` containing the tab's components.
    """
    salesforce_model_name = "Salesforce/blip-vqa-base"
    salesforce_pipe = pipeline("visual-question-answering", model=salesforce_model_name)

    dandelin_model_name = "dandelin/vilt-b32-finetuned-vqa"
    dandelin_pipe = pipeline("visual-question-answering", model=dandelin_model_name)

    # Model id (as shown in the dropdown) -> loaded pipeline.
    pipe_map = {
        salesforce_model_name: salesforce_pipe,
        dandelin_model_name: dandelin_pipe,
    }

    def gradio_process(model_name, image, text):
        """Answer *text* about *image* with the selected model.

        Returns:
            ``[answer, seconds_elapsed]`` matching ``[output_text, elapsed_result]``.

        Raises:
            gr.Error: if no valid model is selected, no image is uploaded, or
                the question is empty (shown to the user in the UI).
        """
        # Validate up front so the user gets a readable message instead of a
        # KeyError / pipeline traceback.
        pipe = pipe_map.get(model_name)
        if pipe is None:
            raise gr.Error("Please select a model.")
        if image is None:
            raise gr.Error("Please upload an image.")
        if not text or not text.strip():
            raise gr.Error("Please enter a question.")

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        output = pipe(image, question=text)
        time_spent = time.perf_counter() - start

        # The pipeline returns a list of candidate dicts; the first is the top answer.
        result = output[0]["answer"]

        return [result, round(time_spent, 3)]

    with gr.TabItem("Visual Q&A") as visual_qa_tab:
        gr.Markdown("# Visual Question & Answering")

        with gr.Row():
            with gr.Column():
                # Input components
                input_image = gr.Image(label="Upload Image", type="pil")
                input_text = gr.Textbox(label="Question")
                # Default to the first model so clicking the button without
                # choosing one does not pass None to gradio_process.
                model_selector = gr.Dropdown(
                    [salesforce_model_name, dandelin_model_name],
                    value=salesforce_model_name,
                    label="Select Model",
                )

                # Process button
                process_btn = gr.Button("Generate answer")

            with gr.Column():
                # Output components
                elapsed_result = gr.Textbox(label="Seconds elapsed", lines=1)
                output_text = gr.Textbox(label="Answer")

        # Connect the input components to the processing function; the output
        # order must match the [answer, seconds] list returned by gradio_process.
        process_btn.click(
            fn=gradio_process,
            inputs=[
                model_selector,
                input_image,
                input_text,
            ],
            outputs=[output_text, elapsed_result],
        )

    return visual_qa_tab