File size: 4,583 Bytes
9f3ed26 471b188 9f3ed26 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import os
import gc
import torch
from PIL.Image import Image
from dataclasses import dataclass
from diffusers import DiffusionPipeline, AutoencoderTiny, FluxTransformer2DModel
from transformers import T5EncoderModel
from huggingface_hub.constants import HF_HUB_CACHE
from torchao.quantization import quantize_, int8_weight_only, float8_weight_only
from caching import apply_cache_on_pipe
from pipelines.models import TextToImageRequest
from torch import Generator
# Configuration settings using a dataclass for clarity
@dataclass
class Config:
    """Static settings for model loading and the runtime environment.

    The rest of the module reads these as class attributes (``Config.DTYPE``,
    ``Config.DEVICE``, ...) rather than instantiating the class. Every entry is
    still declared as an annotated field, so the dataclass-generated
    ``__init__``/``__repr__`` cover all of them consistently.
    """

    # Repo id and pinned revision passed to DiffusionPipeline.from_pretrained.
    CKPT_ID: str = "black-forest-labs/FLUX.1-schnell"
    CKPT_REVISION: str = "741f7c3ce8b383c54771c7003378a50191e9efe9"
    # Device the assembled pipeline is moved to.
    DEVICE: str = "cuda"
    # torch_dtype used when loading every component. Annotated so that it is a
    # real dataclass field like its siblings, not a bare class attribute.
    DTYPE: torch.dtype = torch.bfloat16
    # Value exported as the PYTORCH_CUDA_ALLOC_CONF environment variable.
    PYTORCH_CUDA_ALLOC_CONF: str = "expandable_segments:True"
def _initialize_environment():
    """Apply global PyTorch backend flags and export the CUDA allocator config.

    Side effects:
      * enables TF32 for CUDA matmuls;
      * enables cuDNN and its autotuner (``cudnn.benchmark``);
      * sets ``os.environ['PYTORCH_CUDA_ALLOC_CONF']`` from ``Config``,
        overwriting any value already present.
    """
    cudnn_flags = torch.backends.cudnn
    torch.backends.cuda.matmul.allow_tf32 = True
    cudnn_flags.enabled = True
    cudnn_flags.benchmark = True
    # NOTE(review): the allocator setting only takes effect if it is set before
    # the CUDA caching allocator is first used — confirm nothing allocates on
    # CUDA before this runs.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = Config.PYTORCH_CUDA_ALLOC_CONF
def _clear_gpu_memory():
"""Free up GPU memory to prevent memory-related issues."""
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
def _load_text_encoder_model():
    """Load the T5 encoder that the pipeline uses as ``text_encoder_2``.

    Weights come from ``city96/t5-v1_1-xxl-encoder-bf16`` at a pinned
    revision and are loaded in ``Config.DTYPE``.
    """
    pinned_revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86"
    encoder = T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision=pinned_revision,
        torch_dtype=Config.DTYPE,
    )
    # nn.Module.to(memory_format=...) only re-lays-out 4D parameters/buffers.
    # NOTE(review): an encoder built from linear layers may have none, which
    # would make this a no-op — confirm before relying on it.
    return encoder.to(memory_format=torch.channels_last)
def _load_vae_model():
    """Load the ``AutoencoderTiny`` VAE that replaces the checkpoint's own VAE.

    Fetched from ``manbeast3b/FLUX.1-schnell-taef1-float8`` at a pinned
    revision and loaded in ``Config.DTYPE``.
    NOTE(review): the repo name mentions float8 but the weights are loaded in
    ``Config.DTYPE`` — confirm that is intended.
    """
    vae_repo = "manbeast3b/FLUX.1-schnell-taef1-float8"
    vae_revision = "7c538d53ec698509788ed88b1305c6bb019bdb4d"
    tiny_vae = AutoencoderTiny.from_pretrained(
        vae_repo,
        revision=vae_revision,
        torch_dtype=Config.DTYPE,
    )
    return tiny_vae
