import os
import spaces
import time
# Resolve the Hugging Face access token from the environment. When HF_TOKEN is
# missing, fall back to a placeholder and export it, so that anything later in
# this process that reads HF_TOKEN sees the same value as `token`.
try:
    token = os.environ['HF_TOKEN']
except KeyError:
    # Only a missing variable is expected here; a bare `except:` would also
    # swallow KeyboardInterrupt/SystemExit raised during startup.
    print("paste your hf token here!")
    token = "hf_xxxxxxxxxxxxxxxxxxx"
    os.environ['HF_TOKEN'] = token
import torch
import gradio as gr
from gradio.themes.utils import colors, fonts, sizes

# from faster_whisper import WhisperModel
from moviepy.editor import VideoFileClip
from transformers import AutoTokenizer, AutoModel

import subprocess
# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

# ========================================
#             Model Initialization
# ========================================

# Load the tokenizer and chat model once. Code under `gr.NO_RELOAD` is skipped
# when Gradio re-runs this file in auto-reload mode, so the weights are not
# reloaded on every source edit.
if gr.NO_RELOAD:
    # Whisper ASR model, currently disabled. Because these lines are commented
    # out, this block does NOT define a `speech_model` global (audio2text
    # refers to that name).
    # if torch.cuda.is_available():
        # speech_model = WhisperModel("large-v3", device="cuda", compute_type="float16")
    # else:
        # speech_model = WhisperModel("large-v3", device="cpu")

    # Hugging Face Hub id of the video chat model.
    model_path = 'OpenGVLab/InternVideo2_5_Chat_8B'

    # trust_remote_code=True executes modelling code shipped with the repo.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # Cast weights to float16 and move the model to the CUDA device.
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()


# NOTE(review): presumably turns off a token-compression option implemented in
# the model's remote code — confirm against the model repository.
model.config.mm_llm_compress = False


# ========================================
#          Define Utils
# ========================================


def extract_audio(name):
    """Write the sound track of the video at *name* to a WAV file beside it.

    Args:
        name: path of a video file that moviepy's ``VideoFileClip`` can open.

    Returns:
        Path of the written ``.wav`` file (same stem as *name*), or None when
        the video has no audio track.
    """
    with VideoFileClip(name) as video:
        audio = video.audio
        if audio is None:
            return None
        # Swap the real extension instead of slicing off four characters,
        # which mangles paths whose extension is not exactly three letters
        # (".webm") or that have no extension at all.
        audio_name = os.path.splitext(name)[0] + '.wav'
        # fps is the audio sample rate: 16000 Hz.
        audio.write_audiofile(audio_name, fps=16000)
    return audio_name

@spaces.GPU
def audio2text(audio):
    """Transcribe an audio file into a timestamped transcript string.

    Each recognised segment is rendered as ``"[start -> end] text  "`` (times
    in seconds, two decimals) and the segments are concatenated.

    The speech model is looked up as the module-level ``speech_model`` global.
    Its initialisation is currently commented out in the model-setup section.
    So when it does not exist (or is None), an empty transcript is returned
    instead of raising NameError, which would otherwise break every upload of
    a video that has a sound track.

    Args:
        audio: path of the audio file to transcribe.

    Returns:
        The transcript, or "" when no speech model is loaded.
    """
    # Presumably a faster_whisper WhisperModel (see the commented-out setup).
    asr_model = globals().get("speech_model")
    if asr_model is None:
        return ""
    segments, _ = asr_model.transcribe(audio)
    # join avoids the quadratic cost of repeated string `+=`.
    return "".join(
        "[%.2fs -> %.2fs] %s  " % (segment.start, segment.end, segment.text)
        for segment in segments
    )


# ========================================
#             Gradio Setting
# ========================================
def gradio_reset():
    """Put the UI back into its initial "no video uploaded" state.

    Returns an 8-tuple for these outputs, in order: chatbot, video, text box,
    send button, clear button, upload button, chat-history state, and ASR
    transcript state.
    """
    cleared_chat = None
    video_update = gr.update(value=None, interactive=True)
    textbox_update = gr.update(placeholder='Please upload your video first', interactive=False)
    send_update = gr.update(interactive=False)
    clear_update = gr.update(interactive=False)
    upload_update = gr.update(value="Upload & Start Chat", interactive=True)
    return (
        cleared_chat,
        video_update,
        textbox_update,
        send_update,
        clear_update,
        upload_update,
        [],
        "",
    )




def upload_video(gr_video, text_input="Type and press Enter"):
    """Unlock the chat controls for a newly provided video and pre-compute ASR.

    Args:
        gr_video: the video handed over by the ``gr.Video`` component
            (presumably a file path, since it is passed to extract_audio), or
            None when no video is present.
        text_input: placeholder text to show in the chat box once enabled.

    Returns:
        A 6-tuple of updates for: send button, clear button, video component,
        text box, upload button, and the ASR transcript ("" when there is no
        video or no audio track).
    """
    if gr_video is None:
        # Nothing to chat about yet: keep the chat controls locked.
        return (
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(interactive=True),
            gr.update(interactive=False),
            gr.update(value="Upload & Start Chat", interactive=True),
            "",
        )

    # The transcript is computed on every upload; whether it is actually fed
    # to the model is decided per question by the "Use ASR" checkbox.
    audio_name = extract_audio(gr_video)
    asr_msg = audio2text(audio_name) if audio_name is not None else ""

    return (
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True, placeholder=text_input),
        gr.update(value="Start Chatting", interactive=False),
        asr_msg,
    )

