Anonymous committed · e84616f · Parent(s): e994f84

add spaces
app.py
CHANGED
@@ -22,6 +22,9 @@ from funcs import (
 from utils.utils import instantiate_from_config
 from utils.utils_freetraj import plan_path
 
+video_length = 16
+width = 512
+height = 320
 MAX_KEYS = 5
 
 ckpt_dir_512 = "checkpoints/base_512_v2"
@@ -56,7 +59,7 @@ def check(radio_mode):
     video_bbox_path = "output_freetraj_bbox.mp4"
     return video_path, video_bbox_path
 
-
+
 def infer(*user_args):
     prompt_in = user_args[0]
     target_indices = user_args[1]
@@ -75,9 +78,6 @@ def infer(*user_args):
     w_positions = user_args[-MAX_KEYS:]
     print(user_args)
 
-    video_length = 16
-    width = 512
-    height = 320
     if radio_mode == 'ori':
         config_512 = "configs/inference_t2v_512_v2.0.yaml"
     else:
@@ -110,15 +110,6 @@ def infer(*user_args):
 
     config_512 = OmegaConf.load(config_512)
     model_config_512 = config_512.pop("model", OmegaConf.create())
-    model = instantiate_from_config(model_config_512)
-    model = model.cuda()
-    model = load_model_checkpoint(model, ckpt_path_512)
-    model.eval()
-
-    if seed is None:
-        seed = int.from_bytes(os.urandom(2), "big")
-    print(f"Using seed: {seed}")
-    seed_everything(seed)
 
     args = argparse.Namespace(
         mode="base",
@@ -127,57 +118,20 @@ def infer(*user_args):
         ddim_steps=ddim_steps,
         ddim_eta=0.0,
         bs=1,
-        height=height,
-        width=width,
-        frames=video_length,
         fps=video_fps,
         unconditional_guidance_scale=unconditional_guidance_scale,
         unconditional_guidance_scale_temporal=None,
         cond_input=None,
+        prompt_in = prompt_in,
+        seed = seed,
         ddim_edit = ddim_edit,
+        model_config_512 = model_config_512,
+        idx_list = idx_list,
+        input_traj = input_traj,
     )
 
-
-
-    frames = model.temporal_length if args.frames < 0 else args.frames
-    channels = model.channels
-
-    batch_size = 1
-    noise_shape = [batch_size, channels, frames, h, w]
-    fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
-    prompts = [prompt_in]
-    text_emb = model.get_learned_conditioning(prompts)
-
-    cond = {"c_crossattn": [text_emb], "fps": fps}
-
-    ## inference
-    if radio_mode == 'ori':
-        batch_samples = batch_ddim_sampling(
-            model,
-            cond,
-            noise_shape,
-            args.n_samples,
-            args.ddim_steps,
-            args.ddim_eta,
-            args.unconditional_guidance_scale,
-            args=args,
-        )
-    else:
-        batch_samples = batch_ddim_sampling_freetraj(
-            model,
-            cond,
-            noise_shape,
-            args.n_samples,
-            args.ddim_steps,
-            args.ddim_eta,
-            args.unconditional_guidance_scale,
-            idx_list = idx_list,
-            input_traj = input_traj,
-            args=args,
-        )
-
-    vid_tensor = batch_samples[0]
-    video = vid_tensor.detach().cpu()
+    video = infer_gpu_part(args)
+
     video = torch.clamp(video.float(), -1.0, 1.0)
     video = video.permute(2, 0, 1, 3, 4) # t,n,c,h,w
 
@@ -251,6 +205,67 @@ def infer(*user_args):
 
     return video_path, video_bbox_path
 
+
+
+@spaces.GPU(duration=270)
+def infer_gpu_part(args):
+
+    model = instantiate_from_config(args.model_config_512)
+    model = model.cuda()
+    model = load_model_checkpoint(model, ckpt_path_512)
+    model.eval()
+
+    if args.seed is None:
+        seed = int.from_bytes(os.urandom(2), "big")
+    else:
+        seed = args.seed
+    print(f"Using seed: {seed}")
+    seed_everything(seed)
+
+    ## latent noise shape
+    h, w = height // 8, width // 8
+    frames = video_length
+    channels = model.channels
+
+    batch_size = 1
+    noise_shape = [batch_size, channels, frames, h, w]
+    fps = torch.tensor([args.fps] * batch_size).to(model.device).long()
+    prompts = [args.prompt_in]
+    text_emb = model.get_learned_conditioning(prompts)
+
+    cond = {"c_crossattn": [text_emb], "fps": fps}
+
+    ## inference
+    if radio_mode == 'ori':
+        batch_samples = batch_ddim_sampling(
+            model,
+            cond,
+            noise_shape,
+            args.n_samples,
+            args.ddim_steps,
+            args.ddim_eta,
+            args.unconditional_guidance_scale,
+            args=args,
+        )
+    else:
+        batch_samples = batch_ddim_sampling_freetraj(
+            model,
+            cond,
+            noise_shape,
+            args.n_samples,
+            args.ddim_steps,
+            args.ddim_eta,
+            args.unconditional_guidance_scale,
+            idx_list = args.idx_list,
+            input_traj = args.input_traj,
+            args=args,
+        )
+
+    vid_tensor = batch_samples[0]
+    video = vid_tensor.detach().cpu()
+
+    return video
+
 
 examples = [
     ["A squirrel jumping from one tree to another.",],
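
The refactor above is the standard Hugging Face ZeroGPU pattern: on a ZeroGPU Space, a GPU is only attached while a function decorated with @spaces.GPU is running, so everything that touches CUDA (building the model, loading the checkpoint, sampling) moves into infer_gpu_part. A minimal, self-contained sketch of the pattern on ZeroGPU hardware (the toy double_on_gpu function is hypothetical, not part of app.py):

import spaces
import torch

@spaces.GPU(duration=270)  # request a GPU for up to 270 s per call
def double_on_gpu(x: torch.Tensor) -> torch.Tensor:
    # CUDA work must happen inside the decorated function; outside it,
    # the ZeroGPU process has no GPU attached.
    return (x.cuda() * 2).cpu()

print(double_on_gpu(torch.ones(3)))  # tensor([2., 2., 2.])

This is also why the commit threads prompt_in, seed, model_config_512, idx_list, and input_traj through the argparse.Namespace: infer_gpu_part cannot see infer's locals. Note that it still reads radio_mode, which in this diff is only defined inside infer, so passing it through args as well would keep the decorated function self-contained.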
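
For reference, infer_gpu_part samples in the VAE latent space, where height and width are the pixel dimensions divided by 8. A worked check of the noise shape with the module-level constants from this commit (channels = 4 is an assumption for illustration; app.py reads it from model.channels at runtime):

# Worked check of the latent noise shape built in infer_gpu_part.
# channels = 4 is assumed here; app.py takes it from model.channels.
video_length, width, height = 16, 512, 320
batch_size, channels = 1, 4

h, w = height // 8, width // 8   # VAE downsampling factor of 8
noise_shape = [batch_size, channels, video_length, h, w]
print(noise_shape)               # [1, 4, 16, 40, 64]

The fallback seed drawn with int.from_bytes(os.urandom(2), "big") is likewise a 16-bit value, i.e. an integer in [0, 65535].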