jbilcke's picture
upgrade Finetrainers
57737a0
raw
history blame
20.8 kB
import os
import random
from typing import Any, Dict, List, Optional, Tuple
import torch
from accelerate import init_empty_weights
from diffusers import (
AutoencoderKLLTXVideo,
FlowMatchEulerDiscreteScheduler,
LTXImageToVideoPipeline,
LTXPipeline,
LTXVideoTransformer3DModel,
)
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from PIL.Image import Image
from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer
from ... import data
from ... import functional as FF
from ...logging import get_logger
from ...parallel import ParallelBackendEnum
from ...processors import ProcessorMixin, T5Processor
from ...typing import ArtifactType, SchedulerType
from ...utils import get_non_null_items
from ..modeling_utils import ModelSpecification
# Module-level logger obtained from the package's logging helper.
logger = get_logger()
class LTXLatentEncodeProcessor(ProcessorMixin):
    r"""
    Processor to encode image/video into latents using the LTX VAE.

    Args:
        output_names (`List[str]`):
            The names of the outputs that the processor returns. Exactly six names are required; the outputs are
            in the following order:
            - latents: The sampled latents of the input image/video, or (when `compute_posterior=False`) the raw
              `vae._encode` output (distribution moments) from which latents can be sampled later.
            - num_frames: The number of frames of the latents (i.e. `latents.shape[2]`, in latent space).
            - height: The height of the latents (`latents.shape[3]`, in latent space).
            - width: The width of the latents (`latents.shape[4]`, in latent space).
            - latents_mean: The latent channel means, taken from `vae.latents_mean`.
            - latents_std: The latent channel standard deviations, taken from `vae.latents_std`.

    Raises:
        ValueError: If `output_names` does not contain exactly six names.
    """

    def __init__(self, output_names: List[str]):
        super().__init__()
        self.output_names = output_names
        if len(self.output_names) != 6:
            raise ValueError(
                f"LTXLatentEncodeProcessor expects exactly 6 output names, got {len(self.output_names)}"
            )

    def forward(
        self,
        vae: AutoencoderKLLTXVideo,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
    ) -> Dict[str, torch.Tensor]:
        r"""
        Encode an image or a video into LTX VAE latents.

        Args:
            vae: The LTX VAE; its device and dtype are used for the input and the returned latents.
            image: Optional 4D tensor `[B, C, H, W]`. If given, it takes precedence over `video` and is treated as a
                single-frame video.
            video: Optional 5D tensor `[B, F, C, H, W]`.
            generator: RNG used when sampling from the latent distribution.
            compute_posterior: If True, sample latents from `vae.encode(...).latent_dist`. If False, return the raw
                `vae._encode` output instead of a sample (encoded one sample at a time when VAE slicing is enabled).

        Returns:
            A dict keyed by `self.output_names` (see the class docstring for the order).

        Raises:
            ValueError: If neither `image` nor `video` is provided, or the (resulting) video is not 5D.
        """
        device = vae.device
        dtype = vae.dtype

        if image is not None:
            # Treat an image as a one-frame video: [B, C, H, W] -> [B, 1, C, H, W].
            video = image.unsqueeze(1)
        if video is None:
            raise ValueError("Either `image` or `video` must be provided")
        if video.ndim != 5:
            raise ValueError(f"Expected 5D tensor, got {video.ndim}D tensor")

        video = video.to(device=device, dtype=dtype)
        video = video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]

        if compute_posterior:
            latents = vae.encode(video).latent_dist.sample(generator=generator)
            latents = latents.to(dtype=dtype)
        else:
            # Keep the distribution parameters instead of sampling, so that a fresh sample can be drawn later.
            # Mirror the VAE's slicing behaviour by encoding one sample at a time when it is enabled.
            if vae.use_slicing and video.shape[0] > 1:
                encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
                moments = torch.cat(encoded_slices)
            else:
                moments = vae._encode(video)
            latents = moments.to(dtype=dtype)

        # Sizes are reported in latent space (after VAE compression), not in pixel space.
        _, _, num_frames, height, width = latents.shape

        return {
            self.output_names[0]: latents,
            self.output_names[1]: num_frames,
            self.output_names[2]: height,
            self.output_names[3]: width,
            self.output_names[4]: vae.latents_mean,
            self.output_names[5]: vae.latents_std,
        }
