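"""Text-to-image inference pipeline for FLUX.1-schnell.

Loads the individual model components (T5 text encoder, tiny VAE, and FLUX
transformer), assembles them into a DiffusionPipeline with caching,
quantization, and compilation applied, and exposes `load_pipeline` and
`infer` as the entry points for image generation.
"""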
import os
import gc
import torch
from PIL.Image import Image
from dataclasses import dataclass
from diffusers import DiffusionPipeline, AutoencoderTiny, FluxTransformer2DModel
from transformers import T5EncoderModel
from huggingface_hub.constants import HF_HUB_CACHE
from torchao.quantization import quantize_, int8_weight_only
from caching import apply_cache_on_pipe
from pipelines.models import TextToImageRequest
from torch import Generator


# Configuration settings using a dataclass for clarity
@dataclass
class Config:
    CKPT_ID: str = "black-forest-labs/FLUX.1-schnell"
    CKPT_REVISION: str = "741f7c3ce8b383c54771c7003378a50191e9efe9"
    DEVICE: str = "cuda"
    DTYPE: torch.dtype = torch.bfloat16
    PYTORCH_CUDA_ALLOC_CONF: str = "expandable_segments:True"


def _initialize_environment():
    """Set up PyTorch and CUDA environment variables for optimal performance."""
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = Config.PYTORCH_CUDA_ALLOC_CONF


def _clear_gpu_memory():
    """Free up GPU memory and reset allocator statistics to prevent memory-related issues."""
    gc.collect()
    torch.cuda.empty_cache()
    # reset_max_memory_allocated() is a deprecated alias of this call,
    # so resetting the peak stats once is sufficient.
    torch.cuda.reset_peak_memory_stats()


def _load_text_encoder_model():
    """Load the text encoder model with specified configuration."""
    return T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=Config.DTYPE,
    ).to(memory_format=torch.channels_last)


def _load_vae_model():
    """Load the variational autoencoder (VAE) model with specified configuration."""
    return AutoencoderTiny.from_pretrained(
        "manbeast3b/FLUX.1-schnell-taef1-float8",
        revision="7c538d53ec698509788ed88b1305c6bb019bdb4d",
        torch_dtype=Config.DTYPE,
    )


def _load_transformer_model():
    """Load the transformer model from a specific cached path."""
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
        "transformer",
    )
    return FluxTransformer2DModel.from_pretrained(
        transformer_path,
        torch_dtype=Config.DTYPE,
        use_safetensors=False,
    ).to(memory_format=torch.channels_last)


def _warmup_pipeline(pipeline):
    """Warm up the pipeline with a few blank-prompt runs to trigger compilation and initialize internal caches."""
    for _ in range(3):
        pipeline(prompt=" ")


def load_pipeline():
    """
    Load and configure the diffusion pipeline for text-to-image generation.

    Returns:
        DiffusionPipeline: The configured pipeline ready for inference.
    """
    _clear_gpu_memory()

    # Load individual components
    text_encoder = _load_text_encoder_model()
    vae = _load_vae_model()
    transformer = _load_transformer_model()

    # Assemble the diffusion pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        Config.CKPT_ID,
        vae=vae,
        revision=Config.CKPT_REVISION,
        transformer=transformer,
        text_encoder_2=text_encoder,
        torch_dtype=Config.DTYPE,
    ).to(Config.DEVICE)

    # Apply optimizations: caching, channels-last layout, and a quantized,
    # compiled VAE. Quantize before compiling so the compiled graph traces
    # the quantized weights.
    apply_cache_on_pipe(pipeline)
    pipeline.to(memory_format=torch.channels_last)
    quantize_(pipeline.vae, int8_weight_only())
    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")

    # Warm up the pipeline to ensure readiness
    _warmup_pipeline(pipeline)
    return pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator) -> Image:
    """
    Generate an image from a text prompt using the diffusion pipeline.

    Args:
        request (TextToImageRequest): The request containing the prompt and image parameters.
        pipeline (DiffusionPipeline): The pre-loaded diffusion pipeline.
        generator (Generator): The random seed generator for reproducibility.

    Returns:
        Image: The generated image in PIL format.
    """
    image = pipeline(
        prompt=request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    ).images[0]
    return image


# Initialize environment settings when the module is imported
_initialize_environment()
load = load_pipeline
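

# Minimal usage sketch for local testing (not part of the serving path).
# It assumes TextToImageRequest accepts `prompt`, `height`, and `width`
# keyword arguments; adjust to the actual schema in pipelines.models.
if __name__ == "__main__":
    pipe = load_pipeline()
    request = TextToImageRequest(
        prompt="a photo of an astronaut riding a horse on mars",
        height=1024,
        width=1024,
    )
    # Seed the generator on the target device for reproducible outputs.
    generator = Generator(Config.DEVICE).manual_seed(42)
    infer(request, pipe, generator).save("output.png")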