Manireddy1508 committed
Commit e98b9e9 (verified) · Parent(s): 3b37434

Upload 2 files
uno/flux/modules/conditioner.py ADDED
@@ -0,0 +1,53 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from torch import Tensor, nn
+ from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
+                           T5Tokenizer)
+
+
+ class HFEmbedder(nn.Module):
+     def __init__(self, version: str, max_length: int, **hf_kwargs):
+         super().__init__()
+         self.is_clip = "clip" in version.lower()
+         self.max_length = max_length
+         self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
+
+         if self.is_clip:
+             self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
+             self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained(version, **hf_kwargs)
+         else:
+             self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length)
+             self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained(version, **hf_kwargs)
+
+         self.hf_module = self.hf_module.eval().requires_grad_(False)
+
+     def forward(self, text: list[str]) -> Tensor:
+         batch_encoding = self.tokenizer(
+             text,
+             truncation=True,
+             max_length=self.max_length,
+             return_length=False,
+             return_overflowing_tokens=False,
+             padding="max_length",
+             return_tensors="pt",
+         )
+
+         outputs = self.hf_module(
+             input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
+             attention_mask=None,
+             output_hidden_states=False,
+         )
+         return outputs[self.output_key]
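A minimal usage sketch for the embedder above (the checkpoint names, max lengths, and dtype are illustrative assumptions, not part of this commit; CLIP returns the pooled vector, T5 the full token sequence):

    import torch
    from uno.flux.modules.conditioner import HFEmbedder

    clip = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16)
    t5 = HFEmbedder("google/t5-v1_1-xxl", max_length=512, torch_dtype=torch.bfloat16)

    prompts = ["a photo of a corgi wearing sunglasses"]
    vec = clip(prompts)  # pooled CLIP embedding, shape (1, 768) for this checkpoint
    txt = t5(prompts)    # per-token T5 features, shape (1, 512, 4096) for this checkpoint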
uno/flux/modules/layers.py ADDED
@@ -0,0 +1,435 @@
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. All rights reserved.
+ # Copyright (c) 2024 Black Forest Labs and The XLabs-AI Team. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from dataclasses import dataclass
+
+ import torch
+ from einops import rearrange
+ from torch import Tensor, nn
+
+ from ..math import attention, rope
+ import torch.nn.functional as F
+
+ class EmbedND(nn.Module):
+     def __init__(self, dim: int, theta: int, axes_dim: list[int]):
+         super().__init__()
+         self.dim = dim
+         self.theta = theta
+         self.axes_dim = axes_dim
+
+     def forward(self, ids: Tensor) -> Tensor:
+         n_axes = ids.shape[-1]
+         emb = torch.cat(
+             [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+             dim=-3,
+         )
+
+         return emb.unsqueeze(1)
+
+
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
+     """
+     Create sinusoidal timestep embeddings.
+     :param t: a 1-D Tensor of N indices, one per batch element.
+                       These may be fractional.
+     :param dim: the dimension of the output.
+     :param max_period: controls the minimum frequency of the embeddings.
+     :return: an (N, D) Tensor of positional embeddings.
+     """
+     t = time_factor * t
+     half = dim // 2
+     freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+         t.device
+     )
+
+     args = t[:, None].float() * freqs[None]
+     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+     if dim % 2:
+         embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+     if torch.is_floating_point(t):
+         embedding = embedding.to(t)
+     return embedding
+
+
+ class MLPEmbedder(nn.Module):
+     def __init__(self, in_dim: int, hidden_dim: int):
+         super().__init__()
+         self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
+         self.silu = nn.SiLU()
+         self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.out_layer(self.silu(self.in_layer(x)))
+
+
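A quick sanity check of the two helpers above (a sketch; the 256/3072 widths mirror the usual FLUX time-embedding path but are assumptions here):

    import torch
    from uno.flux.modules.layers import MLPEmbedder, timestep_embedding

    t = torch.tensor([0.25, 0.5, 1.0])                   # fractional timesteps, one per batch element
    emb = timestep_embedding(t, dim=256)                 # sinusoidal features, shape (3, 256)
    vec = MLPEmbedder(in_dim=256, hidden_dim=3072)(emb)  # conditioning vector, shape (3, 3072)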
+ class RMSNorm(torch.nn.Module):
+     def __init__(self, dim: int):
+         super().__init__()
+         self.scale = nn.Parameter(torch.ones(dim))
+
+     def forward(self, x: Tensor):
+         x_dtype = x.dtype
+         x = x.float()
+         rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
+         return ((x * rrms) * self.scale.float()).to(dtype=x_dtype)
+
+
+ class QKNorm(torch.nn.Module):
+     def __init__(self, dim: int):
+         super().__init__()
+         self.query_norm = RMSNorm(dim)
+         self.key_norm = RMSNorm(dim)
+
+     def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
+         q = self.query_norm(q)
+         k = self.key_norm(k)
+         return q.to(v), k.to(v)
+
+ class LoRALinearLayer(nn.Module):
+     def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
+         super().__init__()
+
+         self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
+         self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
+         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+         self.network_alpha = network_alpha
+         self.rank = rank
+
+         nn.init.normal_(self.down.weight, std=1 / rank)
+         nn.init.zeros_(self.up.weight)
+
+     def forward(self, hidden_states):
+         orig_dtype = hidden_states.dtype
+         dtype = self.down.weight.dtype
+
+         down_hidden_states = self.down(hidden_states.to(dtype))
+         up_hidden_states = self.up(down_hidden_states)
+
+         if self.network_alpha is not None:
+             up_hidden_states *= self.network_alpha / self.rank
+
+         return up_hidden_states.to(orig_dtype)
+
+ class FLuxSelfAttnProcessor:
+     def __call__(self, attn, x, pe, **attention_kwargs):
+         qkv = attn.qkv(x)
+         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+         q, k = attn.norm(q, k, v)
+         x = attention(q, k, v, pe=pe)
+         x = attn.proj(x)
+         return x
+
+ class LoraFluxAttnProcessor(nn.Module):
+
+     def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
+         super().__init__()
+         self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
+         self.proj_lora = LoRALinearLayer(dim, dim, rank, network_alpha)
+         self.lora_weight = lora_weight
+
+
+     def __call__(self, attn, x, pe, **attention_kwargs):
+         qkv = attn.qkv(x) + self.qkv_lora(x) * self.lora_weight
+         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+         q, k = attn.norm(q, k, v)
+         x = attention(q, k, v, pe=pe)
+         x = attn.proj(x) + self.proj_lora(x) * self.lora_weight
+         return x
+
+ class SelfAttention(nn.Module):
+     def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.norm = QKNorm(head_dim)
+         self.proj = nn.Linear(dim, dim)
+     def forward(self, *args, **kwargs):
+         pass  # attention is computed by the block's processor; this module only holds qkv/norm/proj
+
+
+ @dataclass
+ class ModulationOut:
+     shift: Tensor
+     scale: Tensor
+     gate: Tensor
+
+
+ class Modulation(nn.Module):
+     def __init__(self, dim: int, double: bool):
+         super().__init__()
+         self.is_double = double
+         self.multiplier = 6 if double else 3
+         self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
+
+     def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
+         out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
+
+         return (
+             ModulationOut(*out[:3]),
+             ModulationOut(*out[3:]) if self.is_double else None,
+         )
+
+ class DoubleStreamBlockLoraProcessor(nn.Module):
+     def __init__(self, dim: int, rank=4, network_alpha=None, lora_weight=1):
+         super().__init__()
+         self.qkv_lora1 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
+         self.proj_lora1 = LoRALinearLayer(dim, dim, rank, network_alpha)
+         self.qkv_lora2 = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
+         self.proj_lora2 = LoRALinearLayer(dim, dim, rank, network_alpha)
+         self.lora_weight = lora_weight
+
+     def forward(self, attn, img, txt, vec, pe, **attention_kwargs):
+         img_mod1, img_mod2 = attn.img_mod(vec)
+         txt_mod1, txt_mod2 = attn.txt_mod(vec)
+
+         # prepare image for attention
+         img_modulated = attn.img_norm1(img)
+         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+         img_qkv = attn.img_attn.qkv(img_modulated) + self.qkv_lora1(img_modulated) * self.lora_weight
+         img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+         img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
+
+         # prepare txt for attention
+         txt_modulated = attn.txt_norm1(txt)
+         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+         txt_qkv = attn.txt_attn.qkv(txt_modulated) + self.qkv_lora2(txt_modulated) * self.lora_weight
+         txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+         txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
+
+         # run actual attention
+         q = torch.cat((txt_q, img_q), dim=2)
+         k = torch.cat((txt_k, img_k), dim=2)
+         v = torch.cat((txt_v, img_v), dim=2)
+
+         attn1 = attention(q, k, v, pe=pe)
+         txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
+
+         # calculate the img blocks
+         img = img + img_mod1.gate * (attn.img_attn.proj(img_attn) + self.proj_lora1(img_attn) * self.lora_weight)
+         img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
+
+         # calculate the txt blocks
+         txt = txt + txt_mod1.gate * (attn.txt_attn.proj(txt_attn) + self.proj_lora2(txt_attn) * self.lora_weight)
+         txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
+         return img, txt
+
+ class DoubleStreamBlockProcessor:
+     def __call__(self, attn, img, txt, vec, pe, **attention_kwargs):
+         img_mod1, img_mod2 = attn.img_mod(vec)
+         txt_mod1, txt_mod2 = attn.txt_mod(vec)
+
+         # prepare image for attention
+         img_modulated = attn.img_norm1(img)
+         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+         img_qkv = attn.img_attn.qkv(img_modulated)
+         img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
+         img_q, img_k = attn.img_attn.norm(img_q, img_k, img_v)
+
+         # prepare txt for attention
+         txt_modulated = attn.txt_norm1(txt)
+         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+         txt_qkv = attn.txt_attn.qkv(txt_modulated)
+         txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads, D=attn.head_dim)
+         txt_q, txt_k = attn.txt_attn.norm(txt_q, txt_k, txt_v)
+
+         # run actual attention
+         q = torch.cat((txt_q, img_q), dim=2)
+         k = torch.cat((txt_k, img_k), dim=2)
+         v = torch.cat((txt_v, img_v), dim=2)
+
+         attn1 = attention(q, k, v, pe=pe)
+         txt_attn, img_attn = attn1[:, : txt.shape[1]], attn1[:, txt.shape[1] :]
+
+         # calculate the img blocks
+         img = img + img_mod1.gate * attn.img_attn.proj(img_attn)
+         img = img + img_mod2.gate * attn.img_mlp((1 + img_mod2.scale) * attn.img_norm2(img) + img_mod2.shift)
+
+         # calculate the txt blocks
+         txt = txt + txt_mod1.gate * attn.txt_attn.proj(txt_attn)
+         txt = txt + txt_mod2.gate * attn.txt_mlp((1 + txt_mod2.scale) * attn.txt_norm2(txt) + txt_mod2.shift)
+         return img, txt
+
+ class DoubleStreamBlock(nn.Module):
+     def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False):
+         super().__init__()
+         mlp_hidden_dim = int(hidden_size * mlp_ratio)
+         self.num_heads = num_heads
+         self.hidden_size = hidden_size
+         self.head_dim = hidden_size // num_heads
+
+         self.img_mod = Modulation(hidden_size, double=True)
+         self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
+
+         self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.img_mlp = nn.Sequential(
+             nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+             nn.GELU(approximate="tanh"),
+             nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+         )
+
+         self.txt_mod = Modulation(hidden_size, double=True)
+         self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
+
+         self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.txt_mlp = nn.Sequential(
+             nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
+             nn.GELU(approximate="tanh"),
+             nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
+         )
+         processor = DoubleStreamBlockProcessor()
+         self.set_processor(processor)
+
+     def set_processor(self, processor) -> None:
+         self.processor = processor
+
+     def get_processor(self):
+         return self.processor
+
+     def forward(
+         self,
+         img: Tensor,
+         txt: Tensor,
+         vec: Tensor,
+         pe: Tensor,
+         image_proj: Tensor | None = None,
+         ip_scale: float = 1.0,
+     ) -> tuple[Tensor, Tensor]:
+         if image_proj is None:
+             return self.processor(self, img, txt, vec, pe)
+         else:
+             return self.processor(self, img, txt, vec, pe, image_proj, ip_scale)
+
+
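The set_processor indirection above is what lets LoRA be attached without modifying the block's own weights. A sketch of swapping in the LoRA processor (the 3072/24 dimensions follow the common FLUX configuration and are assumptions, not taken from this commit):

    from uno.flux.modules.layers import DoubleStreamBlock, DoubleStreamBlockLoraProcessor

    block = DoubleStreamBlock(hidden_size=3072, num_heads=24, mlp_ratio=4.0, qkv_bias=True)
    # replace the default processor; only the LoRA down/up projections add trainable parameters
    block.set_processor(DoubleStreamBlockLoraProcessor(dim=3072, rank=16, lora_weight=1.0))
    # the forward signature is unchanged: img, txt = block(img, txt, vec, pe)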
+ class SingleStreamBlockLoraProcessor(nn.Module):
+     def __init__(self, dim: int, rank: int = 4, network_alpha=None, lora_weight: float = 1):
+         super().__init__()
+         self.qkv_lora = LoRALinearLayer(dim, dim * 3, rank, network_alpha)
+         self.proj_lora = LoRALinearLayer(15360, dim, rank, network_alpha)  # 15360 = hidden_size + mlp_hidden_dim in the default FLUX config
+         self.lora_weight = lora_weight
+
+     def forward(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
+         mod, _ = attn.modulation(vec)
+         x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
+         qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
+         qkv = qkv + self.qkv_lora(x_mod) * self.lora_weight
+
+         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+         q, k = attn.norm(q, k, v)
+
+         # compute attention
+         attn_1 = attention(q, k, v, pe=pe)
+
+         # compute activation in mlp stream, cat again and run second linear layer
+         output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
+         output = output + self.proj_lora(torch.cat((attn_1, attn.mlp_act(mlp)), 2)) * self.lora_weight
+         output = x + mod.gate * output
+         return output
+
+
+ class SingleStreamBlockProcessor:
+     def __call__(self, attn: nn.Module, x: Tensor, vec: Tensor, pe: Tensor, **attention_kwargs) -> Tensor:
+         mod, _ = attn.modulation(vec)
+         x_mod = (1 + mod.scale) * attn.pre_norm(x) + mod.shift
+         qkv, mlp = torch.split(attn.linear1(x_mod), [3 * attn.hidden_size, attn.mlp_hidden_dim], dim=-1)
+
+         q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=attn.num_heads)
+         q, k = attn.norm(q, k, v)
+
+         # compute attention
+         attn_1 = attention(q, k, v, pe=pe)
+
+         # compute activation in mlp stream, cat again and run second linear layer
+         output = attn.linear2(torch.cat((attn_1, attn.mlp_act(mlp)), 2))
+         output = x + mod.gate * output
+         return output
+
+ class SingleStreamBlock(nn.Module):
+     """
+     A DiT block with parallel linear layers as described in
+     https://arxiv.org/abs/2302.05442 and adapted modulation interface.
+     """
+
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         qk_scale: float | None = None,
+     ):
+         super().__init__()
+         self.hidden_dim = hidden_size
+         self.num_heads = num_heads
+         self.head_dim = hidden_size // num_heads
+         self.scale = qk_scale or self.head_dim**-0.5
+
+         self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+         # qkv and mlp_in
+         self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
+         # proj and mlp_out
+         self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
+
+         self.norm = QKNorm(self.head_dim)
+
+         self.hidden_size = hidden_size
+         self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+
+         self.mlp_act = nn.GELU(approximate="tanh")
+         self.modulation = Modulation(hidden_size, double=False)
+
+         processor = SingleStreamBlockProcessor()
+         self.set_processor(processor)
+
+
+     def set_processor(self, processor) -> None:
+         self.processor = processor
+
+     def get_processor(self):
+         return self.processor
+
+     def forward(
+         self,
+         x: Tensor,
+         vec: Tensor,
+         pe: Tensor,
+         image_proj: Tensor | None = None,
+         ip_scale: float = 1.0,
+     ) -> Tensor:
+         if image_proj is None:
+             return self.processor(self, x, vec, pe)
+         else:
+             return self.processor(self, x, vec, pe, image_proj, ip_scale)
+
+
+
+ class LastLayer(nn.Module):
+     def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
+         super().__init__()
+         self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
+         self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
+
+     def forward(self, x: Tensor, vec: Tensor) -> Tensor:
+         shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
+         x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+         x = self.linear(x)
+         return x
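For orientation, LastLayer maps the final hidden states back to packed latent patches. A shape sketch under assumed FLUX-style dimensions (3072 hidden size, 64 packed latent channels):

    import torch
    from uno.flux.modules.layers import LastLayer

    final = LastLayer(hidden_size=3072, patch_size=1, out_channels=64)
    x = torch.randn(1, 4096, 3072)  # sequence of latent tokens
    vec = torch.randn(1, 3072)      # pooled conditioning vector
    out = final(x, vec)             # shape (1, 4096, 64)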