def clear_():
    """Empty both the visible chat log and the model-side chat history.

    Two fresh lists are returned on every call, so the outputs never share state.
    """
    empty_chat, empty_history = [], []
    return empty_chat, empty_history

def gradio_ask(user_message, chatbot):
    """Record the user's question as a pending turn in the chat log.

    Args:
        user_message: the text typed by the user.
        chatbot: current chat log as a list of ``[user, assistant]`` pairs.

    Returns:
        ``(user_message, new_chatbot)``. The message is passed back unchanged so
        it remains available to the answer step. ``new_chatbot`` is a new list
        ending with ``[user_message, None]``; the input list is not mutated.
    """
    # Empty-input validation is intentionally disabled.
    pending_turn = [user_message, None]
    extended_log = chatbot + [pending_turn]
    return user_message, extended_log

@spaces.GPU
def gradio_answer(chatbot, text_input, video_path, max_num_frames, check_asr, asr_msg, chat_state, max_new_tokens, do_sample, num_beams, top_p, temperature):
    """Answer the latest question about the video and stream the reply into the chat.

    On the first turn (``chat_state`` None or empty), the ASR transcript is
    prepended to the prompt if ``check_asr`` is truthy and ``asr_msg`` is
    non-empty. The whole reply comes from a single ``model.chat`` call. The
    streaming effect is simulated by revealing it one piece at a time with a
    trailing "▌" cursor and an 8 ms pause between pieces.

    Yields:
        ``(chatbot, chat_state)`` after every revealed piece, then once more
        with the complete reply and no cursor. ``chatbot[-1][1]`` is updated in
        place.
    """
    is_first_turn = chat_state is None or len(chat_state) == 0
    has_transcript = not (asr_msg is None or len(asr_msg) == 0)
    if is_first_turn and has_transcript and check_asr:
        text_input = f"The speech extracted from the video via ASR is as follows: {asr_msg}\n{text_input}"

    print(f"\033[91m== text_input: \033[0m\n{text_input}\n")

    decode_options = {
        'max_new_tokens': max_new_tokens,
        'do_sample': do_sample,
        'num_beams': num_beams,
        'top_p': top_p,
        'temperature': temperature,
    }
    response, chat_state = model.chat(
        video_path=video_path,
        tokenizer=tokenizer,
        user_prompt=text_input,
        chat_history=chat_state,
        return_history=True,
        max_num_frames=max_num_frames,
        generation_config=decode_options,
    )

    # NOTE(review): the typing animation sleeps inside a @spaces.GPU function,
    # so it presumably extends the GPU allocation time — confirm with spaces docs.
    shown = []
    for piece in response:
        shown.append(piece)
        chatbot[-1][1] = "".join(shown) + "▌"
        yield chatbot, chat_state
        time.sleep(0.008)

    chatbot[-1][1] = "".join(shown)
    yield chatbot, chat_state


