nupurkmr9 commited on
Commit
c598075
·
verified ·
1 Parent(s): 2fbc966

Update pipelines/flux_pipeline/transformer.py

Browse files
pipelines/flux_pipeline/transformer.py CHANGED
@@ -41,19 +41,6 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph
41
 
42
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
 
44
- def log_scale_masking(value, min_value=1, max_value=10):
45
- # Convert the value into a positive domain for the logarithmic function
46
- normalized_value = 1*value
47
-
48
- # Apply logarithmic scaling
49
- # log_scaled_value = 1-np.exp(-normalized_value)
50
- log_scaled_value = 2.0* math.log(normalized_value+1, 2) / math.log(2, 2) # np.log1p(x) = log(1 + x)
51
- # print(log_scaled_value)
52
-
53
- # Rescale to original range
54
- scaled_value = log_scaled_value * (max_value - min_value) + min_value
55
-
56
- return min(max_value, int(scaled_value))
57
 
58
  class FluxAttnProcessor2_0:
59
  """Attention processor used typically in processing the SD3-like self-attention projections."""
@@ -137,7 +124,7 @@ class FluxAttnProcessor2_0:
137
  if neg_mode:
138
  res = int(math.sqrt((end_of_hidden_states-(text_seq if encoder_hidden_states is None else 0)) // num))
139
  hw = res*res
140
- mask_ = torch.ones(1, res, num*res, res, num*res).to(query.device)
141
  for i in range(num):
142
  mask_[:, :, i*res:(i+1)*res, :, i*res:(i+1)*res] = 1
143
  mask_ = rearrange(mask_, "b h w h1 w1 -> b (h w) (h1 w1)")
 
41
 
42
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  class FluxAttnProcessor2_0:
46
  """Attention processor used typically in processing the SD3-like self-attention projections."""
 
124
  if neg_mode:
125
  res = int(math.sqrt((end_of_hidden_states-(text_seq if encoder_hidden_states is None else 0)) // num))
126
  hw = res*res
127
+ mask_ = torch.zeros(1, res, num*res, res, num*res).to(query.device)
128
  for i in range(num):
129
  mask_[:, :, i*res:(i+1)*res, :, i*res:(i+1)*res] = 1
130
  mask_ = rearrange(mask_, "b h w h1 w1 -> b (h w) (h1 w1)")