wan2-1-fast / optimization.py
"""
"""
from datetime import datetime
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from torch.utils._pytree import tree_map_only
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from optimization_utils import capture_component_call
from optimization_utils import aoti_compile
P = ParamSpec('P')
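
# Only the frame dimension (dim 2) of `hidden_states` is treated as dynamic at export
# time; every other input keeps the static shape captured from the example call.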
TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)
TRANSFORMER_DYNAMIC_SHAPES = {
    'hidden_states': {
        2: TRANSFORMER_NUM_FRAMES_DIM,
    },
}
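
# Inductor options applied when ahead-of-time compiling the exported graphs.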
INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,
    'epilogue_fusion': False,
    'coordinate_descent_tuning': True,
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,
    'triton.cudagraphs': True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
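    """Optimize `pipeline` in place: quantize its transformer to FP8 and swap it for
    ahead-of-time compiled artifacts (one per image orientation).

    `*args`/`**kwargs` are forwarded to `pipeline` once so that the transformer call
    can be captured with representative inputs.
    """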
    t0 = datetime.now()
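
    # Compilation runs on a ZeroGPU worker; allow a generous duration for the
    # capture, export and AoT-compile steps.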
    @spaces.GPU(duration=1500)
    def compile_transformer():
        nonlocal t0
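        # Each print reports the time elapsed since the previous checkpoint; the
        # walrus expression resets `t0` as a side effect.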
        print('compile_transformer', -(t0 - (t0 := datetime.now())))
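
        # Run the pipeline once and record the exact args/kwargs its transformer
        # receives; these become the example inputs for torch.export.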
        with capture_component_call(pipeline, 'transformer') as call:
            pipeline(*args, **kwargs)
        print('capture_component_call', -(t0 - (t0 := datetime.now())))
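
        # Start from an all-static dynamic-shapes spec mirroring call.kwargs (None for
        # every tensor/bool leaf), then mark the frame dim of hidden_states as dynamic.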
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
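
        # FP8 dynamic-activation / FP8-weight quantization of the transformer via torchao.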
        quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
        print('quantize_', -(t0 - (t0 := datetime.now())))
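
        # Prepare landscape and portrait variants of the captured latents so a
        # separate graph can be exported for each orientation.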
        hidden_states: torch.Tensor = call.kwargs['hidden_states']
        hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            hidden_states_landscape = hidden_states
            hidden_states_portrait = hidden_states_transposed
        else:
            hidden_states_landscape = hidden_states_transposed
            hidden_states_portrait = hidden_states
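
        # Export one graph per orientation, sharing the same dynamic-shapes spec.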
        exported_landscape = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
            dynamic_shapes=dynamic_shapes,
        )
        print('exported_landscape', -(t0 - (t0 := datetime.now())))
        exported_portrait = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs | {'hidden_states': hidden_states_portrait},
            dynamic_shapes=dynamic_shapes,
        )
        print('exported_portrait', -(t0 - (t0 := datetime.now())))
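
        # Ahead-of-time (AOTInductor) compilation with the Inductor settings above.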
        compiled_landscape = aoti_compile(exported_landscape, INDUCTOR_CONFIGS)
        print('compiled_landscape', -(t0 - (t0 := datetime.now())))
        compiled_portrait = aoti_compile(exported_portrait, INDUCTOR_CONFIGS)
        print('compiled_portrait', -(t0 - (t0 := datetime.now())))
        # Avoid weights duplication when serializing back to main process
        compiled_portrait.weights = compiled_landscape.weights
        return compiled_landscape, compiled_portrait
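
    # Capture, export and compilation all happen on the GPU worker; only the
    # compiled artifacts come back to this process.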
    compiled_landscape, compiled_portrait = compile_transformer()
    print('compiled', -(t0 - (t0 := datetime.now())))
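
    # Deduplicate the shared weights again after deserialization in this process.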
    compiled_portrait.weights = compiled_landscape.weights
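
    # Route each call to the compiled artifact matching the input orientation.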
    def combined_transformer(*args, **kwargs):
        hidden_states: torch.Tensor = kwargs['hidden_states']
        if hidden_states.shape[-1] > hidden_states.shape[-2]:
            return compiled_landscape(*args, **kwargs)
        else:
            return compiled_portrait(*args, **kwargs)
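
    # Preserve attributes that the pipeline expects to find on its transformer.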
    transformer_config = pipeline.transformer.config
    transformer_dtype = pipeline.transformer.dtype
    pipeline.transformer = combined_transformer
    pipeline.transformer.config = transformer_config  # pyright: ignore[reportAttributeAccessIssue]
    pipeline.transformer.dtype = transformer_dtype  # pyright: ignore[reportAttributeAccessIssue]