1inkusFace commited on
Commit
06d7801
·
verified ·
1 Parent(s): c459371

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -13
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import spaces
2
  import gradio as gr
3
  import argparse
@@ -6,17 +7,16 @@ import os
6
  import random
7
  import subprocess
8
  from PIL import Image
9
- import numpy as np # Import NumPy
10
 
11
-
12
- subprocess.run(['sh', './sky.sh']) # Keep if needed
13
  sys.path.append("./SkyReels-V1")
14
 
15
- # Corrected Relative Imports
16
  from skyreelsinfer import TaskType
17
  from skyreelsinfer.offload import OffloadConfig
18
- from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer # Import the class
19
  from diffusers.utils import export_to_video
 
20
  import torch
21
  import logging
22
 
@@ -29,13 +29,12 @@ torch.backends.cudnn.benchmark = False
29
  torch.set_float32_matmul_precision("highest")
30
 
31
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
32
-
33
- # --- Dummy Classes (Moved to skyreelsinfer/__init__.py) ---
34
  logger = logging.getLogger(__name__)
35
- # --- Global Variables and Argument Parsing ---
36
 
37
  _predictor = None
38
- task_type = TaskType.I2V # Default task type.
 
39
  @spaces.GPU(duration=90)
40
  def init_predictor():
41
  global _predictor
@@ -46,7 +45,7 @@ def init_predictor():
46
  if task_type == TaskType.I2V:
47
  model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
48
  elif task_type == TaskType.T2V:
49
- model_id = "your_t2v_model_id" # Replace
50
  else:
51
  raise ValueError(f"Invalid task_type: {task_type}")
52
 
@@ -94,16 +93,15 @@ def generate_video(prompt, seed, image=None):
94
  elif task_type == TaskType.T2V:
95
  pass
96
  else:
97
- raise ValueError(f"Invalid task_type: {task_type}")
98
 
99
  if _predictor is None:
100
  init_predictor()
101
 
102
  output = _predictor.infer(**kwargs)
103
 
104
- # --- Convert to NumPy, move to CPU, scale, and change dtype ---
105
  output = (output.cpu().numpy() * 255).astype(np.uint8)
106
- output = output.transpose(0, 2, 3, 4, 1) #Correct transpose.
107
 
108
  save_dir = f"./result/{task_type.name}"
109
  os.makedirs(save_dir, exist_ok=True)
 
1
+ # app.py
2
  import spaces
3
  import gradio as gr
4
  import argparse
 
7
  import random
8
  import subprocess
9
  from PIL import Image
10
+ import numpy as np
11
 
12
+ subprocess.run(['sh', './sky.sh'])
 
13
  sys.path.append("./SkyReels-V1")
14
 
 
15
  from skyreelsinfer import TaskType
16
  from skyreelsinfer.offload import OffloadConfig
17
+ from skyreelsinfer.skyreels_video_infer import SkyReelsVideoSingleGpuInfer
18
  from diffusers.utils import export_to_video
19
+
20
  import torch
21
  import logging
22
 
 
29
  torch.set_float32_matmul_precision("highest")
30
 
31
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
32
  logger = logging.getLogger(__name__)
33
+
34
 
35
  _predictor = None
36
+ task_type = TaskType.I2V
37
+
38
  @spaces.GPU(duration=90)
39
  def init_predictor():
40
  global _predictor
 
45
  if task_type == TaskType.I2V:
46
  model_id = "Skywork/SkyReels-V1-Hunyuan-I2V"
47
  elif task_type == TaskType.T2V:
48
+ model_id = "your_t2v_model_id"
49
  else:
50
  raise ValueError(f"Invalid task_type: {task_type}")
51
 
 
93
  elif task_type == TaskType.T2V:
94
  pass
95
  else:
96
+ raise ValueError(f"Invalid Tasktype")
97
 
98
  if _predictor is None:
99
  init_predictor()
100
 
101
  output = _predictor.infer(**kwargs)
102
 
 
103
  output = (output.cpu().numpy() * 255).astype(np.uint8)
104
+ output = output.transpose(0, 2, 3, 4, 1)
105
 
106
  save_dir = f"./result/{task_type.name}"
107
  os.makedirs(save_dir, exist_ok=True)