1inkusFace committed on
Commit
a05a8a4
·
verified ·
1 Parent(s): a5c228f

Update skyreelsinfer/skyreels_video_infer.py

Browse files
Files changed (1) hide show
  1. skyreelsinfer/skyreels_video_infer.py +34 -19
skyreelsinfer/skyreels_video_infer.py CHANGED
@@ -1,20 +1,22 @@
1
  import logging
2
- import os
3
  import time
4
  from datetime import timedelta
5
  from typing import Any
6
  from typing import Dict
7
 
8
- import torch
9
- from diffusers import HunyuanVideoTransformer3DModel
10
- from diffusers import DiffusionPipeline
11
- from PIL import Image
12
- from transformers import LlamaModel
 
 
 
 
 
 
13
 
14
- from . import TaskType
15
- from .offload import Offload
16
- from .offload import OffloadConfig
17
- from .pipelines import SkyreelsVideoPipeline
18
 
19
  logger = logging.getLogger("SkyReelsVideoInfer")
20
  logger.setLevel(logging.DEBUG)
@@ -29,11 +31,11 @@ logger.addHandler(console_handler)
29
  class SkyReelsVideoInfer:
30
  def __init__(
31
  self,
32
- task_type: TaskType,
33
  model_id: str,
34
  quant_model: bool = True,
35
  is_offload: bool = True,
36
- offload_config: OffloadConfig = None,
37
  use_multiprocessing: bool = False,
38
  ):
39
  self.task_type = task_type
@@ -50,11 +52,19 @@ class SkyReelsVideoInfer:
50
  base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
51
  quant_model: bool = True,
52
  device: str = "cpu",
53
- ) -> SkyreelsVideoPipeline:
54
- logger.info(f"load model model_id:{model_id} quan_model:{quant_model} device:{device}")
55
-
 
 
 
 
56
  from torchao.quantization import float8_weight_only
57
  from torchao.quantization import quantize_
 
 
 
 
58
 
59
  text_encoder = LlamaModel.from_pretrained(
60
  base_model_id,
@@ -81,7 +91,10 @@ class SkyReelsVideoInfer:
81
  return pipe
82
 
83
  def _initialize_pipeline(self):
84
- self.pipe: SkyreelsVideoPipeline = self._load_model(
 
 
 
85
  model_id=self.model_id, quant_model=self.quant_model, device="cpu"
86
  )
87
 
@@ -92,9 +105,11 @@ class SkyReelsVideoInfer:
92
  )
93
 
94
  def inference(self, kwargs):
 
 
95
  if self.task_type == TaskType.I2V:
96
  image = kwargs.pop("image")
97
- output = self.pipe(image=image, **kwargs) # Get full output
98
  else:
99
- output = self.pipe(**kwargs) # Get full output
100
- return output.frames # Return frames directly
 
1
  import logging
2
+ import os # Keep os here
3
  import time
4
  from datetime import timedelta
5
  from typing import Any
6
  from typing import Dict
7
 
8
+ # DELAY ALL THESE IMPORTS:
9
+ # import torch
10
+ # from diffusers import HunyuanVideoTransformer3DModel
11
+ # from diffusers import DiffusionPipeline
12
+ # from PIL import Image
13
+ # from transformers import LlamaModel
14
+
15
+ # from . import TaskType
16
+ # from .offload import Offload
17
+ # from .offload import OffloadConfig
18
+ # from .pipelines import SkyreelsVideoPipeline
19
 
 
 
 
 
20
 
21
  logger = logging.getLogger("SkyReelsVideoInfer")
22
  logger.setLevel(logging.DEBUG)
 
31
  class SkyReelsVideoInfer:
32
  def __init__(
33
  self,
34
+ task_type, # No TaskType.
35
  model_id: str,
36
  quant_model: bool = True,
37
  is_offload: bool = True,
38
+ offload_config = None, # No OffloadConfig
39
  use_multiprocessing: bool = False,
40
  ):
41
  self.task_type = task_type
 
52
  base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
53
  quant_model: bool = True,
54
  device: str = "cpu",
55
+ ):
56
+ # DELAYED IMPORTS:
57
+ import torch
58
+ from diffusers import HunyuanVideoTransformer3DModel
59
+ from diffusers import DiffusionPipeline
60
+ from PIL import Image
61
+ from transformers import LlamaModel
62
  from torchao.quantization import float8_weight_only
63
  from torchao.quantization import quantize_
64
+ from .pipelines import SkyreelsVideoPipeline # Local import
65
+
66
+
67
+ logger.info(f"load model model_id:{model_id} quan_model:{quant_model} device:{device}")
68
 
69
  text_encoder = LlamaModel.from_pretrained(
70
  base_model_id,
 
91
  return pipe
92
 
93
  def _initialize_pipeline(self):
94
+ #More Delayed Imports
95
+ from .offload import Offload
96
+
97
+ self.pipe = self._load_model( #No : SkyreelsVideoPipeline
98
  model_id=self.model_id, quant_model=self.quant_model, device="cpu"
99
  )
100
 
 
105
  )
106
 
107
  def inference(self, kwargs):
108
+ #DELAYED IMPORTS
109
+ from . import TaskType
110
  if self.task_type == TaskType.I2V:
111
  image = kwargs.pop("image")
112
+ output = self.pipe(image=image, **kwargs)
113
  else:
114
+ output = self.pipe(**kwargs)
115
+ return output.frames