import os
import gc
import torch
from PIL.Image import Image
from dataclasses import dataclass
from diffusers import DiffusionPipeline, AutoencoderTiny, FluxTransformer2DModel
from transformers import T5EncoderModel
from huggingface_hub.constants import HF_HUB_CACHE
from torchao.quantization import quantize_, float8_weight_only
from caching import apply_cache_on_pipe
from pipelines.models import TextToImageRequest
from torch import Generator

# Configuration settings using a dataclass for clarity
@dataclass
class Config:
    CKPT_ID: str = "black-forest-labs/FLUX.1-schnell"
    CKPT_REVISION: str = "741f7c3ce8b383c54771c7003378a50191e9efe9"
    DEVICE: str = "cuda"
    DTYPE: torch.dtype = torch.bfloat16
    PYTORCH_CUDA_ALLOC_CONF: str = "expandable_segments:True"

def _initialize_environment():
    """Set up PyTorch and CUDA environment variables for optimal performance."""
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = Config.PYTORCH_CUDA_ALLOC_CONF

def _clear_gpu_memory():
    """Free up GPU memory to prevent memory-related issues."""
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def _load_text_encoder_model():
    """Load the text encoder model with specified configuration."""
    return T5EncoderModel.from_pretrained(
        "city96/t5-v1_1-xxl-encoder-bf16",
        revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
        torch_dtype=Config.DTYPE
    ).to(memory_format=torch.channels_last)

def _load_vae_model():
    """Load the variational autoencoder (VAE) model with specified configuration."""
    return AutoencoderTiny.from_pretrained(
        "manbeast3b/FLUX.1-schnell-taef1-float8",
        revision="7c538d53ec698509788ed88b1305c6bb019bdb4d",
        torch_dtype=Config.DTYPE
    )

def _load_transformer_model():
    """Load the transformer model from a specific cached path."""
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146", 
        "transformer"
    )
    return FluxTransformer2DModel.from_pretrained(
        transformer_path,
        torch_dtype=Config.DTYPE,
        use_safetensors=False
    ).to(memory_format=torch.channels_last)

def _warmup_pipeline(pipeline):
    """Warm up the pipeline by running it with an empty prompt to initialize internal caches."""
    for _ in range(3):
        pipeline(prompt=" ")

def load_pipeline():
    """
    Load and configure the diffusion pipeline for text-to-image generation.

    Returns:
        DiffusionPipeline: The configured pipeline ready for inference.
    """
    _clear_gpu_memory()

    # Load individual components
    text_encoder = _load_text_encoder_model()
    vae = _load_vae_model()
    transformer = _load_transformer_model()

    # Assemble the diffusion pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        Config.CKPT_ID,
        vae=vae,
        revision=Config.CKPT_REVISION,
        transformer=transformer,
        text_encoder_2=text_encoder,
        torch_dtype=Config.DTYPE,
    ).to(Config.DEVICE)

    # Apply optimizations
    apply_cache_on_pipe(pipeline)
    pipeline.to(memory_format=torch.channels_last)
    # Quantize the VAE weights once (float8, matching the float8 taef1
    # checkpoint) before compiling: torchao's quantize_ mutates the module
    # in place and should run before torch.compile so the compiled graph
    # sees the quantized weights; stacking a second scheme on top of
    # already-quantized weights would be redundant at best.
    quantize_(pipeline.vae, float8_weight_only())
    pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")

    # Warm up the pipeline to ensure readiness
    _warmup_pipeline(pipeline)

    return pipeline

@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator) -> Image:
    """
    Generate an image from a text prompt using the diffusion pipeline.

    Args:
        request (TextToImageRequest): The request containing the prompt and image parameters.
        pipeline (DiffusionPipeline): The pre-loaded diffusion pipeline.
        generator (Generator): The random seed generator for reproducibility.

    Returns:
        Image: The generated image in PIL format.
    """
    image = pipeline(
        prompt=request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil"
    ).images[0]
    return image

# Initialize environment settings when the module is imported
_initialize_environment()
load = load_pipeline
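
# Example usage: a minimal sketch of how load_pipeline and infer fit together
# when the module is run directly rather than through the serving harness.
# It assumes TextToImageRequest can be constructed with prompt/height/width
# keyword arguments (only those attributes are read by infer above); adjust
# to the actual model definition in pipelines.models if it differs.
if __name__ == "__main__":
    pipe = load_pipeline()
    demo_request = TextToImageRequest(
        prompt="a watercolor painting of a lighthouse at dawn",
        height=1024,
        width=1024,
    )
    # Seed the generator on the same device the pipeline runs on for
    # reproducible outputs.
    seed_generator = Generator(Config.DEVICE).manual_seed(0)
    infer(demo_request, pipe, seed_generator).save("output.png")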