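"""Single-GPU inference wrapper for SkyReels video generation.

Intended for a single-GPU, memory-constrained runtime such as a Hugging Face
Space "Running on Zero" (ZeroGPU). The pipeline is built on CPU, the text
encoder and transformer can optionally be quantized to float8 weights, and the
pipeline is then either offloaded or moved wholesale to the GPU, optionally
with torch.compile applied to the transformer.
"""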
import logging
import os
import time
from typing import Any, Dict

import torch
from diffusers import HunyuanVideoTransformer3DModel
from PIL import Image
from torchao.quantization import float8_weight_only, quantize_
from transformers import LlamaModel

from . import TaskType  # Assuming these are still needed
from .offload import Offload, OffloadConfig
from .pipelines import SkyreelsVideoPipeline
logger = logging.getLogger("SkyreelsVideoInfer")
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d - %(funcName)s] - %(message)s"
)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
class SkyReelsVideoSingleGpuInfer:
    def _load_model(
        self,
        model_id: str,
        base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
        quant_model: bool = True,
        gpu_device: str = "cuda:0",
    ) -> SkyreelsVideoPipeline:
        """Build the pipeline on CPU; gpu_device is only logged here, since
        weights are moved to the GPU later (by offloading or pipe.to())."""
        logger.info(f"Loading model: model_id={model_id}, quant_model={quant_model}, gpu_device={gpu_device}")
        text_encoder = LlamaModel.from_pretrained(
            base_model_id,
            subfolder="text_encoder",
            torch_dtype=torch.bfloat16,
        ).to("cpu")
        transformer = (
            HunyuanVideoTransformer3DModel.from_pretrained(
                model_id,
                # subfolder="transformer",
                torch_dtype=torch.bfloat16,
            )
            .to("cpu")
            .eval()
        )
        if quant_model:
            # float8 weight-only quantization roughly halves the footprint of
            # the bfloat16 weights; done on CPU so no GPU memory is touched yet.
            quantize_(text_encoder, float8_weight_only(), device="cpu")
            quantize_(transformer, float8_weight_only(), device="cpu")
        pipe = SkyreelsVideoPipeline.from_pretrained(
            base_model_id,
            transformer=transformer,
            text_encoder=text_encoder,
            torch_dtype=torch.bfloat16,
        ).to("cpu")
        pipe.vae.enable_tiling()
        torch.cuda.empty_cache()
        return pipe
    def __init__(
        self,
        task_type: TaskType,
        model_id: str,
        quant_model: bool = True,
        is_offload: bool = True,
        offload_config: OffloadConfig = OffloadConfig(),
    ):
        self.task_type = task_type
        # os.environ["LOCAL_RANK"] = "0"  # No longer needed for single-GPU.
        # torch.cuda.set_device(0)  # Optional, but being explicit doesn't hurt.
        torch.backends.cuda.enable_cudnn_sdp(False)  # Still a good idea to keep.
        gpu_device = "cuda:0"
        self.pipe: SkyreelsVideoPipeline = self._load_model(
            model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
        )
        if is_offload:
            Offload.offload(
                pipeline=self.pipe,
                config=offload_config,
            )
        else:
            self.pipe.to(gpu_device)
        if offload_config.compiler_transformer:
            torch._dynamo.config.suppress_errors = True
            os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
            os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_1"  # "_1" denotes a single GPU.
            self.pipe.transformer = torch.compile(
                self.pipe.transformer,
                mode="max-autotune-no-cudagraphs",
                dynamic=True,
            )
        self.warm_up()
    def warm_up(self):
        """Run a single-step generation to trigger compilation and caching."""
        init_kwargs = {
            "prompt": "A woman is dancing in a room",
            "height": 512,
            "width": 512,
            "guidance_scale": 6,
            "num_inference_steps": 1,
            "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
            "num_frames": 97,
            "generator": torch.Generator("cuda").manual_seed(42),
            "embedded_guidance_scale": 1.0,
        }
        if self.task_type == TaskType.I2V:
            init_kwargs["image"] = Image.new("RGB", (512, 512), color="black")
        self.pipe(**init_kwargs)
    def inference(self, kwargs: Dict[str, Any]):
        logger.info(f"kwargs: {kwargs}")
        if "seed" in kwargs:
            kwargs["generator"] = torch.Generator("cuda").manual_seed(kwargs.pop("seed"))
        start_time = time.time()
        assert (self.task_type == TaskType.I2V and "image" in kwargs) or self.task_type == TaskType.T2V, (
            "I2V tasks require an 'image' entry in kwargs"
        )
        out = self.pipe(**kwargs).frames[0]
        logger.info(f"inference time: {time.time() - start_time:.2f}s")
        return out