import sys import random import gradio as gr import matplotlib.pyplot as plt import os import argparse import random from omegaconf import OmegaConf import torch import torchvision from pytorch_lightning import seed_everything from huggingface_hub import hf_hub_download import spaces sys.path.insert(0, "scripts/evaluation") from funcs import ( batch_ddim_sampling, batch_ddim_sampling_freetraj, load_model_checkpoint, ) from utils.utils import instantiate_from_config from utils.utils_freetraj import plan_path video_length = 16 width = 512 height = 320 MAX_KEYS = 5 ckpt_dir_512 = "checkpoints/base_512_v2" ckpt_path_512 = "checkpoints/base_512_v2/model.ckpt" if not os.path.exists(ckpt_path_512): os.makedirs(ckpt_dir_512, exist_ok=True) hf_hub_download(repo_id="VideoCrafter/VideoCrafter2", filename="model.ckpt", local_dir=ckpt_dir_512, force_download=True) print('Model Loaded.') def check_move(trajectory, video_length=16): traj_len = len(trajectory) if traj_len < 2: return False prev_pos = trajectory[0] for i in range(1, traj_len): cur_pos = trajectory[i] if cur_pos[0] > video_length - 1: return False if (cur_pos[0] - prev_pos[0]) * ((cur_pos[1] - prev_pos[1]) ** 2 + (cur_pos[2] - prev_pos[2]) ** 2) ** 0.5 < 0.02: print("Too small movement, please use ori mode.") return False prev_pos = cur_pos return True def check(radio_mode): if radio_mode == 'ori': video_path = "output.mp4" video_bbox_path = "output.mp4" else: video_path = "output_freetraj.mp4" video_bbox_path = "output_freetraj_bbox.mp4" return video_path, video_bbox_path def infer(*user_args): prompt_in = user_args[0] target_indices = user_args[1] ddim_edit = user_args[2] seed = user_args[3] ddim_steps = user_args[4] unconditional_guidance_scale = user_args[5] video_fps = user_args[6] save_fps = user_args[7] height_ratio = user_args[8] width_ratio = user_args[9] radio_mode = user_args[10] dropdown_diy = user_args[11] frame_indices = user_args[-3 * MAX_KEYS: -2 * MAX_KEYS] h_positions = user_args[-2 * MAX_KEYS: -MAX_KEYS] w_positions = user_args[-MAX_KEYS:] print(user_args) if radio_mode == 'ori': config_512 = "configs/inference_t2v_512_v2.0.yaml" else: config_512 = "configs/inference_t2v_freetraj_512_v2.0.yaml" trajectory = [] for i in range(dropdown_diy): trajectory.append([int(frame_indices[i]), h_positions[i], w_positions[i]]) trajectory.sort() print(trajectory) if not check_move(trajectory): print("Error trajectory.") input_traj = [] h_remain = 1 - height_ratio w_remain = 1 - width_ratio for i in trajectory: h_relative = i[1] * h_remain w_relative = i[2] * w_remain input_traj.append([i[0], h_relative, h_relative+height_ratio, w_relative, w_relative+width_ratio]) if len(target_indices) < 1: indices_list = [1, 2] else: indices_list = target_indices.split(',') idx_list = [] for i in indices_list: idx_list.append(int(i)) config_512 = OmegaConf.load(config_512) model_config_512 = config_512.pop("model", OmegaConf.create()) args = argparse.Namespace( mode="base", savefps=save_fps, n_samples=1, ddim_steps=ddim_steps, ddim_eta=0.0, bs=1, fps=video_fps, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_guidance_scale_temporal=None, cond_input=None, prompt_in = prompt_in, seed = seed, ddim_edit = ddim_edit, model_config_512 = model_config_512, idx_list = idx_list, input_traj = input_traj, ) print('GPU starts') video = infer_gpu_part(args) print('GPU ends') video = torch.clamp(video.float(), -1.0, 1.0) video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w if radio_mode == 'ori': video_path = "output.mp4" video_bbox_path = "output.mp4" frame_grids = [ torchvision.utils.make_grid(framesheet, nrow=int(args.n_samples)) for framesheet in video ] # [3, 1*h, n*w] grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] grid = (grid + 1.0) / 2.0 grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) torchvision.io.write_video( video_path, grid, fps=args.savefps, video_codec="h264", options={"crf": "10"}, ) else: video_path = "output_freetraj.mp4" video_bbox_path = "output_freetraj_bbox.mp4" frame_grids = [ torchvision.utils.make_grid(framesheet, nrow=int(args.n_samples)) for framesheet in video ] # [3, 1*h, n*w] grid = torch.stack(frame_grids, dim=0) # stack in temporal dim [t, 3, n*h, w] grid = (grid + 1.0) / 2.0 grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) torchvision.io.write_video( video_path, grid, fps=args.savefps, video_codec="h264", options={"crf": "10"}, ) BOX_SIZE_H = input_traj[0][2] - input_traj[0][1] BOX_SIZE_W = input_traj[0][4] - input_traj[0][3] PATHS = plan_path(input_traj) h_len = grid.shape[1] w_len = grid.shape[2] sub_h = int(BOX_SIZE_H * h_len) sub_w = int(BOX_SIZE_W * w_len) for j in range(grid.shape[0]): h_start = int(PATHS[j][0] * h_len) h_end = h_start + sub_h w_start = int(PATHS[j][2] * w_len) w_end = w_start + sub_w h_start = max(1, h_start) h_end = min(h_len-1, h_end) w_start = max(1, w_start) w_end = min(w_len-1, w_end) grid[j, h_start-1:h_end+1, w_start-1:w_start+2, :] = torch.ones_like(grid[j, h_start-1:h_end+1, w_start-1:w_start+2, :]) * torch.Tensor([127, 255, 127]).view(1, 1, 3) grid[j, h_start-1:h_end+1, w_end-2:w_end+1, :] = torch.ones_like(grid[j, h_start-1:h_end+1, w_end-2:w_end+1, :]) * torch.Tensor([127, 255, 127]).view(1, 1, 3) grid[j, h_start-1:h_start+2, w_start-1:w_end+1, :] = torch.ones_like(grid[j, h_start-1:h_start+2, w_start-1:w_end+1, :]) * torch.Tensor([127, 255, 127]).view(1, 1, 3) grid[j, h_end-2:h_end+1, w_start-1:w_end+1, :] = torch.ones_like(grid[j, h_end-2:h_end+1, w_start-1:w_end+1, :]) * torch.Tensor([127, 255, 127]).view(1, 1, 3) torchvision.io.write_video( video_bbox_path, grid, fps=args.savefps, video_codec="h264", options={"crf": "10"}, ) return video_path, video_bbox_path @spaces.GPU(duration=250) def infer_gpu_part(args): model = instantiate_from_config(args.model_config_512) model = model.cuda() model = load_model_checkpoint(model, ckpt_path_512) model.eval() if args.seed is None: seed = int.from_bytes(os.urandom(2), "big") else: seed = args.seed print(f"Using seed: {seed}") seed_everything(seed) ## latent noise shape h, w = height // 8, width // 8 frames = video_length channels = model.channels batch_size = 1 noise_shape = [batch_size, channels, frames, h, w] fps = torch.tensor([args.fps] * batch_size).to(model.device).long() prompts = [args.prompt_in] text_emb = model.get_learned_conditioning(prompts) cond = {"c_crossattn": [text_emb], "fps": fps} ## inference if radio_mode == 'ori': batch_samples = batch_ddim_sampling( model, cond, noise_shape, args.n_samples, args.ddim_steps, args.ddim_eta, args.unconditional_guidance_scale, args=args, ) else: batch_samples = batch_ddim_sampling_freetraj( model, cond, noise_shape, args.n_samples, args.ddim_steps, args.ddim_eta, args.unconditional_guidance_scale, idx_list = args.idx_list, input_traj = args.input_traj, args=args, ) vid_tensor = batch_samples[0] video = vid_tensor.detach().cpu() return video examples = [ ["A squirrel jumping from one tree to another.",], ["A bear climbing down a tree after spotting a threat.",], ["A corgi running on the grassland on the grassland.",], ["A barrel floating in a river.",], ["A horse galloping on a street.",], ["A majestic eagle soaring high above the treetops, surveying its territory.",], ] css = """ #col-container {max-width: 1024px; margin-left: auto; margin-right: auto;} a {text-decoration-line: underline; font-weight: 600;} .animate-spin { animation: spin 1s linear infinite; } #share-btn-container { display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; max-width: 15rem; height: 36px; } div#share-btn-container > div { flex-direction: row; background: black; align-items: center; } #share-btn-container:hover { background-color: #060606; } #share-btn { all: initial; color: #ffffff; font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.5rem !important; padding-bottom: 0.5rem !important; right:0; } #share-btn * { all: unset; } #share-btn-container div:nth-child(-n+2){ width: auto !important; min-height: 0px !important; } #share-btn-container .wrap { display: none !important; } #share-btn-container.hidden { display: none!important; } img[src*='#center'] { display: inline-block; margin: unset; } .footer { margin-bottom: 45px; margin-top: 10px; text-align: center; border-bottom: 1px solid #e5e5e5; } .footer>p { font-size: .8rem; display: inline-block; padding: 0 10px; transform: translateY(10px); background: white; } .dark .footer { border-color: #303030; } .dark .footer>p { background: #0b0f19; } """ def mode_update(mode): if mode == 'demo': trajectories_mode = [gr.Row(visible=True), gr.Row(visible=False)] elif mode == 'diy': trajectories_mode = [gr.Row(visible=False), gr.Row(visible=True)] else: trajectories_mode = [gr.Row(visible=False), gr.Row(visible=False)] return trajectories_mode def keyframe_update(num): keyframes = [] if type(num) != int: num = 0 for i in range(num): keyframes.append(gr.Row(visible=True)) for i in range(MAX_KEYS - num): keyframes.append(gr.Row(visible=False)) return keyframes def demo_update(mode): if mode == 'topleft->bottomright': num = 2 elif mode == 'bottomleft->topright': num = 2 elif mode == 'topleft->bottomleft->bottomright': num = 3 elif mode == 'bottomright->topright->topleft': num = 3 elif mode == '"V"': num = 4 elif mode == '"^"': num = 4 elif mode == 'left->right->left->right': num = 4 elif mode == 'triangle': num = 4 else: num = 0 return num def demo_update_frame(mode): frame_indices = [] if mode == 'topleft->bottomright': num = 2 frame_indices.append(gr.Text(value=0)) frame_indices.append(gr.Text(value=15)) elif mode == 'bottomleft->topright': num = 2 frame_indices.append(gr.Text(value=0)) frame_indices.append(gr.Text(value=15)) elif mode == 'topleft->bottomleft->bottomright': num = 3 frame_indices.append(gr.Text(value=0)) frame_indices.append(gr.Text(value=9)) frame_indices.append(gr.Text(value=15)) elif mode == 'bottomright->topright->topleft': num = 3 frame_indices.append(gr.Text(value=0)) frame_indices.append(gr.Text(value=6)) frame_indices.append(gr.Text(value=15)) elif mode == '"V"': num = 4 frame_indices.append(gr.Text(value=0)) frame_indices.append(gr.Text(value=7)) frame_indices.append(gr.Text(value=8)) frame_indices.append(gr.Text(value=15)) elif mode == '"^"': num = 4 frame_indices.append(gr.Text(value=0)) frame_indices.append(gr.Text(value=7)) frame_indices.append(gr.Text(value=8)) frame_indices.append(gr.Text(value=15)) elif mode == 'left->right->left->right': num = 4 frame_indices.append(gr.Text(value=0)) frame_indices.append(gr.Text(value=5)) frame_indices.append(gr.Text(value=10)) frame_indices.append(gr.Text(value=15)) elif mode == 'triangle': num = 4 frame_indices.append(gr.Text(value=0)) frame_indices.append(gr.Text(value=5)) frame_indices.append(gr.Text(value=10)) frame_indices.append(gr.Text(value=15)) else: num = 0 for i in range(MAX_KEYS - num): frame_indices.append(gr.Text()) return frame_indices def demo_update_h(mode): h_positions = [] if mode == 'topleft->bottomright': num = 2 h_positions.append(gr.Slider(value=0.1)) h_positions.append(gr.Slider(value=0.9)) elif mode == 'bottomleft->topright': num = 2 h_positions.append(gr.Slider(value=0.9)) h_positions.append(gr.Slider(value=0.1)) elif mode == 'topleft->bottomleft->bottomright': num = 3 h_positions.append(gr.Slider(value=0.1)) h_positions.append(gr.Slider(value=0.9)) h_positions.append(gr.Slider(value=0.9)) elif mode == 'bottomright->topright->topleft': num = 3 h_positions.append(gr.Slider(value=0.9)) h_positions.append(gr.Slider(value=0.1)) h_positions.append(gr.Slider(value=0.1)) elif mode == '"V"': num = 4 h_positions.append(gr.Slider(value=0.1)) h_positions.append(gr.Slider(value=0.9)) h_positions.append(gr.Slider(value=0.9)) h_positions.append(gr.Slider(value=0.1)) elif mode == '"^"': num = 4 h_positions.append(gr.Slider(value=0.9)) h_positions.append(gr.Slider(value=0.1)) h_positions.append(gr.Slider(value=0.1)) h_positions.append(gr.Slider(value=0.9)) elif mode == 'left->right->left->right': num = 4 h_positions.append(gr.Slider(value=0.5)) h_positions.append(gr.Slider(value=0.5)) h_positions.append(gr.Slider(value=0.5)) h_positions.append(gr.Slider(value=0.5)) elif mode == 'triangle': num = 4 h_positions.append(gr.Slider(value=0.1)) h_positions.append(gr.Slider(value=0.9)) h_positions.append(gr.Slider(value=0.9)) h_positions.append(gr.Slider(value=0.1)) else: num = 0 for i in range(MAX_KEYS - num): h_positions.append(gr.Slider()) return h_positions def demo_update_w(mode): w_positions = [] if mode == 'topleft->bottomright': num = 2 w_positions.append(gr.Slider(value=0.1)) w_positions.append(gr.Slider(value=0.9)) elif mode == 'bottomleft->topright': num = 2 w_positions.append(gr.Slider(value=0.1)) w_positions.append(gr.Slider(value=0.9)) elif mode == 'topleft->bottomleft->bottomright': num = 3 w_positions.append(gr.Slider(value=0.1)) w_positions.append(gr.Slider(value=0.1)) w_positions.append(gr.Slider(value=0.9)) elif mode == 'bottomright->topright->topleft': num = 3 w_positions.append(gr.Slider(value=0.9)) w_positions.append(gr.Slider(value=0.9)) w_positions.append(gr.Slider(value=0.1)) elif mode == '"V"': num = 4 w_positions.append(gr.Slider(value=0.1)) w_positions.append(gr.Slider(value=0.8/15*7 + 0.1)) w_positions.append(gr.Slider(value=0.8/15*8 + 0.1)) w_positions.append(gr.Slider(value=0.9)) elif mode == '"^"': num = 4 w_positions.append(gr.Slider(value=0.9)) w_positions.append(gr.Slider(value=0.8/15*8 + 0.1)) w_positions.append(gr.Slider(value=0.8/15*7 + 0.1)) w_positions.append(gr.Slider(value=0.1)) elif mode == 'left->right->left->right': num = 4 w_positions.append(gr.Slider(value=0.1)) w_positions.append(gr.Slider(value=0.9)) w_positions.append(gr.Slider(value=0.1)) w_positions.append(gr.Slider(value=0.9)) elif mode == 'triangle': num = 4 w_positions.append(gr.Slider(value=0.5)) w_positions.append(gr.Slider(value=0.9)) w_positions.append(gr.Slider(value=0.1)) w_positions.append(gr.Slider(value=0.5)) else: num = 0 for i in range(MAX_KEYS - num): w_positions.append(gr.Slider()) return w_positions def plot_update(*positions): key_length = positions[-1] frame_indices = positions[:key_length] if type(key_length) != int or len(frame_indices) < 2: traj_plot = gr.Plot( label="Trajectory" ) return traj_plot frame_indices = [int(i) for i in frame_indices] h_positions = positions[MAX_KEYS:MAX_KEYS+key_length] w_positions = positions[2*MAX_KEYS:2*MAX_KEYS+key_length] frame_indices, h_positions, w_positions = zip(*sorted(zip(frame_indices, h_positions, w_positions))) plt.cla() plt.xlim(0, 1) plt.ylim(0, 1) plt.gca().invert_yaxis() plt.gca().xaxis.tick_top() plt.plot(w_positions, h_positions, linestyle='-', marker = 'o', markerfacecolor='r') traj_plot = gr.Plot( label="Trajectory", value = plt ) return traj_plot with gr.Blocks(css=css) as demo: with gr.Column(elem_id="col-container"): gr.Markdown( """ <h1 style="text-align: center;">FreeTraj</h1> <p style="text-align: center;"> Tuning-Free Trajectory Control in Video Diffusion Models </p> <p style="text-align: center;"> <a href="https://arxiv.org/abs/2406.16863" target="_blank"><b>[arXiv]</b></a> <a href="http://haonanqiu.com/projects/FreeTraj.html" target="_blank"><b>[Project Page]</b></a> <a href="https://github.com/arthur-qiu/FreeTraj" target="_blank"><b>[Code]</b></a> </p> """ ) keyframes = [] frame_indices = [] h_positions = [] w_positions = [] with gr.Row(): video_result = gr.Video(label="Video Output") video_result_bbox = gr.Video(label="Video Output with BBox") with gr.Group(): with gr.Row(): prompt_in = gr.Textbox(label="Prompt", placeholder="A corgi running on the grassland on the grassland.", scale = 5) target_indices = gr.Textbox(label="Target Indices (1 for the first word, necessary!)", placeholder="1,2", scale = 2) with gr.Row(): radio_mode = gr.Radio(label='Trajectory Mode', choices = ['demo', 'diy', 'ori'], scale = 1) height_ratio = gr.Slider(label='Height Ratio of BBox', minimum=0.2, maximum=0.4, step=0.01, value=0.3, scale = 1) width_ratio = gr.Slider(label='Width Ratio of BBox', minimum=0.2, maximum=0.4, step=0.01, value=0.3, scale = 1) with gr.Row(visible=False) as row_demo: dropdown_demo = gr.Dropdown( label="Demo Trajectory", choices= ['topleft->bottomright', 'bottomleft->topright', 'topleft->bottomleft->bottomright', 'bottomright->topright->topleft', '"V"', '"^"', 'left->right->left->right', 'triangle'] ) with gr.Row(visible=False) as row_diy: dropdown_diy = gr.Dropdown( label="Number of Keyframes", choices=range(2, MAX_KEYS+1), ) for i in range(MAX_KEYS): with gr.Row(visible=False) as row: text = gr.Textbox( value=f"Keyframe #{i}", interactive=False, container = False, lines = 3, scale=1 ) frame_ids = gr.Textbox( None, label=f"Frame Indices #{i}", interactive=True, scale=2 ) h_position = gr.Slider(label='Position in Height', minimum=0.0, maximum=1.0, step=0.01, scale=2) w_position = gr.Slider(label='Position in Width', minimum=0.0, maximum=1.0, step=0.01, scale=2) frame_indices.append(frame_ids) h_positions.append(h_position) w_positions.append(w_position) keyframes.append(row) dropdown_demo.change(demo_update, dropdown_demo, dropdown_diy) dropdown_diy.change(keyframe_update, dropdown_diy, keyframes) dropdown_demo.change(demo_update_frame, dropdown_demo, frame_indices) dropdown_demo.change(demo_update_h, dropdown_demo, h_positions) dropdown_demo.change(demo_update_w, dropdown_demo, w_positions) radio_mode.change(mode_update, radio_mode, [row_demo, row_diy]) traj_plot = gr.Plot( label="Trajectory" ) h_positions[0].change(plot_update, frame_indices + h_positions + w_positions + [dropdown_diy], traj_plot) h_positions[1].change(plot_update, frame_indices + h_positions + w_positions + [dropdown_diy], traj_plot) h_positions[2].change(plot_update, frame_indices + h_positions + w_positions + [dropdown_diy], traj_plot) h_positions[3].change(plot_update, frame_indices + h_positions + w_positions + [dropdown_diy], traj_plot) h_positions[4].change(plot_update, frame_indices + h_positions + w_positions + [dropdown_diy], traj_plot) w_positions[0].change(plot_update, frame_indices + h_positions + w_positions + [dropdown_diy], traj_plot) w_positions[1].change(plot_update, frame_indices + h_positions + w_positions + [dropdown_diy], traj_plot) w_positions[2].change(plot_update, frame_indices + h_positions + w_positions + [dropdown_diy], traj_plot) w_positions[3].change(plot_update, frame_indices + h_positions + w_positions + [dropdown_diy], traj_plot) w_positions[4].change(plot_update, frame_indices + h_positions + w_positions + [dropdown_diy], traj_plot) with gr.Row(): with gr.Accordion('Useful FreeTraj Parameters (feel free to adjust these parameters based on your prompt): ', open=True): with gr.Row(): ddim_edit = gr.Slider(label='Editing Steps (larger for better control while losing some quality)', minimum=0, maximum=12, step=1, value=6) seed = gr.Slider(label='Random Seed', minimum=0, maximum=10000, step=1, value=123) with gr.Row(): with gr.Accordion('Useless FreeTraj Parameters (mostly no need to adjust): ', open=False): with gr.Row(): ddim_steps = gr.Slider(label='DDIM Steps', minimum=5, maximum=50, step=1, value=50) unconditional_guidance_scale = gr.Slider(label='Unconditional Guidance Scale', minimum=1.0, maximum=20.0, step=0.1, value=12.0) with gr.Row(): video_fps = gr.Slider(label='Video FPS (larger for quicker motion)', minimum=8, maximum=36, step=4, value=16) save_fps = gr.Slider(label='Save FPS', minimum=1, maximum=30, step=1, value=10) with gr.Row(): submit_btn = gr.Button("Generate", variant='primary') with gr.Row(): check_btn = gr.Button("Check Existing Results (in case of the connection lost)", variant='secondary') with gr.Row(): gr.Examples(label='Sample Prompts', examples=examples, inputs=[prompt_in, target_indices, ddim_edit, seed, ddim_steps, unconditional_guidance_scale, video_fps, save_fps, height_ratio, width_ratio, radio_mode, dropdown_diy, *frame_indices, *h_positions, *w_positions]) demo_list = ['0026_0_0.4_0.4.gif', '0047_1_0.4_0.3.gif', '0051_1_0.4_0.4.gif'] demo_pick = random.randint(0, len(demo_list) - 1) with gr.Row(): for i in range(len(demo_list)): gr.Image(show_label = False, show_download_button = False, value='assets/' + demo_list[i]) with gr.Row(): gr.Markdown( """ <h2 style="text-align: center;">Hints</h2> <p style="text-align: center;"> 1. Choose trajectory mode <b>"ori"</b> to see whether the prompt works on the pre-trained model. </p> <p style="text-align: center;"> 2. Adjust the prompt or random seed to get a qualified video. </p> <p style="text-align: center;"> 3. Choose trajectory mode <b>"demo"</b> to see whether <b>FreeTraj</b> works or not. </p> <p style="text-align: center;"> 4. Choose trajectory mode <b>"diy"</b> to plan new trajectory. It may fail in some extreme cases. </p> """ ) submit_btn.click(fn=infer, inputs=[prompt_in, target_indices, ddim_edit, seed, ddim_steps, unconditional_guidance_scale, video_fps, save_fps, height_ratio, width_ratio, radio_mode, dropdown_diy, *frame_indices, *h_positions, *w_positions], outputs=[video_result, video_result_bbox], api_name="generate") check_btn.click(fn=check, inputs=[radio_mode], outputs=[video_result, video_result_bbox], api_name="check") demo.queue(max_size=8).launch(show_api=True)