FLUX.1-Kontext-Dev / optimization.py
cbensimon's picture
cbensimon HF Staff
Unclutter a bit
288103a
raw
history blame
1.87 kB
"""
"""
import spaces
import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from zerogpu import aoti_compile
def _example_tensor(*shape):
    """Return a random bfloat16 tensor of the given shape on the CUDA device.

    Args:
        *shape: Dimensions forwarded unchanged to ``torch.randn``.

    Returns:
        A tensor filled by ``torch.randn`` with ``device='cuda'`` and
        ``dtype=torch.bfloat16``.
    """
    tensor_options = {'device': 'cuda', 'dtype': torch.bfloat16}
    return torch.randn(*shape, **tensor_options)
def optimize_pipeline_(pipeline: FluxPipeline):
    """Quantize and ahead-of-time compile ``pipeline.transformer`` in place.

    Example inputs with the hard-coded shapes below are built first. Then,
    inside a ``spaces.GPU`` task, the transformer's QKV projections are fused,
    ``quantize_`` is applied with ``Float8DynamicActivationFloat8WeightConfig``,
    the module is exported through ``torch.export.export`` and the exported
    program is handed to ``aoti_compile``. The compiled object replaces
    ``pipeline.transformer`` and the original transformer config is
    re-attached to it.

    Args:
        pipeline: Pipeline whose ``transformer`` attribute is replaced.

    Returns:
        None. ``pipeline`` is modified in place.
    """
    # Grab the config up front; it is re-attached to the compiled transformer.
    original_config = pipeline.transformer.config

    # A transformer without guidance embeddings is treated as timestep-distilled:
    # it gets the shorter text sequence and no guidance input.
    bf16_on_gpu = {'device': 'cuda', 'dtype': torch.bfloat16}
    if not original_config.guidance_embeds:
        text_len = 256
        guidance_example = None
    else:
        text_len = 512
        guidance_example = torch.tensor([1.], **bf16_on_gpu)

    # Options passed through to aoti_compile.
    inductor_options = {
        'conv_1x1_as_mm': True,
        'epilogue_fusion': False,
        'coordinate_descent_tuning': True,
        'coordinate_descent_check_all_directions': True,
        'max_autotune': True,
        'triton.cudagraphs': True,
    }

    # Keyword arguments used as example inputs when exporting the transformer.
    example_inputs = {
        'hidden_states': _example_tensor(1, 4096, 64),
        'timestep': torch.tensor([1.], **bf16_on_gpu),
        'guidance': guidance_example,
        'pooled_projections': _example_tensor(1, 768),
        'encoder_hidden_states': _example_tensor(1, text_len, 4096),
        'txt_ids': _example_tensor(text_len, 3),
        'img_ids': _example_tensor(4096, 3),
        'joint_attention_kwargs': {},
        'return_dict': False,
    }

    @spaces.GPU(duration=1500)
    def compile_transformer():
        """Fuse, quantize, export and compile the pipeline's transformer."""
        transformer = pipeline.transformer
        # Fusion runs before quantization, and both before the export.
        transformer.fuse_qkv_projections()
        quantize_(transformer, Float8DynamicActivationFloat8WeightConfig())
        program = torch.export.export(transformer, args=(), kwargs=example_inputs)
        return aoti_compile(program, inductor_options)

    pipeline.transformer = compile_transformer()
    # Put the original config back on the replacement so later reads of
    # ``pipeline.transformer.config`` still work.
    pipeline.transformer.config = original_config