import logging
import os
import time
from typing import Any, Dict

import torch
from diffusers import HunyuanVideoTransformer3DModel
from PIL import Image
from torchao.quantization import float8_weight_only, quantize_
from transformers import LlamaModel

from . import TaskType
from .offload import Offload, OffloadConfig
from .pipelines import SkyreelsVideoPipeline

logger = logging.getLogger("SkyreelsVideoInfer")
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d - %(funcName)s] - %(message)s"
)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)


class SkyReelsVideoSingleGpuInfer:
    def _load_model(
        self,
        model_id: str,
        base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
        quant_model: bool = True,
        gpu_device: str = "cuda:0",
    ) -> SkyreelsVideoPipeline:
        logger.info(f"load model model_id:{model_id} quant_model:{quant_model} gpu_device:{gpu_device}")
        text_encoder = LlamaModel.from_pretrained(
            base_model_id,
            subfolder="text_encoder",
            torch_dtype=torch.bfloat16,
        ).to("cpu")
        transformer = (
            HunyuanVideoTransformer3DModel.from_pretrained(
                model_id,
                # subfolder="transformer",
                torch_dtype=torch.bfloat16,
            )
            .to("cpu")
            .eval()
        )
        if quant_model:
            # float8 weight-only quantization roughly halves weight memory
            # versus bf16; it runs on CPU so the GPU stays free during load.
            quantize_(text_encoder, float8_weight_only(), device="cpu")
            quantize_(transformer, float8_weight_only(), device="cpu")
        pipe = SkyreelsVideoPipeline.from_pretrained(
            base_model_id,
            transformer=transformer,
            text_encoder=text_encoder,
            torch_dtype=torch.bfloat16,
        ).to("cpu")
        pipe.vae.enable_tiling()
        torch.cuda.empty_cache()
        return pipe

    def __init__(
        self,
        task_type: TaskType,
        model_id: str,
        quant_model: bool = True,
        is_offload: bool = True,
        offload_config: OffloadConfig = OffloadConfig(),
    ):
        self.task_type = task_type
        torch.cuda.set_device(0)  # Be explicit even on a single-GPU host.
        # The cuDNN SDP backend is disabled; it can misbehave with these attention shapes.
        torch.backends.cuda.enable_cudnn_sdp(False)
        gpu_device = "cuda:0"

        self.pipe: SkyreelsVideoPipeline = self._load_model(
            model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
        )

        if is_offload:
            Offload.offload(pipeline=self.pipe, config=offload_config)
        else:
            self.pipe.to(gpu_device)

        if offload_config.compiler_transformer:
            torch._dynamo.config.suppress_errors = True
            os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
            # The "_1" suffix marks this as a single-GPU compile cache.
            os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_1"
            self.pipe.transformer = torch.compile(
                self.pipe.transformer,
                mode="max-autotune-no-cudagraphs",
                dynamic=True,
            )
            self.warm_up()

    def warm_up(self):
        """Run one short generation so torch.compile traces the transformer up front."""
        init_kwargs = {
            "prompt": "A woman is dancing in a room",
            "height": 512,
            "width": 512,
            "guidance_scale": 6,
            "num_inference_steps": 1,
            "negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
            "num_frames": 97,
            "generator": torch.Generator("cuda").manual_seed(42),
            "embedded_guidance_scale": 1.0,
        }
        if self.task_type == TaskType.I2V:
            init_kwargs["image"] = Image.new("RGB", (512, 512), color="black")
        self.pipe(**init_kwargs)

    def inference(self, kwargs: Dict[str, Any]):
        logger.info(f"kwargs: {kwargs}")
        if "seed" in kwargs:
            kwargs["generator"] = torch.Generator("cuda").manual_seed(kwargs["seed"])
            del kwargs["seed"]
        assert (
            self.task_type == TaskType.I2V and "image" in kwargs
        ) or self.task_type == TaskType.T2V, "I2V inference requires an 'image' kwarg"
        start_time = time.time()
        out = self.pipe(**kwargs).frames[0]
        logger.info(f"inference time: {time.time() - start_time:.2f}s")
        return out
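

# A minimal usage sketch, not part of the original module: the checkpoint id
# "Skywork/SkyReels-V1-Hunyuan-T2V" is illustrative, as is the assumption that
# TaskType exposes T2V/I2V members; adjust both to your setup. The kwargs
# mirror those used in warm_up(). Because of the relative imports above, run
# this with `python -m <package>.<module>` rather than executing the file directly.
if __name__ == "__main__":
    infer = SkyReelsVideoSingleGpuInfer(
        task_type=TaskType.T2V,
        model_id="Skywork/SkyReels-V1-Hunyuan-T2V",
        quant_model=True,
        is_offload=True,
    )
    frames = infer.inference(
        {
            "prompt": "A sailboat gliding across a calm bay at sunset",
            "height": 512,
            "width": 512,
            "num_frames": 97,
            "num_inference_steps": 30,
            "guidance_scale": 6,
            "embedded_guidance_scale": 1.0,
            "seed": 42,
        }
    )
    logger.info(f"generated {len(frames)} frames")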