import sys
import random
import gradio as gr
import matplotlib.pyplot as plt

import os
import argparse
import random
from omegaconf import OmegaConf
import torch
import torchvision
from pytorch_lightning import seed_everything
from huggingface_hub import hf_hub_download
import spaces

sys.path.insert(0, "scripts/evaluation")
from funcs import (
    batch_ddim_sampling,
    batch_ddim_sampling_freetraj,
    load_model_checkpoint,
)
from utils.utils import instantiate_from_config
from utils.utils_freetraj import plan_path

# Number of frames per generated video (the temporal size of the latent noise).
video_length = 16
# Frame width/height; infer_gpu_part divides both by 8 for the latent noise shape.
width = 512 
height = 320
# Number of keyframe rows created in the UI and the slice width used to unpack
# the per-keyframe inputs in infer().
MAX_KEYS = 5

# Download the VideoCrafter2 checkpoint at import time if it is not on disk yet.
ckpt_dir_512 = "checkpoints/base_512_v2"
ckpt_path_512 = "checkpoints/base_512_v2/model.ckpt"
if not os.path.exists(ckpt_path_512):
    os.makedirs(ckpt_dir_512, exist_ok=True)
    hf_hub_download(repo_id="VideoCrafter/VideoCrafter2", filename="model.ckpt", local_dir=ckpt_dir_512, force_download=True)
    # Printed once the download call returns; the weights themselves are
    # loaded into a model later, inside infer_gpu_part.
    print('Model Loaded.')
    
def check_move(trajectory, video_length=16):
    """Validate a keyframe trajectory.

    Args:
        trajectory: sequence of ``[frame_index, h_position, w_position]``
            keyframes, in the order they should be traversed.
        video_length: number of frames in the video; valid frame indices are
            ``0 .. video_length - 1``.

    Returns:
        ``True`` when there are at least two keyframes, every frame index is
        inside the video, and every consecutive pair moves far enough;
        ``False`` otherwise.
    """
    if len(trajectory) < 2:
        return False

    last_frame = video_length - 1
    prev_pos = trajectory[0]
    # The loop below only range-checks the *current* keyframe, so the first
    # one has to be checked here; negative indices are rejected as well.
    if not 0 <= prev_pos[0] <= last_frame:
        return False

    for cur_pos in trajectory[1:]:
        if not 0 <= cur_pos[0] <= last_frame:
            return False
        frame_gap = cur_pos[0] - prev_pos[0]
        distance = ((cur_pos[1] - prev_pos[1]) ** 2 + (cur_pos[2] - prev_pos[2]) ** 2) ** 0.5
        # The score is frame gap times spatial distance, so a zero or
        # negative gap (duplicate / unsorted frames) is rejected here too.
        if frame_gap * distance < 0.02:
            print("Too small movement, please use ori mode.")
            return False
        prev_pos = cur_pos

    return True

def check(radio_mode):
    """Return the ``(video, bbox video)`` output file paths for a mode.

    In 'ori' mode both entries are the same file; every other mode uses the
    FreeTraj output and its bounding-box overlay.
    """
    if radio_mode == 'ori':
        plain_output = "output.mp4"
        return plain_output, plain_output
    return "output_freetraj.mp4", "output_freetraj_bbox.mp4"
    

def _parse_target_indices(target_indices):
    """Turn the comma-separated "Target Indices" text into a list of ints.

    Blank input (or input containing no tokens) falls back to ``[1, 2]``.

    Raises:
        gr.Error: if a token is not an integer.
    """
    tokens = [tok for tok in (target_indices or "").split(',') if tok.strip()]
    if not tokens:
        return [1, 2]
    try:
        return [int(tok) for tok in tokens]
    except ValueError as err:
        raise gr.Error("Target indices must be comma-separated integers, e.g. 1,2.") from err


def _build_trajectory(num_keys, frame_indices, h_positions, w_positions):
    """Collect the first *num_keys* keyframes as sorted ``[frame, h, w]`` lists.

    Raises:
        gr.Error: if a keyframe's frame index is blank or not an integer.
    """
    trajectory = []
    for i in range(min(num_keys, MAX_KEYS)):
        try:
            frame = int(frame_indices[i])
        except (TypeError, ValueError) as err:
            raise gr.Error(f"Frame index of keyframe #{i} must be an integer.") from err
        trajectory.append([frame, h_positions[i], w_positions[i]])
    trajectory.sort()
    return trajectory