class LTXVideoModelSpecification(ModelSpecification):
    r"""
    Model specification for LTX-Video.

    Loads the condition models (tokenizer + text encoder), the latent model (VAE) and the diffusion models
    (transformer + scheduler), prepares text conditions and VAE latents through the configured processors, and
    computes the flow-matching prediction/target pair used for training in :meth:`forward`.
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str = "Lightricks/LTX-Video",
        tokenizer_id: Optional[str] = None,
        text_encoder_id: Optional[str] = None,
        transformer_id: Optional[str] = None,
        vae_id: Optional[str] = None,
        text_encoder_dtype: torch.dtype = torch.bfloat16,
        transformer_dtype: torch.dtype = torch.bfloat16,
        vae_dtype: torch.dtype = torch.bfloat16,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
        condition_model_processors: Optional[List[ProcessorMixin]] = None,
        latent_model_processors: Optional[List[ProcessorMixin]] = None,
        **kwargs,
    ) -> None:
        r"""
        Args:
            pretrained_model_name_or_path: Base checkpoint; components without an explicit id are loaded from its
                subfolders.
            tokenizer_id / text_encoder_id / transformer_id / vae_id: Optional per-component overrides.
            text_encoder_dtype / transformer_dtype / vae_dtype: dtypes used when loading each component.
            revision / cache_dir: Forwarded to every `from_pretrained` call.
            condition_model_processors: Defaults to a single `T5Processor` producing
                `encoder_hidden_states` and `encoder_attention_mask`.
            latent_model_processors: Defaults to a single `LTXLatentEncodeProcessor`.
            **kwargs: Accepted and ignored.
        """
        super().__init__(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            tokenizer_id=tokenizer_id,
            text_encoder_id=text_encoder_id,
            transformer_id=transformer_id,
            vae_id=vae_id,
            text_encoder_dtype=text_encoder_dtype,
            transformer_dtype=transformer_dtype,
            vae_dtype=vae_dtype,
            revision=revision,
            cache_dir=cache_dir,
        )

        if condition_model_processors is None:
            condition_model_processors = [T5Processor(["encoder_hidden_states", "encoder_attention_mask"])]
        if latent_model_processors is None:
            latent_model_processors = [
                LTXLatentEncodeProcessor(["latents", "num_frames", "height", "width", "latents_mean", "latents_std"])
            ]

        self.condition_model_processors = condition_model_processors
        self.latent_model_processors = latent_model_processors

    @property
    def _resolution_dim_keys(self):
        # Dimensions of `latents` ([B, C, F, H, W]) that depend on the input resolution: frames, height, width.
        return {"latents": (2, 3, 4)}

    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
        r"""
        Load the tokenizer and text encoder.

        When `tokenizer_id` / `text_encoder_id` are set they are loaded with `AutoTokenizer` / `AutoModel`;
        otherwise the T5 classes load the `tokenizer` / `text_encoder` subfolders of the base checkpoint.
        """
        if self.tokenizer_id is not None:
            tokenizer = AutoTokenizer.from_pretrained(
                self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
            )
        else:
            tokenizer = T5Tokenizer.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="tokenizer",
                revision=self.revision,
                cache_dir=self.cache_dir,
            )

        if self.text_encoder_id is not None:
            text_encoder = AutoModel.from_pretrained(
                self.text_encoder_id,
                torch_dtype=self.text_encoder_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )
        else:
            text_encoder = T5EncoderModel.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="text_encoder",
                torch_dtype=self.text_encoder_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )

        return {"tokenizer": tokenizer, "text_encoder": text_encoder}

    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
        r"""Load the LTX VAE from `vae_id` if set, otherwise from the `vae` subfolder of the base checkpoint."""
        if self.vae_id is not None:
            vae = AutoencoderKLLTXVideo.from_pretrained(
                self.vae_id,
                torch_dtype=self.vae_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )
        else:
            vae = AutoencoderKLLTXVideo.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="vae",
                torch_dtype=self.vae_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )

        return {"vae": vae}

    def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
        r"""
        Load the transformer (from `transformer_id` or the `transformer` subfolder) and create a scheduler.

        NOTE(review): the scheduler is a default-configured `FlowMatchEulerDiscreteScheduler()`, not loaded from the
        checkpoint; confirm this is intended, since it is also what `load_pipeline` receives for validation.
        """
        if self.transformer_id is not None:
            transformer = LTXVideoTransformer3DModel.from_pretrained(
                self.transformer_id,
                torch_dtype=self.transformer_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )
        else:
            transformer = LTXVideoTransformer3DModel.from_pretrained(
                self.pretrained_model_name_or_path,
                subfolder="transformer",
                torch_dtype=self.transformer_dtype,
                revision=self.revision,
                cache_dir=self.cache_dir,
            )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {"transformer": transformer, "scheduler": scheduler}

    def load_pipeline(
        self,
        tokenizer: Optional[T5Tokenizer] = None,
        text_encoder: Optional[T5EncoderModel] = None,
        transformer: Optional[LTXVideoTransformer3DModel] = None,
        vae: Optional[AutoencoderKLLTXVideo] = None,
        scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
        enable_slicing: bool = False,
        enable_tiling: bool = False,
        enable_model_cpu_offload: bool = False,
        training: bool = False,
        **kwargs,
    ) -> LTXPipeline:
        r"""
        Build an `LTXPipeline`, reusing any provided components and loading the rest from the base checkpoint.

        The text encoder and VAE are always cast to their configured dtypes; the transformer is cast only when
        `training` is False (presumably so the trainer's parameter dtype is left untouched — confirm with caller).
        """
        components = {
            "tokenizer": tokenizer,
            "text_encoder": text_encoder,
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
        }
        # Components left as None are loaded by `from_pretrained`.
        components = get_non_null_items(components)

        pipe = LTXPipeline.from_pretrained(
            self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
        )
        pipe.text_encoder.to(self.text_encoder_dtype)
        pipe.vae.to(self.vae_dtype)

        if not training:
            pipe.transformer.to(self.transformer_dtype)

        if enable_slicing:
            pipe.vae.enable_slicing()
        if enable_tiling:
            pipe.vae.enable_tiling()
        if enable_model_cpu_offload:
            pipe.enable_model_cpu_offload()

        return pipe

    @torch.no_grad()
    def prepare_conditions(
        self,
        tokenizer: T5Tokenizer,
        text_encoder: T5EncoderModel,
        caption: str,
        max_sequence_length: int = 128,
        **kwargs,
    ) -> Dict[str, Any]:
        r"""
        Run the condition-model processors on `caption` via the base class.

        Returns only the entries produced by the processors; every input key (including extra `kwargs`) is
        removed from the result.
        """
        conditions = {
            "tokenizer": tokenizer,
            "text_encoder": text_encoder,
            "caption": caption,
            "max_sequence_length": max_sequence_length,
            **kwargs,
        }
        input_keys = set(conditions.keys())
        conditions = super().prepare_conditions(**conditions)
        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
        return conditions

    @torch.no_grad()
    def prepare_latents(
        self,
        vae: AutoencoderKLLTXVideo,
        image: Optional[torch.Tensor] = None,
        video: Optional[torch.Tensor] = None,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        r"""
        Run the latent-model processors on `image`/`video` via the base class.

        Returns only the entries produced by the processors; every input key (including extra `kwargs`) is
        removed from the result.
        """
        conditions = {
            "vae": vae,
            "image": image,
            "video": video,
            "generator": generator,
            "compute_posterior": compute_posterior,
            **kwargs,
        }
        input_keys = set(conditions.keys())
        conditions = super().prepare_latents(**conditions)
        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
        return conditions

    def forward(
        self,
        transformer: LTXVideoTransformer3DModel,
        condition_model_conditions: Dict[str, torch.Tensor],
        latent_model_conditions: Dict[str, torch.Tensor],
        sigmas: torch.Tensor,
        generator: Optional[torch.Generator] = None,
        compute_posterior: bool = True,
        **kwargs,
    ) -> Tuple[torch.Tensor, ...]:
        r"""
        Noise the latents with flow matching, run the transformer, and return `(pred, target, sigmas)`.

        Side effect: `latent_model_conditions` is modified in place — `latents`, `latents_mean` and `latents_std`
        are popped and `hidden_states` is inserted; the remaining entries are forwarded to the transformer.

        Args:
            transformer: The LTX transformer to run.
            condition_model_conditions: Outputs of `prepare_conditions`, forwarded to the transformer.
            latent_model_conditions: Outputs of `prepare_latents`.
            sigmas: Noise levels. They are broadcast against the 5D latents by `FF.flow_match_xt` and later
                reshaped with `view(-1, 1, 1)`; presumably one value per sample (e.g. `[B, 1, 1, 1, 1]`) — confirm
                with caller.
            generator: RNG for posterior sampling and the noise tensor.
            compute_posterior: Must match the value used in `prepare_latents`. If False, `latents` holds raw
                encoder moments and a sample is drawn here.

        Returns:
            pred: Transformer output on packed tokens.
            target: `FF.flow_match_target(noise, latents)` on packed tokens.
            sigmas: Per-token sigmas with shape `[N, S, 1]`, where `S` is the packed sequence length.
        """
        # TODO(aryan): make this configurable? Should it be?
        first_frame_conditioning_p = 0.1
        max_first_frame_sigma = 0.25

        if compute_posterior:
            latents = latent_model_conditions.pop("latents")
        else:
            posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
            latents = posterior.sample(generator=generator)
            del posterior

        latents_mean = latent_model_conditions.pop("latents_mean")
        latents_std = latent_model_conditions.pop("latents_std")

        latents = self._normalize_latents(latents, latents_mean, latents_std)
        noise = torch.zeros_like(latents).normal_(generator=generator)

        # NOTE(review): `random.random()` and `torch.rand_like` below do not use `generator`, so this branch is not
        # reproducible from `generator` alone.
        if random.random() < first_frame_conditioning_p:
            # Based on Section 2.4 of the paper, the first frame should be noised with a small random timestep.
            # torch.rand_like returns values in [0, 1); multiplying by `sigmas` rescales this to [0, sigmas), so the
            # first-frame sigma never exceeds the sigma used for the remaining frames. It is then additionally
            # capped at `max_first_frame_sigma` (an upper bound, not a lower bound).
            first_frame_sigma = torch.rand_like(sigmas) * sigmas
            first_frame_sigma = torch.min(first_frame_sigma, sigmas.new_full(sigmas.shape, max_first_frame_sigma))

            latents_first_frame, latents_rest = latents[:, :, :1], latents[:, :, 1:]
            noisy_latents_first_frame = FF.flow_match_xt(latents_first_frame, noise[:, :, :1], first_frame_sigma)
            noisy_latents_remaining = FF.flow_match_xt(latents_rest, noise[:, :, 1:], sigmas)
            noisy_latents = torch.cat([noisy_latents_first_frame, noisy_latents_remaining], dim=2)
        else:
            noisy_latents = FF.flow_match_xt(latents, noise, sigmas)

        patch_size = self.transformer_config.patch_size
        patch_size_t = self.transformer_config.patch_size_t

        latents = self._pack_latents(latents, patch_size, patch_size_t)
        noise = self._pack_latents(noise, patch_size, patch_size_t)
        noisy_latents = self._pack_latents(noisy_latents, patch_size, patch_size_t)
        # Expand sigmas to one value per packed token: [N, S, 1] for packed tensors of shape [B, S, D].
        sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1)
        # NOTE(review): timesteps are derived from `sigmas` only, so when first-frame conditioning is applied the
        # first-frame tokens are labelled with the full sigma rather than `first_frame_sigma` — confirm intended.
        timesteps = (sigmas * 1000.0).long()

        latent_model_conditions["hidden_states"] = noisy_latents.to(latents)

        # TODO(aryan): make this configurable
        frame_rate = 25
        temporal_compression_ratio = 8
        vae_spatial_compression_ratio = 32
        latent_frame_rate = frame_rate / temporal_compression_ratio

        rope_interpolation_scale = [
            1 / latent_frame_rate,
            vae_spatial_compression_ratio,
            vae_spatial_compression_ratio,
        ]

        pred = transformer(
            **latent_model_conditions,
            **condition_model_conditions,
            timestep=timesteps,
            rope_interpolation_scale=rope_interpolation_scale,
            return_dict=False,
        )[0]
        target = FF.flow_match_target(noise, latents)

        return pred, target, sigmas

    def validation(
        self,
        pipeline: LTXPipeline,
        prompt: str,
        image: Optional[Image] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_frames: Optional[int] = None,
        frame_rate: int = 25,
        num_inference_steps: int = 50,
        generator: Optional[torch.Generator] = None,
        **kwargs,
    ) -> List[ArtifactType]:
        r"""
        Generate one validation video.

        When `image` is given, an `LTXImageToVideoPipeline` is derived from `pipeline` via `from_pipe`. Arguments
        left as None are omitted so the pipeline's own defaults apply. Returns a single `VideoArtifact` holding the
        first generated video as PIL frames.
        """
        if image is not None:
            pipeline = LTXImageToVideoPipeline.from_pipe(pipeline)

        generation_kwargs = {
            "prompt": prompt,
            "image": image,
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "frame_rate": frame_rate,
            "num_inference_steps": num_inference_steps,
            "generator": generator,
            "return_dict": True,
            "output_type": "pil",
        }
        generation_kwargs = get_non_null_items(generation_kwargs)
        video = pipeline(**generation_kwargs).frames[0]
        return [data.VideoArtifact(value=video)]

    def _save_lora_weights(
        self,
        directory: str,
        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        scheduler: Optional[SchedulerType] = None,
        *args,
        **kwargs,
    ) -> None:
        r"""Save transformer LoRA weights (safetensors) to `directory` and the scheduler to `directory/scheduler`."""
        # TODO(aryan): this needs refactoring
        if transformer_state_dict is not None:
            LTXPipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
        if scheduler is not None:
            scheduler.save_pretrained(os.path.join(directory, "scheduler"))

    def _save_model(
        self,
        directory: str,
        transformer: LTXVideoTransformer3DModel,
        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
        scheduler: Optional[SchedulerType] = None,
    ) -> None:
        r"""
        Save full transformer weights to `directory/transformer` and the scheduler to `directory/scheduler`.

        The transformer is rebuilt from `transformer.config` without allocating weights, and `transformer_state_dict`
        is assigned into it (`assign=True`) before saving.
        """
        # TODO(aryan): this needs refactoring
        if transformer_state_dict is not None:
            with init_empty_weights():
                transformer_copy = LTXVideoTransformer3DModel.from_config(transformer.config)
            transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
            transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
        if scheduler is not None:
            scheduler.save_pretrained(os.path.join(directory, "scheduler"))

    def apply_tensor_parallel(
        self,
        backend: ParallelBackendEnum,
        device_mesh: torch.distributed.DeviceMesh,
        transformer: LTXVideoTransformer3DModel,
        **kwargs,
    ) -> None:
        r"""
        Apply tensor parallelism to `transformer`.

        Raises:
            NotImplementedError: For any backend other than `ParallelBackendEnum.PTD`.
        """
        if backend == ParallelBackendEnum.PTD:
            _apply_tensor_parallel_ptd(device_mesh, transformer)
        else:
            raise NotImplementedError(f"Parallel backend {backend} is not supported for LTXVideoModelSpecification")

    @staticmethod
    def _normalize_latents(
        latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
    ) -> torch.Tensor:
        r"""
        Return `(latents - mean) * scaling_factor / std`, normalized per channel of `[B, C, F, H, W]` latents.

        The arithmetic runs in float32 and the result is cast back to the dtype/device of `latents`.
        """
        latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(device=latents.device)
        latents_std = latents_std.view(1, -1, 1, 1, 1).to(device=latents.device)
        latents = ((latents.float() - latents_mean) * scaling_factor / latents_std).to(latents)
        return latents

    @staticmethod
    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
        r"""
        Patchify `[B, C, F, H, W]` latents into a token sequence of shape `[B, S, D]`.

        The latents are split into patches of shape [B, C, F // p_t, p_t, H // p, p, W // p, p]. The patch
        dimensions are then permuted and collapsed, giving [B, F // p_t * H // p * W // p, C * p_t * p * p]
        (an ndim=3 tensor): dim=0 is the batch size, dim=1 is the effective video sequence length and dim=2 is the
        effective number of input features. Tokens are ordered frame-major, then height, then width.
        """
        batch_size, num_channels, num_frames, height, width = latents.shape
        post_patch_num_frames = num_frames // patch_size_t
        post_patch_height = height // patch_size
        post_patch_width = width // patch_size
        latents = latents.reshape(
            batch_size,
            -1,
            post_patch_num_frames,
            patch_size_t,
            post_patch_height,
            patch_size,
            post_patch_width,
            patch_size,
        )
        latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
        return latents
def _apply_tensor_parallel_ptd(
    device_mesh: torch.distributed.device_mesh.DeviceMesh, transformer: LTXVideoTransformer3DModel
) -> None:
    r"""
    Shard an LTX-Video transformer in place with PyTorch-native tensor parallelism.

    For every block in `transformer.transformer_blocks`, the feed-forward input projection (`ff.net.0.proj`) is
    split column-wise and the output projection (`ff.net.2`) row-wise. The top-level plan is currently empty, so
    the final `parallelize_module` call on the whole transformer applies no layer styles.
    """
    from torch.distributed.tensor.parallel import parallelize_module
    from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel

    # Candidate plans that were explored but are NOT applied (kept for reference):
    #   Condition embeddings (top level):
    #     "time_embed.emb.timestep_embedder.linear_1": ColwiseParallel()
    #     "time_embed.emb.timestep_embedder.linear_2": RowwiseParallel(output_layouts=Shard(-1))
    #     "time_embed.linear": ColwiseParallel(input_layouts=Shard(-1), output_layouts=Replicate())
    #     "time_embed": PrepareModuleOutput(output_layouts=(Replicate(), Shard(-1)),
    #                                       desired_output_layouts=(Replicate(), Replicate()))
    #     "caption_projection.linear_1": ColwiseParallel()
    #     "caption_projection.linear_2": RowwiseParallel()
    #     "rope": PrepareModuleOutput(output_layouts=(Replicate(), Replicate()),
    #                                 desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False)
    #   Attention (per block; 8 all-to-all, 3 all-reduce), for both attn1 and attn2:
    #     "attnX.to_q" / "attnX.to_k" / "attnX.to_v": ColwiseParallel(use_local_output=False)
    #     "attnX.norm_q" / "attnX.norm_k": SequenceParallel()
    #     "attnX.to_out.0": RowwiseParallel(input_layouts=Shard(1))

    for transformer_block in transformer.transformer_blocks:
        # Each block gets its own freshly constructed parallel-style objects.
        feed_forward_plan = {
            "ff.net.0.proj": ColwiseParallel(),
            "ff.net.2": RowwiseParallel(),
        }
        parallelize_module(transformer_block, device_mesh, feed_forward_plan)

    top_level_plan = {}
    parallelize_module(transformer, device_mesh, top_level_plan)