import os
import random
from typing import Any, Dict, List, Optional, Tuple

import torch
from accelerate import init_empty_weights
from diffusers import (
    AutoencoderKLLTXVideo,
    FlowMatchEulerDiscreteScheduler,
    LTXImageToVideoPipeline,
    LTXPipeline,
    LTXVideoTransformer3DModel,
)
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from PIL.Image import Image
from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer

from ... import data
from ... import functional as FF
from ...logging import get_logger
from ...parallel import ParallelBackendEnum
from ...processors import ProcessorMixin, T5Processor
from ...typing import ArtifactType, SchedulerType
from ...utils import get_non_null_items
from ..modeling_utils import ModelSpecification


logger = get_logger()


class LTXLatentEncodeProcessor(ProcessorMixin):
    r"""
    Processor to encode image/video into latents using the LTX VAE.

    Args:
        output_names (`List[str]`):
            The names of the outputs that the processor returns. The outputs are in the following order:
            - latents: The latents of the input image/video.
            - num_frames: The number of frames in the input video.
            - height: The height of the input image/video.
            - width: The width of the input image/video.
            - latents_mean: The latent channel means from the VAE state dict.
            - latents_std: The latent channel standard deviations from the VAE state dict.
    """

    def __init__(self, output_names: List[str]):
        super().__init__()
        self.output_names = output_names
        assert len(self.output_names) == 6

    def forward(
        self,
        vae: AutoencoderKLLTXVideo,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
    ) -> Dict[str, torch.Tensor]:
        device = vae.device
        dtype = vae.dtype

        if image is not None:
            video = image.unsqueeze(1)

        assert video.ndim == 5, f"Expected 5D tensor, got {video.ndim}D tensor"
        video = video.to(device=device, dtype=vae.dtype)
        video = video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]

        if compute_posterior:
            latents = vae.encode(video).latent_dist.sample(generator=generator)
            latents = latents.to(dtype=dtype)
        else:
            if vae.use_slicing and video.shape[0] > 1:
                encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
                moments = torch.cat(encoded_slices)
            else:
                moments = vae._encode(video)
            latents = moments.to(dtype=dtype)

        _, _, num_frames, height, width = latents.shape

        return {
            self.output_names[0]: latents,
            self.output_names[1]: num_frames,
            self.output_names[2]: height,
            self.output_names[3]: width,
            self.output_names[4]: vae.latents_mean,
            self.output_names[5]: vae.latents_std,
        }
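

# Hedged usage sketch: how LTXLatentEncodeProcessor might be exercised on its own. In
# training, the trainer drives this processor through latent precomputation, so the
# direct call below and the toy tensor shape are illustrative assumptions only; the
# checkpoint id and subfolder are the defaults used elsewhere in this file, and
# `forward` is called explicitly to avoid assuming anything about ProcessorMixin.
#
#   vae = AutoencoderKLLTXVideo.from_pretrained(
#       "Lightricks/LTX-Video", subfolder="vae", torch_dtype=torch.bfloat16
#   )
#   processor = LTXLatentEncodeProcessor(
#       ["latents", "num_frames", "height", "width", "latents_mean", "latents_std"]
#   )
#   video = torch.rand(1, 49, 3, 512, 704) * 2 - 1  # toy [B, F, C, H, W] clip in [-1, 1]
#   outputs = processor.forward(vae=vae, video=video, compute_posterior=True)
#   print(outputs["latents"].shape)  # latent [B, C, F', H', W'] after VAE compression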


class LTXVideoModelSpecification(ModelSpecification):
    def __init__(
        self,
        pretrained_model_name_or_path: str = "Lightricks/LTX-Video",
        tokenizer_id: Optional[str] = None,
        text_encoder_id: Optional[str] = None,
        transformer_id: Optional[str] = None,
        vae_id: Optional[str] = None,
        text_encoder_dtype: torch.dtype = torch.bfloat16,
        transformer_dtype: torch.dtype = torch.bfloat16,
        vae_dtype: torch.dtype = torch.bfloat16,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
        condition_model_processors: List[ProcessorMixin] = None,
        latent_model_processors: List[ProcessorMixin] = None,
        **kwargs,
    ) -> None:
        super().__init__(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            tokenizer_id=tokenizer_id,
            text_encoder_id=text_encoder_id,
            transformer_id=transformer_id,
            vae_id=vae_id,
            text_encoder_dtype=text_encoder_dtype,
            transformer_dtype=transformer_dtype,
            vae_dtype=vae_dtype,
            revision=revision,
            cache_dir=cache_dir,
        )

        if condition_model_processors is None:
            condition_model_processors = [T5Processor(["encoder_hidden_states", "encoder_attention_mask"])]
        if latent_model_processors is None:
            latent_model_processors = [
                LTXLatentEncodeProcessor(["latents", "num_frames", "height", "width", "latents_mean", "latents_std"])
            ]

        self.condition_model_processors = condition_model_processors
        self.latent_model_processors = latent_model_processors

    def _resolution_dim_keys(self):
        return {"latents": (2, 3, 4)}

    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
        if self.tokenizer_id is not None:
            tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
            )
        else:
            tokenizer = T5Tokenizer.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="tokenizer",
                revision=self.revision,
                cache_dir=self.cache_dir,
            )

        if self.text_encoder_id is not None:
            text_encoder = AutoModel.from_pretrained(
                self.text_encoder_id,
                torch_dtype=self.text_encoder_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )
        else:
            text_encoder = T5EncoderModel.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="text_encoder",
                torch_dtype=self.text_encoder_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )

        return {"tokenizer": tokenizer, "text_encoder": text_encoder}

    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
        if self.vae_id is not None:
            vae = AutoencoderKLLTXVideo.from_pretrained(
                self.vae_id,
                torch_dtype=self.vae_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )
        else:
            vae = AutoencoderKLLTXVideo.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="vae",
                torch_dtype=self.vae_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )

        return {"vae": vae}

    def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
        if self.transformer_id is not None:
            transformer = LTXVideoTransformer3DModel.from_pretrained(
                self.transformer_id,
                torch_dtype=self.transformer_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )
        else:
            transformer = LTXVideoTransformer3DModel.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="transformer",
                torch_dtype=self.transformer_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {"transformer": transformer, "scheduler": scheduler}

    def load_pipeline(
        self,
        tokenizer: Optional[T5Tokenizer] = None,
        text_encoder: Optional[T5EncoderModel] = None,
        transformer: Optional[LTXVideoTransformer3DModel] = None,
        vae: Optional[AutoencoderKLLTXVideo] = None,
        scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
        enable_slicing: bool = False,
        enable_tiling: bool = False,
        enable_model_cpu_offload: bool = False,
        training: bool = False,
        **kwargs,
    ) -> LTXPipeline:
        components = {
            "tokenizer": tokenizer,
            "text_encoder": text_encoder,
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
        }
        components = get_non_null_items(components)

        pipe = LTXPipeline.from_pretrained(
            self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
        )
        pipe.text_encoder.to(self.text_encoder_dtype)
        pipe.vae.to(self.vae_dtype)
        if not training:
            pipe.transformer.to(self.transformer_dtype)

        if enable_slicing:
            pipe.vae.enable_slicing()
        if enable_tiling:
            pipe.vae.enable_tiling()
        if enable_model_cpu_offload:
            pipe.enable_model_cpu_offload()

        return pipe
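
    # Hedged usage sketch: assembling an inference pipeline from this specification. The
    # flow below mirrors how a validation run might wire the loaders together; the
    # variable names and the choice to enable slicing/tiling/offload are illustrative
    # assumptions, not the trainer's exact call sequence.
    #
    #   spec = LTXVideoModelSpecification()
    #   condition_models = spec.load_condition_models()  # {"tokenizer", "text_encoder"}
    #   latent_models = spec.load_latent_models()        # {"vae"}
    #   diffusion_models = spec.load_diffusion_models()  # {"transformer", "scheduler"}
    #   pipe = spec.load_pipeline(
    #       **condition_models,
    #       **latent_models,
    #       **diffusion_models,
    #       enable_slicing=True,
    #       enable_tiling=True,
    #       enable_model_cpu_offload=True,
    #   )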

    def prepare_conditions(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        caption: str,
        max_sequence_length: int = 128,
        **kwargs,
    ) -> Dict[str, Any]:
        conditions = {
            "tokenizer": tokenizer,
            "text_encoder": text_encoder,
            "caption": caption,
            "max_sequence_length": max_sequence_length,
            **kwargs,
        }
        input_keys = set(conditions.keys())
        conditions = super().prepare_conditions(**conditions)
        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
        return conditions

    def prepare_latents(
        self,
        vae: AutoencoderKLLTXVideo,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        conditions = {
            "vae": vae,
            "image": image,
            "video": video,
            "generator": generator,
            "compute_posterior": compute_posterior,
            **kwargs,
        }
        input_keys = set(conditions.keys())
        conditions = super().prepare_latents(**conditions)
        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
        return conditions
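
    # Hedged sketch of the precomputation flow: both prepare_* hooks strip their own
    # inputs from the returned dicts, so what remains is exactly what forward() consumes
    # as condition_model_conditions / latent_model_conditions. The caption, variables,
    # and single-sample call pattern below are illustrative; batching and dataloading
    # are handled elsewhere by the trainer.
    #
    #   condition_model_conditions = spec.prepare_conditions(
    #       tokenizer=tokenizer, text_encoder=text_encoder, caption="a cat playing piano"
    #   )  # {"encoder_hidden_states", "encoder_attention_mask"}
    #   latent_model_conditions = spec.prepare_latents(
    #       vae=vae, video=video, compute_posterior=True
    #   )  # {"latents", "num_frames", "height", "width", "latents_mean", "latents_std"}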

    def forward(
        self,
        transformer: LTXVideoTransformer3DModel,
        condition_model_conditions: Dict[str, torch.Tensor],
        latent_model_conditions: Dict[str, torch.Tensor],
        sigmas: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
        **kwargs,
    ) -> Tuple[torch.Tensor, ...]:
        # TODO(aryan): make this configurable? Should it be?
        first_frame_conditioning_p = 0.1
        min_first_frame_sigma = 0.25

        if compute_posterior:
            latents = latent_model_conditions.pop("latents")
        else:
            posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
            latents = posterior.sample(generator=generator)
            del posterior

        latents_mean = latent_model_conditions.pop("latents_mean")
        latents_std = latent_model_conditions.pop("latents_std")

        latents = self._normalize_latents(latents, latents_mean, latents_std)
        noise = torch.zeros_like(latents).normal_(generator=generator)

        if random.random() < first_frame_conditioning_p:
            # Section 2.4 of the paper notes that the first-frame timestep should be a small random value.
            # As an educated guess, we cap the first-frame sigma at `min_first_frame_sigma`.
            # torch.rand_like returns values in [0, 1). To ensure the first-frame sigma never exceeds the
            # actual sigmas used for the remaining frames (image conditioning), we rescale by multiplying
            # with sigmas so the range becomes [0, sigmas), then apply the cap above.
            first_frame_sigma = torch.rand_like(sigmas) * sigmas
            first_frame_sigma = torch.min(first_frame_sigma, sigmas.new_full(sigmas.shape, min_first_frame_sigma))

            latents_first_frame, latents_rest = latents[:, :, :1], latents[:, :, 1:]
            noisy_latents_first_frame = FF.flow_match_xt(latents_first_frame, noise[:, :, :1], first_frame_sigma)
            noisy_latents_remaining = FF.flow_match_xt(latents_rest, noise[:, :, 1:], sigmas)
            noisy_latents = torch.cat([noisy_latents_first_frame, noisy_latents_remaining], dim=2)
        else:
            noisy_latents = FF.flow_match_xt(latents, noise, sigmas)

        patch_size = self.transformer_config.patch_size
        patch_size_t = self.transformer_config.patch_size_t

        latents = self._pack_latents(latents, patch_size, patch_size_t)
        noise = self._pack_latents(noise, patch_size, patch_size_t)
        noisy_latents = self._pack_latents(noisy_latents, patch_size, patch_size_t)
        sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1)
        timesteps = (sigmas * 1000.0).long()

        latent_model_conditions["hidden_states"] = noisy_latents.to(latents)

        # TODO(aryan): make this configurable
        frame_rate = 25
        temporal_compression_ratio = 8
        vae_spatial_compression_ratio = 32
        latent_frame_rate = frame_rate / temporal_compression_ratio
        rope_interpolation_scale = [
            1 / latent_frame_rate,
            vae_spatial_compression_ratio,
            vae_spatial_compression_ratio,
        ]

        pred = transformer(
            **latent_model_conditions,
            **condition_model_conditions,
            timestep=timesteps,
            rope_interpolation_scale=rope_interpolation_scale,
            return_dict=False,
        )[0]
        target = FF.flow_match_target(noise, latents)

        return pred, target, sigmas
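
    # Hedged sketch of the flow-matching objective used above, assuming FF.flow_match_xt
    # and FF.flow_match_target implement the usual rectified-flow interpolation (an
    # assumption about those helpers, documented here only to spell out the math):
    #
    #   x_t    = (1 - sigma) * x_0 + sigma * noise   # flow_match_xt(x_0, noise, sigma)
    #   target = noise - x_0                         # flow_match_target(noise, x_0)
    #
    # i.e. the transformer is trained to predict the velocity from data to noise; the
    # expanded sigmas are returned alongside pred/target for use by the training loss.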

    def validation(
        self,
        pipeline: LTXPipeline,
        prompt: str,
        image: Optional[Image] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: Optional[int] = None,
        frame_rate: int = 25,
        num_inference_steps: int = 50,
        generator: Optional[torch.Generator] = None,
        **kwargs,
    ) -> List[ArtifactType]:
        if image is not None:
            pipeline = LTXImageToVideoPipeline.from_pipe(pipeline)

        generation_kwargs = {
            "prompt": prompt,
            "image": image,
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "frame_rate": frame_rate,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            "return_dict": True,
            "output_type": "pil",
        }
        generation_kwargs = get_non_null_items(generation_kwargs)
        video = pipeline(**generation_kwargs).frames[0]
        return [data.VideoArtifact(value=video)]
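
    # Hedged usage sketch for a text-to-video validation pass. The prompt, resolution,
    # frame count, and seed are illustrative choices (sizes divisible by 32 and 8k + 1
    # frames keep the VAE compression ratios happy); `pipe` is assumed to come from
    # load_pipeline() as sketched earlier.
    #
    #   artifacts = spec.validation(
    #       pipeline=pipe,
    #       prompt="a corgi surfing a wave at sunset",
    #       height=512,
    #       width=704,
    #       num_frames=49,
    #       generator=torch.Generator().manual_seed(42),
    #   )
    #   artifacts[0].value  # the generated frames ("pil" output) wrapped in a VideoArtifact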

    def _save_lora_weights(
        self,
        directory: str,
        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        scheduler: Optional[SchedulerType] = None,
        *args,
        **kwargs,
    ) -> None:
        # TODO(aryan): this needs refactoring
        if transformer_state_dict is not None:
            LTXPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
        if scheduler is not None:
            scheduler.save_pretrained(os.path.join(directory, "scheduler"))

    def _save_model(
        self,
        directory: str,
        transformer: LTXVideoTransformer3DModel,
        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        scheduler: Optional[SchedulerType] = None,
    ) -> None:
        # TODO(aryan): this needs refactoring
        if transformer_state_dict is not None:
            with init_empty_weights():
                transformer_copy = LTXVideoTransformer3DModel.from_config(transformer.config)
            transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
            transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
        if scheduler is not None:
            scheduler.save_pretrained(os.path.join(directory, "scheduler"))

    def apply_tensor_parallel(
        self,
        backend: ParallelBackendEnum,
        device_mesh: torch.distributed.DeviceMesh,
        transformer: LTXVideoTransformer3DModel,
        **kwargs,
    ) -> None:
        if backend == ParallelBackendEnum.PTD:
            _apply_tensor_parallel_ptd(device_mesh, transformer)
        else:
            raise NotImplementedError(f"Parallel backend {backend} is not supported for LTXVideoModelSpecification")

    @staticmethod
    def _normalize_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
    ) -> torch.Tensor:
        # Normalize latents across the channel dimension [B, C, F, H, W]
        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
        latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
        latents = ((latents.float() - latents_mean) * scaling_factor / latents_std).to(latents)
        return latents

    @staticmethod
    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
        # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
        # The patch dimensions are then permuted and collapsed into the channel dimension of shape:
        # [B, F // p_t * H // p * W // p, C * p_t * p * p] (an ndim=3 tensor).
        # dim=0 is the batch size, dim=1 is the effective video sequence length, dim=2 is the effective number of input features
        batch_size, num_channels, num_frames, height, width = latents.shape
        post_patch_num_frames = num_frames // patch_size_t
        post_patch_height = height // patch_size
        post_patch_width = width // patch_size
        latents = latents.reshape(
            batch_size,
            -1,
            post_patch_num_frames,
            patch_size_t,
            post_patch_height,
            patch_size,
            post_patch_width,
            patch_size,
        )
        latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
        return latents
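
    # Hedged shape check for _pack_latents with made-up sizes. With patch_size=1 and
    # patch_size_t=1 (assumed here to match the default LTX transformer config), packing
    # reduces to flattening the spatio-temporal latent grid into a token sequence.
    #
    #   latents = torch.randn(2, 128, 7, 16, 22)  # [B, C, F, H, W]
    #   packed = LTXVideoModelSpecification._pack_latents(latents, 1, 1)
    #   assert packed.shape == (2, 7 * 16 * 22, 128)  # [B, F*H*W, C]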


def _apply_tensor_parallel_ptd(
    device_mesh: torch.distributed.device_mesh.DeviceMesh, transformer: LTXVideoTransformer3DModel
) -> None:
    from torch.distributed.tensor.parallel import parallelize_module
    from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel

    transformer_plan = {
        # ===== Condition embeddings =====
        # "time_embed.emb.timestep_embedder.linear_1": ColwiseParallel(),
        # "time_embed.emb.timestep_embedder.linear_2": RowwiseParallel(output_layouts=Shard(-1)),
        # "time_embed.linear": ColwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate()),
        # "time_embed": PrepareModuleOutput(output_layouts=(Replicate(), Shard(-1)), desired_output_layouts=(Replicate(), Replicate())),
        # "caption_projection.linear_1": ColwiseParallel(),
        # "caption_projection.linear_2": RowwiseParallel(),
        # "rope": PrepareModuleOutput(output_layouts=(Replicate(), Replicate()), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False),
        # ===== =====
    }

    for block in transformer.transformer_blocks:
        block_plan = {}

        # ===== Attention =====
        # 8 all-to-all, 3 all-reduce
        # block_plan["attn1.to_q"] = ColwiseParallel(use_local_output=False)
        # block_plan["attn1.to_k"] = ColwiseParallel(use_local_output=False)
        # block_plan["attn1.to_v"] = ColwiseParallel(use_local_output=False)
        # block_plan["attn1.norm_q"] = SequenceParallel()
        # block_plan["attn1.norm_k"] = SequenceParallel()
        # block_plan["attn1.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
        # block_plan["attn2.to_q"] = ColwiseParallel(use_local_output=False)
        # block_plan["attn2.to_k"] = ColwiseParallel(use_local_output=False)
        # block_plan["attn2.to_v"] = ColwiseParallel(use_local_output=False)
        # block_plan["attn2.norm_q"] = SequenceParallel()
        # block_plan["attn2.norm_k"] = SequenceParallel()
        # block_plan["attn2.to_out.0"] = RowwiseParallel(input_layouts=Shard(1))
        # ===== =====

        block_plan["ff.net.0.proj"] = ColwiseParallel()
        block_plan["ff.net.2"] = RowwiseParallel()

        parallelize_module(block, device_mesh, block_plan)

    parallelize_module(transformer, device_mesh, transformer_plan)
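

# Hedged usage sketch: applying the tensor-parallel plan above with torch.distributed.
# The mesh construction assumes a single "tp" dimension over all local GPUs and a
# process group that has already been initialized (e.g. via torchrun); in practice the
# trainer's ParallelBackendEnum.PTD path is what reaches _apply_tensor_parallel_ptd.
#
#   from torch.distributed.device_mesh import init_device_mesh
#
#   device_mesh = init_device_mesh("cuda", (torch.cuda.device_count(),), mesh_dim_names=("tp",))
#   spec = LTXVideoModelSpecification()
#   transformer = spec.load_diffusion_models()["transformer"]
#   spec.apply_tensor_parallel(ParallelBackendEnum.PTD, device_mesh, transformer)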