def _frames_to_uint8_grid(video, n_samples):
    """Convert a ``[t, n, c, h, w]`` tensor in [-1, 1] to uint8 ``[t, H, W, c]`` frames."""
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=int(n_samples))
        for framesheet in video
    ]  # each [3, 1*h, n*w]
    grid = torch.stack(frame_grids, dim=0)  # stack in temporal dim
    grid = (grid + 1.0) / 2.0
    return (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)


def _write_video(path, frames, fps):
    """Encode *frames* to an H.264 video file at *path*."""
    torchvision.io.write_video(
        path,
        frames,
        fps=fps,
        video_codec="h264",
        options={"crf": "10"},
    )


def _draw_trajectory_boxes(grid, input_traj):
    """Paint the per-frame box returned by ``plan_path`` onto *grid* in place."""
    box_size_h = input_traj[0][2] - input_traj[0][1]
    box_size_w = input_traj[0][4] - input_traj[0][3]
    paths = plan_path(input_traj)
    h_len = grid.shape[1]
    w_len = grid.shape[2]
    sub_h = int(box_size_h * h_len)
    sub_w = int(box_size_w * w_len)
    # Broadcast directly in the grid's dtype instead of going through floats.
    color = torch.tensor([127, 255, 127], dtype=grid.dtype)
    for j in range(grid.shape[0]):
        h_start = int(paths[j][0] * h_len)
        h_end = h_start + sub_h
        w_start = int(paths[j][2] * w_len)
        w_end = w_start + sub_w

        # Keep a one-pixel margin so the 3-pixel-wide edges stay in the frame.
        h_start = max(1, h_start)
        h_end = min(h_len - 1, h_end)
        w_start = max(1, w_start)
        w_end = min(w_len - 1, w_end)

        grid[j, h_start-1:h_end+1, w_start-1:w_start+2, :] = color  # left edge
        grid[j, h_start-1:h_end+1, w_end-2:w_end+1, :] = color      # right edge
        grid[j, h_start-1:h_start+2, w_start-1:w_end+1, :] = color  # top edge
        grid[j, h_end-2:h_end+1, w_start-1:w_end+1, :] = color      # bottom edge


