# skyreelsinfer/skyreels_video_infer.py
import logging
import os
import time
from datetime import timedelta
from typing import Any, Dict
import torch
from diffusers import HunyuanVideoTransformer3DModel
from PIL import Image
from torchao.quantization import float8_weight_only
from torchao.quantization import quantize_
from transformers import LlamaModel
from . import TaskType
from .offload import Offload, OffloadConfig
from .pipelines import SkyreelsVideoPipeline
logger = logging.getLogger("SkyreelsVideoInfer")
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
    "%(asctime)s - %(name)s - %(levelname)s - [%(filename)s:%(lineno)d - %(funcName)s] - %(message)s"
)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
class SkyReelsVideoSingleGpuInfer:
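    """Single-GPU inference wrapper around SkyreelsVideoPipeline.

    Loads the text encoder and transformer on CPU (optionally float8-quantized),
    then either installs CPU/GPU offloading or moves the whole pipeline to the GPU,
    optionally compiles the transformer, and runs a one-step warm-up generation.
    """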
def _load_model(
self,
model_id: str,
base_model_id: str = "hunyuanvideo-community/HunyuanVideo",
quant_model: bool = True,
gpu_device: str = "cuda:0",
) -> SkyreelsVideoPipeline:
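        """Assemble the pipeline on CPU; `gpu_device` is only logged here, and placement happens later."""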
        logger.info(f"load model model_id:{model_id} quant_model:{quant_model} gpu_device:{gpu_device}")
text_encoder = LlamaModel.from_pretrained(
base_model_id,
subfolder="text_encoder",
torch_dtype=torch.bfloat16,
).to("cpu")
        # Note: from_pretrained has no `device` kwarg; the explicit .to("cpu") handles placement.
        transformer = HunyuanVideoTransformer3DModel.from_pretrained(
            model_id,
            # subfolder="transformer",  # enable if the checkpoint nests weights in a subfolder
            torch_dtype=torch.bfloat16,
        ).to("cpu").eval()
        if quant_model:
            # float8 weight-only quantization on CPU cuts weight memory roughly in half
            # (bfloat16 -> float8) before anything is moved to the GPU.
            quantize_(text_encoder, float8_weight_only(), device="cpu")
            quantize_(transformer, float8_weight_only(), device="cpu")
pipe = SkyreelsVideoPipeline.from_pretrained(
base_model_id,
transformer=transformer,
text_encoder=text_encoder,
torch_dtype=torch.bfloat16,
).to("cpu")
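        # Decode the VAE in tiles so full-resolution video frames fit in memory.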
pipe.vae.enable_tiling()
torch.cuda.empty_cache()
return pipe
def __init__(
self,
task_type: TaskType,
model_id: str,
quant_model: bool = True,
is_offload: bool = True,
offload_config: OffloadConfig = OffloadConfig(),
):
self.task_type = task_type
        # Single-GPU build: LOCAL_RANK and torch.cuda.set_device are no longer needed.
        # Keep the cuDNN SDP attention backend disabled, as in the multi-GPU version.
        torch.backends.cuda.enable_cudnn_sdp(False)
gpu_device = "cuda:0"
self.pipe: SkyreelsVideoPipeline = self._load_model(
model_id=model_id, quant_model=quant_model, gpu_device=gpu_device
)
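        # Either stream weights to the GPU on demand, or keep the whole pipeline resident.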
if is_offload:
Offload.offload(
pipeline=self.pipe,
config=offload_config,
)
else:
self.pipe.to(gpu_device)
if offload_config.compiler_transformer:
torch._dynamo.config.suppress_errors = True
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
            os.environ["TORCHINDUCTOR_CACHE_DIR"] = f"{offload_config.compiler_cache}_1"  # the "_1" suffix marks the single-GPU cache
self.pipe.transformer = torch.compile(
self.pipe.transformer,
mode="max-autotune-no-cudagraphs",
dynamic=True,
)
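        # A dummy generation pays the compilation and first-run cost up front.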
self.warm_up()
def warm_up(self):
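        """Run a minimal one-step generation so first-call overhead is not paid at inference time."""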
init_kwargs = {
"prompt": "A woman is dancing in a room",
"height": 512,
"width": 512,
"guidance_scale": 6,
"num_inference_steps": 1,
"negative_prompt": "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion",
"num_frames": 97,
"generator": torch.Generator("cuda").manual_seed(42),
"embedded_guidance_scale": 1.0,
}
if self.task_type == TaskType.I2V:
init_kwargs["image"] = Image.new("RGB", (512, 512), color="black")
self.pipe(**init_kwargs)
def inference(self, kwargs: Dict[str, Any]):
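        """Generate frames; `kwargs` mirrors the SkyreelsVideoPipeline call signature, plus an optional `seed`."""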
logger.info(f"kwargs: {kwargs}")
        if "seed" in kwargs:
            kwargs["generator"] = torch.Generator("cuda").manual_seed(kwargs.pop("seed"))
start_time = time.time()
        assert (
            self.task_type == TaskType.I2V and "image" in kwargs
        ) or self.task_type == TaskType.T2V, "I2V tasks require an `image` input"
out = self.pipe(**kwargs).frames[0]
        logger.info(f"inference time: {time.time() - start_time:.2f}s")
return out
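

# Minimal usage sketch (assumptions: this file's relative imports resolve when run
# as part of the package, and the checkpoint id below is illustrative -- substitute
# the SkyReels checkpoint you actually use).
if __name__ == "__main__":
    infer = SkyReelsVideoSingleGpuInfer(
        task_type=TaskType.T2V,
        model_id="Skywork/SkyReels-V1-Hunyuan-T2V",  # assumed checkpoint id
        quant_model=True,
        is_offload=True,
        offload_config=OffloadConfig(),
    )
    frames = infer.inference(
        {
            "prompt": "A sailboat drifting across a calm bay at sunset",
            "height": 512,
            "width": 512,
            "num_frames": 97,
            "num_inference_steps": 30,
            "guidance_scale": 6.0,
            "embedded_guidance_scale": 1.0,
            "seed": 42,
        }
    )
    logger.info(f"generated {len(frames)} frames")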