Spaces:
Running
on
Zero
Running
on
Zero
Update pipelines/flux_pipeline/transformer.py
Browse files
pipelines/flux_pipeline/transformer.py
CHANGED
@@ -41,19 +41,6 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph
|
|
41 |
|
42 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
43 |
|
44 |
-
def log_scale_masking(value, min_value=1, max_value=10):
|
45 |
-
# Convert the value into a positive domain for the logarithmic function
|
46 |
-
normalized_value = 1*value
|
47 |
-
|
48 |
-
# Apply logarithmic scaling
|
49 |
-
# log_scaled_value = 1-np.exp(-normalized_value)
|
50 |
-
log_scaled_value = 2.0* math.log(normalized_value+1, 2) / math.log(2, 2) # np.log1p(x) = log(1 + x)
|
51 |
-
# print(log_scaled_value)
|
52 |
-
|
53 |
-
# Rescale to original range
|
54 |
-
scaled_value = log_scaled_value * (max_value - min_value) + min_value
|
55 |
-
|
56 |
-
return min(max_value, int(scaled_value))
|
57 |
|
58 |
class FluxAttnProcessor2_0:
|
59 |
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
@@ -137,7 +124,7 @@ class FluxAttnProcessor2_0:
|
|
137 |
if neg_mode:
|
138 |
res = int(math.sqrt((end_of_hidden_states-(text_seq if encoder_hidden_states is None else 0)) // num))
|
139 |
hw = res*res
|
140 |
-
mask_ = torch.
|
141 |
for i in range(num):
|
142 |
mask_[:, :, i*res:(i+1)*res, :, i*res:(i+1)*res] = 1
|
143 |
mask_ = rearrange(mask_, "b h w h1 w1 -> b (h w) (h1 w1)")
|
|
|
41 |
|
42 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
class FluxAttnProcessor2_0:
|
46 |
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
|
|
124 |
if neg_mode:
|
125 |
res = int(math.sqrt((end_of_hidden_states-(text_seq if encoder_hidden_states is None else 0)) // num))
|
126 |
hw = res*res
|
127 |
+
mask_ = torch.zeros(1, res, num*res, res, num*res).to(query.device)
|
128 |
for i in range(num):
|
129 |
mask_[:, :, i*res:(i+1)*res, :, i*res:(i+1)*res] = 1
|
130 |
mask_ = rearrange(mask_, "b h w h1 w1 -> b (h w) (h1 w1)")
|