Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import spaces
|
2 |
import gradio as gr
|
3 |
import argparse
|
@@ -6,17 +7,16 @@ import os
|
|
6 |
import random
|
7 |
import subprocess
|
8 |
from PIL import Image
|
9 |
-
import numpy as np
|
10 |
|
11 |
-
|
12 |
-
subprocess.run(['sh', './sky.sh']) # Keep if needed
|
13 |
sys.path.append("./SkyReels-V1")
|
14 |
|
15 |
-
# Corrected Relative Imports
|
16 |
from skyreelsinfer import TaskType
|
17 |
from skyreelsinfer.offload import OffloadConfig
|
18 |
-
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
|
19 |
from diffusers.utils import export_to_video
|
|
|
20 |
import torch
|
21 |
import logging
|
22 |
|
@@ -29,13 +29,12 @@ torch.backends.cudnn.benchmark = False
|
|
29 |
torch.set_float32_matmul_precision("highest")
|
30 |
|
31 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
32 |
-
|
33 |
-
# --- Dummy Classes (Moved to skyreelsinfer/__init__.py) ---
|
34 |
logger = logging.getLogger(__name__)
|
35 |
-
|
36 |
|
37 |
_predictor = None
|
38 |
-
task_type = TaskType.I2V
|
|
|
39 |
@spaces.GPU(duration=90)
|
40 |
def init_predictor():
|
41 |
global _predictor
|
@@ -46,7 +45,7 @@ def init_predictor():
|
|
46 |
if task_type == TaskType.I2V:
|
47 |
model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
|
48 |
elif task_type == TaskType.T2V:
|
49 |
-
model_id = "your_t2v_model_id"
|
50 |
else:
|
51 |
raise ValueError(f"Invalid task_type: {task_type}")
|
52 |
|
@@ -94,16 +93,15 @@ def generate_video(prompt, seed, image=None):
|
|
94 |
elif task_type == TaskType.T2V:
|
95 |
pass
|
96 |
else:
|
97 |
-
raise ValueError(f"Invalid
|
98 |
|
99 |
if _predictor is None:
|
100 |
init_predictor()
|
101 |
|
102 |
output = _predictor.infer(**kwargs)
|
103 |
|
104 |
-
# --- Convert to NumPy, move to CPU, scale, and change dtype ---
|
105 |
output = (output.cpu().numpy() * 255).astype(np.uint8)
|
106 |
-
output = output.transpose(0, 2, 3, 4, 1)
|
107 |
|
108 |
save_dir = f"./result/{task_type.name}"
|
109 |
os.makedirs(save_dir, exist_ok=True)
|
|
|
1 |
+
# app.py
|
2 |
import spaces
|
3 |
import gradio as gr
|
4 |
import argparse
|
|
|
7 |
import random
|
8 |
import subprocess
|
9 |
from PIL import Image
|
10 |
+
import numpy as np
|
11 |
|
12 |
+
subprocess.run(['sh', './sky.sh'])
|
|
|
13 |
sys.path.append("./SkyReels-V1")
|
14 |
|
|
|
15 |
from skyreelsinfer import TaskType
|
16 |
from skyreelsinfer.offload import OffloadConfig
|
17 |
+
from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
|
18 |
from diffusers.utils import export_to_video
|
19 |
+
|
20 |
import torch
|
21 |
import logging
|
22 |
|
|
|
29 |
torch.set_float32_matmul_precision("highest")
|
30 |
|
31 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
32 |
logger = logging.getLogger(__name__)
|
33 |
+
|
34 |
|
35 |
_predictor = None
|
36 |
+
task_type = TaskType.I2V
|
37 |
+
|
38 |
@spaces.GPU(duration=90)
|
39 |
def init_predictor():
|
40 |
global _predictor
|
|
|
45 |
if task_type == TaskType.I2V:
|
46 |
model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
|
47 |
elif task_type == TaskType.T2V:
|
48 |
+
model_id = "your_t2v_model_id"
|
49 |
else:
|
50 |
raise ValueError(f"Invalid task_type: {task_type}")
|
51 |
|
|
|
93 |
elif task_type == TaskType.T2V:
|
94 |
pass
|
95 |
else:
|
96 |
+
raise ValueError(f"Invalid Tasktype")
|
97 |
|
98 |
if _predictor is None:
|
99 |
init_predictor()
|
100 |
|
101 |
output = _predictor.infer(**kwargs)
|
102 |
|
|
|
103 |
output = (output.cpu().numpy() * 255).astype(np.uint8)
|
104 |
+
output = output.transpose(0, 2, 3, 4, 1)
|
105 |
|
106 |
save_dir = f"./result/{task_type.name}"
|
107 |
os.makedirs(save_dir, exist_ok=True)
|