def infer(*user_args):
    """Gradio callback for the "Generate" button.

    Positional layout of ``user_args``: prompt, target indices, editing steps,
    seed, DDIM steps, guidance scale, video fps, save fps, bbox height ratio,
    bbox width ratio, trajectory mode, number of keyframes, followed by
    ``MAX_KEYS`` frame indices, ``MAX_KEYS`` height positions and
    ``MAX_KEYS`` width positions.

    Returns:
        ``(video_path, video_bbox_path)``; both are the same file in 'ori' mode.

    Raises:
        gr.Error: on unparsable target/frame indices, or when a trajectory
            mode is selected but no keyframes are defined.
    """
    (prompt_in, target_indices, ddim_edit, seed, ddim_steps,
     unconditional_guidance_scale, video_fps, save_fps,
     height_ratio, width_ratio, radio_mode, dropdown_diy) = user_args[:12]
    frame_indices = user_args[-3 * MAX_KEYS: -2 * MAX_KEYS]
    h_positions = user_args[-2 * MAX_KEYS: -MAX_KEYS]
    w_positions = user_args[-MAX_KEYS:]
    print(user_args)

    is_ori = radio_mode == 'ori'
    if is_ori:
        config_512 = "configs/inference_t2v_512_v2.0.yaml"
    else:
        config_512 = "configs/inference_t2v_freetraj_512_v2.0.yaml"

    if is_ori:
        # The baseline sampler takes no trajectory, so the keyframe widgets
        # (usually still unset in this mode) are ignored rather than parsed.
        trajectory = []
    else:
        # The dropdown yields None until a keyframe count has been chosen.
        num_keys = dropdown_diy if isinstance(dropdown_diy, int) else 0
        trajectory = _build_trajectory(num_keys, frame_indices, h_positions, w_positions)
        if not trajectory:
            # Fail before the GPU run instead of crashing after it.
            raise gr.Error("Please pick a demo trajectory or set the DIY keyframes first.")
        if not check_move(trajectory, video_length):
            print("Error trajectory.")
    print(trajectory)

    # Map each keyframe position in [0, 1] to a box that stays inside the frame:
    # [frame, h_start, h_end, w_start, w_end].
    h_remain = 1 - height_ratio
    w_remain = 1 - width_ratio
    input_traj = []
    for frame, h_pos, w_pos in trajectory:
        h_relative = h_pos * h_remain
        w_relative = w_pos * w_remain
        input_traj.append([frame, h_relative, h_relative + height_ratio, w_relative, w_relative + width_ratio])

    idx_list = _parse_target_indices(target_indices)

    config_512 = OmegaConf.load(config_512)
    model_config_512 = config_512.pop("model", OmegaConf.create())

    args = argparse.Namespace(
        mode="base",
        savefps=save_fps,
        n_samples=1,
        ddim_steps=ddim_steps,
        ddim_eta=0.0,
        bs=1,
        fps=video_fps,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_guidance_scale_temporal=None,
        cond_input=None,
        prompt_in=prompt_in,
        seed=seed,
        ddim_edit=ddim_edit,
        model_config_512=model_config_512,
        idx_list=idx_list,
        input_traj=input_traj,
        # Carried explicitly: inside infer_gpu_part the bare name `radio_mode`
        # would resolve to the module-level gr.Radio component, not this value.
        radio_mode=radio_mode,
    )

    print('GPU starts')
    video = infer_gpu_part(args)
    print('GPU ends')

    video = torch.clamp(video.float(), -1.0, 1.0)
    video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
    grid = _frames_to_uint8_grid(video, args.n_samples)

    # Same file names the "Check Existing Results" button looks up.
    video_path, video_bbox_path = check(radio_mode)
    _write_video(video_path, grid, args.savefps)
    if not is_ori:
        _draw_trajectory_boxes(grid, input_traj)
        _write_video(video_bbox_path, grid, args.savefps)

    return video_path, video_bbox_path


@spaces.GPU(duration=250)
def infer_gpu_part(args):
    """Build the model, sample one video, and return it as a CPU tensor.

    Args:
        args: namespace built by ``infer``. ``args.radio_mode == 'ori'``
            selects ``batch_ddim_sampling``; any other value (or a missing
            attribute) selects ``batch_ddim_sampling_freetraj``.

    Returns:
        The first element of the sampler output, detached and moved to CPU.
    """
    model = instantiate_from_config(args.model_config_512)
    model = model.cuda()
    model = load_model_checkpoint(model, ckpt_path_512)
    model.eval()

    if args.seed is None:
        seed = int.from_bytes(os.urandom(2), "big")
    else:
        seed = args.seed
    print(f"Using seed: {seed}")
    seed_everything(seed)

    ## latent noise shape
    h, w = height // 8, width // 8
    frames = video_length
    channels = model.channels

    batch_size = 1
    noise_shape = [batch_size, channels, frames, h, w]
    fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
    prompts = [args.prompt_in]
    text_emb = model.get_learned_conditioning(prompts)

    cond = {"c_crossattn": [text_emb], "fps": fps}

    ## inference
    sampler_args = (
        model,
        cond,
        noise_shape,
        args.n_samples,
        args.ddim_steps,
        args.ddim_eta,
        args.unconditional_guidance_scale,
    )
    # Read the mode from args, never from the global `radio_mode` (a gr.Radio
    # component whose comparison with 'ori' is always False).
    if getattr(args, "radio_mode", None) == 'ori':
        batch_samples = batch_ddim_sampling(*sampler_args, args=args)
    else:
        batch_samples = batch_ddim_sampling_freetraj(
            *sampler_args,
            idx_list=args.idx_list,
            input_traj=args.input_traj,
            args=args,
        )

    vid_tensor = batch_samples[0]
    video = vid_tensor.detach().cpu()

    return video


