import os
import gc
import torch
from PIL.Image import Image
from dataclasses import dataclass
from diffusers import DiffusionPipeline, AutoencoderTiny, FluxTransformer2DModel
from transformers import T5EncoderModel
from huggingface_hub.constants import HF_HUB_CACHE
from torchao.quantization import quantize_, float8_weight_only
from caching import apply_cache_on_pipe
from pipelines.models import TextToImageRequest
from torch import Generator
# Configuration settings using a dataclass for clarity
@dataclass
class Config:
CKPT_ID: str = "black-forest-labs/FLUX.1-schnell"
CKPT_REVISION: str = "741f7c3ce8b383c54771c7003378a50191e9efe9"
DEVICE: str = "cuda"
    DTYPE: torch.dtype = torch.bfloat16
PYTORCH_CUDA_ALLOC_CONF: str = "expandable_segments:True"
def _initialize_environment():
"""Set up PyTorch and CUDA environment variables for optimal performance."""
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
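    # The CUDA caching allocator reads this variable when the first CUDA tensor
    # is allocated, so it must be set before any model is moved to the GPU;
    # this module configures it at import time (see the bottom of the file).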
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = Config.PYTORCH_CUDA_ALLOC_CONF
def _clear_gpu_memory():
"""Free up GPU memory to prevent memory-related issues."""
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
def _load_text_encoder_model():
"""Load the text encoder model with specified configuration."""
return T5EncoderModel.from_pretrained(
"city96/t5-v1_1-xxl-encoder-bf16",
revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
torch_dtype=Config.DTYPE
).to(memory_format=torch.channels_last)
def _load_vae_model():
"""Load the variational autoencoder (VAE) model with specified configuration."""
return AutoencoderTiny.from_pretrained(
"manbeast3b/FLUX.1-schnell-taef1-float8",
revision="7c538d53ec698509788ed88b1305c6bb019bdb4d",
torch_dtype=Config.DTYPE
)
def _load_transformer_model():
"""Load the transformer model from a specific cached path."""
transformer_path = os.path.join(
HF_HUB_CACHE,
"models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
"transformer"
)
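    # use_safetensors=False forces loading of pickled .bin weights (the snapshot
    # presumably ships no .safetensors files).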
return FluxTransformer2DModel.from_pretrained(
transformer_path,
torch_dtype=Config.DTYPE,
use_safetensors=False
).to(memory_format=torch.channels_last)
def _warmup_pipeline(pipeline):
"""Warm up the pipeline by running it with an empty prompt to initialize internal caches."""
for _ in range(3):
pipeline(prompt=" ")
def load_pipeline():
"""
Load and configure the diffusion pipeline for text-to-image generation.
Returns:
DiffusionPipeline: The configured pipeline ready for inference.
"""
_clear_gpu_memory()
# Load individual components
text_encoder = _load_text_encoder_model()
vae = _load_vae_model()
transformer = _load_transformer_model()
# Assemble the diffusion pipeline
pipeline = DiffusionPipeline.from_pretrained(
Config.CKPT_ID,
vae=vae,
revision=Config.CKPT_REVISION,
transformer=transformer,
text_encoder_2=text_encoder,
torch_dtype=Config.DTYPE,
).to(Config.DEVICE)
# Apply optimizations
apply_cache_on_pipe(pipeline)
    # The transformer and text encoder were loaded in channels_last memory
    # format; convert the VAE to match.
    pipeline.vae.to(memory_format=torch.channels_last)
    # Quantize the VAE weights before compiling so torch.compile specializes on
    # the quantized kernels; float8 matches the taef1-float8 checkpoint.
    quantize_(pipeline.vae, float8_weight_only())
    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
# Warm up the pipeline to ensure readiness
_warmup_pipeline(pipeline)
return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator) -> Image:
"""
Generate an image from a text prompt using the diffusion pipeline.
Args:
request (TextToImageRequest): The request containing the prompt and image parameters.
pipeline (DiffusionPipeline): The pre-loaded diffusion pipeline.
        generator (Generator): Seeded torch generator, used for reproducible sampling.
Returns:
Image: The generated image in PIL format.
"""
image = pipeline(
prompt=request.prompt,
generator=generator,
guidance_scale=0.0,
num_inference_steps=4,
max_sequence_length=256,
height=request.height,
width=request.width,
output_type="pil"
).images[0]
return image
# Initialize environment settings when the module is imported
_initialize_environment()
# For compatibility with other scripts, alias load_pipeline as load
load = load_pipeline
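# Minimal usage sketch, guarded so it never runs on import. The keyword
# arguments to TextToImageRequest are an assumption about its schema (see
# pipelines.models); adjust them if the actual fields differ.
if __name__ == "__main__":
    pipe = load()
    request = TextToImageRequest(  # hypothetical field names
        prompt="a watercolor fox in a snowy forest",
        height=1024,
        width=1024,
    )
    generator = Generator(Config.DEVICE).manual_seed(0)
    infer(request, pipe, generator).save("out.png")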