""" """ import spaces import torch from diffusers.pipelines.flux.pipeline_flux import FluxPipeline from torchao.quantization import quantize_ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig from zerogpu import aoti_compile def _example_tensor(*shape): return torch.randn(*shape, device='cuda', dtype=torch.bfloat16) def optimize_pipeline_(pipeline: FluxPipeline): is_timestep_distilled = not pipeline.transformer.config.guidance_embeds seq_length = 256 if is_timestep_distilled else 512 transformer_kwargs = { 'hidden_states': _example_tensor(1, 4096, 64), 'timestep': torch.tensor([1.], device='cuda', dtype=torch.bfloat16), 'guidance': None if is_timestep_distilled else torch.tensor([1.], device='cuda', dtype=torch.bfloat16), 'pooled_projections': _example_tensor(1, 768), 'encoder_hidden_states': _example_tensor(1, seq_length, 4096), 'txt_ids': _example_tensor(seq_length, 3), 'img_ids': _example_tensor(4096, 3), 'joint_attention_kwargs': {}, 'return_dict': False, } inductor_configs = { 'conv_1x1_as_mm': True, 'epilogue_fusion': False, 'coordinate_descent_tuning': True, 'coordinate_descent_check_all_directions': True, 'max_autotune': True, 'triton.cudagraphs': True, } @spaces.GPU(duration=1500) def compile_transformer(): pipeline.transformer.fuse_qkv_projections() quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig()) exported = torch.export.export(pipeline.transformer, args=(), kwargs=transformer_kwargs) return aoti_compile(exported, inductor_configs) transformer_config = pipeline.transformer.config pipeline.transformer = compile_transformer() pipeline.transformer.config = transformer_config