rahul7star commited on
Commit
897f552
·
verified ·
1 Parent(s): 9432369

Create generate.py

Browse files
Files changed (1) hide show
  1. generate.py +43 -0
generate.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import subprocess
3
+ import os
4
+ import torch
5
+ from huggingface_hub import snapshot_download
6
+
7
+ # Arguments
8
+ parser = argparse.ArgumentParser()
9
+ parser.add_argument("--task", type=str, default="t2v-14B")
10
+ parser.add_argument("--size", type=str, default="200*200")
11
+ parser.add_argument("--frame_num", type=int, default=60)
12
+ parser.add_argument("--sample_steps", type=int, default=20)
13
+ parser.add_argument("--ckpt_dir", type=str, default="./Wan2.1-T2V-14B")
14
+ parser.add_argument("--offload_model", type=str, default="True")
15
+ parser.add_argument("--prompt", type=str, required=True)
16
+ args = parser.parse_args()
17
+
18
+ # Ensure the model is downloaded
19
+ if not os.path.exists(args.ckpt_dir):
20
+ print("🔄 Downloading WAN 2.1 - 14B model from Hugging Face...")
21
+ snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-14B", local_dir=args.ckpt_dir)
22
+
23
+ # Free up GPU memory
24
+ if torch.cuda.is_available():
25
+ torch.cuda.empty_cache()
26
+ torch.backends.cudnn.benchmark = False
27
+ torch.backends.cudnn.deterministic = True
28
+
29
+ # Run WAN 2.1 - 14B Model
30
+ command = f"python generate.py --task {args.task} --size {args.size} --frame_num {args.frame_num} --sample_steps {args.sample_steps} --ckpt_dir {args.ckpt_dir} --offload_model {args.offload_model} --prompt \"{args.prompt}\""
31
+
32
+ process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
33
+ stdout, stderr = process.communicate()
34
+
35
+ # Print logs for debugging
36
+ print("🔹 Output:", stdout.decode())
37
+ print("🔺 Error:", stderr.decode())
38
+
39
+ # Verify if video was created
40
+ if os.path.exists("output.mp4"):
41
+ print("✅ Video generated successfully: output.mp4")
42
+ else:
43
+ print("❌ Error: Video file not found!")