class OpenGVLab(gr.themes.base.Base):
    """Gradio theme used by this demo.

    Forwards its keyword-only palette, sizing and font options unchanged to
    ``gr.themes.base.Base``, then overrides one theme variable: the page
    background (``body_background_fill``) is set to ``*neutral_50``.
    """

    def __init__(
        self,
        *,
        primary_hue=colors.blue,
        secondary_hue=colors.sky,
        neutral_hue=colors.gray,
        spacing_size=sizes.spacing_md,
        radius_size=sizes.radius_sm,
        text_size=sizes.text_md,
        font=(fonts.GoogleFont("Noto Sans"), "ui-sans-serif", "sans-serif"),
        font_mono=(fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace"),
    ):
        base_options = dict(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().__init__(**base_options)
        # The only deviation from the base theme: the page background.
        self.set(body_background_fill="*neutral_50")


# Theme instance passed to gr.Blocks. The arguments spell out the same
# palette and sizing values that OpenGVLab uses as its keyword defaults.
gvlabtheme = OpenGVLab(
    primary_hue=colors.blue,
    secondary_hue=colors.sky,
    neutral_hue=colors.gray,
    spacing_size=sizes.spacing_md,
    radius_size=sizes.radius_sm,
    text_size=sizes.text_md,
)

# Page header rendered with gr.Markdown: a centred logo image that links to the
# VideoChat-Flash GitHub repository.
title = """<h1 align="center"><a href="https://github.com/OpenGVLab/VideoChat-Flash"><img src="https://s1.ax1x.com/2023/05/07/p9dBMOU.png" alt="VideoChat-Flash" border="0" style="margin: 0 auto; height: 100px;" /></a> </h1>"""
# Short blurb rendered under the title, with a GitHub badge link.
# NOTE(review): this text advertises "VideoChat-Flash-7B@448", while model_path
# in the model-initialization section loads 'OpenGVLab/InternVideo2_5_Chat_8B'.
# Confirm which name is intended. The trailing "<p>" is also never closed.
description ="""
        VideoChat-Flash-7B@448 powered by InternVideo!<br><p><a href='https://github.com/OpenGVLab/VideoChat-Flash'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p>
        """


# Build the UI: a left column with the video, session buttons and decoding
# options, and a right column with the chat. Then wire the events and launch.
# CSS: "visibility: none" is not a valid value, so the footer was never hidden;
# "hidden" is the valid keyword.
with gr.Blocks(title="VideoChat-Flash", theme=gvlabtheme, css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: hidden}") as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        # Left column: video input, session controls and generation options.
        with gr.Column(scale=0.5, visible=True) as video_upload:
            with gr.Column(elem_id="image", scale=0.5) as img_part:
                up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload")

            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            restart = gr.Button("Restart")

            max_num_frames = gr.Slider(
                minimum=4,
                maximum=1024,
                value=512,
                step=4,
                interactive=True,
                label="Max Input Frames",
            )

            max_new_tokens = gr.Slider(
                minimum=1,
                maximum=4096,
                value=1024,
                step=1,
                interactive=True,
                label="Max Output Tokens",
            )

            check_asr = gr.Checkbox(label="Use ASR", info="Whether to extract speech using ASR.")
            check_do_sample = gr.Checkbox(label="Do Sample", info="Whether to do sample during decoding.")

            # Decoding sliders, hidden until "Do Sample" is ticked.
            num_beams = gr.Slider(
                minimum=1,
                maximum=10,
                value=1,
                step=1,
                interactive=True,
                visible=False,
                label="beam search numbers",
            )

            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.1,
                step=0.1,
                visible=False,
                interactive=True, label="Top_P",
            )

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.1,
                step=0.1,
                visible=False,
                interactive=True, label="Temperature",
            )

            def toggle_slide(is_checked):
                """Show or hide the three decoding sliders together."""
                return gr.update(visible=is_checked), gr.update(visible=is_checked), gr.update(visible=is_checked)

            check_do_sample.select(fn=toggle_slide, inputs=check_do_sample, outputs=[num_beams, top_p, temperature])

        # Right column: the conversation.
        with gr.Column(visible=True) as input_raws:
            chat_state = gr.State([])  # chat history returned by model.chat
            asr_msg = gr.State()  # ASR transcript produced by upload_video
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                label='VideoChat',
                avatar_images=[
                    "human.jpg",  # user avatar
                    "assistant.png",  # AI avatar
                ])
            with gr.Row():
                with gr.Column(scale=0.7):
                    text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False)
                with gr.Column(scale=0.15, min_width=0):
                    run = gr.Button("💭Send", interactive=False)
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("🔄Clear️", interactive=False)
            with gr.Row():
                def _upload_example(video, use_asr, prompt):
                    """Adapt an example row (video, ASR flag, prompt) to upload_video.

                    Every example row has three columns, so it needs three input
                    components. With only [up_video, text_input] as inputs, the
                    ASR flag was routed into the text box and the prompt was
                    dropped. The flag only presets the "Use ASR" checkbox and is
                    not needed by upload_video.
                    """
                    return upload_video(video, prompt)

                examples = gr.Examples(
                    examples=[
                        ["demo_videos/basketball.mp4", False, "Describe this video in detail."],
                        ["demo_videos/cup1.mp4", False, "Describe this video in detail."],
                        ["demo_videos/dog.mp4", False, "Describe this video in detail."],
                    ],
                    inputs=[up_video, check_asr, text_input],
                    outputs=[run, clear, up_video, text_input, upload_button, asr_msg],
                    fn=_upload_example,
                    run_on_click=True,
                )

    # Components reset by both "clear video" and "Restart" (order matches gradio_reset).
    reset_outputs = [chatbot, up_video, text_input, run, clear, upload_button, chat_state, asr_msg]
    # Arguments of gradio_answer, in its positional order.
    answer_inputs = [chatbot, text_input, up_video, max_num_frames, check_asr, asr_msg, chat_state, max_new_tokens, check_do_sample, num_beams, top_p, temperature]

    up_video.clear(gradio_reset, None, reset_outputs, queue=False)

    upload_button.click(upload_video, [up_video], [run, clear, up_video, text_input, upload_button, asr_msg])

    # Ask (append pending turn) -> answer (stream reply) -> empty the text box.
    text_input.submit(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
        gradio_answer, answer_inputs, [chatbot, chat_state]
    ).then(lambda: "", None, text_input)

    run.click(gradio_ask, [text_input, chatbot], [text_input, chatbot]).then(
        gradio_answer, answer_inputs, [chatbot, chat_state]
    ).then(lambda: "", None, text_input)

    clear.click(clear_, None, [chatbot, chat_state])
    restart.click(gradio_reset, None, reset_outputs, queue=False)

demo.launch(server_name='0.0.0.0')