# Sample prompts for the gr.Examples widget. Each row holds a single value,
# which lines up with the first entry of the widget's inputs list (the prompt).
examples = [
    ["A squirrel jumping from one tree to another.",],
    ["A bear climbing down a tree after spotting a threat.",],
    ["A corgi running on the grassland on the grassland.",],
    ["A barrel floating in a river.",],
    ["A horse galloping on a street.",],
    ["A majestic eagle soaring high above the treetops, surveying its territory.",],
]

# Custom stylesheet handed to gr.Blocks(css=css) when the UI is built.
css = """
#col-container {max-width: 1024px; margin-left: auto; margin-right: auto;}
a {text-decoration-line: underline; font-weight: 600;}
.animate-spin {
  animation: spin 1s linear infinite;
}
#share-btn-container {
  display: flex; 
  padding-left: 0.5rem !important; 
  padding-right: 0.5rem !important; 
  background-color: #000000; 
  justify-content: center; 
  align-items: center; 
  border-radius: 9999px !important; 
  max-width: 15rem;
  height: 36px;
}
div#share-btn-container > div {
    flex-direction: row;
    background: black;
    align-items: center;
}
#share-btn-container:hover {
  background-color: #060606;
}
#share-btn {
  all: initial; 
  color: #ffffff;
  font-weight: 600; 
  cursor:pointer; 
  font-family: 'IBM Plex Sans', sans-serif; 
  margin-left: 0.5rem !important; 
  padding-top: 0.5rem !important; 
  padding-bottom: 0.5rem !important;
  right:0;
}
#share-btn * {
  all: unset;
}
#share-btn-container div:nth-child(-n+2){
  width: auto !important;
  min-height: 0px !important;
}
#share-btn-container .wrap {
  display: none !important;
}
#share-btn-container.hidden {
  display: none!important;
}
img[src*='#center'] { 
    display: inline-block;
    margin: unset;
}
.footer {
        margin-bottom: 45px;
        margin-top: 10px;
        text-align: center;
        border-bottom: 1px solid #e5e5e5;
    }
    .footer>p {
        font-size: .8rem;
        display: inline-block;
        padding: 0 10px;
        transform: translateY(10px);
        background: white;
    }
    .dark .footer {
        border-color: #303030;
    }
    .dark .footer>p {
        background: #0b0f19;
    }
"""

def mode_update(mode):
    """Return visibility updates for the demo row and the DIY row.

    'demo' shows only the first row, 'diy' only the second, and any other
    mode hides both.
    """
    show_demo = mode == 'demo'
    show_diy = (not show_demo) and mode == 'diy'
    return [gr.Row(visible=show_demo), gr.Row(visible=show_diy)]

def keyframe_update(num):
    """Return ``MAX_KEYS`` row updates: the first *num* visible, the rest hidden.

    Anything that is not exactly an ``int`` (e.g. ``None``) counts as zero.
    """
    shown = num if type(num) is int else 0
    visible_rows = [gr.Row(visible=True) for _ in range(shown)]
    hidden_rows = [gr.Row(visible=False) for _ in range(MAX_KEYS - shown)]
    return visible_rows + hidden_rows

def demo_update(mode):
    """Return how many keyframes the named demo trajectory uses (0 if unknown)."""
    keyframe_counts = (
        (('topleft->bottomright', 'bottomleft->topright'), 2),
        (('topleft->bottomleft->bottomright', 'bottomright->topright->topleft'), 3),
        (('"V"', '"^"', 'left->right->left->right', 'triangle'), 4),
    )
    for names, count in keyframe_counts:
        if mode in names:
            return count
    return 0

def demo_update_frame(mode):
    """Return ``MAX_KEYS`` textbox updates holding the demo's frame indices.

    Known demos fill the leading boxes with their preset frame indices; the
    remaining boxes (all of them for an unknown mode) get a value-less update.
    """
    preset_frames = (
        ('topleft->bottomright', (0, 15)),
        ('bottomleft->topright', (0, 15)),
        ('topleft->bottomleft->bottomright', (0, 9, 15)),
        ('bottomright->topright->topleft', (0, 6, 15)),
        ('"V"', (0, 7, 8, 15)),
        ('"^"', (0, 7, 8, 15)),
        ('left->right->left->right', (0, 5, 10, 15)),
        ('triangle', (0, 5, 10, 15)),
    )
    chosen = next((frames for name, frames in preset_frames if mode == name), ())
    updates = [gr.Text(value=frame) for frame in chosen]
    updates.extend(gr.Text() for _ in range(MAX_KEYS - len(chosen)))
    return updates