def _load_transformer_model():
    """Load the Flux transformer from a pinned snapshot inside ``HF_HUB_CACHE``.

    The model is loaded from a local directory path rather than a repo id,
    so the snapshot must already be present in the cache. Weights are read
    with ``use_safetensors=False`` and in ``Config.DTYPE``.
    """
    snapshot_dir = os.path.join(
        HF_HUB_CACHE,
        "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
    )
    model = FluxTransformer2DModel.from_pretrained(
        os.path.join(snapshot_dir, "transformer"),
        torch_dtype=Config.DTYPE,
        use_safetensors=False,
    )
    # Only 4D parameters/buffers are affected by memory_format in nn.Module.to.
    return model.to(memory_format=torch.channels_last)
def _warmup_pipeline(pipeline):
"""Warm up the pipeline by running it with an empty prompt to initialize internal caches."""
for _ in range(3):
pipeline(prompt=" ")
def load_pipeline():
    """
    Load and configure the diffusion pipeline for text-to-image generation.

    Steps:
      1. Free cached GPU memory.
      2. Load the text encoder, tiny VAE and transformer individually.
      3. Assemble them into a ``DiffusionPipeline`` and move it to ``Config.DEVICE``.
      4. Apply ``apply_cache_on_pipe`` and the VAE optimizations.
      5. Warm the pipeline up so compilation happens before the first request.

    Returns:
        DiffusionPipeline: The configured pipeline ready for inference.
    """
    _clear_gpu_memory()

    text_encoder = _load_text_encoder_model()
    vae = _load_vae_model()
    transformer = _load_transformer_model()

    pipeline = DiffusionPipeline.from_pretrained(
        Config.CKPT_ID,
        vae=vae,
        revision=Config.CKPT_REVISION,
        transformer=transformer,
        text_encoder_2=text_encoder,
        torch_dtype=Config.DTYPE,
    ).to(Config.DEVICE)

    apply_cache_on_pipe(pipeline)
    _optimize_vae(pipeline)

    _warmup_pipeline(pipeline)
    return pipeline


def _optimize_vae(pipeline):
    """Set channels_last on, quantize, then compile ``pipeline.vae`` (in that order)."""
    # DiffusionPipeline.to() only parses device/dtype from its arguments, so a
    # memory_format passed to it never reaches the sub-modules. Set it on the
    # VAE module directly.
    pipeline.vae.to(memory_format=torch.channels_last)
    # Quantize the eager module *before* compiling (torchao's recommended
    # order), so the compiled graph is traced over the quantized weights.
    # A single weight-only pass is enough. quantize_'s default filter skips
    # Linear layers whose weights are already quantized, so a follow-up
    # float8 pass after int8 changed nothing.
    # NOTE(review): that default filter only targets nn.Linear modules —
    # confirm the VAE has any, otherwise this call is a no-op.
    quantize_(pipeline.vae, int8_weight_only())
    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator) -> Image:
    """
    Render one image for ``request`` with the preloaded pipeline.

    Only the prompt, height and width come from the request. The sampling
    settings are fixed: 4 inference steps, ``guidance_scale=0.0`` and
    ``max_sequence_length=256``, with PIL output.

    Args:
        request (TextToImageRequest): Supplies ``prompt``, ``height`` and ``width``.
        pipeline (DiffusionPipeline): Pipeline produced by ``load_pipeline``.
        generator (Generator): Torch RNG forwarded to the pipeline for reproducible sampling.
    Returns:
        Image: The first image of the pipeline's output.
    """
    fixed_sampling = dict(
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        output_type="pil",
    )
    output = pipeline(
        prompt=request.prompt,
        height=request.height,
        width=request.width,
        generator=generator,
        **fixed_sampling,
    )
    return output.images[0]
# Initialize environment settings when the module is imported
_initialize_environment()
# `load` is an alias for `load_pipeline`.
load = load_pipeline