manbeast3b committed
Commit 9f3ed26
0 Parent(s)

Initial commit
.gitattributes ADDED
@@ -0,0 +1,37 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ RobertML.png filter=lfs diff=lfs merge=lfs -text
+ backup.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1 @@
+ # speedmax
pyproject.toml ADDED
@@ -0,0 +1,48 @@
+ [build-system]
+ requires = ["setuptools >= 75.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "flux-schnell-edge-inference"
+ description = "An edge-maxxing model submission by RobertML for the 4090 Flux contest"
+ requires-python = ">=3.10,<3.13"
+ version = "8"
+ dependencies = [
+     "diffusers==0.31.0",
+     "transformers==4.46.2",
+     "accelerate==1.1.0",
+     "omegaconf==2.3.0",
+     "torch==2.6.0",
+     "protobuf==5.28.3",
+     "sentencepiece==0.2.0",
+     "edge-maxxing-pipelines @ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines",
+     "gitpython>=3.1.43",
+     "hf_transfer==0.1.8",
+     "torchao==0.6.1",
+     "setuptools>=75.3.0",
+     "para-attn==0.3.15",
+     "git-lfs<=1.6"
+ ]
+
+ [[tool.edge-maxxing.models]]
+ repository = "black-forest-labs/FLUX.1-schnell"
+ revision = "741f7c3ce8b383c54771c7003378a50191e9efe9"
+ exclude = ["transformer"]
+
+ [[tool.edge-maxxing.models]]
+ repository = "manbeast3b/flux.1-schnell-full1"
+ revision = "cb1b599b0d712b9aab2c4df3ad27b050a27ec146"
+
+ [[tool.edge-maxxing.models]]
+ repository = "city96/t5-v1_1-xxl-encoder-bf16"
+ revision = "1b9c856aadb864af93c1dcdc226c2774fa67bc86"
+
+ [[tool.edge-maxxing.models]]
+ repository = "RobertML/FLUX.1-schnell-vae_e3m2"
+ revision = "da0d2cd7815792fb40d084dbd8ed32b63f153d8d"
+
+ [project.scripts]
+ start_inference = "main:main"
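Note: the [[tool.edge-maxxing.models]] tables pin exact repositories and revisions. A rough sketch of how a consumer could prefetch them with huggingface_hub follows; this is standard snapshot_download usage, not the edge-maxxing harness itself, and mapping exclude = ["transformer"] to ignore_patterns is an assumption.

    from huggingface_hub import snapshot_download

    # FLUX.1-schnell without its transformer weights; the transformer comes from
    # the manbeast3b snapshot instead (mirrors the `exclude` entry above).
    snapshot_download(
        "black-forest-labs/FLUX.1-schnell",
        revision="741f7c3ce8b383c54771c7003378a50191e9efe9",
        ignore_patterns=["transformer/*"],
    )
    snapshot_download(
        "manbeast3b/flux.1-schnell-full1",
        revision="cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
    )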
src/__pycache__/main.cpython-311.pyc ADDED
Binary file (4.42 kB).
src/__pycache__/pipeline.cpython-311.pyc ADDED
Binary file (11.1 kB).
src/caching.py ADDED
@@ -0,0 +1,294 @@
+ # caching.py
+ # First-block residual caching for the FLUX transformer: when the first block's
+ # output residual barely changes between denoising steps, the remaining blocks
+ # are skipped and the cached residuals from the previous step are reused.
+
+ import functools
+ import unittest.mock
+ import contextlib
+ import dataclasses
+ from collections import defaultdict
+ from typing import DefaultDict, Dict
+
+ import torch
+ from diffusers import DiffusionPipeline, FluxTransformer2DModel
+
+
+ @dataclasses.dataclass
+ class CacheContext:
+     buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
+     incremental_name_counters: DefaultDict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int))
+
+     def get_incremental_name(self, name=None):
+         if name is None:
+             name = "default"
+         idx = self.incremental_name_counters[name]
+         self.incremental_name_counters[name] += 1
+         return f"{name}_{idx}"
+
+     def reset_incremental_names(self):
+         self.incremental_name_counters.clear()
+
+     @torch.compiler.disable
+     def get_buffer(self, name):
+         return self.buffers.get(name)
+
+     @torch.compiler.disable
+     def set_buffer(self, name, buffer):
+         self.buffers[name] = buffer
+
+     def clear_buffers(self):
+         self.buffers.clear()
+
+
+ @torch.compiler.disable
+ def get_buffer(name):
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.get_buffer(name)
+
+
+ @torch.compiler.disable
+ def set_buffer(name, buffer):
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     cache_context.set_buffer(name, buffer)
+
+
+ _current_cache_context = None
+
+
+ def create_cache_context():
+     return CacheContext()
+
+
+ def get_current_cache_context():
+     return _current_cache_context
+
+
+ def set_current_cache_context(cache_context=None):
+     global _current_cache_context
+     _current_cache_context = cache_context
+
+
+ @contextlib.contextmanager
+ def cache_context(cache_context):
+     global _current_cache_context
+     old_cache_context = _current_cache_context
+     _current_cache_context = cache_context
+     try:
+         yield
+     finally:
+         _current_cache_context = old_cache_context
+
+
+ @torch.compiler.disable
+ def apply_prev_hidden_states_residual(hidden_states, encoder_hidden_states):
+     hidden_states_residual = get_buffer("hidden_states_residual")
+     assert hidden_states_residual is not None, "hidden_states_residual must be set before"
+     hidden_states = hidden_states_residual + hidden_states
+
+     encoder_hidden_states_residual = get_buffer("encoder_hidden_states_residual")
+     assert encoder_hidden_states_residual is not None, "encoder_hidden_states_residual must be set before"
+     encoder_hidden_states = encoder_hidden_states_residual + encoder_hidden_states
+
+     hidden_states = hidden_states.contiguous()
+     encoder_hidden_states = encoder_hidden_states.contiguous()
+
+     return hidden_states, encoder_hidden_states
+
+
+ def are_two_tensors_similar(t1, t2, *, threshold=0.85):
+     # Relative mean absolute difference between the two tensors.
+     mean_diff = (t1 - t2).abs().mean()
+     mean_t1 = t1.abs().mean()
+     diff = mean_diff / mean_t1
+     return diff.item() < threshold
+
+
+ @torch.compiler.disable
+ def get_can_use_cache(first_hidden_states_residual, threshold, parallelized=False):
+     prev_first_hidden_states_residual = get_buffer("first_hidden_states_residual")
+     can_use_cache = prev_first_hidden_states_residual is not None and are_two_tensors_similar(
+         prev_first_hidden_states_residual,
+         first_hidden_states_residual,
+         threshold=threshold,  # forward the caller's residual_diff_threshold instead of the 0.85 default
+     )
+     return can_use_cache
+
+
+ class CachedTransformerBlocks(torch.nn.Module):
+     """Wraps the FLUX transformer blocks with first-block residual caching."""
+
+     def __init__(
+         self,
+         transformer_blocks,
+         single_transformer_blocks=None,
+         *,
+         transformer=None,
+         residual_diff_threshold,
+         return_hidden_states_first=True,
+     ):
+         super().__init__()
+         self.transformer = transformer
+         self.transformer_blocks = transformer_blocks
+         self.single_transformer_blocks = single_transformer_blocks
+         self.residual_diff_threshold = residual_diff_threshold
+         self.return_hidden_states_first = return_hidden_states_first
+
+     def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs):
+         if self.residual_diff_threshold <= 0.0:
+             # Caching disabled: run every block as usual.
+             for block in self.transformer_blocks:
+                 hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
+                 if not self.return_hidden_states_first:
+                     hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+             if self.single_transformer_blocks is not None:
+                 hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+                 for block in self.single_transformer_blocks:
+                     hidden_states = block(hidden_states, *args, **kwargs)
+                 hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :]
+             return (
+                 (hidden_states, encoder_hidden_states)
+                 if self.return_hidden_states_first
+                 else (encoder_hidden_states, hidden_states)
+             )
+
+         original_hidden_states = hidden_states
+         first_transformer_block = self.transformer_blocks[0]
+         hidden_states, encoder_hidden_states = first_transformer_block(
+             hidden_states, encoder_hidden_states, *args, **kwargs
+         )
+         if not self.return_hidden_states_first:
+             hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+         first_hidden_states_residual = hidden_states - original_hidden_states
+         del original_hidden_states
+
+         can_use_cache = get_can_use_cache(
+             first_hidden_states_residual,
+             threshold=self.residual_diff_threshold,
+             parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
+         )
+
+         torch._dynamo.graph_break()
+         if can_use_cache:
+             del first_hidden_states_residual
+             hidden_states, encoder_hidden_states = apply_prev_hidden_states_residual(
+                 hidden_states, encoder_hidden_states
+             )
+         else:
+             set_buffer("first_hidden_states_residual", first_hidden_states_residual)
+             del first_hidden_states_residual
+             (
+                 hidden_states,
+                 encoder_hidden_states,
+                 hidden_states_residual,
+                 encoder_hidden_states_residual,
+             ) = self.call_remaining_transformer_blocks(hidden_states, encoder_hidden_states, *args, **kwargs)
+             set_buffer("hidden_states_residual", hidden_states_residual)
+             set_buffer("encoder_hidden_states_residual", encoder_hidden_states_residual)
+         torch._dynamo.graph_break()
+
+         return (
+             (hidden_states, encoder_hidden_states)
+             if self.return_hidden_states_first
+             else (encoder_hidden_states, hidden_states)
+         )
+
+     def call_remaining_transformer_blocks(self, hidden_states, encoder_hidden_states, *args, **kwargs):
+         original_hidden_states = hidden_states
+         original_encoder_hidden_states = encoder_hidden_states
+         for block in self.transformer_blocks[1:]:
+             hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs)
+             if not self.return_hidden_states_first:
+                 hidden_states, encoder_hidden_states = encoder_hidden_states, hidden_states
+         if self.single_transformer_blocks is not None:
+             hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+             for block in self.single_transformer_blocks:
+                 hidden_states = block(hidden_states, *args, **kwargs)
+             encoder_hidden_states, hidden_states = hidden_states.split(
+                 [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
+             )
+
+         hidden_states = hidden_states.reshape(-1).contiguous().reshape(original_hidden_states.shape)
+         encoder_hidden_states = (
+             encoder_hidden_states.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape)
+         )
+
+         hidden_states_residual = hidden_states - original_hidden_states
+         encoder_hidden_states_residual = encoder_hidden_states - original_encoder_hidden_states
+
+         hidden_states_residual = hidden_states_residual.reshape(-1).contiguous().reshape(original_hidden_states.shape)
+         encoder_hidden_states_residual = (
+             encoder_hidden_states_residual.reshape(-1).contiguous().reshape(original_encoder_hidden_states.shape)
+         )
+
+         return hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual
+
+
+ def apply_cache_on_transformer(
+     transformer: FluxTransformer2DModel,
+     *,
+     residual_diff_threshold=0.1,
+ ):
+     cached_transformer_blocks = torch.nn.ModuleList(
+         [
+             CachedTransformerBlocks(
+                 transformer.transformer_blocks,
+                 transformer.single_transformer_blocks,
+                 transformer=transformer,
+                 residual_diff_threshold=residual_diff_threshold,
+                 return_hidden_states_first=False,
+             )
+         ]
+     )
+     dummy_single_transformer_blocks = torch.nn.ModuleList()
+
+     original_forward = transformer.forward
+
+     @functools.wraps(original_forward)
+     def new_forward(
+         self,
+         *args,
+         **kwargs,
+     ):
+         with unittest.mock.patch.object(
+             self,
+             "transformer_blocks",
+             cached_transformer_blocks,
+         ), unittest.mock.patch.object(
+             self,
+             "single_transformer_blocks",
+             dummy_single_transformer_blocks,
+         ):
+             return original_forward(
+                 *args,
+                 **kwargs,
+             )
+
+     transformer.forward = new_forward.__get__(transformer)
+
+     return transformer
+
+
+ def apply_cache_on_pipe(
+     pipe: DiffusionPipeline,
+     *,
+     shallow_patch: bool = False,
+     **kwargs,
+ ):
+     original_call = pipe.__class__.__call__
+
+     if not getattr(original_call, "_is_cached", False):
+
+         @functools.wraps(original_call)
+         def new_call(self, *args, **kwargs):
+             with cache_context(create_cache_context()):
+                 return original_call(self, *args, **kwargs)
+
+         pipe.__class__.__call__ = new_call
+
+         new_call._is_cached = True
+
+     if not shallow_patch:
+         apply_cache_on_transformer(pipe.transformer, **kwargs)
+
+     pipe._is_cached = True
+
+     return pipe
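A minimal usage sketch for this module, assuming FLUX.1-schnell weights are available locally and a CUDA GPU is present. apply_cache_on_pipe wraps the pipeline's __call__ in a fresh CacheContext and swaps the transformer blocks for CachedTransformerBlocks, so steps whose first-block residual barely changes reuse the cached residuals instead of running every block:

    import torch
    from diffusers import FluxPipeline
    from caching import apply_cache_on_pipe

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    ).to("cuda")

    # Higher residual_diff_threshold skips more steps (faster, less faithful).
    apply_cache_on_pipe(pipe, residual_diff_threshold=0.1)

    image = pipe(
        "a photo of an astronaut riding a horse",
        num_inference_steps=4,
        guidance_scale=0.0,
    ).images[0]
    image.save("out.png")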
src/flux_schnell_edge_inference.egg-info/PKG-INFO ADDED
@@ -0,0 +1,16 @@
+ Metadata-Version: 2.2
+ Name: flux-schnell-edge-inference
+ Version: 8
+ Summary: An edge-maxxing model submission by RobertML for the 4090 Flux contest
+ Requires-Python: <3.13,>=3.10
+ Requires-Dist: diffusers==0.31.0
+ Requires-Dist: transformers==4.46.2
+ Requires-Dist: accelerate==1.1.0
+ Requires-Dist: omegaconf==2.3.0
+ Requires-Dist: torch==2.6.0
+ Requires-Dist: protobuf==5.28.3
+ Requires-Dist: sentencepiece==0.2.0
+ Requires-Dist: edge-maxxing-pipelines@ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines
+ Requires-Dist: gitpython>=3.1.43
+ Requires-Dist: hf_transfer==0.1.8
+ Requires-Dist: torchao==0.6.1
src/flux_schnell_edge_inference.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,17 @@
+ README.md
+ pyproject.toml
+ src/main.py
+ src/pipeline.py
+ src/first_block_cache/__init__.py
+ src/first_block_cache/utils.py
+ src/first_block_cache/diffusers_adapters/__init__.py
+ src/first_block_cache/diffusers_adapters/cogvideox.py
+ src/first_block_cache/diffusers_adapters/flux.py
+ src/first_block_cache/diffusers_adapters/hunyuan_video.py
+ src/first_block_cache/diffusers_adapters/mochi.py
+ src/flux_schnell_edge_inference.egg-info/PKG-INFO
+ src/flux_schnell_edge_inference.egg-info/SOURCES.txt
+ src/flux_schnell_edge_inference.egg-info/dependency_links.txt
+ src/flux_schnell_edge_inference.egg-info/entry_points.txt
+ src/flux_schnell_edge_inference.egg-info/requires.txt
+ src/flux_schnell_edge_inference.egg-info/top_level.txt
src/flux_schnell_edge_inference.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
src/flux_schnell_edge_inference.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ start_inference = main:main
src/flux_schnell_edge_inference.egg-info/requires.txt ADDED
@@ -0,0 +1,11 @@
+ diffusers==0.31.0
+ transformers==4.46.2
+ accelerate==1.1.0
+ omegaconf==2.3.0
+ torch==2.6.0
+ protobuf==5.28.3
+ sentencepiece==0.2.0
+ edge-maxxing-pipelines@ git+https://github.com/womboai/edge-maxxing@7c760ac54f6052803dadb3ade8ebfc9679a94589#subdirectory=pipelines
+ gitpython>=3.1.43
+ hf_transfer==0.1.8
+ torchao==0.6.1
src/flux_schnell_edge_inference.egg-info/top_level.txt ADDED
@@ -0,0 +1,3 @@
+ first_block_cache
+ main
+ pipeline
src/main.py ADDED
@@ -0,0 +1,55 @@
+ import atexit
+ from io import BytesIO
+ from multiprocessing.connection import Listener
+ from os import chmod, remove
+ from os.path import abspath, exists
+ from pathlib import Path
+ from git import Repo
+ import torch
+
+ from PIL.JpegImagePlugin import JpegImageFile
+ from pipelines.models import TextToImageRequest
+ from pipeline import load_pipeline, infer
+
+ SOCKET = abspath(Path(__file__).parent.parent / "inferences.sock")
+
+
+ def at_exit():
+     torch.cuda.empty_cache()
+
+
+ def main():
+     atexit.register(at_exit)
+
+     print("Loading pipeline")
+     pipeline = load_pipeline()
+
+     print(f"Pipeline loaded, creating socket at '{SOCKET}'")
+
+     if exists(SOCKET):
+         remove(SOCKET)
+
+     with Listener(SOCKET) as listener:
+         chmod(SOCKET, 0o777)
+
+         print("Awaiting connections")
+         with listener.accept() as connection:
+             print("Connected")
+             generator = torch.Generator("cuda")
+             while True:
+                 try:
+                     request = TextToImageRequest.model_validate_json(connection.recv_bytes().decode("utf-8"))
+                 except EOFError:
+                     print("Inference socket exiting")
+                     return
+
+                 image = infer(request, pipeline, generator.manual_seed(request.seed))
+
+                 data = BytesIO()
+                 image.save(data, format=JpegImageFile.format)
+
+                 packet = data.getvalue()
+                 connection.send_bytes(packet)
+
+
+ if __name__ == '__main__':
+     main()
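For reference, a hypothetical client for the socket protocol above. The server expects a JSON-encoded TextToImageRequest (defined in the edge-maxxing pipelines package; the prompt/height/width/seed fields shown are the ones this code actually reads, the full schema is an assumption) and replies with raw JPEG bytes:

    import json
    from io import BytesIO
    from multiprocessing.connection import Client
    from PIL import Image

    # Path assumed to match SOCKET above (repo root / "inferences.sock").
    with Client("inferences.sock") as connection:
        request = {"prompt": "a lighthouse at dusk", "height": 1024, "width": 1024, "seed": 42}
        connection.send_bytes(json.dumps(request).encode("utf-8"))
        image = Image.open(BytesIO(connection.recv_bytes()))
        image.save("result.jpg")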
src/pipeline.py ADDED
@@ -0,0 +1,141 @@
+ import os
+ import gc
+ import torch
+ from PIL.Image import Image
+ from dataclasses import dataclass
+ from diffusers import DiffusionPipeline, AutoencoderTiny, FluxTransformer2DModel
+ from transformers import T5EncoderModel
+ from huggingface_hub.constants import HF_HUB_CACHE
+ from torchao.quantization import quantize_, int8_weight_only, float8_weight_only
+ from caching import apply_cache_on_pipe
+ from pipelines.models import TextToImageRequest
+ from torch import Generator
+
+
+ # Configuration settings using a dataclass for clarity
+ @dataclass
+ class Config:
+     CKPT_ID: str = "black-forest-labs/FLUX.1-schnell"
+     CKPT_REVISION: str = "741f7c3ce8b383c54771c7003378a50191e9efe9"
+     DEVICE: str = "cuda"
+     DTYPE: torch.dtype = torch.bfloat16
+     PYTORCH_CUDA_ALLOC_CONF: str = "expandable_segments:True"
+
+
+ def _initialize_environment():
+     """Set up PyTorch and CUDA environment variables for optimal performance."""
+     torch.backends.cuda.matmul.allow_tf32 = True
+     torch.backends.cudnn.enabled = True
+     torch.backends.cudnn.benchmark = True
+     os.environ['PYTORCH_CUDA_ALLOC_CONF'] = Config.PYTORCH_CUDA_ALLOC_CONF
+
+
+ def _clear_gpu_memory():
+     """Free up GPU memory to prevent memory-related issues."""
+     gc.collect()
+     torch.cuda.empty_cache()
+     torch.cuda.reset_max_memory_allocated()
+     torch.cuda.reset_peak_memory_stats()
+
+
+ def _load_text_encoder_model():
+     """Load the T5 text encoder with the pinned revision and dtype."""
+     return T5EncoderModel.from_pretrained(
+         "city96/t5-v1_1-xxl-encoder-bf16",
+         revision="1b9c856aadb864af93c1dcdc226c2774fa67bc86",
+         torch_dtype=Config.DTYPE
+     ).to(memory_format=torch.channels_last)
+
+
+ def _load_vae_model():
+     """Load the tiny variational autoencoder (VAE) with the pinned revision and dtype."""
+     return AutoencoderTiny.from_pretrained(
+         "RobertML/FLUX.1-schnell-vae_e3m2",
+         revision="da0d2cd7815792fb40d084dbd8ed32b63f153d8d",
+         torch_dtype=Config.DTYPE
+     )
+
+
+ def _load_transformer_model():
+     """Load the transformer from its snapshot path in the Hugging Face cache."""
+     transformer_path = os.path.join(
+         HF_HUB_CACHE,
+         "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146",
+         "transformer"
+     )
+     return FluxTransformer2DModel.from_pretrained(
+         transformer_path,
+         torch_dtype=Config.DTYPE,
+         use_safetensors=False
+     ).to(memory_format=torch.channels_last)
+
+
+ def _warmup_pipeline(pipeline):
+     """Warm up the pipeline by running it with a blank prompt to initialize internal caches."""
+     for _ in range(3):
+         pipeline(prompt=" ")
+
+
+ def load_pipeline():
+     """
+     Load and configure the diffusion pipeline for text-to-image generation.
+
+     Returns:
+         DiffusionPipeline: The configured pipeline ready for inference.
+     """
+     _clear_gpu_memory()
+
+     # Load individual components
+     text_encoder = _load_text_encoder_model()
+     vae = _load_vae_model()
+     transformer = _load_transformer_model()
+
+     # Assemble the diffusion pipeline
+     pipeline = DiffusionPipeline.from_pretrained(
+         Config.CKPT_ID,
+         vae=vae,
+         revision=Config.CKPT_REVISION,
+         transformer=transformer,
+         text_encoder_2=text_encoder,
+         torch_dtype=Config.DTYPE,
+     ).to(Config.DEVICE)
+
+     # Apply optimizations: first-block caching, channels-last layout,
+     # a compiled VAE, and weight-only quantization of the VAE
+     apply_cache_on_pipe(pipeline)
+     pipeline.to(memory_format=torch.channels_last)
+     pipeline.vae = torch.compile(pipeline.vae, mode="max-autotune")
+     quantize_(pipeline.vae, int8_weight_only())
+     quantize_(pipeline.vae, float8_weight_only())
+
+     # Warm up the pipeline to ensure readiness
+     _warmup_pipeline(pipeline)
+
+     return pipeline
+
+
+ @torch.no_grad()
+ def infer(request: TextToImageRequest, pipeline: DiffusionPipeline, generator: Generator) -> Image:
+     """
+     Generate an image from a text prompt using the diffusion pipeline.
+
+     Args:
+         request (TextToImageRequest): The request containing the prompt and image parameters.
+         pipeline (DiffusionPipeline): The pre-loaded diffusion pipeline.
+         generator (Generator): The seeded random generator for reproducibility.
+
+     Returns:
+         Image: The generated image in PIL format.
+     """
+     image = pipeline(
+         prompt=request.prompt,
+         generator=generator,
+         guidance_scale=0.0,
+         num_inference_steps=4,
+         max_sequence_length=256,
+         height=request.height,
+         width=request.width,
+         output_type="pil"
+     ).images[0]
+     return image
+
+
+ # Initialize environment settings when the module is imported
+ _initialize_environment()
+
+ # For compatibility with other scripts, alias load_pipeline as load
+ load = load_pipeline
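A quick local smoke test for this module, assuming the models pinned in pyproject.toml are already in the Hugging Face cache and a CUDA GPU is available (the TextToImageRequest field names below are taken from how infer() uses them and may not match the full schema):

    import torch
    from pipeline import load_pipeline, infer
    from pipelines.models import TextToImageRequest

    pipe = load_pipeline()  # loads, patches, quantizes, and warms up the pipeline
    request = TextToImageRequest(prompt="a watercolor fox", height=1024, width=1024, seed=0)
    image = infer(request, pipe, torch.Generator("cuda").manual_seed(request.seed))
    image.save("sample.jpg")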
uv.lock ADDED
The diff for this file is too large to render.