def demo_update_h(mode):
    """Return ``MAX_KEYS`` slider updates holding the demo's height positions.

    Known demos fill the leading sliders with their preset values; the
    remaining sliders (all of them for an unknown mode) get a value-less update.
    """
    preset_heights = (
        ('topleft->bottomright', (0.1, 0.9)),
        ('bottomleft->topright', (0.9, 0.1)),
        ('topleft->bottomleft->bottomright', (0.1, 0.9, 0.9)),
        ('bottomright->topright->topleft', (0.9, 0.1, 0.1)),
        ('"V"', (0.1, 0.9, 0.9, 0.1)),
        ('"^"', (0.9, 0.1, 0.1, 0.9)),
        ('left->right->left->right', (0.5, 0.5, 0.5, 0.5)),
        ('triangle', (0.1, 0.9, 0.9, 0.1)),
    )
    chosen = next((heights for name, heights in preset_heights if mode == name), ())
    updates = [gr.Slider(value=height) for height in chosen]
    updates.extend(gr.Slider() for _ in range(MAX_KEYS - len(chosen)))
    return updates

def demo_update_w(mode):
    """Return ``MAX_KEYS`` slider updates holding the demo's width positions.

    Known demos fill the leading sliders with their preset values; the
    remaining sliders (all of them for an unknown mode) get a value-less update.
    """
    # Width at frames 7 and 8 when moving linearly from 0.1 to 0.9 over 15 frames.
    at_frame_7 = 0.8/15*7 + 0.1
    at_frame_8 = 0.8/15*8 + 0.1
    preset_widths = (
        ('topleft->bottomright', (0.1, 0.9)),
        ('bottomleft->topright', (0.1, 0.9)),
        ('topleft->bottomleft->bottomright', (0.1, 0.1, 0.9)),
        ('bottomright->topright->topleft', (0.9, 0.9, 0.1)),
        ('"V"', (0.1, at_frame_7, at_frame_8, 0.9)),
        ('"^"', (0.9, at_frame_8, at_frame_7, 0.1)),
        ('left->right->left->right', (0.1, 0.9, 0.1, 0.9)),
        ('triangle', (0.5, 0.9, 0.1, 0.5)),
    )
    chosen = next((widths for name, widths in preset_widths if mode == name), ())
    updates = [gr.Slider(value=width_pos) for width_pos in chosen]
    updates.extend(gr.Slider() for _ in range(MAX_KEYS - len(chosen)))
    return updates

def plot_update(*positions):
    """Redraw the trajectory preview from the current keyframe widgets.

    Args:
        *positions: ``MAX_KEYS`` frame-index texts, ``MAX_KEYS`` height
            positions, ``MAX_KEYS`` width positions, then the number of
            active keyframes as the last element.

    Returns:
        A ``gr.Plot`` update. It carries no figure when fewer than two
        keyframes are active or when an active frame index is not an integer
        yet, so a half-filled form does not raise inside the callback.
    """
    key_length = positions[-1]
    if type(key_length) is not int or len(positions[:key_length]) < 2:
        return gr.Plot(label="Trajectory")

    try:
        frame_indices = [int(i) for i in positions[:key_length]]
    except (TypeError, ValueError):
        # A frame index box is still blank or holds non-numeric text.
        return gr.Plot(label="Trajectory")

    h_positions = positions[MAX_KEYS:MAX_KEYS + key_length]
    w_positions = positions[2 * MAX_KEYS:2 * MAX_KEYS + key_length]
    # Connect the keyframes in frame order, whatever order the rows are in.
    frame_indices, h_positions, w_positions = zip(*sorted(zip(frame_indices, h_positions, w_positions)))

    plt.cla()
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    # Image-style axes: origin at the top-left, x ticks along the top.
    plt.gca().invert_yaxis()
    plt.gca().xaxis.tick_top()
    plt.plot(w_positions, h_positions, linestyle='-', marker='o', markerfacecolor='r')
    return gr.Plot(
        label="Trajectory",
        value=plt,
    )


