# Wan2GP: hyvideo/modules/models.py
from typing import Any, List, Tuple, Optional, Union, Dict
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from .activation_layers import get_activation_layer
from .norm_layers import get_norm_layer
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
from .attenion import attention, parallel_attention, get_cu_seqlens
from .posemb_layers import apply_rotary_emb
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
from .modulate_layers import ModulateDiT, modulate, modulate_, apply_gate, apply_gate_and_accumulate_
from .token_refiner import SingleTokenRefiner
import numpy as np
from mmgp import offload
from wan.modules.attention import pay_attention
from .audio_adapters import AudioProjNet2, PerceiverAttentionCA
def get_linear_split_map():
hidden_size = 3072
split_linear_modules_map = {
"img_attn_qkv" : {"mapped_modules" : ["img_attn_q", "img_attn_k", "img_attn_v"] , "split_sizes": [hidden_size, hidden_size, hidden_size]},
"linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]}
}
return split_linear_modules_map
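
# Illustrative sketch (not part of the original file): the split map above describes how the
# output features of a fused linear layer are carved into separate sub-modules, e.g. linear1
# packs q, k, v (hidden_size each) plus the MLP input (4 * hidden_size) into one projection.
# The helper below is hypothetical and only shows, via torch.split, how the "split_sizes"
# entries line up with the fused weight's output dimension.
def _example_linear_split():
    hidden_size = 3072
    sizes = get_linear_split_map()["linear1"]["split_sizes"]  # [3072, 3072, 3072, 12288]
    # Fused linear1 weight has shape (3*hidden + 4*hidden, hidden); split along the output dim.
    fused_weight = torch.empty(3 * hidden_size + 4 * hidden_size, hidden_size)
    q_w, k_w, v_w, mlp_w = torch.split(fused_weight, sizes, dim=0)
    return [t.shape for t in (q_w, k_w, v_w, mlp_w)]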
try:
from xformers.ops.fmha.attn_bias import BlockDiagonalPaddedKeysMask
except ImportError:
BlockDiagonalPaddedKeysMask = None
class MMDoubleStreamBlock(nn.Module):
"""
A multimodal dit block with seperate modulation for
text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qkv_bias: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
attention_mode: str = "sdpa",
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.attention_mode = attention_mode
self.deterministic = False
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.img_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.img_norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.img_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.img_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.img_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.img_norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.img_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.txt_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.txt_norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.txt_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
self.txt_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.txt_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.txt_norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.txt_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
attn_mask = None,
seqlens_q: Optional[torch.Tensor] = None,
seqlens_kv: Optional[torch.Tensor] = None,
freqs_cis: tuple = None,
condition_type: str = None,
token_replace_vec: torch.Tensor = None,
frist_frame_token_num: int = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if condition_type == "token_replace":
img_mod1, token_replace_img_mod1 = self.img_mod(vec, condition_type=condition_type, \
token_replace_vec=token_replace_vec)
(img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate) = img_mod1.chunk(6, dim=-1)
(tr_img_mod1_shift,
tr_img_mod1_scale,
tr_img_mod1_gate,
tr_img_mod2_shift,
tr_img_mod2_scale,
tr_img_mod2_gate) = token_replace_img_mod1.chunk(6, dim=-1)
else:
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = self.img_mod(vec).chunk(6, dim=-1)
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.txt_mod(vec).chunk(6, dim=-1)
        ##### Enjoy these spaghetti VRAM optimizations done by DeepBeepMeep!
        # I am sure you are a nice person and as you copy this code, you will give me officially proper credits:
        # Please link to https://github.com/deepbeepmeep/HunyuanVideoGP and @deepbeepmeep on twitter
# Prepare image for attention.
img_modulated = self.img_norm1(img)
img_modulated = img_modulated.to(torch.bfloat16)
if condition_type == "token_replace":
modulate_(img_modulated[:, :frist_frame_token_num], shift=tr_img_mod1_shift, scale=tr_img_mod1_scale)
modulate_(img_modulated[:, frist_frame_token_num:], shift=img_mod1_shift, scale=img_mod1_scale)
else:
modulate_( img_modulated, shift=img_mod1_shift, scale=img_mod1_scale )
shape = (*img_modulated.shape[:2], self.heads_num, int(img_modulated.shape[-1] / self.heads_num) )
img_q = self.img_attn_q(img_modulated).view(*shape)
img_k = self.img_attn_k(img_modulated).view(*shape)
img_v = self.img_attn_v(img_modulated).view(*shape)
del img_modulated
# Apply QK-Norm if needed
self.img_attn_q_norm.apply_(img_q).to(img_v)
img_q_len = img_q.shape[1]
self.img_attn_k_norm.apply_(img_k).to(img_v)
img_kv_len= img_k.shape[1]
batch_size = img_k.shape[0]
# Apply RoPE if needed.
qklist = [img_q, img_k]
del img_q, img_k
img_q, img_k = apply_rotary_emb(qklist, freqs_cis, head_first=False)
# Prepare txt for attention.
txt_modulated = self.txt_norm1(txt)
modulate_(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale )
txt_qkv = self.txt_attn_qkv(txt_modulated)
del txt_modulated
txt_q, txt_k, txt_v = rearrange(
txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
)
del txt_qkv
# Apply QK-Norm if needed.
self.txt_attn_q_norm.apply_(txt_q).to(txt_v)
self.txt_attn_k_norm.apply_(txt_k).to(txt_v)
# Run actual attention.
q = torch.cat((img_q, txt_q), dim=1)
del img_q, txt_q
k = torch.cat((img_k, txt_k), dim=1)
del img_k, txt_k
v = torch.cat((img_v, txt_v), dim=1)
del img_v, txt_v
# attention computation start
qkv_list = [q,k,v]
del q, k, v
attn = pay_attention(
qkv_list,
attention_mask=attn_mask,
q_lens=seqlens_q,
k_lens=seqlens_kv,
)
b, s, a, d = attn.shape
attn = attn.reshape(b, s, -1)
del qkv_list
# attention computation end
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
del attn
        # Calculate the img blocks.
if condition_type == "token_replace":
img_attn = self.img_attn_proj(img_attn)
apply_gate_and_accumulate_(img[:, :frist_frame_token_num], img_attn[:, :frist_frame_token_num], gate=tr_img_mod1_gate)
apply_gate_and_accumulate_(img[:, frist_frame_token_num:], img_attn[:, frist_frame_token_num:], gate=img_mod1_gate)
del img_attn
img_modulated = self.img_norm2(img)
img_modulated = img_modulated.to(torch.bfloat16)
modulate_( img_modulated[:, :frist_frame_token_num], shift=tr_img_mod2_shift, scale=tr_img_mod2_scale)
modulate_( img_modulated[:, frist_frame_token_num:], shift=img_mod2_shift, scale=img_mod2_scale)
self.img_mlp.apply_(img_modulated)
apply_gate_and_accumulate_(img[:, :frist_frame_token_num], img_modulated[:, :frist_frame_token_num], gate=tr_img_mod2_gate)
apply_gate_and_accumulate_(img[:, frist_frame_token_num:], img_modulated[:, frist_frame_token_num:], gate=img_mod2_gate)
del img_modulated
else:
img_attn = self.img_attn_proj(img_attn)
apply_gate_and_accumulate_(img, img_attn, gate=img_mod1_gate)
del img_attn
img_modulated = self.img_norm2(img)
img_modulated = img_modulated.to(torch.bfloat16)
modulate_( img_modulated , shift=img_mod2_shift, scale=img_mod2_scale)
self.img_mlp.apply_(img_modulated)
apply_gate_and_accumulate_(img, img_modulated, gate=img_mod2_gate)
del img_modulated
        # Calculate the txt blocks.
txt_attn = self.txt_attn_proj(txt_attn)
apply_gate_and_accumulate_(txt, txt_attn, gate=txt_mod1_gate)
del txt_attn
txt_modulated = self.txt_norm2(txt)
txt_modulated = txt_modulated.to(torch.bfloat16)
modulate_(txt_modulated, shift=txt_mod2_shift, scale=txt_mod2_scale)
txt_mlp = self.txt_mlp(txt_modulated)
del txt_modulated
apply_gate_and_accumulate_(txt, txt_mlp, gate=txt_mod2_gate)
return img, txt
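
# Illustrative sketch (not part of the original file): MMDoubleStreamBlock modulates image and
# text tokens with six per-sample vectors each (shift/scale/gate for the attention branch, then
# for the MLP branch). The hypothetical helper below shows that 6-way chunking and the standard
# DiT modulation recipe, which modulate_() is assumed to apply in place.
def _example_modulation_chunking():
    hidden_size, batch = 3072, 2
    vec = torch.randn(batch, 6 * hidden_size)  # what ModulateDiT(factor=6) would emit
    shift1, scale1, gate1, shift2, scale2, gate2 = vec.chunk(6, dim=-1)
    x = torch.randn(batch, 16, hidden_size)    # a few tokens
    x_mod = x * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1)
    return x_mod.shape                         # torch.Size([2, 16, 3072])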
class MMSingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
Also refer to (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
attention_mode: str = "sdpa",
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.attention_mode = attention_mode
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.mlp_hidden_dim = mlp_hidden_dim
self.scale = qk_scale or head_dim ** -0.5
# qkv and mlp_in
self.linear1 = nn.Linear(
hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
)
# proj and mlp_out
self.linear2 = nn.Linear(
hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.pre_norm = nn.LayerNorm(
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
)
self.mlp_act = get_activation_layer(mlp_act_type)()
self.modulation = ModulateDiT(
hidden_size,
factor=3,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
# x: torch.Tensor,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
txt_len: int,
attn_mask= None,
seqlens_q: Optional[torch.Tensor] = None,
seqlens_kv: Optional[torch.Tensor] = None,
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
condition_type: str = None,
token_replace_vec: torch.Tensor = None,
frist_frame_token_num: int = None,
) -> torch.Tensor:
        ##### More spaghetti VRAM optimizations done by DeepBeepMeep!
        # I am sure you are a nice person and as you copy this code, you will give me proper credits:
        # Please link to https://github.com/deepbeepmeep/HunyuanVideoGP and @deepbeepmeep on twitter
if condition_type == "token_replace":
mod, tr_mod = self.modulation(vec,
condition_type=condition_type,
token_replace_vec=token_replace_vec)
(mod_shift,
mod_scale,
mod_gate) = mod.chunk(3, dim=-1)
(tr_mod_shift,
tr_mod_scale,
tr_mod_gate) = tr_mod.chunk(3, dim=-1)
else:
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
img_mod = self.pre_norm(img)
img_mod = img_mod.to(torch.bfloat16)
if condition_type == "token_replace":
modulate_(img_mod[:, :frist_frame_token_num], shift=tr_mod_shift, scale=tr_mod_scale)
modulate_(img_mod[:, frist_frame_token_num:], shift=mod_shift, scale=mod_scale)
else:
modulate_(img_mod, shift=mod_shift, scale=mod_scale)
txt_mod = self.pre_norm(txt)
txt_mod = txt_mod.to(torch.bfloat16)
modulate_(txt_mod, shift=mod_shift, scale=mod_scale)
shape = (*img_mod.shape[:2], self.heads_num, int(img_mod.shape[-1] / self.heads_num) )
img_q = self.linear1_attn_q(img_mod).view(*shape)
img_k = self.linear1_attn_k(img_mod).view(*shape)
img_v = self.linear1_attn_v(img_mod).view(*shape)
shape = (*txt_mod.shape[:2], self.heads_num, int(txt_mod.shape[-1] / self.heads_num) )
txt_q = self.linear1_attn_q(txt_mod).view(*shape)
txt_k = self.linear1_attn_k(txt_mod).view(*shape)
txt_v = self.linear1_attn_v(txt_mod).view(*shape)
batch_size = img_mod.shape[0]
# Apply QK-Norm if needed.
# q = self.q_norm(q).to(v)
self.q_norm.apply_(img_q)
self.k_norm.apply_(img_k)
self.q_norm.apply_(txt_q)
self.k_norm.apply_(txt_k)
qklist = [img_q, img_k]
del img_q, img_k
img_q, img_k = apply_rotary_emb(qklist, freqs_cis, head_first=False)
img_q_len=img_q.shape[1]
q = torch.cat((img_q, txt_q), dim=1)
del img_q, txt_q
k = torch.cat((img_k, txt_k), dim=1)
img_kv_len=img_k.shape[1]
del img_k, txt_k
v = torch.cat((img_v, txt_v), dim=1)
del img_v, txt_v
# attention computation start
qkv_list = [q,k,v]
del q, k, v
attn = pay_attention(
qkv_list,
attention_mask=attn_mask,
q_lens = seqlens_q,
k_lens = seqlens_kv,
)
b, s, a, d = attn.shape
attn = attn.reshape(b, s, -1)
del qkv_list
# attention computation end
x_mod = torch.cat((img_mod, txt_mod), 1)
del img_mod, txt_mod
x_mod_shape = x_mod.shape
x_mod = x_mod.view(-1, x_mod.shape[-1])
chunk_size = int(x_mod_shape[1]/6)
x_chunks = torch.split(x_mod, chunk_size)
attn = attn.view(-1, attn.shape[-1])
attn_chunks =torch.split(attn, chunk_size)
for x_chunk, attn_chunk in zip(x_chunks, attn_chunks):
mlp_chunk = self.linear1_mlp(x_chunk)
mlp_chunk = self.mlp_act(mlp_chunk)
attn_mlp_chunk = torch.cat((attn_chunk, mlp_chunk), -1)
del attn_chunk, mlp_chunk
x_chunk[...] = self.linear2(attn_mlp_chunk)
del attn_mlp_chunk
x_mod = x_mod.view(x_mod_shape)
if condition_type == "token_replace":
apply_gate_and_accumulate_(img[:, :frist_frame_token_num, :], x_mod[:, :frist_frame_token_num, :], gate=tr_mod_gate)
apply_gate_and_accumulate_(img[:, frist_frame_token_num:, :], x_mod[:, frist_frame_token_num:-txt_len, :], gate=mod_gate)
else:
apply_gate_and_accumulate_(img, x_mod[:, :-txt_len, :], gate=mod_gate)
apply_gate_and_accumulate_(txt, x_mod[:, -txt_len:, :], gate=mod_gate)
return img, txt
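
# Illustrative sketch (not part of the original file): in MMSingleStreamBlock, linear1 fuses the
# qkv projection with the MLP input projection, and linear2 fuses the attention output projection
# with the MLP output projection. The hypothetical helper below walks through the shapes only;
# the real attention output is replaced by a stand-in, and the "gelu_tanh" activation is assumed
# to correspond to F.gelu(..., approximate="tanh").
def _example_single_block_shapes():
    hidden_size, mlp_ratio = 3072, 4
    mlp_hidden = hidden_size * mlp_ratio
    linear1 = torch.nn.Linear(hidden_size, 3 * hidden_size + mlp_hidden)
    linear2 = torch.nn.Linear(hidden_size + mlp_hidden, hidden_size)
    x = torch.randn(1, 8, hidden_size)
    qkv, mlp_in = torch.split(linear1(x), [3 * hidden_size, mlp_hidden], dim=-1)
    q, k, v = qkv.chunk(3, dim=-1)
    attn_out = v  # stand-in for the real attention output, same shape (1, 8, hidden_size)
    out = linear2(torch.cat((attn_out, F.gelu(mlp_in, approximate="tanh")), dim=-1))
    return out.shape  # torch.Size([1, 8, 3072])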
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
def preprocess_loras(self, model_type, sd):
if model_type != "i2v" :
return sd
new_sd = {}
for k,v in sd.items():
repl_list = ["double_blocks", "single_blocks", "final_layer", "img_mlp", "img_attn_qkv", "img_attn_proj","img_mod", "txt_mlp", "txt_attn_qkv","txt_attn_proj", "txt_mod", "linear1",
"linear2", "modulation", "mlp_fc1"]
src_list = [k +"_" for k in repl_list] + ["_" + k for k in repl_list]
tgt_list = [k +"." for k in repl_list] + ["." + k for k in repl_list]
if k.startswith("Hunyuan_video_I2V_lora_"):
                # crappy conversion script for the non-reversible lora naming
k = k.replace("Hunyuan_video_I2V_lora_","diffusion_model.")
k = k.replace("lora_up","lora_B")
k = k.replace("lora_down","lora_A")
if "txt_in_individual" in k:
pass
for s,t in zip(src_list, tgt_list):
k = k.replace(s,t)
if "individual_token_refiner" in k:
k = k.replace("txt_in_individual_token_refiner_blocks_", "txt_in.individual_token_refiner.blocks.")
k = k.replace("_mlp_fc", ".mlp.fc",)
k = k.replace(".mlp_fc", ".mlp.fc",)
new_sd[k] = v
return new_sd
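
    # Hypothetical example of the key renaming performed by preprocess_loras() above (the exact
    # source-key layout is an assumption, not taken from a real checkpoint):
    #   "Hunyuan_video_I2V_lora_double_blocks_0_img_attn_qkv.lora_up.weight"
    #   -> "diffusion_model.double_blocks.0.img_attn_qkv.lora_B.weight"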
"""
    HunyuanVideo Transformer backbone.
    Inherits from ModelMixin and ConfigMixin for compatibility with diffusers' pipelines and samplers.
Reference:
[1] Flux.1: https://github.com/black-forest-labs/flux
[2] MMDiT: http://arxiv.org/abs/2403.03206
Parameters
----------
    i2v_condition_type: str
        The image-to-video conditioning scheme (e.g. "token_replace").
patch_size: list
The size of the patch.
in_channels: int
The number of input channels.
out_channels: int
The number of output channels.
hidden_size: int
The hidden size of the transformer backbone.
heads_num: int
The number of attention heads.
mlp_width_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
mlp_act_type: str
The activation function of the MLP in the transformer block.
    mm_double_blocks_depth: int
        The number of double-stream transformer blocks.
    mm_single_blocks_depth: int
        The number of single-stream transformer blocks.
rope_dim_list: list
The dimension of the rotary embedding for t, h, w.
qkv_bias: bool
Whether to use bias in the qkv linear layer.
qk_norm: bool
Whether to use qk norm.
qk_norm_type: str
The type of qk norm.
guidance_embed: bool
Whether to use guidance embedding for distillation.
text_projection: str
The type of the text projection, default is single_refiner.
use_attention_mask: bool
Whether to use attention mask for text encoder.
dtype: torch.dtype
The dtype of the model.
device: torch.device
The device of the model.
"""
@register_to_config
def __init__(
self,
i2v_condition_type,
patch_size: list = [1, 2, 2],
in_channels: int = 4, # Should be VAE.config.latent_channels.
out_channels: int = None,
hidden_size: int = 3072,
heads_num: int = 24,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
mm_double_blocks_depth: int = 20,
mm_single_blocks_depth: int = 40,
rope_dim_list: List[int] = [16, 56, 56],
qkv_bias: bool = True,
qk_norm: bool = True,
qk_norm_type: str = "rms",
guidance_embed: bool = False, # For modulation.
text_projection: str = "single_refiner",
use_attention_mask: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
attention_mode: Optional[str] = "sdpa",
video_condition: bool = False,
audio_condition: bool = False,
avatar = False,
custom = False,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
# mm_double_blocks_depth , mm_single_blocks_depth = 5, 5
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.unpatchify_channels = self.out_channels
self.guidance_embed = guidance_embed
self.rope_dim_list = rope_dim_list
self.i2v_condition_type = i2v_condition_type
self.attention_mode = attention_mode
self.video_condition = video_condition
self.audio_condition = audio_condition
self.avatar = avatar
self.custom = custom
# Text projection. Default to linear projection.
# Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
self.use_attention_mask = use_attention_mask
self.text_projection = text_projection
self.text_states_dim = 4096
self.text_states_dim_2 = 768
if hidden_size % heads_num != 0:
raise ValueError(
f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
)
pe_dim = hidden_size // heads_num
if sum(rope_dim_list) != pe_dim:
raise ValueError(
f"Got {rope_dim_list} but expected positional dim {pe_dim}"
)
self.hidden_size = hidden_size
self.heads_num = heads_num
# image projection
self.img_in = PatchEmbed(
self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
)
# text projection
if self.text_projection == "linear":
self.txt_in = TextProjection(
self.text_states_dim,
self.hidden_size,
get_activation_layer("silu"),
**factory_kwargs,
)
elif self.text_projection == "single_refiner":
self.txt_in = SingleTokenRefiner(
self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
)
else:
raise NotImplementedError(
f"Unsupported text_projection: {self.text_projection}"
)
# time modulation
self.time_in = TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
)
# text modulation
self.vector_in = MLPEmbedder(
self.text_states_dim_2, self.hidden_size, **factory_kwargs
)
# guidance modulation
self.guidance_in = (
TimestepEmbedder(
self.hidden_size, get_activation_layer("silu"), **factory_kwargs
)
if guidance_embed
else None
)
# double blocks
self.double_blocks = nn.ModuleList(
[
MMDoubleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
attention_mode = attention_mode,
**factory_kwargs,
)
for _ in range(mm_double_blocks_depth)
]
)
# single blocks
self.single_blocks = nn.ModuleList(
[
MMSingleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
attention_mode = attention_mode,
**factory_kwargs,
)
for _ in range(mm_single_blocks_depth)
]
)
self.final_layer = FinalLayer(
self.hidden_size,
self.patch_size,
self.out_channels,
get_activation_layer("silu"),
**factory_kwargs,
)
if self.video_condition:
self.bg_in = PatchEmbed(
self.patch_size, self.in_channels * 2, self.hidden_size, **factory_kwargs
)
self.bg_proj = nn.Linear(self.hidden_size, self.hidden_size)
if audio_condition:
if avatar:
self.ref_in = PatchEmbed(
self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
)
# -------------------- audio_proj_model --------------------
self.audio_proj = AudioProjNet2(seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=3072, context_tokens=4)
                # -------------------- motion-embedder --------------------
self.motion_exp = TimestepEmbedder(
self.hidden_size // 4,
get_activation_layer("silu"),
**factory_kwargs
)
self.motion_pose = TimestepEmbedder(
self.hidden_size // 4,
get_activation_layer("silu"),
**factory_kwargs
)
self.fps_proj = TimestepEmbedder(
self.hidden_size,
get_activation_layer("silu"),
**factory_kwargs
)
self.before_proj = nn.Linear(self.hidden_size, self.hidden_size)
# -------------------- audio_insert_model --------------------
self.double_stream_list = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
audio_block_name = "audio_adapter_blocks"
elif custom:
self.audio_proj = AudioProjNet2(seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=3072, context_tokens=4)
self.double_stream_list = [1, 3, 5, 7, 9, 11]
audio_block_name = "audio_models"
self.double_stream_map = {str(i): j for j, i in enumerate(self.double_stream_list)}
self.single_stream_list = []
self.single_stream_map = {str(i): j+len(self.double_stream_list) for j, i in enumerate(self.single_stream_list)}
setattr(self, audio_block_name, nn.ModuleList([
PerceiverAttentionCA(dim=3072, dim_head=1024, heads=33) for _ in range(len(self.double_stream_list) + len(self.single_stream_list))
]))
def lock_layers_dtypes(self, dtype = torch.float32):
layer_list = [self.final_layer, self.final_layer.linear, self.final_layer.adaLN_modulation[1]]
        target_dtype = dtype
        for current_layer_list, current_dtype in zip([layer_list], [target_dtype]):
for layer in current_layer_list:
layer._lock_dtype = dtype
if hasattr(layer, "weight") and layer.weight.dtype != current_dtype :
layer.weight.data = layer.weight.data.to(current_dtype)
if hasattr(layer, "bias"):
layer.bias.data = layer.bias.data.to(current_dtype)
self._lock_dtype = dtype
def enable_deterministic(self):
for block in self.double_blocks:
block.enable_deterministic()
for block in self.single_blocks:
block.enable_deterministic()
def disable_deterministic(self):
for block in self.double_blocks:
block.disable_deterministic()
for block in self.single_blocks:
block.disable_deterministic()
def forward(
self,
x: torch.Tensor,
t: torch.Tensor, # Should be in range(0, 1000).
ref_latents: torch.Tensor=None,
text_states: torch.Tensor = None,
text_mask: torch.Tensor = None, # Now we don't use it.
text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
freqs_cos: Optional[torch.Tensor] = None,
freqs_sin: Optional[torch.Tensor] = None,
guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
pipeline=None,
x_id = 0,
step_no = 0,
callback = None,
audio_prompts = None,
motion_exp = None,
motion_pose = None,
fps = None,
face_mask = None,
audio_strength = None,
bg_latents = None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
img = x
bsz, _, ot, oh, ow = x.shape
del x
txt = text_states
tt, th, tw = (
ot // self.patch_size[0],
oh // self.patch_size[1],
ow // self.patch_size[2],
)
# Prepare modulation vectors.
vec = self.time_in(t)
if motion_exp != None:
vec += self.motion_exp(motion_exp.view(-1)).view(bsz, -1) # (b, 3072)
if motion_pose != None:
vec += self.motion_pose(motion_pose.view(-1)).view(bsz, -1) # (b, 3072)
if fps != None:
vec += self.fps_proj(fps) # (b, 3072)
if audio_prompts != None:
audio_feature_all = self.audio_proj(audio_prompts)
audio_feature_pad = audio_feature_all[:,:1].repeat(1,3,1,1)
audio_feature_all_insert = torch.cat([audio_feature_pad, audio_feature_all], dim=1).view(bsz, ot, 16, 3072)
audio_feature_all = None
if self.i2v_condition_type == "token_replace":
token_replace_t = torch.zeros_like(t)
token_replace_vec = self.time_in(token_replace_t)
frist_frame_token_num = th * tw
else:
token_replace_vec = None
frist_frame_token_num = None
# token_replace_mask_img = None
# token_replace_mask_txt = None
# text modulation
vec_2 = self.vector_in(text_states_2)
del text_states_2
vec += vec_2
if self.i2v_condition_type == "token_replace":
token_replace_vec += vec_2
del vec_2
# guidance modulation
if self.guidance_embed:
if guidance is None:
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
# our timestep_embedding is merged into guidance_in(TimestepEmbedder)
vec += self.guidance_in(guidance)
# Embed image and text.
img, shape_mask = self.img_in(img)
if self.avatar:
ref_latents_first = ref_latents[:, :, :1].clone()
ref_latents,_ = self.ref_in(ref_latents)
ref_latents_first,_ = self.img_in(ref_latents_first)
elif self.custom:
if ref_latents != None:
ref_latents, _ = self.img_in(ref_latents)
if bg_latents is not None and self.video_condition:
bg_latents, _ = self.bg_in(bg_latents)
img += self.bg_proj(bg_latents)
if self.text_projection == "linear":
txt = self.txt_in(txt)
elif self.text_projection == "single_refiner":
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
else:
raise NotImplementedError(
f"Unsupported text_projection: {self.text_projection}"
)
if self.avatar:
img += self.before_proj(ref_latents)
ref_length = ref_latents_first.shape[-2] # [b s c]
img = torch.cat([ref_latents_first, img], dim=-2) # t c
img_len = img.shape[1]
mask_len = img_len - ref_length
if face_mask.shape[2] == 1:
face_mask = face_mask.repeat(1,1,ot,1,1) # repeat if number of mask frame is 1
face_mask = torch.nn.functional.interpolate(face_mask, size=[ot, shape_mask[-2], shape_mask[-1]], mode="nearest")
# face_mask = face_mask.view(-1,mask_len,1).repeat(1,1,img.shape[-1]).type_as(img)
face_mask = face_mask.view(-1,mask_len,1).type_as(img)
elif ref_latents == None:
ref_length = None
else:
ref_length = ref_latents.shape[-2]
img = torch.cat([ref_latents, img], dim=-2) # t c
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
text_len = text_mask.sum(1)
total_len = text_len + img_seq_len
seqlens_q = seqlens_kv = total_len
attn_mask = None
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
if self.enable_cache:
if x_id == 0:
self.should_calc = True
inp = img[0:1]
vec_ = vec[0:1]
( img_mod1_shift, img_mod1_scale, _ , _ , _ , _ , ) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
normed_inp = self.double_blocks[0].img_norm1(inp)
normed_inp = normed_inp.to(torch.bfloat16)
modulated_inp = modulate( normed_inp, shift=img_mod1_shift, scale=img_mod1_scale )
del normed_inp, img_mod1_shift, img_mod1_scale
if step_no <= self.cache_start_step or step_no == self.num_steps-1:
self.accumulated_rel_l1_distance = 0
else:
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
self.should_calc = False
self.teacache_skipped_steps += 1
else:
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
else:
self.should_calc = True
if not self.should_calc:
img += self.previous_residual[x_id]
else:
if self.enable_cache:
self.previous_residual[x_id] = None
ori_img = img[0:1].clone()
# --------------------- Pass through DiT blocks ------------------------
for layer_num, block in enumerate(self.double_blocks):
for i in range(len(img)):
if callback != None:
callback(-1, None, False, True)
if pipeline._interrupt:
return None
double_block_args = [
img[i:i+1],
txt[i:i+1],
vec[i:i+1],
attn_mask,
seqlens_q[i:i+1],
seqlens_kv[i:i+1],
freqs_cis,
self.i2v_condition_type,
token_replace_vec,
frist_frame_token_num,
]
img[i], txt[i] = block(*double_block_args)
double_block_args = None
# insert audio feature to img
if audio_prompts != None:
audio_adapter = getattr(self.double_blocks[layer_num], "audio_adapter", None)
if audio_adapter != None:
real_img = img[i:i+1,ref_length:].view(1, ot, -1, 3072)
real_img = audio_adapter(audio_feature_all_insert[i:i+1], real_img).view(1, -1, 3072)
if face_mask != None:
real_img *= face_mask[i:i+1]
if audio_strength != None and audio_strength != 1:
real_img *= audio_strength
img[i:i+1, ref_length:] += real_img
real_img = None
for _, block in enumerate(self.single_blocks):
for i in range(len(img)):
if callback != None:
callback(-1, None, False, True)
if pipeline._interrupt:
return None
single_block_args = [
# x,
img[i:i+1],
txt[i:i+1],
vec[i:i+1],
txt_seq_len,
attn_mask,
seqlens_q[i:i+1],
seqlens_kv[i:i+1],
(freqs_cos, freqs_sin),
self.i2v_condition_type,
token_replace_vec,
frist_frame_token_num,
]
img[i], txt[i] = block(*single_block_args)
single_block_args = None
# img = x[:, :img_seq_len, ...]
if self.enable_cache:
if len(img) > 1:
self.previous_residual[0] = torch.empty_like(img)
for i, (x, residual) in enumerate(zip(img, self.previous_residual[0])):
if i < len(img) - 1:
residual[...] = torch.sub(x, ori_img)
else:
residual[...] = ori_img
torch.sub(x, ori_img, out=residual)
x = None
else:
self.previous_residual[x_id] = ori_img
torch.sub(img, ori_img, out=self.previous_residual[x_id])
if ref_length != None:
img = img[:, ref_length:]
# ---------------------------- Final layer ------------------------------
out_dtype = self.final_layer.linear.weight.dtype
vec = vec.to(out_dtype)
img_list = []
for img_chunk, vec_chunk in zip(img,vec):
img_list.append( self.final_layer(img_chunk.to(out_dtype).unsqueeze(0), vec_chunk.unsqueeze(0))) # (N, T, patch_size ** 2 * out_channels)
img = torch.cat(img_list)
img_list = None
img = self.unpatchify(img, tt, th, tw)
return img
def unpatchify(self, x, t, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.unpatchify_channels
pt, ph, pw = self.patch_size
assert t * h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
x = torch.einsum("nthwcopq->nctohpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def params_count(self):
counts = {
"double": sum(
[
sum(p.numel() for p in block.img_attn_qkv.parameters())
+ sum(p.numel() for p in block.img_attn_proj.parameters())
+ sum(p.numel() for p in block.img_mlp.parameters())
+ sum(p.numel() for p in block.txt_attn_qkv.parameters())
+ sum(p.numel() for p in block.txt_attn_proj.parameters())
+ sum(p.numel() for p in block.txt_mlp.parameters())
for block in self.double_blocks
]
),
"single": sum(
[
sum(p.numel() for p in block.linear1.parameters())
+ sum(p.numel() for p in block.linear2.parameters())
for block in self.single_blocks
]
),
"total": sum(p.numel() for p in self.parameters()),
}
counts["attn+mlp"] = counts["double"] + counts["single"]
return counts
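
# Illustrative sketch (not part of the original file): how unpatchify() above maps the token
# sequence back to a latent video. The patch size and channel count below are assumptions made
# only for the example (patch_size = [1, 2, 2], out_channels = 16).
def _example_unpatchify_shapes():
    pt, ph, pw, c = 1, 2, 2, 16
    t, h, w = 8, 30, 52                              # latent grid after patch embedding
    x = torch.randn(1, t * h * w, pt * ph * pw * c)  # (N, T, patch_size**2 * C) token layout
    x = x.reshape(1, t, h, w, c, pt, ph, pw)
    x = torch.einsum("nthwcopq->nctohpwq", x)        # same permutation as unpatchify()
    return x.reshape(1, c, t * pt, h * ph, w * pw).shape  # torch.Size([1, 16, 8, 60, 104])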
#################################################################################
# HunyuanVideo Configs #
#################################################################################
HUNYUAN_VIDEO_CONFIG = {
"HYVideo-T/2": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
},
"HYVideo-T/2-cfgdistill": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
"guidance_embed": True,
},
"HYVideo-S/2": {
"mm_double_blocks_depth": 6,
"mm_single_blocks_depth": 12,
"rope_dim_list": [12, 42, 42],
"hidden_size": 480,
"heads_num": 5,
"mlp_width_ratio": 4,
},
'HYVideo-T/2-custom': { # 9.0B / 12.5B
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
'custom' : True
},
'HYVideo-T/2-custom-audio': { # 9.0B / 12.5B
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
'custom' : True,
'audio_condition' : True,
},
'HYVideo-T/2-custom-edit': { # 9.0B / 12.5B
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
'custom' : True,
'video_condition' : True,
},
'HYVideo-T/2-avatar': { # 9.0B / 12.5B
'mm_double_blocks_depth': 20,
'mm_single_blocks_depth': 40,
'rope_dim_list': [16, 56, 56],
'hidden_size': 3072,
'heads_num': 24,
'mlp_width_ratio': 4,
'avatar': True,
'audio_condition' : True,
},
}
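
# Illustrative sketch (not part of the original file): a config entry above can be expanded
# straight into the constructor. The small "HYVideo-S/2" variant keeps the example cheap to
# build, and i2v_condition_type=None is an assumption made only for illustration.
def _example_build_from_config():
    cfg = HUNYUAN_VIDEO_CONFIG["HYVideo-S/2"]
    model = HYVideoDiffusionTransformer(i2v_condition_type=None, **cfg)
    return model.params_count()["total"]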