# Build the Gradio UI: output videos, prompt and trajectory controls, sampling
# parameters, action buttons, sample prompts, demo images and usage hints.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(
            """
            <h1 style="text-align: center;">FreeTraj</h1>
            <p style="text-align: center;">
            Tuning-Free Trajectory Control in Video Diffusion Models
            </p>
            <p style="text-align: center;">
            <a href="https://arxiv.org/abs/2406.16863" target="_blank"><b>[arXiv]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
            <a href="http://haonanqiu.com/projects/FreeTraj.html" target="_blank"><b>[Project Page]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
            <a href="https://github.com/arthur-qiu/FreeTraj" target="_blank"><b>[Code]</b></a>
            </p>         
            """
        )

        # Per-keyframe components, collected so callbacks can take them as lists.
        keyframes = []
        frame_indices = []
        h_positions = []
        w_positions = []

        with gr.Row():
            video_result = gr.Video(label="Video Output")
            video_result_bbox = gr.Video(label="Video Output with BBox")

        with gr.Group():
            with gr.Row():
                prompt_in = gr.Textbox(label="Prompt", placeholder="A corgi running on the grassland on the grassland.", scale=5)
                target_indices = gr.Textbox(label="Target Indices (1 for the first word, necessary!)", placeholder="1,2", scale=2)

            with gr.Row():
                radio_mode = gr.Radio(label='Trajectory Mode', choices=['demo', 'diy', 'ori'], scale=1)
                height_ratio = gr.Slider(label='Height Ratio of BBox',
                                minimum=0.2,
                                maximum=0.4,
                                step=0.01,
                                value=0.3,
                                scale=1)
                width_ratio = gr.Slider(label='Width Ratio of BBox',
                                minimum=0.2,
                                maximum=0.4,
                                step=0.01,
                                value=0.3,
                                scale=1)

            with gr.Row(visible=False) as row_demo:
                dropdown_demo = gr.Dropdown(
                    label="Demo Trajectory",
                    choices=['topleft->bottomright', 'bottomleft->topright', 'topleft->bottomleft->bottomright', 'bottomright->topright->topleft', '"V"', '"^"', 'left->right->left->right', 'triangle']
                )

            with gr.Row(visible=False) as row_diy:
                dropdown_diy = gr.Dropdown(
                    label="Number of Keyframes",
                    choices=list(range(2, MAX_KEYS + 1)),
                )

            # One hidden row per possible keyframe; keyframe_update reveals them.
            for i in range(MAX_KEYS):
                with gr.Row(visible=False) as row:
                    text = gr.Textbox(
                        value=f"Keyframe #{i}",
                        interactive=False,
                        container=False,
                        lines=3,
                        scale=1
                    )
                    frame_ids = gr.Textbox(
                        None,
                        label=f"Frame Indices #{i}",
                        interactive=True,
                        scale=2
                    )
                    h_position = gr.Slider(label='Position in Height',
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        scale=2)
                    w_position = gr.Slider(label='Position in Width',
                        minimum=0.0,
                        maximum=1.0,
                        step=0.01,
                        scale=2)

                frame_indices.append(frame_ids)
                h_positions.append(h_position)
                w_positions.append(w_position)
                keyframes.append(row)

            # Picking a demo sets the keyframe count (which in turn reveals the
            # rows) and fills in the preset frame indices and positions.
            dropdown_demo.change(demo_update, dropdown_demo, dropdown_diy)
            dropdown_diy.change(keyframe_update, dropdown_diy, keyframes)
            dropdown_demo.change(demo_update_frame, dropdown_demo, frame_indices)
            dropdown_demo.change(demo_update_h, dropdown_demo, h_positions)
            dropdown_demo.change(demo_update_w, dropdown_demo, w_positions)
            radio_mode.change(mode_update, radio_mode, [row_demo, row_diy])

            traj_plot = gr.Plot(
                label="Trajectory"
            )

            # Redraw the preview whenever any position slider moves. Looping
            # keeps the number of listeners tied to MAX_KEYS.
            plot_inputs = frame_indices + h_positions + w_positions + [dropdown_diy]
            for position_slider in h_positions + w_positions:
                position_slider.change(plot_update, plot_inputs, traj_plot)

            with gr.Row():
                with gr.Accordion('Useful FreeTraj Parameters (feel free to adjust these parameters based on your prompt): ', open=True):
                    with gr.Row():
                        ddim_edit = gr.Slider(label='Editing Steps (larger for better control while losing some quality)',
                                minimum=0,
                                maximum=12,
                                step=1,
                                value=6)
                        seed = gr.Slider(label='Random Seed',
                                minimum=0,
                                maximum=10000,
                                step=1,
                                value=123)

            with gr.Row():
                with gr.Accordion('Useless FreeTraj Parameters (mostly no need to adjust): ', open=False):
                    with gr.Row():
                        ddim_steps = gr.Slider(label='DDIM Steps',
                                minimum=5,
                                maximum=50,
                                step=1,
                                value=50)
                        unconditional_guidance_scale = gr.Slider(label='Unconditional Guidance Scale',
                                minimum=1.0,
                                maximum=20.0,
                                step=0.1,
                                value=12.0)
                    with gr.Row():
                        video_fps = gr.Slider(label='Video FPS (larger for quicker motion)',
                                minimum=8,
                                maximum=36,
                                step=4,
                                value=16)
                        save_fps = gr.Slider(label='Save FPS',
                                minimum=1,
                                maximum=30,
                                step=1,
                                value=10)

        with gr.Row():
            submit_btn = gr.Button("Generate", variant='primary')

        with gr.Row():
            check_btn = gr.Button("Check Existing Results (in case of the connection lost)", variant='secondary')

        with gr.Row():
            gr.Examples(label='Sample Prompts', examples=examples, inputs=[prompt_in, target_indices, ddim_edit, seed, ddim_steps, unconditional_guidance_scale, video_fps, save_fps, height_ratio, width_ratio, radio_mode, dropdown_diy, *frame_indices, *h_positions, *w_positions])

        demo_list = ['0026_0_0.4_0.4.gif', '0047_1_0.4_0.3.gif', '0051_1_0.4_0.4.gif']
        # NOTE(review): demo_pick is not read anywhere in this file; kept so the
        # module's globals and random-state consumption stay unchanged.
        demo_pick = random.randint(0, len(demo_list) - 1)
        with gr.Row():
            for i in range(len(demo_list)):
                gr.Image(show_label=False, show_download_button=False, value='assets/' + demo_list[i])

        with gr.Row():
            gr.Markdown(
                """
                <h2 style="text-align: center;">Hints</h2>
                <p style="text-align: center;">
                1. Choose trajectory mode <b>"ori"</b> to see whether the prompt works on the pre-trained model. 
                </p>    
                <p style="text-align: center;">
                2. Adjust the prompt or random seed to get a qualified video.
                </p>  
                <p style="text-align: center;">
                3. Choose trajectory mode <b>"demo"</b> to see whether <b>FreeTraj</b> works or not.
                </p>  
                <p style="text-align: center;">
                4. Choose trajectory mode <b>"diy"</b> to plan new trajectory. It may fail in some extreme cases.
                </p>       
                """
            )

    # The input order here must match the positional layout infer() unpacks.
    submit_btn.click(fn=infer,
            inputs=[prompt_in, target_indices, ddim_edit, seed, ddim_steps, unconditional_guidance_scale, video_fps, save_fps, height_ratio, width_ratio, radio_mode, dropdown_diy, *frame_indices, *h_positions, *w_positions],
            outputs=[video_result, video_result_bbox],
            api_name="generate")

    # Re-shows the files written by the last run for the selected mode.
    check_btn.click(fn=check,
            inputs=[radio_mode],
            outputs=[video_result, video_result_bbox],
            api_name="check")

demo.queue(max_size=8).launch(show_api=True)