surokpro2 committed on
Commit 215c4b7 · verified · 1 Parent(s): f3855f2

Upload 47 files

Files changed (48)
  1. .gitattributes +1 -0
  2. LICENSE +21 -0
  3. README.MD +82 -0
  4. SAE/__init__.py +1 -0
  5. SAE/config.json +23 -0
  6. SAE/dataset_iterator.py +53 -0
  7. SAE/sae.py +216 -0
  8. SAE/sae_utils.py +48 -0
  9. SDLens/__init__.py +2 -0
  10. SDLens/cache_and_edit/__init__.py +1 -0
  11. SDLens/cache_and_edit/activation_cache.py +147 -0
  12. SDLens/cache_and_edit/cached_pipeline.py +342 -0
  13. SDLens/cache_and_edit/edits.py +223 -0
  14. SDLens/cache_and_edit/flux_pipeline.py +998 -0
  15. SDLens/cache_and_edit/hooks.py +108 -0
  16. SDLens/cache_and_edit/inversion.py +568 -0
  17. SDLens/cache_and_edit/metrics.py +116 -0
  18. SDLens/cache_and_edit/qkv_cache.py +557 -0
  19. SDLens/cache_and_edit/scheduler_inversion.py +98 -0
  20. SDLens/hooked_scheduler.py +40 -0
  21. SDLens/hooked_sd_pipeline.py +319 -0
  22. app.ipynb +0 -0
  23. app.py +768 -0
  24. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  25. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  26. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  27. checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  28. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  29. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  30. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  31. checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  32. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  33. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  34. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  35. checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  36. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
  37. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
  38. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
  39. checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
  40. colab_requirements.txt +8 -0
  41. example.ipynb +0 -0
  42. requirements.txt +12 -0
  43. resourses/image.png +3 -0
  44. retrieval.py +71 -0
  45. scripts/collect_latents_dataset.py +96 -0
  46. scripts/train_sae.py +308 -0
  47. utils/__init__.py +1 -0
  48. utils/hooks.py +145 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ resourses/image.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Viacheslav Surkov
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.MD ADDED
@@ -0,0 +1,82 @@
+ # Unpacking SDXL Turbo: Interpreting Text-to-Image Models with Sparse Autoencoders
+
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-red)](https://arxiv.org/abs/2410.22366)
+ [![Hugging Face Spaces Demo](https://img.shields.io/badge/Hugging%20Face-Demo-blue)](https://huggingface.co/spaces/surokpro2/Unboxing_SDXL_with_SAEs)
+ [![Colab](https://img.shields.io/badge/Colab-Notebook-yellow)](https://colab.research.google.com/drive/1lWZ2yCRwCf4iuykvb-91QYUNkuzIwI3k?usp=sharing)
+
+ ![modification demonstration](resourses/image.png)
+
+ This repository contains code to reproduce the results from our paper on using sparse autoencoders (SAEs) to analyze and interpret the internal representations of text-to-image diffusion models, specifically SDXL Turbo.
+
+ ## Repository Structure
+
+ ```
+ |-- SAE/                            # Core sparse autoencoder implementation
+ |-- SDLens/                         # Tools for analyzing diffusion models
+ |   `-- hooked_sd_pipeline.py       # Modified stable diffusion pipeline
+ |-- scripts/
+ |   |-- collect_latents_dataset.py  # Generate training data
+ |   `-- train_sae.py                # Train SAE models
+ |-- utils/
+ |   `-- hooks.py                    # Hook utility functions
+ |-- checkpoints/                    # Pretrained SAE model checkpoints
+ |-- app.py                          # Demo application
+ |-- app.ipynb                       # Interactive notebook demo
+ |-- example.ipynb                   # Usage examples
+ `-- requirements.txt                # Python dependencies
+ ```
+
+ ## Installation
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Demo Application
+
+ You can try our Gradio demo application (`app.ipynb`) to browse and experiment with 20K+ features of our trained SAEs out of the box. The same notebook is available on [Google Colab](https://colab.research.google.com/drive/1lWZ2yCRwCf4iuykvb-91QYUNkuzIwI3k?usp=sharing).
+
+ ## Usage
+
+ 1. Collect latent data from SDXL Turbo:
+ ```bash
+ python scripts/collect_latents_dataset.py --save_path={your_save_path}
+ ```
+
+ 2. Train sparse autoencoders:
+
+    2.1. Set the path to the stored latents and the checkpoint output directory in `SAE/config.json`.
+
+    2.2. Run the training script:
+
+ ```bash
+ python scripts/train_sae.py
+ ```
+
+ ## Pretrained Models
+
+ We provide pretrained SAE checkpoints for 4 key transformer blocks in SDXL Turbo's U-Net in the `checkpoints` folder. See `example.ipynb` for analysis examples and visualization of learned features. More pretrained SAEs with different parameters are available in our [Hugging Face repo](https://huggingface.co/surokpro2/sdxl-saes/tree/main).
+
+ ## Citation
+
+ If you find this code useful in your research, please cite our paper:
+
+ ```bibtex
+ @misc{surkov2024unpackingsdxlturbointerpreting,
+       title={Unpacking SDXL Turbo: Interpreting Text-to-Image Models with Sparse Autoencoders},
+       author={Viacheslav Surkov and Chris Wendler and Mikhail Terekhov and Justin Deschenaux and Robert West and Caglar Gulcehre},
+       year={2024},
+       eprint={2410.22366},
+       archivePrefix={arXiv},
+       primaryClass={cs.LG},
+       url={https://arxiv.org/abs/2410.22366},
+ }
+ ```
+
+ ## Acknowledgements
+
+ The SAE component was implemented based on the [`openai/sparse_autoencoder`](https://github.com/openai/sparse_autoencoder) repository.
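A minimal sketch of how one of the bundled checkpoints can be loaded and applied to a batch of activations. The checkpoint path comes from the `checkpoints/` folder added in this commit; the random input is only a stand-in for real SDXL Turbo activations, and standardizing with the saved `mean.pt`/`std.pt` is an assumption about how those statistics are meant to be used.

```python
import torch
from SAE import SparseAutoencoder

ckpt = "checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final"
sae = SparseAutoencoder.load_from_disk(ckpt).to("cuda")

# Normalization statistics stored next to the weights (assumed usage: standardize inputs).
mean = torch.load(f"{ckpt}/mean.pt").to("cuda")
std = torch.load(f"{ckpt}/std.pt").to("cuda")

x = torch.randn(8, sae.d_model, device="cuda")    # placeholder for cached activations
latents = sae.encode((x - mean) / std)            # top-k sparse feature activations
recons, _ = sae((x - mean) / std)                 # reconstruction of the normalized input
print(latents.shape, (latents != 0).sum(dim=-1))  # at most k non-zeros per row
```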
SAE/__init__.py ADDED
@@ -0,0 +1 @@
+ from .sae import SparseAutoencoder
SAE/config.json ADDED
@@ -0,0 +1,23 @@
+ {
+     "sae_configs": [
+         {
+             "d_model": 1280,
+             "n_dirs": 5120,
+             "k": 20
+         },
+         {
+             "d_model": 1280,
+             "n_dirs": 640,
+             "k": 20
+         }
+     ],
+     "bs": 4096,
+     "log_interval": 500,
+     "save_interval": 5000,
+
+     "paths_to_latents": [
+         "PASS YOUR PATHS HERE. Example /home/username/latents/<timestamp>. It should contain tar archives with latents."
+     ],
+     "save_path_base": "<Your SAE save path>",
+     "block_name": "unet.down_blocks.2.attentions.1"
+ }
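As a rough illustration (not part of the commit itself), this is how the fields above map onto the training-side dataclasses added later in this commit in `SAE/sae_utils.py`; the path placeholders must be filled in before running.

```python
import json
from SAE.sae_utils import Config

with open("SAE/config.json") as f:
    cfg = Config(json.load(f))

# Two SAEs are trained side by side on the same block: 1280 -> 5120 and 1280 -> 640 directions.
for sae_cfg in cfg.saes:
    print(sae_cfg.sae_name)  # e.g. unet.down_blocks.2.attentions.1_k20_hidden5120_auxk256_bs4096_lr0.0001
```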
SAE/dataset_iterator.py ADDED
@@ -0,0 +1,53 @@
+ import webdataset as wds
+ import os
+ import torch
+
+ class ActivationsDataloader:
+     def __init__(self, paths_to_datasets, block_name, batch_size, output_or_diff='diff', num_in_buffer=50):
+         assert output_or_diff in ['diff', 'output'], "Provide 'output' or 'diff'"
+
+         self.dataset = wds.WebDataset(
+             [os.path.join(path_to_dataset, f"{block_name}.tar")
+              for path_to_dataset in paths_to_datasets]
+         ).decode("torch")
+         self.iter = iter(self.dataset)
+         self.buffer = None
+         self.pointer = 0
+         self.num_in_buffer = num_in_buffer
+         self.output_or_diff = output_or_diff
+         self.batch_size = batch_size
+         self.one_size = None
+
+     def renew_buffer(self, to_retrieve):
+         to_merge = []
+         if self.buffer is not None and self.buffer.shape[0] > self.pointer:
+             to_merge = [self.buffer[self.pointer:].clone()]
+             del self.buffer
+         for _ in range(to_retrieve):
+             sample = next(self.iter)
+             latents = sample['output.pth'] if self.output_or_diff == 'output' else sample['diff.pth']
+             latents = latents.permute((0, 1, 3, 4, 2))
+             latents = latents.reshape((-1, latents.shape[-1]))
+             to_merge.append(latents.to('cuda'))
+             self.one_size = latents.shape[0]
+         self.buffer = torch.cat(to_merge, dim=0)
+         shuffled_indices = torch.randperm(self.buffer.shape[0])
+         self.buffer = self.buffer[shuffled_indices]
+         self.pointer = 0
+
+     def iterate(self):
+         while True:
+             if self.buffer is None or self.buffer.shape[0] - self.pointer < self.num_in_buffer * self.one_size * 4 // 5:
+                 try:
+                     to_retrieve = self.num_in_buffer if self.buffer is None else self.num_in_buffer // 5
+                     self.renew_buffer(to_retrieve)
+                 except StopIteration:
+                     break
+
+             batch = self.buffer[self.pointer: self.pointer + self.batch_size]
+             self.pointer += self.batch_size
+
+             assert batch.shape[0] == self.batch_size
+             yield batch
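A small usage sketch, assuming a dataset previously collected with `scripts/collect_latents_dataset.py` under a hypothetical path: the loader streams the per-block tar archives, keeps a shuffled buffer on the GPU, and yields fixed-size batches.

```python
from SAE.dataset_iterator import ActivationsDataloader

loader = ActivationsDataloader(
    paths_to_datasets=["/path/to/latents/2024-01-01"],  # hypothetical path
    block_name="unet.down_blocks.2.attentions.1",
    batch_size=4096,
    output_or_diff="diff",  # train on the block's additive update rather than its raw output
)

for batch in loader.iterate():
    print(batch.shape)  # (4096, d_model), already on CUDA
    break
```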
SAE/sae.py ADDED
@@ -0,0 +1,216 @@
+ '''
+ Adapted from
+ https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/model.py
+ '''
+
+ import torch
+ import torch.nn as nn
+ import os
+ import json
+
+ class SparseAutoencoder(nn.Module):
+     """
+     Top-K Autoencoder with sparse kernels. Implements:
+
+         latents = relu(topk(encoder(x - pre_bias) + latent_bias))
+         recons = decoder(latents) + pre_bias
+     """
+
+     def __init__(
+         self,
+         n_dirs_local: int,
+         d_model: int,
+         k: int,
+         auxk: int | None,
+         dead_steps_threshold: int,
+     ):
+         super().__init__()
+         self.n_dirs_local = n_dirs_local
+         self.d_model = d_model
+         self.k = k
+         self.auxk = auxk
+         self.dead_steps_threshold = dead_steps_threshold
+
+         self.encoder = nn.Linear(d_model, n_dirs_local, bias=False)
+         self.decoder = nn.Linear(n_dirs_local, d_model, bias=False)
+
+         self.pre_bias = nn.Parameter(torch.zeros(d_model))
+         self.latent_bias = nn.Parameter(torch.zeros(n_dirs_local))
+
+         self.stats_last_nonzero: torch.Tensor
+         self.register_buffer("stats_last_nonzero", torch.zeros(n_dirs_local, dtype=torch.long))
+
+         def auxk_mask_fn(x):
+             dead_mask = self.stats_last_nonzero > dead_steps_threshold
+             x.data *= dead_mask  # inplace to save memory
+             return x
+
+         self.auxk_mask_fn = auxk_mask_fn
+
+         ## initialization
+
+         # "tied" init
+         self.decoder.weight.data = self.encoder.weight.data.T.clone()
+
+         # store decoder in column major layout for kernel
+         self.decoder.weight.data = self.decoder.weight.data.T.contiguous().T
+
+         unit_norm_decoder_(self)
+
+     def save_to_disk(self, path: str):
+         PATH_TO_CFG = 'config.json'
+         PATH_TO_WEIGHTS = 'state_dict.pth'
+
+         cfg = {
+             "n_dirs_local": self.n_dirs_local,
+             "d_model": self.d_model,
+             "k": self.k,
+             "auxk": self.auxk,
+             "dead_steps_threshold": self.dead_steps_threshold,
+         }
+
+         os.makedirs(path, exist_ok=True)
+
+         with open(os.path.join(path, PATH_TO_CFG), 'w') as f:
+             json.dump(cfg, f)
+
+         torch.save({
+             "state_dict": self.state_dict(),
+         }, os.path.join(path, PATH_TO_WEIGHTS))
+
+     @classmethod
+     def load_from_disk(cls, path: str):
+         PATH_TO_CFG = 'config.json'
+         PATH_TO_WEIGHTS = 'state_dict.pth'
+
+         with open(os.path.join(path, PATH_TO_CFG), 'r') as f:
+             cfg = json.load(f)
+
+         ae = cls(
+             n_dirs_local=cfg["n_dirs_local"],
+             d_model=cfg["d_model"],
+             k=cfg["k"],
+             auxk=cfg["auxk"],
+             dead_steps_threshold=cfg["dead_steps_threshold"],
+         )
+
+         state_dict = torch.load(os.path.join(path, PATH_TO_WEIGHTS))["state_dict"]
+         ae.load_state_dict(state_dict)
+
+         return ae
+
+     @property
+     def n_dirs(self):
+         return self.n_dirs_local
+
+     def encode(self, x):
+         x = x - self.pre_bias
+         latents_pre_act = self.encoder(x) + self.latent_bias
+
+         vals, inds = torch.topk(
+             latents_pre_act,
+             k=self.k,
+             dim=-1
+         )
+
+         latents = torch.zeros_like(latents_pre_act)
+         latents.scatter_(-1, inds, torch.relu(vals))
+
+         return latents
+
+     def forward(self, x):
+         x = x - self.pre_bias
+         latents_pre_act = self.encoder(x) + self.latent_bias
+         vals, inds = torch.topk(
+             latents_pre_act,
+             k=self.k,
+             dim=-1
+         )
+
+         ## set num nonzero stat ##
+         tmp = torch.zeros_like(self.stats_last_nonzero)
+         tmp.scatter_add_(
+             0,
+             inds.reshape(-1),
+             (vals > 1e-3).to(tmp.dtype).reshape(-1),
+         )
+         self.stats_last_nonzero *= 1 - tmp.clamp(max=1)
+         self.stats_last_nonzero += 1
+         ## end stats ##
+
+         ## auxk
+         if self.auxk is not None:  # for auxk
+             # IMPORTANT: has to go after stats update!
+             # WARN: auxk_mask_fn can mutate latents_pre_act!
+             auxk_vals, auxk_inds = torch.topk(
+                 self.auxk_mask_fn(latents_pre_act),
+                 k=self.auxk,
+                 dim=-1
+             )
+         else:
+             auxk_inds = None
+             auxk_vals = None
+         ## end auxk
+
+         vals = torch.relu(vals)
+         if auxk_vals is not None:
+             auxk_vals = torch.relu(auxk_vals)
+
+         rows, cols = latents_pre_act.size()
+         row_indices = torch.arange(rows).unsqueeze(1).expand(-1, self.k).reshape(-1)
+         vals = vals.reshape(-1)
+         inds = inds.reshape(-1)
+
+         indices = torch.stack([row_indices.to(inds.device), inds])
+
+         sparse_tensor = torch.sparse_coo_tensor(indices, vals, torch.Size([rows, cols]))
+
+         recons = torch.sparse.mm(sparse_tensor, self.decoder.weight.T) + self.pre_bias
+
+         return recons, {
+             "inds": inds,
+             "vals": vals,
+             "auxk_inds": auxk_inds,
+             "auxk_vals": auxk_vals,
+         }
+
+     def decode_sparse(self, inds, vals):
+         rows, cols = inds.shape[0], self.n_dirs
+
+         row_indices = torch.arange(rows).unsqueeze(1).expand(-1, inds.shape[1]).reshape(-1)
+         vals = vals.reshape(-1)
+         inds = inds.reshape(-1)
+
+         indices = torch.stack([row_indices.to(inds.device), inds])
+
+         sparse_tensor = torch.sparse_coo_tensor(indices, vals, torch.Size([rows, cols]))
+
+         recons = torch.sparse.mm(sparse_tensor, self.decoder.weight.T) + self.pre_bias
+         return recons
+
+     @property
+     def device(self):
+         return next(self.parameters()).device
+
+
+ def unit_norm_decoder_(autoencoder: SparseAutoencoder) -> None:
+     """
+     Unit normalize the decoder weights of an autoencoder.
+     """
+     autoencoder.decoder.weight.data /= autoencoder.decoder.weight.data.norm(dim=0)
+
+
+ def unit_norm_decoder_grad_adjustment_(autoencoder) -> None:
+     """Project out gradient information parallel to the dictionary vectors - assumes that the decoder is already unit normed."""
+
+     assert autoencoder.decoder.weight.grad is not None
+
+     autoencoder.decoder.weight.grad += \
+         torch.einsum("bn,bn->n", autoencoder.decoder.weight.data, autoencoder.decoder.weight.grad) * \
+         autoencoder.decoder.weight.data * -1
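A hedged sketch of a single training step using the helpers above, to show how they fit together. The actual training loop lives in `scripts/train_sae.py`; the optimizer choice, the plain MSE loss, and the omission of the auxk term are assumptions made for brevity.

```python
import torch
from SAE.sae import SparseAutoencoder, unit_norm_decoder_, unit_norm_decoder_grad_adjustment_

sae = SparseAutoencoder(n_dirs_local=5120, d_model=1280, k=20,
                        auxk=256, dead_steps_threshold=1000).cuda()
opt = torch.optim.Adam(sae.parameters(), lr=1e-4, eps=6.25e-10)

x = torch.randn(4096, 1280, device="cuda")  # a batch of activations
recons, info = sae(x)
loss = ((recons - x) ** 2).mean()           # plain MSE; the auxk loss is omitted here

opt.zero_grad()
loss.backward()
unit_norm_decoder_grad_adjustment_(sae)     # remove gradient components parallel to dictionary directions
opt.step()
unit_norm_decoder_(sae)                     # re-project decoder columns back to unit norm
```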
SAE/sae_utils.py ADDED
@@ -0,0 +1,48 @@
+ import torch
+ from dataclasses import dataclass, field
+ import os
+
+ @dataclass
+ class SAETrainingConfig:
+     d_model: int
+     n_dirs: int
+     k: int
+     block_name: str
+     bs: int
+     save_path_base: str
+     auxk: int = 256
+     lr: float = 1e-4
+     eps: float = 6.25e-10
+     dead_toks_threshold: int = 10_000_000
+     auxk_coef: float = 1/32
+
+     @property
+     def sae_name(self):
+         return f'{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}'
+
+     @property
+     def save_path(self):
+         return os.path.join(self.save_path_base, f'{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}')
+
+
+ @dataclass
+ class Config:
+     saes: list[SAETrainingConfig]
+     paths_to_latents: list[str]
+     log_interval: int
+     save_interval: int
+     bs: int
+     block_name: str
+     wandb_project: str = 'sdxl_sae_train'
+     wandb_name: str = 'multiple_sae'
+
+     def __init__(self, cfg_json):
+         self.saes = [SAETrainingConfig(**sae_cfg, block_name=cfg_json['block_name'], bs=cfg_json['bs'], save_path_base=cfg_json['save_path_base'])
+                      for sae_cfg in cfg_json['sae_configs']]
+
+         self.save_path_base = cfg_json['save_path_base']
+         self.paths_to_latents = cfg_json['paths_to_latents']
+         self.log_interval = cfg_json['log_interval']
+         self.save_interval = cfg_json['save_interval']
+         self.bs = cfg_json['bs']
+         self.block_name = cfg_json['block_name']
SDLens/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .hooked_sd_pipeline import HookedIFPipeline, HookedStableDiffusionXLPipeline
+ from .cache_and_edit import CachedPipeline
SDLens/cache_and_edit/__init__.py ADDED
@@ -0,0 +1 @@
+ from .cached_pipeline import CachedPipeline
SDLens/cache_and_edit/activation_cache.py ADDED
@@ -0,0 +1,147 @@
+ from abc import ABC, abstractmethod
+ from collections import defaultdict
+ from typing import List
+ from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxSingleTransformerBlock
+ from SDLens.cache_and_edit.hooks import fix_inf_values_hook, register_general_hook
+ import torch
+
+ class ModelActivationCache(ABC):
+     """
+     Cache for an inference pass of a Diffusion Transformer.
+     Used to cache residual streams and activations.
+     """
+     def __init__(self):
+
+         # Initialize caches for "double transformer" blocks using the subclass-defined NUM_TRANSFORMER_BLOCKS
+         if hasattr(self, 'NUM_TRANSFORMER_BLOCKS'):
+             self.image_residual = []
+             self.image_activation = []
+             self.text_residual = []
+             self.text_activation = []
+
+         # Initialize caches for "single transformer" blocks if defined (using NUM_SINGLE_TRANSFORMER_BLOCKS)
+         if hasattr(self, 'NUM_SINGLE_TRANSFORMER_BLOCKS'):
+             self.text_image_residual = []
+             self.text_image_activation = []
+
+     def __str__(self):
+         lines = [f"{self.__class__.__name__}:"]
+         for attr_name, value in self.__dict__.items():
+             if isinstance(value, list) and value and all(isinstance(v, torch.Tensor) for v in value):
+                 shapes = value[0].shape
+                 lines.append(f"  {attr_name}: len={len(value)}, shapes={shapes}")
+             else:
+                 lines.append(f"  {attr_name}: {type(value)}")
+         return "\n".join(lines)
+
+     def _repr_pretty_(self, p, cycle):
+         p.text(str(self))
+
+     @abstractmethod
+     def get_cache_info(self):
+         """
+         Return details about the cache configuration.
+         Subclasses must implement this to provide info on their transformer block counts.
+         """
+         pass
+
+
+ class FluxActivationCache(ModelActivationCache):
+     # Define number of blocks for double and single transformer caches
+     NUM_TRANSFORMER_BLOCKS = 19
+     NUM_SINGLE_TRANSFORMER_BLOCKS = 38
+
+     def __init__(self):
+         super().__init__()
+
+     def get_cache_info(self):
+         return {
+             "transformer_blocks": self.NUM_TRANSFORMER_BLOCKS,
+             "single_transformer_blocks": self.NUM_SINGLE_TRANSFORMER_BLOCKS,
+         }
+
+     def __getitem__(self, key):
+         return getattr(self, key)
+
+
+ class PixartActivationCache(ModelActivationCache):
+     # Define number of blocks for the double transformer cache only
+     NUM_TRANSFORMER_BLOCKS = 28
+
+     def __init__(self):
+         super().__init__()
+
+     def get_cache_info(self):
+         return {
+             "double_transformer_blocks": self.NUM_TRANSFORMER_BLOCKS,
+         }
+
+
+ class ActivationCacheHandler:
+     """Used to manage the ModelActivationCache of a Diffusion Transformer."""
+
+     def __init__(self, cache: ModelActivationCache, positions_to_cache: List[str] = None):
+         """Constructor.
+
+         Args:
+             cache (ModelActivationCache): cache used to store tensors.
+             positions_to_cache (List[str], optional): names of modules to cache.
+                 If None, all modules listed by `cache.get_cache_info()` will be cached. Defaults to None.
+         """
+         self.cache = cache
+         self.positions_to_cache = positions_to_cache
+
+     @torch.no_grad()
+     def cache_residual_and_activation_hook(self, *args):
+         """
+         To be used as a forward hook on a Transformer Block.
+         It caches both residual_stream and activation (defined as output - residual_stream).
+         """
+
+         if len(args) == 3:
+             module, input, output = args
+         elif len(args) == 4:
+             module, input, kwinput, output = args
+
+         if isinstance(module, FluxTransformerBlock):
+             encoder_hidden_states = output[0]
+             hidden_states = output[1]
+
+             self.cache.image_activation.append(hidden_states - kwinput["hidden_states"])
+             self.cache.text_activation.append(encoder_hidden_states - kwinput["encoder_hidden_states"])
+             self.cache.image_residual.append(kwinput["hidden_states"])
+             self.cache.text_residual.append(kwinput["encoder_hidden_states"])
+
+         elif isinstance(module, FluxSingleTransformerBlock):
+             self.cache.text_image_activation.append(output - kwinput["hidden_states"])
+             self.cache.text_image_residual.append(kwinput["hidden_states"])
+         else:
+             raise NotImplementedError(f"Caching not implemented for {type(module)}")
+
+     @property
+     def forward_hooks_dict(self):
+
+         # insert cache storing in dict
+         hooks = defaultdict(list)
+
+         if self.positions_to_cache is None:
+             for block_type, num_layers in self.cache.get_cache_info().items():
+                 for i in range(num_layers):
+                     module_name: str = f"transformer.{block_type}.{i}"
+                     hooks[module_name].append(fix_inf_values_hook)
+                     hooks[module_name].append(self.cache_residual_and_activation_hook)
+         else:
+             for module_name in self.positions_to_cache:
+                 hooks[module_name].append(fix_inf_values_hook)
+                 hooks[module_name].append(self.cache_residual_and_activation_hook)
+
+         return hooks
+
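For intuition, a self-contained sketch of what the hook stores: for every cached block it keeps the incoming residual stream and the block's additive update, so the block output is recoverable as their sum. The hidden width 3072 below is FLUX.1's inner dimension, but any size works for the illustration.

```python
import torch
from SDLens.cache_and_edit.activation_cache import FluxActivationCache

cache = FluxActivationCache()
residual = torch.randn(1, 4096, 3072)          # dummy image-stream input to a block
block_output = residual + torch.randn_like(residual)

# What the forward hook stores for a double-stream block:
cache.image_residual.append(residual)
cache.image_activation.append(block_output - residual)

# The block output equals residual + activation.
assert torch.allclose(cache.image_residual[0] + cache.image_activation[0], block_output)
print(cache)  # lists each cache attribute with its length and tensor shapes
```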
SDLens/cache_and_edit/cached_pipeline.py ADDED
@@ -0,0 +1,342 @@
1
+ from collections import defaultdict
2
+ from functools import partial
3
+ import gc
4
+ from typing import Callable, Dict, List, Literal, Union, Optional, Type, Union
5
+ import torch
6
+ from SDLens.cache_and_edit.activation_cache import FluxActivationCache, ModelActivationCache, PixartActivationCache, ActivationCacheHandler
7
+ from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxSingleTransformerBlock
8
+ from SDLens.cache_and_edit.hooks import locate_block, register_general_hook, fix_inf_values_hook, edit_streams_hook
9
+ from SDLens.cache_and_edit.qkv_cache import QKVCacheFluxHandler, QKVCache, CachedFluxAttnProcessor3_0
10
+ from SDLens.cache_and_edit.scheduler_inversion import FlowMatchEulerDiscreteSchedulerForInversion
11
+ from SDLens.cache_and_edit.flux_pipeline import EditedFluxPipeline
12
+
13
+ from diffusers.pipelines import FluxPipeline
14
+
15
+
16
+
17
+ class CachedPipeline:
18
+
19
+ def __init__(self, pipe: EditedFluxPipeline, text_seq_length: int = 512):
20
+
21
+ assert isinstance(pipe, EditedFluxPipeline) or isinstance(pipe, FluxPipeline), "Use EditedFluxPipeline class in `cache_and_edit/flux_pipeline.py`"
22
+ self.pipe = pipe
23
+ self.text_seq_length = text_seq_length
24
+
25
+ # Cache handlers
26
+ self.activation_cache_handler = None
27
+ self.qkv_cache_handler = None
28
+ # keeps references to all registered hooks
29
+ self.registered_hooks = []
30
+
31
+
32
+ def setup_cache(self, use_activation_cache = True,
33
+ use_qkv_cache = False,
34
+ positions_to_cache: List[str] = None,
35
+ positions_to_cache_foreground: List[str] = None,
36
+ qkv_to_inject: QKVCache = None,
37
+ inject_kv_mode: Literal["image", "text", "both"] = None,
38
+ q_mask=None,
39
+ processor_class: Optional[Type] = CachedFluxAttnProcessor3_0
40
+ ) -> None:
41
+ """
42
+ Sets up activation_cache and/or qkv_cache, setting the required hooks.
43
+ If positions_to_cache is None, then all modules will be cached.
44
+ If inject_kv_mode is None, then qkv cache will be stored, otherwise qkv_to_inject will be injected.
45
+ """
46
+
47
+ if use_activation_cache:
48
+ if isinstance(self.pipe, EditedFluxPipeline) or isinstance(self.pipe, FluxPipeline):
49
+ activation_cache = FluxActivationCache()
50
+ else:
51
+ raise AssertionError(f"activation cache not implemented for {type(self.pipe)}")
52
+
53
+ self.activation_cache_handler = ActivationCacheHandler(activation_cache, positions_to_cache)
54
+ # register hooks crated by activation_cache
55
+ self._set_hooks(position_hook_dict=self.activation_cache_handler.forward_hooks_dict,
56
+ with_kwargs=True)
57
+
58
+ if use_qkv_cache:
59
+ if isinstance(self.pipe, EditedFluxPipeline) or isinstance(self.pipe, FluxPipeline):
60
+ self.qkv_cache_handler = QKVCacheFluxHandler(self.pipe,
61
+ positions_to_cache,
62
+ positions_to_cache_foreground,
63
+ inject_kv=inject_kv_mode,
64
+ text_seq_length=self.text_seq_length,
65
+ q_mask=q_mask,
66
+ processor_class=processor_class,
67
+ )
68
+ else:
69
+ raise AssertionError(f"QKV cache not implemented for {type(self.pipe)}")
70
+
71
+ # qkv_cache does not use hooks
72
+
73
+
74
+ @property
75
+ def activation_cache(self) -> ModelActivationCache:
76
+ return self.activation_cache_handler.cache if hasattr(self, "activation_cache_handler") and self.activation_cache_handler else None
77
+
78
+
79
+ @property
80
+ def qkv_cache(self) -> QKVCache:
81
+ return self.qkv_cache_handler.cache if hasattr(self, "qkv_cache_handler") and self.qkv_cache_handler else None
82
+
83
+
84
+ @torch.no_grad
85
+ def run(self,
86
+ prompt: Union[str, List[str]],
87
+ num_inference_steps: int = 1,
88
+ seed: int = 42,
89
+ width=1024,
90
+ height=1024,
91
+ cache_activations: bool = False,
92
+ cache_qkv: bool = False,
93
+ guidance_scale: float = 0.0,
94
+ positions_to_cache: List[str] = None,
95
+ empty_clip_embeddings: bool = True,
96
+ inverse: bool = False,
97
+ **kwargs):
98
+ """run the pipeline, possibly cachine activations or QKV.
99
+
100
+ Args:
101
+ prompt (str): Prompt to run the pipeline (NOTE: for Flux, parameters passed are prompt='' and prompt2=prompt)
102
+ num_inference_steps (int, optional): Num steps for inference. Defaults to 1.
103
+ seed (int, optional): seed for generators. Defaults to 42.
104
+ cache_activations (bool, optional): Whether to cache activations. Defaults to True.
105
+ cache_qkv (bool, optional): Whether to cache queries, keys, values. Defaults to False.
106
+ positions_to_cache (List[str], optional): list of blocks to cache.
107
+ If None, all transformer blocks will be cached. Defaults to None.
108
+
109
+ Returns:
110
+ _type_: same output as wrapped pipeline.
111
+ """
112
+
113
+ # First, clear all registered hooks
114
+ self.clear_all_hooks()
115
+
116
+ # Delete cache already present
117
+ if self.activation_cache or self.qkv_cache:
118
+
119
+ if self.activation_cache:
120
+ del(self.activation_cache_handler.cache)
121
+ del(self.activation_cache_handler)
122
+
123
+ if self.qkv_cache:
124
+ # Necessary to delete the old cache.
125
+ self.qkv_cache_handler.clear_cache()
126
+ del(self.qkv_cache_handler)
127
+
128
+ gc.collect() # force Python to clean up unreachable objects
129
+ torch.cuda.empty_cache() # tell PyTorch to release unused GPU memory from its cache
130
+
131
+ # Setup cache again for the current inference pass
132
+ self.setup_cache(cache_activations, cache_qkv, positions_to_cache, inject_kv_mode=None)
133
+
134
+ assert isinstance(seed, int)
135
+
136
+ if isinstance(prompt, str):
137
+ empty_prompt = [""]
138
+ prompt = [prompt]
139
+ else:
140
+ empty_prompt = [""] * len(prompt)
141
+
142
+ gen = [torch.Generator(device="cpu").manual_seed(seed) for _ in range(len(prompt))]
143
+
144
+ if inverse:
145
+ # maybe create scheduler for inversion
146
+ if not hasattr(self, "inversion_scheduler"):
147
+ self.inversion_scheduler = FlowMatchEulerDiscreteSchedulerForInversion.from_config(
148
+ self.pipe.scheduler.config,
149
+ inverse=True
150
+ )
151
+ self.og_scheduler = self.pipe.scheduler
152
+
153
+ self.pipe.scheduler = self.inversion_scheduler
154
+
155
+ output = self.pipe(
156
+ prompt=empty_prompt if empty_clip_embeddings else prompt,
157
+ prompt_2=prompt,
158
+ num_inference_steps=num_inference_steps,
159
+ guidance_scale=guidance_scale,
160
+ generator=gen,
161
+ width=width,
162
+ height=height,
163
+ **kwargs
164
+ )
165
+
166
+ # Restore original scheduler
167
+ if inverse:
168
+ self.pipe.scheduler = self.og_scheduler
169
+
170
+ return output
171
+
172
+ @torch.no_grad
173
+ def run_inject_qkv(self,
174
+ prompt: Union[str, List[str]],
175
+ positions_to_inject: List[str] = None,
176
+ positions_to_inject_foreground: List[str] = None,
177
+ inject_kv_mode: Literal["image", "text", "both"] = "image",
178
+ num_inference_steps: int = 1,
179
+ guidance_scale: float = 0.0,
180
+ seed: int = 42,
181
+ empty_clip_embeddings: bool = True,
182
+ q_mask=None,
183
+ width: int = 1024,
184
+ height: int = 1024,
185
+ processor_class: Optional[Type] = CachedFluxAttnProcessor3_0,
186
+ **kwargs):
187
+ """run the pipeline, possibly cachine activations or QKV.
188
+
189
+ Args:
190
+ prompt (str): Prompt to run the pipeline (NOTE: for Flux, parameters passed are prompt='' and prompt2=prompt)
191
+ num_inference_steps (int, optional): Num steps for inference. Defaults to 1.
192
+ seed (int, optional): seed for generators. Defaults to 42.
193
+ cache_activations (bool, optional): Whether to cache activations. Defaults to True.
194
+ cache_qkv (bool, optional): Whether to cache queries, keys, values. Defaults to False.
195
+ positions_to_cache (List[str], optional): list of blocks to cache.
196
+ If None, all transformer blocks will be cached. Defaults to None.
197
+
198
+ Returns:
199
+ _type_: same output as wrapped pipeline.
200
+ """
201
+
202
+ # First, clear all registered hooks
203
+ self.clear_all_hooks()
204
+
205
+ # Delete previous QKVCache
206
+ if hasattr(self, "qkv_cache_handler") and self.qkv_cache_handler is not None:
207
+ self.qkv_cache_handler.clear_cache()
208
+ del(self.qkv_cache_handler)
209
+ gc.collect() # force Python to clean up unreachable objects
210
+ torch.cuda.empty_cache() # tell PyTorch to release unused GPU memory from its cache
211
+
212
+ # Will setup existing QKV cache to be injected
213
+ self.setup_cache(use_activation_cache=False,
214
+ use_qkv_cache=True,
215
+ positions_to_cache=positions_to_inject,
216
+ positions_to_cache_foreground=positions_to_inject_foreground,
217
+ inject_kv_mode=inject_kv_mode,
218
+ q_mask=q_mask,
219
+ processor_class=processor_class,
220
+ )
221
+
222
+ self.qkv_cache_handler
223
+
224
+ assert isinstance(seed, int)
225
+
226
+ if isinstance(prompt, str):
227
+ empty_prompt = [""]
228
+ prompt = [prompt]
229
+ else:
230
+ empty_prompt = [""] * len(prompt)
231
+
232
+ gen = [torch.Generator(device="cpu").manual_seed(seed) for _ in range(len(prompt))]
233
+
234
+ output = self.pipe(
235
+ prompt=empty_prompt if empty_clip_embeddings else prompt,
236
+ prompt_2=prompt,
237
+ num_inference_steps=num_inference_steps,
238
+ guidance_scale=guidance_scale,
239
+ generator=gen,
240
+ width=width,
241
+ height=height,
242
+ **kwargs
243
+ )
244
+
245
+
246
+
247
+ return output
248
+
249
+
250
+ def clear_all_hooks(self):
251
+
252
+ # 1. Clear all registered hooks
253
+ for hook in self.registered_hooks:
254
+ hook.remove()
255
+ self.registered_hooks = []
256
+
257
+ # 2. Eventually clear other hooks registered in the pipeline but not present here
258
+ # TODO: make it general for other models
259
+ for i in range(len(locate_block(self.pipe, "transformer.transformer_blocks"))):
260
+ locate_block(self.pipe, f"transformer.transformer_blocks.{i}")._forward_hooks.clear()
261
+
262
+ for i in range(len(locate_block(self.pipe, "transformer.single_transformer_blocks"))):
263
+ locate_block(self.pipe, f"transformer.single_transformer_blocks.{i}")._forward_hooks.clear()
264
+
265
+
266
+ def _set_hooks(self,
267
+ position_hook_dict: Dict[str, List[Callable]] = {},
268
+ position_pre_hook_dict: Dict[str, List[Callable]] = {},
269
+ with_kwargs=False
270
+ ):
271
+ '''
272
+ Set hooks at specified positions and register them.
273
+ Args:
274
+ position_hook_dict: A dictionary mapping positions to hooks.
275
+ The keys are positions in the pipeline where the hooks should be registered.
276
+ The values are either a single hook or a list of hooks to be registered at the specified position.
277
+ Each hook should be a callable that takes three arguments: (module, input, output).
278
+ **kwargs: Keyword arguments to pass to the pipeline.
279
+ '''
280
+
281
+ # Register hooks
282
+ for is_pre_hook, hook_dict in [(True, position_pre_hook_dict), (False, position_hook_dict)]:
283
+ for position, hook in hook_dict.items():
284
+ assert isinstance(hook, list)
285
+ for h in hook:
286
+ self.registered_hooks.append(register_general_hook(self.pipe, position, h, with_kwargs, is_pre_hook))
287
+
288
+
289
+ def run_with_edit(self,
290
+ prompt: str,
291
+ edit_fn: callable,
292
+ layers_for_edit_fn: List[int],
293
+ stream: Literal['text', 'image', 'both'],
294
+ guidance_scale: float = 0.0,
295
+ seed=42,
296
+ num_inference_steps=1,
297
+ empty_clip_embeddings: bool = True,
298
+ width: int = 1024,
299
+ height: int = 1024,
300
+ **kwargs,
301
+ ):
302
+
303
+ assert isinstance(seed, int)
304
+
305
+ self.clear_all_hooks()
306
+
307
+
308
+ # Setup hooks for edit_fn at the specified layers
309
+ # NOTE: edit_fn_hooks has to be Dict[str, List[Callable]]
310
+ edit_fn_hooks = {f"transformer.transformer_blocks.{layer}": [lambda *args: edit_streams_hook(*args, recompute_fn=edit_fn, stream=stream)]
311
+ for layer in layers_for_edit_fn if layer < 19}
312
+ edit_fn_hooks.update({f"transformer.single_transformer_blocks.{layer - 19}": [lambda *args: edit_streams_hook(*args, recompute_fn=edit_fn, stream=stream)]
313
+ for layer in layers_for_edit_fn if layer >= 19})
314
+
315
+
316
+ # register hooks in the pipe
317
+ self._set_hooks(position_hook_dict=edit_fn_hooks, with_kwargs=True)
318
+
319
+ # Create generators
320
+
321
+ if isinstance(prompt, str):
322
+ empty_prompt = [""]
323
+ prompt = [prompt]
324
+ else:
325
+ empty_prompt = [""] * len(prompt)
326
+
327
+ gen = [torch.Generator(device="cpu").manual_seed(seed) for _ in range(len(prompt))]
328
+
329
+ with torch.no_grad():
330
+ output = self.pipe(
331
+ prompt=empty_prompt if empty_clip_embeddings else prompt,
332
+ prompt_2=prompt,
333
+ num_inference_steps=num_inference_steps,
334
+ guidance_scale=guidance_scale,
335
+ generator=gen,
336
+ width=width,
337
+ height=height,
338
+ **kwargs
339
+ )
340
+
341
+ return output
342
+
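A usage sketch for `CachedPipeline`, under the assumption that a Flux checkpoint is available (the model id below is only an example). Each call to `run` clears previously registered hooks and caches, re-registers them if caching is requested, and leaves the results on `activation_cache` / `qkv_cache`.

```python
import torch
from SDLens.cache_and_edit.flux_pipeline import EditedFluxPipeline
from SDLens.cache_and_edit.cached_pipeline import CachedPipeline

pipe = EditedFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16  # example checkpoint
).to("cuda")

cached = CachedPipeline(pipe)
out = cached.run(
    "A cat holding a sign that says hello world",
    num_inference_steps=1,
    guidance_scale=0.0,
    cache_activations=True,   # fill cached.activation_cache
    cache_qkv=False,
)
out.images[0].save("flux_cached.png")

cache = cached.activation_cache
# One entry per double-stream block per denoising step (19 blocks for Flux).
print(len(cache.image_residual), cache.image_residual[0].shape)
```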
SDLens/cache_and_edit/edits.py ADDED
@@ -0,0 +1,223 @@
1
+ from functools import partial
+ from typing import Callable
+
+ import torch
+
2
+ class Edit:
3
+
4
+ def __init__(self, ablator, vanilla_pre_forward_dict: Callable[[str, int], dict],
5
+ vanilla_forward_dict: Callable[[str, int], dict],
6
+ ablated_pre_forward_dict: Callable[[str, int], dict],
7
+ ablated_forward_dict: Callable[[str, int], dict],):
8
+ self.ablator=ablator
9
+ self.vanilla_seed = 42
10
+ self.vanilla_pre_forward_dict = vanilla_pre_forward_dict
11
+ self.vanilla_forward_dict = vanilla_forward_dict
12
+
13
+ self.ablated_seed = 42
14
+ self.ablated_pre_forward_dict = ablated_pre_forward_dict
15
+ self.ablated_forward_dict = ablated_forward_dict
16
+
17
+
18
+ def get_edit(name: str):
19
+
20
+ if name == "edit_streams":
21
+ ablator = TransformerActivationCache()
22
+ stream: str = kwargs["stream"]
23
+ layers = kwargs["layers"]
24
+ edit_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = kwargs["edit_fn"]
25
+
26
+ interventions = {f"transformer.transformer_blocks.{layer}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer < 19}
27
+ interventions.update({f"transformer.single_transformer_blocks.{layer - 19}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer >= 19})
28
+
29
+ return Ablation(ablator,
30
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {},
31
+ vanilla_forward_dict=lambda block_type, layer_num: {},
32
+ ablated_pre_forward_dict=lambda block_type, layer_num: {},
33
+ ablated_forward_dict=lambda block_type, layer_num: interventions,
34
+ )
35
+
36
+
37
+ """
38
+ def get_ablation(name: str, **kwargs):
39
+
40
+ if name == "intermediate_text_stream_to_input":
41
+
42
+ ablator = TransformerActivationCache()
43
+ return Ablation(ablator,
44
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {},
45
+ vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
46
+ ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, stream="text")},
47
+ ablated_forward_dict=lambda block_type, layer_num: {})
48
+ elif name == "input_to_intermediate_text_stream":
49
+ ablator = TransformerActivationCache()
50
+ return Ablation(ablator,
51
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {},
52
+ vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
53
+ ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.replace_stream_input(*args, stream="text")},
54
+ ablated_forward_dict=lambda block_type, layer_num: {})
55
+
56
+ elif name == "set_input_text":
57
+
58
+ tensor: torch.Tensor = kwargs["tensor"]
59
+
60
+ ablator = TransformerActivationCache()
61
+ return Ablation(ablator,
62
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {},
63
+ vanilla_forward_dict=lambda block_type, layer_num: {},
64
+ ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.replace_stream_input(*args, use_tensor=tensor, stream="text")},
65
+ ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.clamp_output(*args)})
66
+
67
+ elif name == "replace_text_stream_activation":
68
+ ablator = AttentionAblationCacheHook()
69
+ weight = kwargs["weight"] if "weight" in kwargs else 1.0
70
+
71
+
72
+ return Ablation(ablator,
73
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
74
+ vanilla_forward_dict=lambda block_type, layer_num: {},
75
+ ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
76
+ ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.set_ablated_attention(*args, weight=weight)})
77
+
78
+ elif name == "replace_text_stream":
79
+ ablator = TransformerActivationCache()
80
+ weight = kwargs["weight"] if "weight" in kwargs else 1.0
81
+
82
+ return Ablation(ablator,
83
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
84
+ vanilla_forward_dict=lambda block_type, layer_num: {},
85
+ ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
86
+ ablated_forward_dict=lambda block_type, layer_num: {})
87
+
88
+
89
+ elif name == "input=output":
90
+ return Ablation(None,
91
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {},
92
+ vanilla_forward_dict=lambda block_type, layer_num: {},
93
+ ablated_pre_forward_dict=lambda block_type, layer_num: {},
94
+ ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablate_block(*args)})
95
+
96
+ elif name == "reweight_text_stream":
97
+ ablator = TransformerActivationCache()
98
+
99
+ residual_w=kwargs["residual_w"]
100
+ activation_w=kwargs["activation_w"]
101
+
102
+ return Ablation(ablator,
103
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {},
104
+ vanilla_forward_dict=lambda block_type, layer_num: {},
105
+ ablated_pre_forward_dict=lambda block_type, layer_num: {},
106
+ ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.reweight_text_stream(*args, residual_w=residual_w, activation_w=activation_w)})
107
+
108
+ elif name == "add_input_text":
109
+
110
+ tensor: torch.Tensor = kwargs["tensor"]
111
+
112
+ ablator = TransformerActivationCache()
113
+ return Ablation(ablator,
114
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {},
115
+ vanilla_forward_dict=lambda block_type, layer_num: {},
116
+ ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.add_text_stream_input(*args, use_tensor=tensor)},
117
+ ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.clamp_output(*args)})
118
+
119
+ elif name == "nothing":
120
+ ablator = TransformerActivationCache()
121
+ return Ablation(ablator,
122
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {},
123
+ vanilla_forward_dict=lambda block_type, layer_num: {},
124
+ ablated_pre_forward_dict=lambda block_type, layer_num: {},
125
+ ablated_forward_dict=lambda block_type, layer_num: {})
126
+
127
+ elif name == "reweight_image_stream":
128
+ ablator = TransformerActivationCache()
129
+ residual_w=kwargs["residual_w"]
130
+ activation_w=kwargs["activation_w"]
131
+
132
+ return Ablation(ablator,
133
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {},
134
+ vanilla_forward_dict=lambda block_type, layer_num: {},
135
+ ablated_pre_forward_dict=lambda block_type, layer_num: {},
136
+ ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.reweight_image_stream(*args, residual_w=residual_w, activation_w=activation_w)})
137
+
138
+ if name == "intermediate_image_stream_to_input":
139
+
140
+ ablator = TransformerActivationCache()
141
+ return Ablation(ablator,
142
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {},
143
+ vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
144
+ ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, stream='image')},
145
+ ablated_forward_dict=lambda block_type, layer_num: {})
146
+
147
+
148
+ elif name == "replace_text_stream_one_layer":
149
+ ablator = AttentionAblationCacheHook()
150
+ weight = kwargs["weight"] if "weight" in kwargs else 1.0
151
+
152
+
153
+ return Ablation(ablator,
154
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
155
+ vanilla_forward_dict=lambda block_type, layer_num: {},
156
+ ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
157
+ ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.restore_text_stream})
158
+
159
+ elif name == "replace_intermediate_representation":
160
+ ablator = TransformerActivationCache()
161
+ tensor: torch.Tensor = kwargs["tensor"]
162
+
163
+ return Ablation(ablator,
164
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {},
165
+ vanilla_forward_dict=lambda block_type, layer_num: {},
166
+ ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, use_tensor=tensor, stream='text_image')},
167
+ ablated_forward_dict=lambda block_type, layer_num: {})
168
+
169
+ elif name == "destroy_registers":
170
+ ablator = TransformerActivationCache()
171
+ layers: List[int] = kwargs['layers']
172
+ k: float = kwargs["k"]
173
+ stream: str = kwargs['stream']
174
+ random: bool = kwargs["random"] if "random" in kwargs else False
175
+ lowest_norm: bool = kwargs["lowest_norm"] if "lowest_norm" in kwargs else False
176
+
177
+ return Ablation(ablator,
178
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {},
179
+ vanilla_forward_dict=lambda block_type, layer_num: {},
180
+ ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.destroy_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
181
+ ablated_forward_dict=lambda block_type, layer_num: {})
182
+
183
+ elif name == "patch_registers":
184
+ ablator = TransformerActivationCache()
185
+ layers: List[int] = kwargs['layers']
186
+ k: float = kwargs["k"]
187
+ stream: str = kwargs['stream']
188
+ random: bool = kwargs["random"] if "random" in kwargs else False
189
+ lowest_norm: bool = kwargs["lowest_norm"] if "lowest_norm" in kwargs else False
190
+
191
+ return Ablation(ablator,
192
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.destroy_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
193
+ vanilla_forward_dict=lambda block_type, layer_num: {},
194
+ ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.set_cached_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
195
+ ablated_forward_dict=lambda block_type, layer_num: {})
196
+
197
+ elif name == "add_registers":
198
+ ablator = TransformerActivationCache()
199
+ num_registers: int = kwargs["num_registers"]
200
+
201
+ return Ablation(ablator,
202
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {},
203
+ vanilla_forward_dict=lambda block_type, layer_num: {},
204
+ ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer": lambda *args: insert_extra_registers(*args, num_registers=num_registers)},
205
+ ablated_forward_dict=lambda block_type, layer_num: {f"transformer": lambda *args: discard_extra_registers(*args, num_registers=num_registers)},)
206
+
207
+
208
+ elif name == "edit_streams":
209
+ ablator = TransformerActivationCache()
210
+ stream: str = kwargs["stream"]
211
+ layers = kwargs["layers"]
212
+ edit_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = kwargs["edit_fn"]
213
+
214
+ interventions = {f"transformer.transformer_blocks.{layer}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer < 19}
215
+ interventions.update({f"transformer.single_transformer_blocks.{layer - 19}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer >= 19})
216
+
217
+ return Ablation(ablator,
218
+ vanilla_pre_forward_dict=lambda block_type, layer_num: {},
219
+ vanilla_forward_dict=lambda block_type, layer_num: {},
220
+ ablated_pre_forward_dict=lambda block_type, layer_num: {},
221
+ ablated_forward_dict=lambda block_type, layer_num: interventions,
222
+ )
223
+ """
SDLens/cache_and_edit/flux_pipeline.py ADDED
@@ -0,0 +1,998 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import (
21
+ CLIPImageProcessor,
22
+ CLIPTextModel,
23
+ CLIPTokenizer,
24
+ CLIPVisionModelWithProjection,
25
+ T5EncoderModel,
26
+ T5TokenizerFast,
27
+ )
28
+
29
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
30
+ from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
31
+ from diffusers.models.autoencoders import AutoencoderKL
32
+ from diffusers.models.transformers import FluxTransformer2DModel
33
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
34
+ from diffusers.utils import (
35
+ USE_PEFT_BACKEND,
36
+ is_torch_xla_available,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from diffusers.utils.torch_utils import randn_tensor
43
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
44
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
45
+
46
+
47
+ if is_torch_xla_available():
48
+ import torch_xla.core.xla_model as xm
49
+
50
+ XLA_AVAILABLE = True
51
+ else:
52
+ XLA_AVAILABLE = False
53
+
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+ EXAMPLE_DOC_STRING = """
58
+ Examples:
59
+ ```py
60
+ >>> import torch
61
+ >>> from diffusers import FluxPipeline
62
+
63
+ >>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
64
+ >>> pipe.to("cuda")
65
+ >>> prompt = "A cat holding a sign that says hello world"
66
+ >>> # Depending on the variant being used, the pipeline call will slightly vary.
67
+ >>> # Refer to the pipeline documentation for more details.
68
+ >>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
69
+ >>> image.save("flux.png")
70
+ ```
71
+ """
72
+
73
+
74
+ def calculate_shift(
75
+ image_seq_len,
76
+ base_seq_len: int = 256,
77
+ max_seq_len: int = 4096,
78
+ base_shift: float = 0.5,
79
+ max_shift: float = 1.16,
80
+ ):
81
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
82
+ b = base_shift - m * base_seq_len
83
+ mu = image_seq_len * m + b
84
+ return mu
85
+
86
+
87
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
88
+ def retrieve_timesteps(
89
+ scheduler,
90
+ num_inference_steps: Optional[int] = None,
91
+ device: Optional[Union[str, torch.device]] = None,
92
+ timesteps: Optional[List[int]] = None,
93
+ sigmas: Optional[List[float]] = None,
94
+ **kwargs,
95
+ ):
96
+ r"""
97
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
98
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
99
+
100
+ Args:
101
+ scheduler (`SchedulerMixin`):
102
+ The scheduler to get timesteps from.
103
+ num_inference_steps (`int`):
104
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
105
+ must be `None`.
106
+ device (`str` or `torch.device`, *optional*):
107
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
108
+ timesteps (`List[int]`, *optional*):
109
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
110
+ `num_inference_steps` and `sigmas` must be `None`.
111
+ sigmas (`List[float]`, *optional*):
112
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
113
+ `num_inference_steps` and `timesteps` must be `None`.
114
+
115
+ Returns:
116
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
117
+ second element is the number of inference steps.
118
+ """
119
+ if timesteps is not None and sigmas is not None:
120
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
121
+ if timesteps is not None:
122
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
123
+ if not accepts_timesteps:
124
+ raise ValueError(
125
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
126
+ f" timestep schedules. Please check whether you are using the correct scheduler."
127
+ )
128
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
129
+ timesteps = scheduler.timesteps
130
+ num_inference_steps = len(timesteps)
131
+ elif sigmas is not None:
132
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
133
+ if not accept_sigmas:
134
+ raise ValueError(
135
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
136
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
137
+ )
138
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
139
+ timesteps = scheduler.timesteps
140
+ num_inference_steps = len(timesteps)
141
+ else:
142
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
143
+ timesteps = scheduler.timesteps
144
+ return timesteps, num_inference_steps
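As a hedged usage sketch (assuming `pipe` is an already-loaded Flux pipeline), `retrieve_timesteps` can be fed the same sigma schedule and `mu` that `__call__` builds further below:

```py
import numpy as np

num_steps = 28
sigmas = np.linspace(1.0, 1 / num_steps, num_steps)   # same schedule as in __call__
mu = calculate_shift(4096)                            # 1024x1024 image -> 4096 packed tokens
timesteps, num_steps = retrieve_timesteps(pipe.scheduler, device="cuda", sigmas=sigmas, mu=mu)
```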
145
+
146
+
147
+ class EditedFluxPipeline(
148
+ DiffusionPipeline,
149
+ FluxLoraLoaderMixin,
150
+ FromSingleFileMixin,
151
+ TextualInversionLoaderMixin,
152
+ FluxIPAdapterMixin,
153
+ ):
154
+ r"""
155
+ The Flux pipeline for text-to-image generation.
156
+
157
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
158
+
159
+ Args:
160
+ transformer ([`FluxTransformer2DModel`]):
161
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
162
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
163
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
164
+ vae ([`AutoencoderKL`]):
165
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
166
+ text_encoder ([`CLIPTextModel`]):
167
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
168
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
169
+ text_encoder_2 ([`T5EncoderModel`]):
170
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
171
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
172
+ tokenizer (`CLIPTokenizer`):
173
+ Tokenizer of class
174
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
175
+ tokenizer_2 (`T5TokenizerFast`):
176
+ Second Tokenizer of class
177
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
178
+ """
179
+
180
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
181
+ _optional_components = ["image_encoder", "feature_extractor"]
182
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
183
+
184
+ def __init__(
185
+ self,
186
+ scheduler: FlowMatchEulerDiscreteScheduler,
187
+ vae: AutoencoderKL,
188
+ text_encoder: CLIPTextModel,
189
+ tokenizer: CLIPTokenizer,
190
+ text_encoder_2: T5EncoderModel,
191
+ tokenizer_2: T5TokenizerFast,
192
+ transformer: FluxTransformer2DModel,
193
+ image_encoder: CLIPVisionModelWithProjection = None,
194
+ feature_extractor: CLIPImageProcessor = None,
195
+ ):
196
+ super().__init__()
197
+
198
+ self.register_modules(
199
+ vae=vae,
200
+ text_encoder=text_encoder,
201
+ text_encoder_2=text_encoder_2,
202
+ tokenizer=tokenizer,
203
+ tokenizer_2=tokenizer_2,
204
+ transformer=transformer,
205
+ scheduler=scheduler,
206
+ image_encoder=image_encoder,
207
+ feature_extractor=feature_extractor,
208
+ )
209
+ self.vae_scale_factor = (
210
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
211
+ )
212
+ # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
213
+ # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
214
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
215
+ self.tokenizer_max_length = (
216
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
217
+ )
218
+ self.default_sample_size = 128
219
+
220
+ def _get_t5_prompt_embeds(
221
+ self,
222
+ prompt: Union[str, List[str]] = None,
223
+ num_images_per_prompt: int = 1,
224
+ max_sequence_length: int = 512,
225
+ device: Optional[torch.device] = None,
226
+ dtype: Optional[torch.dtype] = None,
227
+ ):
228
+ device = device or self._execution_device
229
+ dtype = dtype or self.text_encoder.dtype
230
+
231
+ prompt = [prompt] if isinstance(prompt, str) else prompt
232
+ batch_size = len(prompt)
233
+
234
+ if isinstance(self, TextualInversionLoaderMixin):
235
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
236
+
237
+ text_inputs = self.tokenizer_2(
238
+ prompt,
239
+ padding="max_length",
240
+ max_length=max_sequence_length,
241
+ truncation=True,
242
+ return_length=False,
243
+ return_overflowing_tokens=False,
244
+ return_tensors="pt",
245
+ )
246
+ text_input_ids = text_inputs.input_ids
247
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
248
+
249
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
250
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
251
+ logger.warning(
252
+ "The following part of your input was truncated because `max_sequence_length` is set to "
253
+ f" {max_sequence_length} tokens: {removed_text}"
254
+ )
255
+
256
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
257
+
258
+ dtype = self.text_encoder_2.dtype
259
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
260
+
261
+ _, seq_len, _ = prompt_embeds.shape
262
+
263
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
264
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
265
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
266
+
267
+ return prompt_embeds
268
+
269
+ def _get_clip_prompt_embeds(
270
+ self,
271
+ prompt: Union[str, List[str]],
272
+ num_images_per_prompt: int = 1,
273
+ device: Optional[torch.device] = None,
274
+ ):
275
+ device = device or self._execution_device
276
+
277
+ prompt = [prompt] if isinstance(prompt, str) else prompt
278
+ batch_size = len(prompt)
279
+
280
+ if isinstance(self, TextualInversionLoaderMixin):
281
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
282
+
283
+ text_inputs = self.tokenizer(
284
+ prompt,
285
+ padding="max_length",
286
+ max_length=self.tokenizer_max_length,
287
+ truncation=True,
288
+ return_overflowing_tokens=False,
289
+ return_length=False,
290
+ return_tensors="pt",
291
+ )
292
+
293
+ text_input_ids = text_inputs.input_ids
294
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
295
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
296
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
297
+ logger.warning(
298
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
299
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
300
+ )
301
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
302
+
303
+ # Use pooled output of CLIPTextModel
304
+ prompt_embeds = prompt_embeds.pooler_output
305
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
306
+
307
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
308
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
309
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
310
+
311
+ return prompt_embeds
312
+
313
+ def encode_prompt(
314
+ self,
315
+ prompt: Union[str, List[str]],
316
+ prompt_2: Union[str, List[str]],
317
+ device: Optional[torch.device] = None,
318
+ num_images_per_prompt: int = 1,
319
+ prompt_embeds: Optional[torch.FloatTensor] = None,
320
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
321
+ max_sequence_length: int = 512,
322
+ lora_scale: Optional[float] = None,
323
+ ):
324
+ r"""
325
+
326
+ Args:
327
+ prompt (`str` or `List[str]`, *optional*):
328
+ prompt to be encoded
329
+ prompt_2 (`str` or `List[str]`, *optional*):
330
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
331
+ used in all text-encoders
332
+ device: (`torch.device`):
333
+ torch device
334
+ num_images_per_prompt (`int`):
335
+ number of images that should be generated per prompt
336
+ prompt_embeds (`torch.FloatTensor`, *optional*):
337
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
338
+ provided, text embeddings will be generated from `prompt` input argument.
339
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
340
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
341
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
342
+ lora_scale (`float`, *optional*):
343
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
344
+ """
345
+ device = device or self._execution_device
346
+
347
+ # set lora scale so that monkey patched LoRA
348
+ # function of text encoder can correctly access it
349
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
350
+ self._lora_scale = lora_scale
351
+
352
+ # dynamically adjust the LoRA scale
353
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
354
+ scale_lora_layers(self.text_encoder, lora_scale)
355
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
356
+ scale_lora_layers(self.text_encoder_2, lora_scale)
357
+
358
+ prompt = [prompt] if isinstance(prompt, str) else prompt
359
+
360
+ if prompt_embeds is None:
361
+ prompt_2 = prompt_2 or prompt
362
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
363
+
364
+ # We only use the pooled prompt output from the CLIPTextModel
365
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
366
+ prompt=prompt,
367
+ device=device,
368
+ num_images_per_prompt=num_images_per_prompt,
369
+ )
370
+ prompt_embeds = self._get_t5_prompt_embeds(
371
+ prompt=prompt_2,
372
+ num_images_per_prompt=num_images_per_prompt,
373
+ max_sequence_length=max_sequence_length,
374
+ device=device,
375
+ )
376
+
377
+ if self.text_encoder is not None:
378
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
379
+ # Retrieve the original scale by scaling back the LoRA layers
380
+ unscale_lora_layers(self.text_encoder, lora_scale)
381
+
382
+ if self.text_encoder_2 is not None:
383
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
384
+ # Retrieve the original scale by scaling back the LoRA layers
385
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
386
+
387
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
388
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
389
+
390
+ return prompt_embeds, pooled_prompt_embeds, text_ids
391
+
392
+ def encode_image(self, image, device, num_images_per_prompt):
393
+ dtype = next(self.image_encoder.parameters()).dtype
394
+
395
+ if not isinstance(image, torch.Tensor):
396
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
397
+
398
+ image = image.to(device=device, dtype=dtype)
399
+ image_embeds = self.image_encoder(image).image_embeds
400
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
401
+ return image_embeds
402
+
403
+ def prepare_ip_adapter_image_embeds(
404
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
405
+ ):
406
+ image_embeds = []
407
+ if ip_adapter_image_embeds is None:
408
+ if not isinstance(ip_adapter_image, list):
409
+ ip_adapter_image = [ip_adapter_image]
410
+
411
+ if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
412
+ raise ValueError(
413
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
414
+ )
415
+
416
+ for single_ip_adapter_image, image_proj_layer in zip(
417
+ ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
418
+ ):
419
+ single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
420
+
421
+ image_embeds.append(single_image_embeds[None, :])
422
+ else:
423
+ for single_image_embeds in ip_adapter_image_embeds:
424
+ image_embeds.append(single_image_embeds)
425
+
426
+ ip_adapter_image_embeds = []
427
+ for i, single_image_embeds in enumerate(image_embeds):
428
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
429
+ single_image_embeds = single_image_embeds.to(device=device)
430
+ ip_adapter_image_embeds.append(single_image_embeds)
431
+
432
+ return ip_adapter_image_embeds
433
+
434
+ def check_inputs(
435
+ self,
436
+ prompt,
437
+ prompt_2,
438
+ height,
439
+ width,
440
+ negative_prompt=None,
441
+ negative_prompt_2=None,
442
+ prompt_embeds=None,
443
+ negative_prompt_embeds=None,
444
+ pooled_prompt_embeds=None,
445
+ negative_pooled_prompt_embeds=None,
446
+ callback_on_step_end_tensor_inputs=None,
447
+ max_sequence_length=None,
448
+ ):
449
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
450
+ logger.warning(
451
+ f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
452
+ )
453
+
454
+ if callback_on_step_end_tensor_inputs is not None and not all(
455
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
456
+ ):
457
+ raise ValueError(
458
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
459
+ )
460
+
461
+ if prompt is not None and prompt_embeds is not None:
462
+ raise ValueError(
463
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
464
+ " only forward one of the two."
465
+ )
466
+ elif prompt_2 is not None and prompt_embeds is not None:
467
+ raise ValueError(
468
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
469
+ " only forward one of the two."
470
+ )
471
+ elif prompt is None and prompt_embeds is None:
472
+ raise ValueError(
473
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
474
+ )
475
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
476
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
477
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
478
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
479
+
480
+ if negative_prompt is not None and negative_prompt_embeds is not None:
481
+ raise ValueError(
482
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
483
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
484
+ )
485
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
486
+ raise ValueError(
487
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
488
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
489
+ )
490
+
491
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
492
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
493
+ raise ValueError(
494
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
495
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
496
+ f" {negative_prompt_embeds.shape}."
497
+ )
498
+
499
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
500
+ raise ValueError(
501
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
502
+ )
503
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
504
+ raise ValueError(
505
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
506
+ )
507
+
508
+ if max_sequence_length is not None and max_sequence_length > 512:
509
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
510
+
511
+ @staticmethod
512
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
513
+ latent_image_ids = torch.zeros(height, width, 3)
514
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
515
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
516
+
517
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
518
+
519
+ latent_image_ids = latent_image_ids.reshape(
520
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
521
+ )
522
+
523
+ return latent_image_ids.to(device=device, dtype=dtype)
524
+
525
+ @staticmethod
526
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
527
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
528
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
529
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
530
+
531
+ return latents
532
+
533
+ @staticmethod
534
+ def _unpack_latents(latents, height, width, vae_scale_factor):
535
+ batch_size, num_patches, channels = latents.shape
536
+
537
+ # VAE applies 8x compression on images but we must also account for packing which requires
538
+ # latent height and width to be divisible by 2.
539
+ height = 2 * (int(height) // (vae_scale_factor * 2))
540
+ width = 2 * (int(width) // (vae_scale_factor * 2))
541
+
542
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
543
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
544
+
545
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
546
+
547
+ return latents
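A shape sketch of the packing round-trip (illustrative only, assuming the class above is importable as `EditedFluxPipeline`): a 1024x1024 image yields 16x128x128 latents, which pack into 4096 tokens of dimension 64 and unpack back to the original layout.

```py
import torch

latents = torch.randn(1, 16, 128, 128)                                  # (B, C, H/8, W/8)
packed = EditedFluxPipeline._pack_latents(latents, 1, 16, 128, 128)     # 2x2 patches -> tokens
assert packed.shape == (1, 64 * 64, 64)
unpacked = EditedFluxPipeline._unpack_latents(packed, 1024, 1024, vae_scale_factor=8)
assert unpacked.shape == (1, 16, 128, 128)
```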
548
+
549
+ def enable_vae_slicing(self):
550
+ r"""
551
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
552
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
553
+ """
554
+ self.vae.enable_slicing()
555
+
556
+ def disable_vae_slicing(self):
557
+ r"""
558
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
559
+ computing decoding in one step.
560
+ """
561
+ self.vae.disable_slicing()
562
+
563
+ def enable_vae_tiling(self):
564
+ r"""
565
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
566
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
567
+ processing larger images.
568
+ """
569
+ self.vae.enable_tiling()
570
+
571
+ def disable_vae_tiling(self):
572
+ r"""
573
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
574
+ computing decoding in one step.
575
+ """
576
+ self.vae.disable_tiling()
577
+
578
+
579
+ def prepare_latents(
580
+ self,
581
+ batch_size,
582
+ num_channels_latents,
583
+ height,
584
+ width,
585
+ dtype,
586
+ device,
587
+ generator,
588
+ latents=None,
589
+ ):
590
+ # VAE applies 8x compression on images but we must also account for packing which requires
591
+ # latent height and width to be divisible by 2.
592
+ height = 2 * (int(height) // (self.vae_scale_factor * 2))
593
+ width = 2 * (int(width) // (self.vae_scale_factor * 2))
594
+
595
+ shape = (batch_size, num_channels_latents, height, width)
596
+
597
+ if latents is not None:
598
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
599
+ return latents.to(device=device, dtype=dtype), latent_image_ids
600
+
601
+ if isinstance(generator, list) and len(generator) != batch_size:
602
+ raise ValueError(
603
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
604
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
605
+ )
606
+
607
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
608
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
609
+
610
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
611
+
612
+ return latents, latent_image_ids
613
+
614
+ @property
615
+ def guidance_scale(self):
616
+ return self._guidance_scale
617
+
618
+ @property
619
+ def joint_attention_kwargs(self):
620
+ return self._joint_attention_kwargs
621
+
622
+ @property
623
+ def num_timesteps(self):
624
+ return self._num_timesteps
625
+
626
+ @property
627
+ def interrupt(self):
628
+ return self._interrupt
629
+
630
+ @torch.no_grad()
631
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
632
+ def __call__(
633
+ self,
634
+ prompt: Union[str, List[str]] = None,
635
+ prompt_2: Optional[Union[str, List[str]]] = None,
636
+ negative_prompt: Union[str, List[str]] = None,
637
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
638
+ true_cfg_scale: float = 1.0,
639
+ height: Optional[int] = None,
640
+ width: Optional[int] = None,
641
+ num_inference_steps: int = 28,
642
+ sigmas: Optional[List[float]] = None,
643
+ guidance_scale: float = 3.5,
644
+ num_images_per_prompt: Optional[int] = 1,
645
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
646
+ latents: Optional[torch.FloatTensor] = None,
647
+ prompt_embeds: Optional[torch.FloatTensor] = None,
648
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
649
+ ip_adapter_image: Optional[PipelineImageInput] = None,
650
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
651
+ negative_ip_adapter_image: Optional[PipelineImageInput] = None,
652
+ negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
653
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
654
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
655
+ output_type: Optional[str] = "pil",
656
+ return_dict: bool = True,
657
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
658
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
659
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
660
+ max_sequence_length: int = 512,
661
+ is_inverted_generation: bool = False,
662
+ inverted_latents_list: List[torch.Tensor] = None,
663
+ tau_b: Optional[float] = None,
664
+ bg_consistency_mask: Optional[torch.Tensor] = None,
665
+ ):
666
+ r"""
667
+ Function invoked when calling the pipeline for generation.
668
+
669
+ Args:
670
+ prompt (`str` or `List[str]`, *optional*):
671
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
672
+ instead.
673
+ prompt_2 (`str` or `List[str]`, *optional*):
674
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
675
+ will be used instead.
676
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
677
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
678
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
679
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
680
+ num_inference_steps (`int`, *optional*, defaults to 28):
681
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
682
+ expense of slower inference.
683
+ sigmas (`List[float]`, *optional*):
684
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
685
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
686
+ will be used.
687
+ guidance_scale (`float`, *optional*, defaults to 3.5):
688
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
689
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
690
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
691
+ 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
692
+ usually at the expense of lower image quality.
693
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
694
+ The number of images to generate per prompt.
695
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
696
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
697
+ to make generation deterministic.
698
+ latents (`torch.FloatTensor`, *optional*):
699
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
700
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
701
+ tensor will be generated by sampling using the supplied random `generator`.
702
+ prompt_embeds (`torch.FloatTensor`, *optional*):
703
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
704
+ provided, text embeddings will be generated from `prompt` input argument.
705
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
706
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
707
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
708
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
709
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
710
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of
711
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
712
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
713
+ negative_ip_adapter_image:
714
+ (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
715
+ negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
716
+ Pre-generated image embeddings for IP-Adapter. It should be a list whose length equals the number of
717
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
718
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
719
+ output_type (`str`, *optional*, defaults to `"pil"`):
720
+ The output format of the generated image. Choose between
721
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
722
+ return_dict (`bool`, *optional*, defaults to `True`):
723
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
724
+ joint_attention_kwargs (`dict`, *optional*):
725
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
726
+ `self.processor` in
727
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
728
+ callback_on_step_end (`Callable`, *optional*):
729
+ A function called at the end of each denoising step during inference. The function is called
730
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
731
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
732
+ `callback_on_step_end_tensor_inputs`.
733
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
734
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
735
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
736
+ `._callback_tensor_inputs` attribute of your pipeline class.
737
+ max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
738
+ tau_b (`float`, *optional*): Proportion of steps during which the background consistency is applied.
739
+ bg_consistency_mask (`torch.Tensor`, *optional*): Mask to use when applying background consistency. Background
740
+ consistency will be applied to the areas outside of the mask.
741
+
742
+ Examples:
743
+
744
+ Returns:
745
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
746
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
747
+ images.
748
+ """
749
+
750
+ height = height or self.default_sample_size * self.vae_scale_factor
751
+ width = width or self.default_sample_size * self.vae_scale_factor
752
+
753
+ # 1. Check inputs. Raise error if not correct
754
+ self.check_inputs(
755
+ prompt,
756
+ prompt_2,
757
+ height,
758
+ width,
759
+ negative_prompt=negative_prompt,
760
+ negative_prompt_2=negative_prompt_2,
761
+ prompt_embeds=prompt_embeds,
762
+ negative_prompt_embeds=negative_prompt_embeds,
763
+ pooled_prompt_embeds=pooled_prompt_embeds,
764
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
765
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
766
+ max_sequence_length=max_sequence_length,
767
+ )
768
+
769
+ self._guidance_scale = guidance_scale
770
+ self._joint_attention_kwargs = joint_attention_kwargs
771
+ self._interrupt = False
772
+
773
+ # 2. Define call parameters
774
+ if prompt is not None and isinstance(prompt, str):
775
+ batch_size = 1
776
+ elif prompt is not None and isinstance(prompt, list):
777
+ batch_size = len(prompt)
778
+ else:
779
+ batch_size = prompt_embeds.shape[0]
780
+
781
+ device = self._execution_device
782
+
783
+ lora_scale = (
784
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
785
+ )
786
+ do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
787
+ (
788
+ prompt_embeds,
789
+ pooled_prompt_embeds,
790
+ text_ids,
791
+ ) = self.encode_prompt(
792
+ prompt=prompt,
793
+ prompt_2=prompt_2,
794
+ prompt_embeds=prompt_embeds,
795
+ pooled_prompt_embeds=pooled_prompt_embeds,
796
+ device=device,
797
+ num_images_per_prompt=num_images_per_prompt,
798
+ max_sequence_length=max_sequence_length,
799
+ lora_scale=lora_scale,
800
+ )
801
+ if do_true_cfg:
802
+ (
803
+ negative_prompt_embeds,
804
+ negative_pooled_prompt_embeds,
805
+ _,
806
+ ) = self.encode_prompt(
807
+ prompt=negative_prompt,
808
+ prompt_2=negative_prompt_2,
809
+ prompt_embeds=negative_prompt_embeds,
810
+ pooled_prompt_embeds=negative_pooled_prompt_embeds,
811
+ device=device,
812
+ num_images_per_prompt=num_images_per_prompt,
813
+ max_sequence_length=max_sequence_length,
814
+ lora_scale=lora_scale,
815
+ )
816
+
817
+ # 4. Prepare latent variables
818
+ num_channels_latents = self.transformer.config.in_channels // 4
819
+ latents, latent_image_ids = self.prepare_latents(
820
+ batch_size * num_images_per_prompt,
821
+ num_channels_latents,
822
+ height,
823
+ width,
824
+ prompt_embeds.dtype,
825
+ device,
826
+ generator,
827
+ latents,
828
+ )
829
+
830
+ # 5. Prepare timesteps
831
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
832
+ image_seq_len = latents.shape[1]
833
+ mu = calculate_shift(
834
+ image_seq_len,
835
+ self.scheduler.config.base_image_seq_len,
836
+ self.scheduler.config.max_image_seq_len,
837
+ self.scheduler.config.base_shift,
838
+ self.scheduler.config.max_shift,
839
+ )
840
+ timesteps, num_inference_steps = retrieve_timesteps(
841
+ self.scheduler,
842
+ num_inference_steps,
843
+ device,
844
+ sigmas=sigmas,
845
+ mu=mu,
846
+ )
847
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
848
+ self._num_timesteps = len(timesteps)
849
+
850
+ if is_inverted_generation:
851
+ timesteps = reversed(timesteps)
852
+
853
+
854
+ # handle guidance
855
+ if self.transformer.config.guidance_embeds:
856
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
857
+ guidance = guidance.expand(latents.shape[0])
858
+ else:
859
+ guidance = None
860
+
861
+ if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
862
+ negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
863
+ ):
864
+ negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
865
+ elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
866
+ negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
867
+ ):
868
+ ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
869
+
870
+ if self.joint_attention_kwargs is None:
871
+ self._joint_attention_kwargs = {}
872
+
873
+ image_embeds = None
874
+ negative_image_embeds = None
875
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
876
+ image_embeds = self.prepare_ip_adapter_image_embeds(
877
+ ip_adapter_image,
878
+ ip_adapter_image_embeds,
879
+ device,
880
+ batch_size * num_images_per_prompt,
881
+ )
882
+ if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
883
+ negative_image_embeds = self.prepare_ip_adapter_image_embeds(
884
+ negative_ip_adapter_image,
885
+ negative_ip_adapter_image_embeds,
886
+ device,
887
+ batch_size * num_images_per_prompt,
888
+ )
889
+
890
+ # 6. Denoising loop
891
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
892
+ for i, t in enumerate(timesteps):
893
+
894
+ if self.interrupt:
895
+ continue
896
+
897
+ if image_embeds is not None:
898
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
899
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
900
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
901
+
902
+ noise_pred = self.transformer(
903
+ hidden_states=latents,
904
+ timestep=timestep / 1000,
905
+ guidance=guidance,
906
+ pooled_projections=pooled_prompt_embeds,
907
+ encoder_hidden_states=prompt_embeds,
908
+ txt_ids=text_ids,
909
+ img_ids=latent_image_ids,
910
+ joint_attention_kwargs=self.joint_attention_kwargs,
911
+ return_dict=False,
912
+ )[0]
913
+
914
+ if do_true_cfg:
915
+ if negative_image_embeds is not None:
916
+ self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
917
+ neg_noise_pred = self.transformer(
918
+ hidden_states=latents,
919
+ timestep=timestep / 1000,
920
+ guidance=guidance,
921
+ pooled_projections=negative_pooled_prompt_embeds,
922
+ encoder_hidden_states=negative_prompt_embeds,
923
+ txt_ids=text_ids,
924
+ img_ids=latent_image_ids,
925
+ joint_attention_kwargs=self.joint_attention_kwargs,
926
+ return_dict=False,
927
+ )[0]
928
+ noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
929
+
930
+ # compute the previous noisy sample x_t -> x_t-1
931
+ latents_dtype = latents.dtype
932
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
933
+
934
+ if tau_b:
935
+ if bg_consistency_mask is None:
936
+ raise ValueError("If `tau_b` is set, `bg_consistency_mask` must be provided for background consistency to work.")
937
+
938
+ assert latents.shape[0] >= 3, "Three processes are required for background consistency injection (background, foreground, and composed processes)."
939
+ assert latents.shape[1] == bg_consistency_mask.shape[0], f"Latents and background-consistency mask must have the same sequence length. Got {latents.shape[1]} and {bg_consistency_mask.shape[0]}."
940
+
941
+ bg_consistency_mask = bg_consistency_mask.to(device=latents.device, dtype=torch.int32)
942
+
943
+ # TF-ICON background consistency: if we're in the first tau_b part of the de-noising process,
944
+ # overwrite the latents of the composed image with those of the background process (only outside the segmentation mask)
945
+ if i <= tau_b * num_inference_steps:
946
+ latents[2, :, :] = latents[0, :, :] * (1 - bg_consistency_mask) + latents[2, :, :] * bg_consistency_mask
947
+
948
+ # NOTE: this was the added part for inversion
949
+ if is_inverted_generation:
950
+ inverted_latents_list.append(latents)
951
+ else:
952
+ if inverted_latents_list is not None:
953
+ if isinstance(inverted_latents_list[0], torch.Tensor):
954
+ latents[0] = inverted_latents_list[-i][0]
955
+ else:
956
+ assert isinstance(inverted_latents_list[0], tuple)
957
+ for j, tensor_tuple in enumerate(inverted_latents_list[-i]):
958
+ latents[j] = tensor_tuple
959
+
960
+
961
+
962
+ if latents.dtype != latents_dtype:
963
+ if torch.backends.mps.is_available():
964
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
965
+ latents = latents.to(latents_dtype)
966
+
967
+ if callback_on_step_end is not None:
968
+ callback_kwargs = {}
969
+ for k in callback_on_step_end_tensor_inputs:
970
+ callback_kwargs[k] = locals()[k]
971
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
972
+
973
+ latents = callback_outputs.pop("latents", latents)
974
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
975
+
976
+ # call the callback, if provided
977
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
978
+ progress_bar.update()
979
+
980
+ if XLA_AVAILABLE:
981
+ xm.mark_step()
982
+
983
+ if output_type == "latent":
984
+ image = latents
985
+
986
+ else:
987
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
988
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
989
+ image = self.vae.decode(latents, return_dict=False)[0]
990
+ image = self.image_processor.postprocess(image, output_type=output_type)
991
+
992
+ # Offload all models
993
+ self.maybe_free_model_hooks()
994
+
995
+ if not return_dict:
996
+ return (image,)
997
+
998
+ return FluxPipelineOutput(images=image)
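The arguments added on top of the stock Flux pipeline (`is_inverted_generation`, `inverted_latents_list`, `tau_b`, `bg_consistency_mask`) are exercised roughly as follows. This is a hedged sketch with illustrative values, assuming the repo root is on `PYTHONPATH`; `packed_image_latents` and `mask` are placeholders that the helpers in `SDLens/cache_and_edit/inversion.py` are meant to provide.

```py
import torch
from SDLens.cache_and_edit.flux_pipeline import EditedFluxPipeline

pipe = EditedFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

# 1) Inversion: run the timestep schedule in reverse and collect the latents of every step.
inverted_latents = []
pipe("", num_inference_steps=28, guidance_scale=1.0, output_type="latent",
     latents=packed_image_latents,                 # packed latents of the image to invert (placeholder)
     is_inverted_generation=True, inverted_latents_list=inverted_latents)

# 2) Composition with TF-ICON-style background consistency for the first 40% of steps.
#    `mask` is a packed-latent mask of shape (image_seq_len, 1) (placeholder); batch index 0 is the
#    background process and index 2 the composed process, as assumed in the denoising loop above.
out = pipe(["background prompt", "foreground prompt", "composed prompt"],
           num_inference_steps=28, tau_b=0.4, bg_consistency_mask=mask,
           inverted_latents_list=inverted_latents)
```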
SDLens/cache_and_edit/hooks.py ADDED
@@ -0,0 +1,108 @@
1
+ from typing import Callable, Literal
2
+ import torch
3
+ import torch.nn as nn
4
+ from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxSingleTransformerBlock
5
+
6
+
7
+ def register_general_hook(pipe, position, hook, with_kwargs=False, is_pre_hook=False):
8
+ """Registers a forward hook on the module of the pipeline specified by `position`.
9
+
10
+ Args:
11
+ pipe: Pipeline (or wrapper object) containing the module to hook.
12
+ position (str): Dot-separated path to the target module inside `pipe`, resolved by `locate_block`.
13
+ hook (Callable): Hook function to register on the module.
14
+ with_kwargs (bool, optional): Also pass the module's keyword arguments to the hook. Defaults to False.
15
+ is_pre_hook (bool, optional): Register a forward pre-hook instead of a forward hook. Defaults to False.
16
+
17
+ Returns:
18
+ torch.utils.hooks.RemovableHandle: Handle that can be used to remove the hook.
19
+ """
20
+
21
+ block: nn.Module = locate_block(pipe, position)
22
+
23
+ if is_pre_hook:
24
+ return block.register_forward_pre_hook(hook, with_kwargs=with_kwargs)
25
+ else:
26
+ return block.register_forward_hook(hook, with_kwargs=with_kwargs)
27
+
28
+
29
+ def locate_block(pipe, position: str) -> nn.Module:
30
+ '''
31
+ Locate the block at the specified position in the pipeline.
32
+ '''
33
+ block = pipe
34
+ for step in position.split('.'):
35
+ if step.isdigit():
36
+ step = int(step)
37
+ block = block[step]
38
+ else:
39
+ block = getattr(block, step)
40
+ return block
41
+
42
+
43
+ def _safe_clip(x: torch.Tensor):
44
+ if x.dtype == torch.float16:
45
+ x[torch.isposinf(x)] = 65504
46
+ x[torch.isneginf(x)] = -65504
47
+ return x
48
+
49
+
50
+ @torch.no_grad()
51
+ def fix_inf_values_hook(*args):
52
+
53
+ # Case 1: no kwargs are passed to the module
54
+ if len(args) == 3:
55
+ module, input, output = args
56
+ # Case 2: kwargs are passed to the module as input
57
+ elif len(args) == 4:
58
+ module, input, kwinput, output = args
59
+
60
+ if isinstance(module, FluxTransformerBlock):
61
+ return _safe_clip(output[0]), _safe_clip(output[1])
62
+
63
+ elif isinstance(module, FluxSingleTransformerBlock):
64
+ return _safe_clip(output)
65
+
66
+
67
+ @torch.no_grad()
68
+ def edit_streams_hook(*args,
69
+ recompute_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
70
+ stream: Literal["text", "image", "both"]):
71
+ """
72
+ recompute_fn will get as input the input tensor and the output tensor for such stream
73
+ and returns what should be the new modified output
74
+ """
75
+
76
+ # Case 1: no kwards are passed to the module
77
+ if len(args) == 3:
78
+ module, input, output = args
79
+ # Case 2: when kwargs are passed to the model as input
80
+ elif len(args) == 4:
81
+ module, input, kwinput, output = args
82
+ else:
83
+ raise AssertionError(f'Weird len(args):{len(args)}')
84
+
85
+ if isinstance(module, FluxTransformerBlock):
86
+
87
+ if stream == 'text':
88
+ output_text = recompute_fn(kwinput["encoder_hidden_states"], output[0])
89
+ output_image = output[1]
90
+ elif stream == 'image':
91
+ output_image = recompute_fn(kwinput["hidden_states"], output[1])
92
+ output_text = output[0]
93
+ else:
94
+ raise AssertionError("Branch not supported for this layer.")
95
+
96
+ return _safe_clip(output_text), _safe_clip(output_image)
97
+
98
+ elif isinstance(module, FluxSingleTransformerBlock):
99
+
100
+ if stream == 'text':
101
+ output[:, :512] = recompute_fn(kwinput["hidden_states"][:, :512], output[:, :512])
102
+ elif stream == 'image':
103
+ output[:, 512:] = recompute_fn(kwinput["hidden_states"][:, 512:], output[:, 512:])
104
+ else:
105
+ output = recompute_fn(kwinput["hidden_states"], output)
106
+
107
+ return _safe_clip(output)
108
+
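A hedged example of wiring these helpers together (module path and scaling factor are illustrative; `pipe` is assumed to be an already-loaded Flux pipeline): `register_general_hook` resolves the dotted path with `locate_block`, and `edit_streams_hook` then rewrites one stream of the block's output.

```py
from functools import partial
import torch

def damp_residual(inp: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    # Illustrative recompute_fn: keep only half of the update the block applied to the stream.
    return inp + 0.5 * (out - inp)

handle = register_general_hook(
    pipe,
    "transformer.transformer_blocks.0",          # a double-stream (FluxTransformerBlock) module
    partial(edit_streams_hook, recompute_fn=damp_residual, stream="image"),
    with_kwargs=True,                            # edit_streams_hook reads kwinput["hidden_states"]
)
# ... run the pipeline ...
handle.remove()
```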
SDLens/cache_and_edit/inversion.py ADDED
@@ -0,0 +1,568 @@
1
+ from typing import Optional, Tuple
2
+ import torch
3
+ import torchvision.transforms.functional as TF
4
+ from PIL import Image
5
+ from cache_and_edit import CachedPipeline
6
+ import numpy as np
7
+ from IPython.display import display
8
+
9
+ from cache_and_edit.flux_pipeline import EditedFluxPipeline
10
+
11
+ def image2latent(pipe, image, latent_nudging_scalar = 1.15):
12
+ image = pipe.image_processor.preprocess(image).type(pipe.vae.dtype).to("cuda")
13
+ latents = pipe.vae.encode(image)["latent_dist"].mean
14
+ latents = (latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
15
+ latents = latents * latent_nudging_scalar
16
+
17
+ latents = pipe._pack_latents(
18
+ latents=latents,
19
+ batch_size=1,
20
+ num_channels_latents=16,
21
+ height=image.size(2) // 8,
22
+ width= image.size(3) // 8
23
+ )
24
+
25
+ return latents
26
+
27
+
28
+ def get_inverted_input_noise(pipe: CachedPipeline,
29
+ image,
30
+ prompt: str = "",
31
+ num_steps: int = 28,
32
+ latent_nudging_scalar: float = 1.15):
33
+ """Inverts an image into the pipeline's input noise by running the sampling schedule in reverse.
34
+
35
+ Args:
36
+ pipe (CachedPipeline): Cached pipeline wrapping the Flux pipeline used for inversion.
37
+ image (PIL.Image.Image): Image to invert.
38
+ num_steps (int, optional): Number of inversion steps. Defaults to 28.
39
+
40
+ Returns:
41
+ The list of per-step inverted latents when the wrapped pipeline is an EditedFluxPipeline, otherwise the final inverted noise tensor.
42
+ """
43
+
44
+ width, height = image.size
45
+ inverted_latents_list = []
46
+
47
+ if isinstance(pipe.pipe, EditedFluxPipeline):
48
+
49
+ _ = pipe.run(
50
+ prompt,
51
+ num_inference_steps=num_steps,
52
+ seed=42,
53
+ guidance_scale=1,
54
+ output_type="latent",
55
+ latents=image2latent(pipe.pipe, image, latent_nudging_scalar=latent_nudging_scalar),
56
+ empty_clip_embeddings=False,
57
+ inverse=True,
58
+ width=width,
59
+ height=height,
60
+ is_inverted_generation=True,
61
+ inverted_latents_list=inverted_latents_list
62
+ ).images[0]
63
+
64
+ return inverted_latents_list
65
+
66
+
67
+ else:
68
+ noise = pipe.run(
69
+ prompt,
70
+ num_inference_steps=num_steps,
71
+ seed=42,
72
+ guidance_scale=1,
73
+ output_type="latent",
74
+ latents=image2latent(pipe.pipe, image, latent_nudging_scalar=latent_nudging_scalar),
75
+ empty_clip_embeddings=False,
76
+ inverse=True,
77
+ width=width,
78
+ height=height
79
+ ).images[0]
80
+
81
+ return noise
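A hedged usage sketch (the file path is a placeholder and `cached_pipe` is assumed to be an existing `CachedPipeline`): with an `EditedFluxPipeline` inside the wrapper, the helper returns the full list of per-step inverted latents; with the plain pipeline it returns only the final inverted noise.

```py
from PIL import Image

# cached_pipe: an existing CachedPipeline wrapping the Flux pipeline (assumption)
image = Image.open("background.png").resize((1024, 1024))    # placeholder path
inverted_latents = get_inverted_input_noise(cached_pipe, image, num_steps=28)
```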
82
+
83
+
84
+
85
+
86
+ def resize_bounding_box(
87
+ bb_mask: torch.Tensor,
88
+ target_size: Tuple[int, int] = (64, 64),
89
+ ) -> torch.Tensor:
90
+ """
91
+ Given a bounding box mask, patches it into a mask with the target size.
92
+ The mask is a 2D tensor of shape (H, W) where each element is either 0 or 1.
93
+ Any patch that contains at least one 1 in the original mask will be set to 1 in the output mask.
94
+
95
+ Args:
96
+ bb_mask (torch.Tensor): The bounding box mask as a boolean tensor of shape (H, W).
97
+ target_size (Tuple[int, int]): The size of the target mask as a tuple (H, W).
98
+
99
+ Returns:
100
+ torch.Tensor: The resized bounding box mask as a boolean tensor of shape (H, W).
101
+ """
102
+
103
+ w_mask, h_mask = bb_mask.shape[-2:]
104
+ w_target, h_target = target_size
105
+
106
+ # Make sure the sizes are compatible
107
+ if w_mask % w_target != 0 or h_mask % h_target != 0:
108
+ raise ValueError(
109
+ f"Mask size {bb_mask.shape[-2:]} is not compatible with target size {target_size}"
110
+ )
111
+
112
+ # Compute the size of a patch
113
+ patch_size = (w_mask // w_target, h_mask // h_target)
114
+
115
+ # Iterate over the mask, one patch at a time, and save a 0 patch if the patch is empty or a 1 patch if the patch is not empty
116
+ out_mask = torch.zeros((w_target, h_target), dtype=bb_mask.dtype, device=bb_mask.device)
117
+ for i in range(w_target):
118
+ for j in range(h_target):
119
+ patch = bb_mask[
120
+ i * patch_size[0] : (i + 1) * patch_size[0],
121
+ j * patch_size[1] : (j + 1) * patch_size[1],
122
+ ]
123
+ if torch.sum(patch) > 0:
124
+ out_mask[i, j] = 1
125
+ else:
126
+ out_mask[i, j] = 0
127
+
128
+ return out_mask
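For example, with a 1024x1024 pixel mask and the 64x64 packed-latent grid used elsewhere in this repo, every 16x16 pixel patch that touches the box becomes a single latent cell (sizes are illustrative):

```py
import torch

bb = torch.zeros(1024, 1024, dtype=torch.bool)
bb[256:512, 128:384] = True                           # a 256x256 bounding box
latent_mask = resize_bounding_box(bb, target_size=(64, 64))
print(latent_mask.shape, int(latent_mask.sum()))      # torch.Size([64, 64]) 256
```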
129
+
130
+
131
+ def place_image_in_bounding_box(
132
+ image_tensor_whc: torch.Tensor,
133
+ mask_tensor_wh: torch.Tensor
134
+ ) -> tuple[torch.Tensor, torch.Tensor]:
135
+ """
136
+ Resizes an input image to fit within a bounding box (from a mask)
137
+ preserving aspect ratio, and places it centered on a new canvas.
138
+
139
+ Args:
140
+ image_tensor_whc: Input image tensor, shape [width, height, channels].
141
+ mask_tensor_wh: Bounding box mask, shape [width, height]. Defines canvas size
142
+ and contains a rectangle of 1s for the BB.
143
+
144
+ Returns:
145
+ A tuple:
146
+ - output_image_whc (torch.Tensor): Canvas with the resized image placed.
147
+ Shape [canvas_width, canvas_height, channels].
148
+ - new_mask_wh (torch.Tensor): Mask showing the actual placement of the image.
149
+ Shape [canvas_width, canvas_height].
150
+ """
151
+
152
+ # Validate input image dimensions
153
+ if not (image_tensor_whc.ndim == 3 and image_tensor_whc.shape[0] > 0 and image_tensor_whc.shape[1] > 0):
154
+ raise ValueError(
155
+ "Input image_tensor_whc must be a 3D tensor [width, height, channels] "
156
+ "with width > 0 and height > 0."
157
+ )
158
+ img_orig_w, img_orig_h, num_channels = image_tensor_whc.shape
159
+
160
+ # Validate mask tensor dimensions
161
+ if not (mask_tensor_wh.ndim == 2):
162
+ raise ValueError("Input mask_tensor_wh must be a 2D tensor [width, height].")
163
+ canvas_w, canvas_h = mask_tensor_wh.shape
164
+
165
+ # Prepare default empty outputs for early exit scenarios
166
+ empty_output_image = torch.zeros(
167
+ canvas_w, canvas_h, num_channels,
168
+ dtype=image_tensor_whc.dtype, device=image_tensor_whc.device
169
+ )
170
+ empty_new_mask = torch.zeros(
171
+ canvas_w, canvas_h,
172
+ dtype=mask_tensor_wh.dtype, device=mask_tensor_wh.device
173
+ )
174
+
175
+ # 1. Find Bounding Box (BB) coordinates from the input mask_tensor_wh
176
+ # fg_coords shape: [N, 2], where N is num_nonzero. Each row: [x_coord, y_coord].
177
+ fg_coords = torch.nonzero(mask_tensor_wh, as_tuple=False)
178
+
179
+ if fg_coords.numel() == 0: # No bounding box found in mask
180
+ return empty_output_image, empty_new_mask
181
+
182
+ # Determine min/max extents of the bounding box
183
+ x_min_bb, y_min_bb = fg_coords[:, 0].min(), fg_coords[:, 1].min()
184
+ x_max_bb, y_max_bb = fg_coords[:, 0].max(), fg_coords[:, 1].max()
185
+
186
+ bb_target_w = x_max_bb - x_min_bb + 1
187
+ bb_target_h = y_max_bb - y_min_bb + 1
188
+
189
+ if bb_target_w <= 0 or bb_target_h <= 0: # Should not happen if fg_coords not empty
190
+ return empty_output_image, empty_new_mask
191
+
192
+ # 2. Prepare image for resizing: TF.resize expects [C, H, W]
193
+ # Input image_tensor_whc is [W, H, C]. Permute to [C, H_orig, W_orig].
194
+ image_tensor_chw = image_tensor_whc.permute(2, 1, 0)
195
+
196
+ # 3. Calculate new dimensions for the image to fit in BB, preserving aspect ratio
197
+ scale_factor_w = bb_target_w / img_orig_w
198
+ scale_factor_h = bb_target_h / img_orig_h
199
+ scale = min(scale_factor_w, scale_factor_h) # Fit entirely within BB
200
+
201
+ resized_img_w = int(img_orig_w * scale)
202
+ resized_img_h = int(img_orig_h * scale)
203
+
204
+ if resized_img_w == 0 or resized_img_h == 0: # Image scaled to nothing
205
+ return empty_output_image, empty_new_mask
206
+
207
+ # 4. Resize the image. TF.resize expects size as [H, W].
208
+ try:
209
+ # antialias=True for better quality (requires torchvision >= 0.8.0 approx)
210
+ resized_image_chw = TF.resize(image_tensor_chw, [resized_img_h, resized_img_w], antialias=True)
211
+ except TypeError: # Fallback for older torchvision versions
212
+ resized_image_chw = TF.resize(image_tensor_chw, [resized_img_h, resized_img_w])
213
+
214
+ # Permute resized image back to [W, H, C] format
215
+ resized_image_whc = resized_image_chw.permute(2, 1, 0)
216
+
217
+ # 5. Create the output canvas image (initialized to zeros)
218
+ output_image_whc = torch.zeros(
219
+ canvas_w, canvas_h, num_channels,
220
+ dtype=image_tensor_whc.dtype, device=image_tensor_whc.device
221
+ )
222
+
223
+ # 6. Calculate pasting coordinates to center the resized image within the original BB
224
+ offset_x = (bb_target_w - resized_img_w) // 2
225
+ offset_y = (bb_target_h - resized_img_h) // 2
226
+
227
+ paste_x_start = x_min_bb + offset_x
228
+ paste_y_start = y_min_bb + offset_y
229
+
230
+ paste_x_end = paste_x_start + resized_img_w
231
+ paste_y_end = paste_y_start + resized_img_h
232
+
233
+ # Place the resized image onto the canvas
234
+ output_image_whc[paste_x_start:paste_x_end, paste_y_start:paste_y_end, :] = resized_image_whc
235
+
236
+ # 7. Create the new mask representing where the image was actually placed
237
+ new_mask_wh = torch.zeros(
238
+ canvas_w, canvas_h,
239
+ dtype=mask_tensor_wh.dtype, device=mask_tensor_wh.device
240
+ )
241
+ new_mask_wh[paste_x_start:paste_x_end, paste_y_start:paste_y_end] = 1
242
+
243
+ return output_image_whc, new_mask_wh
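And a matching sketch for the placement helper (illustrative sizes; note the `[W, H, C]` / `[W, H]` layout the docstring assumes):

```py
import torch

canvas_mask = torch.zeros(1024, 1024, dtype=torch.bool)
canvas_mask[256:512, 128:384] = True                                  # target bounding box
fg = torch.randint(0, 256, (512, 512, 3), dtype=torch.uint8)          # foreground image, [W, H, C]
canvas, placement_mask = place_image_in_bounding_box(fg, canvas_mask)
print(canvas.shape, placement_mask.shape)                             # (1024, 1024, 3) (1024, 1024)
```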
244
+
245
+
246
+
247
+ ### Function to cut image and put it in bounding box (either cut or not cut)
248
+ def compose_noise_masks(cached_pipe,
249
+ foreground_image: Image,
250
+ background_image: Image,
251
+ target_mask: torch.Tensor,
252
+ foreground_mask: torch.Tensor,
253
+ option: str = "bg", # bg, bg_fg, segmentation1, segmentation2
254
+ photoshop_fg_noise: bool = False,
255
+ num_inversion_steps: int = 100,
256
+ ):
257
+
258
+ """
259
+ Composes noise masks for image generation using different strategies.
260
+ This function composes noise masks for diffusion inversion, with several composition strategies:
261
+ - "bg": Uses only background noise
262
+ - "bg_fg": Combines background and foreground noise using a target mask
263
+ - "segmentation1": Uses segmentation mask to compose foreground and background noise
264
+ - "segmentation2": Implements advanced composition with additional boundary noise
265
+ Parameters:
266
+ ----------
267
+ cached_pipe : object
268
+ The cached diffusion pipeline used for noise inversion
269
+ foreground_image : PIL.Image
270
+ The foreground image to be placed in the background
271
+ background_image : PIL.Image
272
+ The background image
273
+ target_mask : torch.Tensor
274
+ Target mask indicating the position where the foreground should be placed
275
+ foreground_mask : torch.Tensor
276
+ Segmentation mask of the foreground object
277
+ option : str, default="bg"
278
+ Composition strategy: "bg", "bg_fg", "segmentation1", or "segmentation2"
279
+ photoshop_fg_noise : bool, default=False
280
+ Whether to generate noise from a photoshopped composition of foreground and background
281
+ num_inversion_steps : int, default=100
282
+ Number of steps for the inversion process
283
+ Returns:
284
+ -------
285
+ dict
286
+ A dictionary containing:
287
+ - "noise": Dictionary of generated noises (composed_noise, foreground_noise, background_noise)
288
+ - "latent_masks": Dictionary of latent masks used for composition
289
+ """
290
+
291
+ # assert options
292
+ assert option in ["bg", "bg_fg", "segmentation1", "segmentation2"], f"Invalid option: {option}"
293
+
294
+ # calculate size of latent noise for mask resizing
295
+ PATCH_SIZE = 16
296
+ latent_size = background_image.size[0] // PATCH_SIZE
297
+ latents = (latent_size, latent_size)
298
+
299
+ # process the options
300
+ if option == "bg":
301
+ # only background noise
302
+ bg_noise = get_inverted_input_noise(cached_pipe, background_image, num_steps=num_inversion_steps)
303
+ composed_noise = bg_noise
304
+
305
+ all_noise = {
306
+ "composed_noise": composed_noise,
307
+ "background_noise": bg_noise,
308
+ }
309
+ all_latent_masks = {}
310
+
311
+
312
+ elif option == "bg_fg":
313
+
314
+ # resize and scale the image to the bounding box
315
+ reframed_fg_img, resized_mask = place_image_in_bounding_box(
316
+ torch.from_numpy(np.array(foreground_image)),
317
+ (torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
318
+ )
319
+
320
+ #print("Placed Foreground Image")
321
+ reframed_fg_img = Image.fromarray(reframed_fg_img.numpy())
322
+ #display(reframed_fg_img)
323
+
324
+ #print("Placed Mask")
325
+ resized_mask_img = Image.fromarray((resized_mask.numpy() * 255).astype(np.uint8))
326
+ #display(resized_mask_img)
327
+
328
+ # invert resized & padded image
329
+ if photoshop_fg_noise:
330
+ #print("Photoshopping FG IMAGE")
331
+ photoshop_img = Image.fromarray(
332
+ (torch.tensor(np.array(background_image)) * ~resized_mask.cpu().unsqueeze(-1) + torch.tensor(np.array(reframed_fg_img)) * resized_mask.cpu().unsqueeze(-1)).numpy()
333
+ )
334
+ #display(photoshop_img)
335
+ fg_noise = get_inverted_input_noise(cached_pipe, photoshop_img, num_steps=num_inversion_steps)
336
+ else:
337
+ fg_noise = get_inverted_input_noise(cached_pipe, reframed_fg_img, num_steps=num_inversion_steps)
338
+ bg_noise = get_inverted_input_noise(cached_pipe, background_image, num_steps=num_inversion_steps)
339
+
340
+ # overwrite get masked in latent space
341
+ latent_mask = resize_bounding_box(
342
+ resized_mask,
343
+ target_size=latents,
344
+ ).flatten().unsqueeze(-1).to("cuda")
345
+
346
+ # compose the noise
347
+ composed_noise = bg_noise * (~latent_mask) + fg_noise * latent_mask
348
+ all_latent_masks = {
349
+ "latent_mask": latent_mask,
350
+ }
351
+ all_noise = {
352
+ "composed_noise": composed_noise,
353
+ "foreground_noise": fg_noise,
354
+ "background_noise": bg_noise,
355
+ }
356
+
357
+ elif option == "segmentation1":
358
+ # cut out the object and compose it with the background noise
359
+
360
+ # segmented foreground image
361
+ segmented_fg_image = torch.tensor(
362
+ np.array(
363
+ foreground_mask.resize(foreground_image.size)
364
+ )).to(torch.bool).unsqueeze(-1) * torch.tensor(
365
+ np.array(foreground_image)
366
+ )
367
+
368
+ # resize and scale the image to the bounding box
369
+ reframed_fg_img, resized_mask = place_image_in_bounding_box(
370
+ segmented_fg_image,
371
+ (torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
372
+ )
373
+
374
+ reframed_fg_img = Image.fromarray(reframed_fg_img.numpy())
375
+ #display(reframed_fg_img)
376
+
377
+ resized_mask_img = Image.fromarray((resized_mask.numpy() * 255).astype(np.uint8))
378
+
379
+ # resize and scale the mask itself
380
+ foreground_mask = foreground_mask.convert("RGB") # convert to 3 channels so the mask matches the input expected by place_image_in_bounding_box
381
+ reframed_segmentation_mask, resized_mask = place_image_in_bounding_box(
382
+ torch.from_numpy(np.array(foreground_mask)),
383
+ (torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
384
+ )
385
+
386
+ reframed_segmentation_mask = reframed_segmentation_mask.numpy()
387
+ reframed_segmentation_mask_img = Image.fromarray(reframed_segmentation_mask)
388
+ #print("Placed Segmentation Mask")
389
+ #display(reframed_segmentation_mask_img)
390
+
391
+ # invert resized & padded image
392
+ # fg_noise = get_inverted_input_noise(cached_pipe, reframed_fg_img, num_steps=num_inversion_steps)
393
+
394
+ if photoshop_fg_noise:
395
+ # temporarily convert to apply mask
396
+ #print("Photoshopping FG IMAGE")
397
+ seg_mask_temp = torch.from_numpy(reframed_segmentation_mask).bool()
398
+ bg_temp = torch.tensor(np.array(background_image))
399
+ fg_temp = torch.tensor(np.array(reframed_fg_img))
400
+
401
+ photoshop_img = Image.fromarray(
402
+ (bg_temp * (~seg_mask_temp) + fg_temp * seg_mask_temp).numpy()
403
+ ).convert("RGB")
404
+ #display(photoshop_img)
405
+ fg_noise = get_inverted_input_noise(cached_pipe, photoshop_img, num_steps=num_inversion_steps)
406
+ else:
407
+ fg_noise = get_inverted_input_noise(cached_pipe, reframed_fg_img, num_steps=num_inversion_steps)
408
+
409
+
410
+ bg_noise = get_inverted_input_noise(cached_pipe, background_image, num_steps=num_inversion_steps)
411
+ bg_noise_init = bg_noise[-1].squeeze(0) if isinstance(bg_noise, list) else bg_noise
412
+ fg_noise_init = fg_noise[-1].squeeze(0) if isinstance(fg_noise, list) else fg_noise
413
+
414
+ # overwrite background in resized mask
415
+ # convert mask from 512x512x3 to 512x512 first
416
+ reframed_segmentation_mask = reframed_segmentation_mask[:, :, 0]
417
+ reframed_segmentation_mask = torch.from_numpy(reframed_segmentation_mask).to(dtype=bool)
418
+ latent_mask = resize_bounding_box(
419
+ reframed_segmentation_mask,
420
+ target_size=latents,
421
+ ).flatten().unsqueeze(-1).to("cuda")
422
+ bb_mask = resize_bounding_box(
423
+ resized_mask,
424
+ target_size=latents,
425
+ ).flatten().unsqueeze(-1).to("cuda")
426
+
427
+ # compose noise
428
+ composed_noise = bg_noise_init * (~latent_mask) + fg_noise_init * latent_mask
429
+
430
+ all_latent_masks = {
431
+ "latent_segmentation_mask": latent_mask,
432
+ # FIXME: handle bounding box better (making sure shapes are correct, especially when bg and fg images have different sizes, e.g. test image 69)
433
+ "bb_mask": bb_mask,
434
+ }
435
+ all_noise = {
436
+ "composed_noise": composed_noise,
437
+ "foreground_noise": fg_noise_init,
438
+ "background_noise": bg_noise_init,
439
+ "foreground_noise_list": fg_noise if isinstance(fg_noise, list) else None,
440
+ "background_noise_list": bg_noise if isinstance(bg_noise, list) else None,
441
+ }
442
+
443
+
444
+ elif option == "segmentation2":
445
+ # add random noise in the background
446
+
447
+ # segmented foreground image
448
+ segmented_fg_image = torch.tensor(
449
+ np.array(
450
+ foreground_mask.resize(foreground_image.size)
451
+ )).to(torch.bool).unsqueeze(-1) * torch.tensor(
452
+ np.array(foreground_image)
453
+ )
454
+
455
+ # resize and scale the image to the bounding box
456
+ reframed_fg_img, resized_mask = place_image_in_bounding_box(
457
+ segmented_fg_image,
458
+ (torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
459
+ )
460
+
461
+ #print("Segmented and Placed FG Image")
462
+ reframed_fg_img = Image.fromarray(reframed_fg_img.numpy())
463
+ #display(reframed_fg_img)
464
+
465
+ # resize and scale the mask itself
466
+ foreground_mask = foreground_mask.convert("RGB")
467
+ reframed_segmentation_mask, resized_mask = place_image_in_bounding_box(
468
+ torch.from_numpy(np.array(foreground_mask)),
469
+ (torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
470
+ )
471
+
472
+ reframed_segmentation_mask = reframed_segmentation_mask.numpy()
473
+ reframed_segmentation_mask_img = Image.fromarray(reframed_segmentation_mask)
474
+ #print("Reframed Segmentation Mask")
475
+ #display(reframed_segmentation_mask_img)
476
+
477
+ xor_mask = target_mask ^ np.array(reframed_segmentation_mask_img.convert("L"))
478
+ #print("XOR Mask")
479
+ #display(Image.fromarray(xor_mask))
480
+
481
+ # invert resized & padded image
482
+ # fg_noise = get_inverted_input_noise(cached_pipe, reframed_fg_img, num_steps=num_inversion_steps)
483
+ if photoshop_fg_noise:
484
+ #print("Photoshopping FG IMAGE")
485
+ # temporarily convert to apply mask
486
+ seg_mask_temp = torch.from_numpy(reframed_segmentation_mask).bool()
487
+ bg_temp = torch.tensor(np.array(background_image))
488
+ fg_temp = torch.tensor(np.array(reframed_fg_img))
489
+
490
+ photoshop_img = Image.fromarray(
491
+ (bg_temp * (~seg_mask_temp) + fg_temp * seg_mask_temp).numpy()
492
+ ).convert("RGB")
493
+ #display(photoshop_img)
494
+ fg_noise = get_inverted_input_noise(cached_pipe, photoshop_img, num_steps=num_inversion_steps)
495
+ else:
496
+ fg_noise = get_inverted_input_noise(cached_pipe, reframed_fg_img, num_steps=num_inversion_steps)
497
+ bg_noise = get_inverted_input_noise(cached_pipe, background_image, num_steps=num_inversion_steps)
498
+
499
+ # overwrite background in resized mask
500
+ # convert mask from 512x512x3 to 512x512
501
+ reframed_segmentation_mask = reframed_segmentation_mask[:, :, 0]
502
+ reframed_segmentation_mask = torch.from_numpy(reframed_segmentation_mask).to(dtype=bool)
503
+
504
+ # get all masks in latents and move to device
505
+ latent_seg_mask = resize_bounding_box(
506
+ reframed_segmentation_mask,
507
+ target_size=latents,
508
+ ).flatten().unsqueeze(-1).to("cuda")
509
+ print(latent_seg_mask.shape)
510
+
511
+
512
+ latent_xor_mask = resize_bounding_box(
513
+ torch.from_numpy(xor_mask),
514
+ target_size=latents,
515
+ ).flatten().unsqueeze(-1).to("cuda")
516
+
517
+
518
+ print(resized_mask.shape)
519
+ latent_target_mask = resize_bounding_box(
520
+ resized_mask,
521
+ target_size=latents,
522
+ ).flatten().unsqueeze(-1).to("cuda")
523
+
524
+ # implement x*_T = x^r_T ⊙ M_seg + x^m_T ⊙ (1 − M_user) + z ⊙ (M_user ⊕ M_seg)
525
+ bg_noise_init = bg_noise[-1].squeeze(0) if isinstance(bg_noise, list) else bg_noise
526
+ fg_noise_init = fg_noise[-1].squeeze(0) if isinstance(fg_noise, list) else fg_noise
527
+
528
+ bg = bg_noise_init[-1] * (~latent_target_mask)
529
+ fg = fg_noise_init[-1] * latent_seg_mask
530
+ boundary = latent_xor_mask * torch.randn(latent_xor_mask.shape).to("cuda")
531
+ composed_noise = bg + fg + boundary
532
+
533
+ all_latent_masks = {
534
+ "latent_target_mask": latent_target_mask,
535
+ "latent_segmentation_mask": latent_seg_mask,
536
+ "latent_xor_mask": latent_xor_mask,
537
+ }
538
+ all_noise = {
539
+ "composed_noise": composed_noise,
540
+ "foreground_noise": fg_noise_init,
541
+ "background_noise": bg_noise_init,
542
+ "foreground_noise_list": fg_noise if isinstance(fg_noise, list) else None,
543
+ "background_noise_list": bg_noise if isinstance(bg_noise, list) else None,
544
+ }
545
+
546
+ # always add latent bbox mask (for bg consistency or any other future application)
547
+ latent_bbox_mask = resize_bounding_box(
548
+ torch.from_numpy(np.array(target_mask.resize(background_image.size))), # resize just to be sure
549
+ target_size=latents,
550
+ ).flatten().unsqueeze(-1).to("cuda")
551
+ all_latent_masks["latent_bbox_mask"] = latent_bbox_mask
552
+
553
+ # always add latent segmentation mask
554
+ reframed_fg_img, resized_mask = place_image_in_bounding_box(
555
+ torch.from_numpy(np.array(foreground_image)),
556
+ (torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
557
+ )
558
+ bb_mask = resize_bounding_box(
559
+ resized_mask,
560
+ target_size=latents,
561
+ ).flatten().unsqueeze(-1).to("cuda")
562
+ all_latent_masks["latent_segmentation_mask"] = bb_mask
563
+
564
+ # output
565
+ return {
566
+ "noise": all_noise,
567
+ "latent_masks": all_latent_masks,
568
+ }
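A minimal sketch of how compose_noise_masks might be called; the cached pipeline, file paths and chosen option are assumptions, and the inputs follow the PIL conventions used above:

from PIL import Image

foreground = Image.open("foreground.png").convert("RGB")        # hypothetical paths
background = Image.open("background.png").convert("RGB")
target_bbox = Image.open("target_bbox_mask.png").convert("L")   # 0/255 bounding-box mask
fg_segmentation = Image.open("fg_segmentation.png").convert("L")

result = compose_noise_masks(
    cached_pipe,                      # assumed: a cached Flux pipeline set up elsewhere
    foreground_image=foreground,
    background_image=background,
    target_mask=target_bbox,
    foreground_mask=fg_segmentation,
    option="segmentation1",
    photoshop_fg_noise=False,
    num_inversion_steps=50,
)
composed_noise = result["noise"]["composed_noise"]   # starting latent for generation
latent_masks = result["latent_masks"]                # masks in latent (token) space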
SDLens/cache_and_edit/metrics.py ADDED
@@ -0,0 +1,116 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ from typing import Union
4
+ import torch
5
+ from transformers import CLIPProcessor, CLIPModel
6
+ import torch.nn.functional as F
7
+ from transformers import AutoModel, AutoImageProcessor
8
+
9
+
10
+
11
+ def masked_mse_tiled_mask(
12
+ image1_pil: Image.Image,
13
+ image2_pil: Image.Image,
14
+ tile_mask: Union[np.ndarray, torch.Tensor],
15
+ tile_size: int = 16
16
+ ) -> float:
17
+ # Convert images to float32 numpy arrays, normalized [0, 1]
18
+ img1 = np.asarray(image1_pil).astype(np.float32) / 255.0
19
+ img2 = np.asarray(image2_pil).astype(np.float32) / 255.0
20
+
21
+ # Convert mask to numpy if it's a torch tensor
22
+ if isinstance(tile_mask, torch.Tensor):
23
+ tile_mask = tile_mask.detach().cpu().numpy()
24
+
25
+ tile_mask = tile_mask.astype(np.float32)
26
+
27
+ # Upsample mask using np.kron to match image resolution
28
+ upsampled_mask = np.expand_dims(np.kron(tile_mask, np.ones((tile_size, tile_size), dtype=np.float32)), axis=-1)
29
+
30
+ # Invert mask: 1 = exclude → 0; 0 = include → 1
31
+ include_mask = 1.0 - upsampled_mask
32
+
33
+ # Compute squared difference
34
+ diff_squared = (img1 - img2) ** 2
35
+ masked_diff = diff_squared * include_mask
36
+
37
+ # Sum and normalize by valid (included) pixels
38
+ valid_pixel_count = np.sum(include_mask)
39
+ if valid_pixel_count == 0:
40
+ raise ValueError("All pixels are masked out. Cannot compute MSE.")
41
+
42
+ mse = np.sum(masked_diff) / valid_pixel_count
43
+ return float(mse)
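To make the np.kron upsampling above concrete, a toy 2x2 tile mask expanded with tile_size=2 (the values shown are exactly what the code computes):

import numpy as np

tile_mask = np.array([[1, 0],
                      [0, 0]], dtype=np.float32)              # 1 = excluded tile
up = np.kron(tile_mask, np.ones((2, 2), dtype=np.float32))    # 4x4 pixel mask
# up == [[1, 1, 0, 0],
#        [1, 1, 0, 0],
#        [0, 0, 0, 0],
#        [0, 0, 0, 0]]
include = 1.0 - up                                            # pixels that count toward the MSE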
44
+
45
+
46
+ def compute_clip_similarity(image: Image.Image, prompt: str) -> float:
47
+ """
48
+ Compute CLIP similarity between a PIL image and a text prompt.
49
+ Loads CLIP model only once and caches it.
50
+
51
+ Args:
52
+ image (PIL.Image.Image): Input image.
53
+ prompt (str): Text prompt.
54
+
55
+ Returns:
56
+ float: Cosine similarity between image and text.
57
+ """
58
+ if not hasattr(compute_clip_similarity, "model"):
59
+ compute_clip_similarity.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
60
+ compute_clip_similarity.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
61
+ compute_clip_similarity.model.eval()
62
+
63
+ model = compute_clip_similarity.model
64
+ processor = compute_clip_similarity.processor
65
+
66
+ image = image.convert("RGB")
67
+ image_inputs = processor(images=image, return_tensors="pt")
68
+ text_inputs = processor(text=[prompt], return_tensors="pt")
69
+
70
+
71
+ with torch.no_grad():
72
+ image_features = model.get_image_features(**image_inputs)
73
+ text_features = model.get_text_features(**text_inputs)
74
+
75
+ image_features = F.normalize(image_features, p=2, dim=-1)
76
+ text_features = F.normalize(text_features, p=2, dim=-1)
77
+
78
+ similarity = (image_features @ text_features.T).item()
79
+
80
+ return similarity
81
+
82
+
83
+ def compute_dinov2_similarity(image1: Image.Image, image2: Image.Image) -> float:
84
+ """
85
+ Compute perceptual similarity between two images using DINOv2 embeddings.
86
+
87
+ Args:
88
+ image1 (PIL.Image.Image): First image.
89
+ image2 (PIL.Image.Image): Second image.
90
+
91
+ Returns:
92
+ float: Cosine similarity between DINOv2 embeddings of the images.
93
+ """
94
+ # Load model and processor only once
95
+ if not hasattr(compute_dinov2_similarity, "model"):
96
+ compute_dinov2_similarity.processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
97
+ compute_dinov2_similarity.model = AutoModel.from_pretrained("facebook/dinov2-base")
98
+ compute_dinov2_similarity.model.eval()
99
+
100
+ processor = compute_dinov2_similarity.processor
101
+ model = compute_dinov2_similarity.model
102
+
103
+ # Preprocess both images
104
+ inputs = processor(images=[image1.convert("RGB"), image2.convert("RGB")], return_tensors="pt")
105
+
106
+ with torch.no_grad():
107
+ outputs = model(**inputs)
108
+ features = outputs.last_hidden_state.mean(dim=1) # [CLS] or mean-pooled features
109
+
110
+ # Normalize
111
+ features = F.normalize(features, p=2, dim=-1)
112
+
113
+ # Cosine similarity
114
+ similarity = (features[0] @ features[1].T).item()
115
+
116
+ return similarity
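A sketch of combining the three metrics above when scoring a composition; the image paths, prompt and tile mask are illustrative, and the images are assumed to be 512x512 so that a 32x32 tile mask with tile_size=16 matches the resolution:

from PIL import Image
import torch

generated = Image.open("generated.png").convert("RGB")
reference = Image.open("background.png").convert("RGB")
tile_mask = torch.zeros(32, 32)          # 1 marks edited tiles excluded from the background MSE
tile_mask[10:20, 12:22] = 1

bg_mse = masked_mse_tiled_mask(generated, reference, tile_mask, tile_size=16)
clip_score = compute_clip_similarity(generated, "a cat sitting on a sofa")
dino_score = compute_dinov2_similarity(generated, reference)
print(f"background MSE: {bg_mse:.4f}, CLIP: {clip_score:.3f}, DINOv2: {dino_score:.3f}")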
SDLens/cache_and_edit/qkv_cache.py ADDED
@@ -0,0 +1,557 @@
1
+ # Add parent directory to sys.path
2
+ from collections import defaultdict
3
+ import gc
4
+ import os, sys
5
+ from pathlib import Path
6
+
7
+ from SDLens.cache_and_edit.flux_pipeline import EditedFluxPipeline
8
+ parent_dir = Path.cwd().parent.resolve()
9
+ if str(parent_dir) not in sys.path:
10
+ sys.path.insert(0, str(parent_dir))
11
+
12
+ from typing import Dict, List, Literal, Optional, TypedDict, Type, Union
13
+ import torch
14
+ from diffusers.models.attention_processor import Attention
15
+ from diffusers.models.transformers import FluxTransformer2DModel
16
+ from diffusers import FluxPipeline
17
+ from diffusers.models.embeddings import apply_rotary_emb
18
+ from SDLens.cache_and_edit.hooks import locate_block
19
+ import torch.nn.functional as F
20
+ from diffusers.models.attention_processor import FluxAttnProcessor2_0
21
+
22
+ class QKVCache(TypedDict):
23
+ query: List[torch.Tensor]
24
+ key: List[torch.Tensor]
25
+ value: List[torch.Tensor]
26
+
27
+
28
+ class CachedFluxAttnProcessor2_0:
29
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
30
+
31
+ def __init__(self, external_cache: QKVCache,
32
+ inject_kv: Literal["image", "text", "both"]= None,
33
+ text_seq_length: int = 512):
34
+ """Constructor for Cached attention processor.
35
+
36
+ Args:
37
+ external_cache (QKVCache): cache to store/inject values.
38
+ inject_kv (Literal["image", "text", "both"], optional): whether to inject image, text or both streams KV.
39
+ If None, it does not perform injection but the full cache is stored. Defaults to None.
40
+ """
41
+
42
+ if not hasattr(F, "scaled_dot_product_attention"):
43
+ raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
44
+ self.cache = external_cache
45
+ self.inject_kv = inject_kv
46
+ self.text_seq_length = text_seq_length
47
+ assert all((cache_key in external_cache) for cache_key in {"query", "key", "value"}), "Cache has to contain 'query', 'key' and 'value' keys."
48
+
49
+ def __call__(
50
+ self,
51
+ attn: Attention,
52
+ hidden_states: torch.FloatTensor,
53
+ encoder_hidden_states: torch.FloatTensor = None,
54
+ attention_mask: Optional[torch.FloatTensor] = None,
55
+ image_rotary_emb: Optional[torch.Tensor] = None,
56
+ ) -> torch.FloatTensor:
57
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
58
+
59
+ # `sample` projections.
60
+ query = attn.to_q(hidden_states)
61
+ key = attn.to_k(hidden_states)
62
+ value = attn.to_v(hidden_states)
63
+
64
+ inner_dim = key.shape[-1]
65
+ head_dim = inner_dim // attn.heads
66
+
67
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
68
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
69
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
70
+
71
+ if attn.norm_q is not None:
72
+ query = attn.norm_q(query)
73
+ if attn.norm_k is not None:
74
+ key = attn.norm_k(key)
75
+
76
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
77
+ if encoder_hidden_states is not None:
78
+ # `context` projections.
79
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
80
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
81
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
82
+
83
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
84
+ batch_size, -1, attn.heads, head_dim
85
+ ).transpose(1, 2)
86
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
87
+ batch_size, -1, attn.heads, head_dim
88
+ ).transpose(1, 2)
89
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
90
+ batch_size, -1, attn.heads, head_dim
91
+ ).transpose(1, 2)
92
+
93
+ if attn.norm_added_q is not None:
94
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
95
+ if attn.norm_added_k is not None:
96
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
97
+
98
+ # attention
99
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
100
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
101
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
102
+
103
+ if image_rotary_emb is not None:
104
+ query = apply_rotary_emb(query, image_rotary_emb)
105
+ key = apply_rotary_emb(key, image_rotary_emb)
106
+
107
+ # Cache Q, K, V
108
+ if self.inject_kv == "image":
109
+ # NOTE: keys and values are replaced only for the image branch
111
+ # NOTE: in default settings, encoder_hidden_states_key_proj.shape[2] == 512
111
+ # the first element of the batch is the image whose key and value will be injected into all the other images
112
+ key[1:, :, self.text_seq_length:] = key[:1, :, self.text_seq_length:]
113
+ value[1:, :, self.text_seq_length:] = value[:1, :, self.text_seq_length:]
114
+ elif self.inject_kv == "text":
115
+ key[1:, :, :self.text_seq_length] = key[:1, :, :self.text_seq_length]
116
+ value[1:, :, :self.text_seq_length] = value[:1, :, :self.text_seq_length]
117
+ elif self.inject_kv == "both":
118
+ key[1:] = key[:1]
119
+ value[1:] = value[:1]
120
+ else: # Don't inject, store cache!
121
+ self.cache["query"].append(query)
122
+ self.cache["key"].append(key)
123
+ self.cache["value"].append(value)
124
+
125
+ hidden_states = F.scaled_dot_product_attention(
126
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
127
+ )
128
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
129
+ hidden_states = hidden_states.to(query.dtype)
130
+
131
+
132
+ if encoder_hidden_states is not None:
133
+ encoder_hidden_states, hidden_states = (
134
+ hidden_states[:, : encoder_hidden_states.shape[1]],
135
+ hidden_states[:, encoder_hidden_states.shape[1] :],
136
+ )
137
+
138
+ # linear proj
139
+ hidden_states = attn.to_out[0](hidden_states)
140
+ # dropout
141
+ hidden_states = attn.to_out[1](hidden_states)
142
+
143
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
144
+
145
+ return hidden_states, encoder_hidden_states
146
+ else:
147
+ return hidden_states
148
+
149
+
150
+ class CachedFluxAttnProcessor3_0:
151
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
152
+
153
+ def __init__(self, external_cache: QKVCache,
154
+ inject_kv: Literal["image", "text", "both"]= None,
155
+ inject_kv_foreground: bool = False,
156
+ text_seq_length: int = 512,
157
+ q_mask: Optional[torch.Tensor] = None,):
158
+ """Constructor for Cached attention processor.
159
+
160
+ Args:
161
+ external_cache (QKVCache): cache to store/inject values.
162
+ inject_kv (Literal["image", "text", "both"], optional): whether to inject image, text or both streams KV.
163
+ If None, it does not perform injection but the full cache is stored. Defaults to None.
164
+ """
165
+
166
+ if not hasattr(F, "scaled_dot_product_attention"):
167
+ raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
168
+ self.cache = external_cache
169
+ self.inject_kv = inject_kv
170
+ self.inject_kv_foreground = inject_kv_foreground
171
+ self.text_seq_length = text_seq_length
172
+ self.q_mask = q_mask
173
+ assert all((cache_key in external_cache) for cache_key in {"query", "key", "value"}), "Cache has to contain 'query', 'key' and 'value' keys."
174
+
175
+ def __call__(
176
+ self,
177
+ attn: Attention,
178
+ hidden_states: torch.FloatTensor,
179
+ encoder_hidden_states: torch.FloatTensor = None,
180
+ attention_mask: Optional[torch.FloatTensor] = None,
181
+ image_rotary_emb: Optional[torch.Tensor] = None,
182
+ ) -> torch.FloatTensor:
183
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
184
+
185
+ # `sample` projections.
186
+ query = attn.to_q(hidden_states)
187
+ key = attn.to_k(hidden_states)
188
+ value = attn.to_v(hidden_states)
189
+
190
+ inner_dim = key.shape[-1]
191
+ head_dim = inner_dim // attn.heads
192
+
193
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
194
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
195
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
196
+
197
+ if attn.norm_q is not None:
198
+ query = attn.norm_q(query)
199
+ if attn.norm_k is not None:
200
+ key = attn.norm_k(key)
201
+
202
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
203
+ if encoder_hidden_states is not None:
204
+ # `context` projections.
205
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
206
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
207
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
208
+
209
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
210
+ batch_size, -1, attn.heads, head_dim
211
+ ).transpose(1, 2)
212
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
213
+ batch_size, -1, attn.heads, head_dim
214
+ ).transpose(1, 2)
215
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
216
+ batch_size, -1, attn.heads, head_dim
217
+ ).transpose(1, 2)
218
+
219
+ if attn.norm_added_q is not None:
220
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
221
+ if attn.norm_added_k is not None:
222
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
223
+
224
+ # attention
225
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
226
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
227
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
228
+
229
+
230
+ # # Cache Q, K, V
231
+ # if self.inject_kv == "image":
232
+ # # NOTE: keys and values are replaced only for the image branch
233
+ # # NOTE: in default settings, encoder_hidden_states_key_proj.shape[2] == 512
234
+ # # the first element of the batch is the image whose key and value will be injected into all the other images
235
+ # key[1:, :, self.text_seq_length:] = key[:1, :, self.text_seq_length:]
236
+ # value[1:, :, self.text_seq_length:] = value[:1, :, self.text_seq_length:]
237
+ # elif self.inject_kv == "text":
238
+ # key[1:, :, :self.text_seq_length] = key[:1, :, :self.text_seq_length]
239
+ # value[1:, :, :self.text_seq_length] = value[:1, :, :self.text_seq_length]
240
+ # elif self.inject_kv == "both":
241
+ # key[1:] = key[:1]
242
+ # value[1:] = value[:1]
243
+ # else: # Don't inject, store cache!
244
+ # self.cache["query"].append(query)
245
+ # self.cache["key"].append(key)
246
+ # self.cache["value"].append(value)
247
+
248
+ # extend the mask to match key and values dimension:
249
+ # Shape of mask is: (num_image_tokens, 1)
250
+ mask = self.q_mask.permute(1, 0).unsqueeze(0).unsqueeze(-1) # Shape: (1, 1, num_image_tokens, 1)
251
+ # put mask on gpu
252
+ mask = mask.to(key.device)
253
+ # first check that we inject only kv in images:
254
+ if self.inject_kv is not None and self.inject_kv != "image":
255
+ raise NotImplementedError("Injecting is implemented only for images.")
256
+ # the second tensor dimension is the number of attention heads
257
+ # Batch layout: element 0 is the background image, element 1 is the foreground image, and the
258
+ # remaining elements are the compositions into which keys/values are injected according to the
259
+ # query mask: the background (element 0) is used where the mask is False, and the foreground
260
+ # (element 1) or the composition's own values where it is True, depending on the flags below
261
+
262
+ if image_rotary_emb is not None:
263
+ query = apply_rotary_emb(query, image_rotary_emb)
264
+ key = apply_rotary_emb(key, image_rotary_emb)
265
+
266
+ # Get the index range after the text tokens
267
+ start_idx = self.text_seq_length
268
+
269
+ if self.inject_kv_foreground and self.inject_kv == "image":
270
+ key[2:, :, start_idx:] = torch.where(mask, key[1:2, :, start_idx:], key[:1, :, start_idx:])
271
+ value[2:, :, start_idx:] = torch.where(mask, value[1:2, :, start_idx:], value[:1, :, start_idx:])
272
+ elif self.inject_kv == "image" and not self.inject_kv_foreground:
273
+ key[2:, :, start_idx:] = torch.where(mask, key[2:, :, start_idx:], key[:1, :, start_idx:])
274
+ value[2:, :, start_idx:] = torch.where(mask, value[2:, :, start_idx:], value[:1, :, start_idx:])
275
+ elif self.inject_kv is None and self.inject_kv_foreground:
276
+ key[2:, :, start_idx:] = torch.where(mask, key[1:2, :, start_idx:], key[2:, :, start_idx:])
277
+ value[2:, :, start_idx:] = torch.where(mask, value[1:2, :, start_idx:], value[2:, :, start_idx:])
278
+
279
+ hidden_states = F.scaled_dot_product_attention(
280
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
281
+ )
282
+ # mask hidden states from bg:
283
+ # hidden_states = hidden_states_fg[:, :, start_idx:] * mask + hidden_states_bg[:, :, start_idx:] * (~mask)
284
+
285
+ # concatenate the text
286
+ #hidden_states = torch.cat([hidden_states_bg[:, :, :start_idx], hidden_states], dim=2)
287
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
288
+ hidden_states = hidden_states.to(query.dtype)
289
+
290
+
291
+ if encoder_hidden_states is not None:
292
+ encoder_hidden_states, hidden_states = (
293
+ hidden_states[:, : encoder_hidden_states.shape[1]],
294
+ hidden_states[:, encoder_hidden_states.shape[1] :],
295
+ )
296
+
297
+ # linear proj
298
+ hidden_states = attn.to_out[0](hidden_states)
299
+ # dropout
300
+ hidden_states = attn.to_out[1](hidden_states)
301
+
302
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
303
+
304
+ return hidden_states, encoder_hidden_states
305
+ else:
306
+ return hidden_states
307
+
308
+
309
+ class QKVCacheFluxHandler:
310
+ """Used to cache queries, keys and values of a FluxPipeline.
311
+ """
312
+
313
+ def __init__(self, pipe: Union[FluxPipeline, EditedFluxPipeline],
314
+ positions_to_cache: List[str] = None,
315
+ positions_to_cache_foreground: List[str] = None,
316
+ inject_kv: Literal["image", "text", "both"] = None,
317
+ text_seq_length: int = 512,
318
+ q_mask: Optional[torch.Tensor] = None,
319
+ processor_class: Optional[Type] = CachedFluxAttnProcessor3_0
320
+ ):
321
+
322
+ print(type(pipe))
323
+ if not isinstance(pipe, FluxPipeline) and not isinstance(pipe, EditedFluxPipeline):
324
+ raise NotImplementedError(f"QKVCache not yet implemented for {type(pipe)}.")
325
+
326
+ self.pipe = pipe
327
+
328
+ if positions_to_cache is not None:
329
+ self.positions_to_cache = positions_to_cache
330
+ else:
331
+ # act on all transformer layers
332
+ self.positions_to_cache = []
333
+
334
+ if positions_to_cache_foreground is not None:
335
+ self.positions_to_cache_foreground = positions_to_cache_foreground
336
+ else:
337
+ self.positions_to_cache_foreground = []
338
+
339
+ self._cache = {"query": [], "key": [], "value": []}
340
+
341
+ # Set Cached Processor to perform editing
342
+
343
+ all_layers = [f"transformer.transformer_blocks.{i}" for i in range(19)] + \
344
+ [f"transformer.single_transformer_blocks.{i}" for i in range(38)]
345
+ for module_name in all_layers:
346
+
347
+ inject_kv = "image" if module_name in self.positions_to_cache else None
348
+ inject_kv_foreground = module_name in self.positions_to_cache_foreground
349
+
350
+
351
+ module = locate_block(pipe, module_name)
352
+ module.attn.set_processor(processor_class(external_cache=self._cache,
353
+ inject_kv=inject_kv,
354
+ inject_kv_foreground=inject_kv_foreground,
355
+ text_seq_length=text_seq_length,
356
+ q_mask=q_mask,
357
+ ))
358
+
359
+
360
+ @property
361
+ def cache(self) -> QKVCache:
362
+ """Returns a dictionary initialized as {"query": [], "key": [], "value": []}.
363
+ After calling a forward pass for pipe, queries, keys and values will be
364
+ appended in the respective list for each layer.
365
+
366
+ Returns:
367
+ Dict[str, List[torch.Tensor]]: cache dictionary containing 'query', 'key' and 'value'
368
+ """
369
+ return self._cache
370
+
371
+ def clear_cache(self) -> None:
372
+ # TODO: check if we have to force clean GPU memory
373
+ del(self._cache)
374
+ gc.collect() # force Python to clean up unreachable objects
375
+ torch.cuda.empty_cache() # tell PyTorch to release unused GPU memory from its cache
376
+ self._cache = {"query": [], "key": [], "value": []}
377
+
378
+ for module_name in self.positions_to_cache:
379
+ module = locate_block(self.pipe, module_name)
380
+ module.attn.set_processor(FluxAttnProcessor2_0())
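A minimal sketch of attaching the handler above to a Flux pipeline; the model id, dtype, token count and mask construction are assumptions, and the batch fed to the pipeline is expected to follow the background/foreground/composition layout described in the processors:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# boolean mask over the image tokens (e.g. 64x64 = 4096 for a 1024px image); True = foreground region
q_mask = torch.zeros(4096, 1, dtype=torch.bool)
q_mask[1024:2048] = True

handler = QKVCacheFluxHandler(
    pipe,
    positions_to_cache=[f"transformer.transformer_blocks.{i}" for i in range(19)],
    q_mask=q_mask,
)
# run the pipeline on a [background, foreground, composition] batch, then inspect/reset the cache
# images = pipe(prompt=[...], latents=..., num_inference_steps=28).images
handler.clear_cache()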
381
+
382
+
383
+ class TFICONAttnProcessor:
384
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
385
+
386
+ def __init__(self,
387
+ external_cache: QKVCache,
388
+ inject_kv: Literal["image", "text", "both"]= None,
389
+ inject_kv_foreground: bool = False,
390
+ text_seq_length: int = 512,
391
+ q_mask: Optional[torch.Tensor] = None,
392
+ call_max_times = None,
393
+ inject_q = True,
394
+ inject_k = True,
395
+ inject_v = True,
396
+ ):
397
+ """Constructor for Cached attention processor.
398
+
399
+ Args:
400
+ external_cache (QKVCache): cache to store/inject values.
401
+ inject_kv (Literal["image", "text", "both"], optional): whether to inject image, text or both streams KV.
402
+ If None, it does not perform injection but the full cache is stored. Defaults to None.
403
+ """
404
+
405
+ if not hasattr(F, "scaled_dot_product_attention"):
406
+ raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
407
+ self.cache = external_cache
408
+ self.inject_kv = inject_kv
409
+ self.inject_kv_foreground = inject_kv_foreground
410
+ self.text_seq_length = text_seq_length
411
+ self.q_mask = q_mask
412
+ self.inject_q = inject_q
413
+ self.inject_k = inject_k
414
+ self.inject_v = inject_v
415
+
416
+ self.call_max_times = call_max_times
417
+ if self.call_max_times is not None:
418
+ self.num_calls = call_max_times
419
+ else:
420
+ self.num_calls = None
421
+ assert all((cache_key in external_cache) for cache_key in {"query", "key", "value"}), "Cache has to contain 'query', 'key' and 'value' keys."
422
+
423
+ def __call__(
424
+ self,
425
+ attn: Attention,
426
+ hidden_states: torch.FloatTensor,
427
+ encoder_hidden_states: torch.FloatTensor = None,
428
+ attention_mask: Optional[torch.FloatTensor] = None,
429
+ image_rotary_emb: Optional[torch.Tensor] = None,
430
+ ) -> torch.FloatTensor:
431
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
432
+
433
+ # `sample` projections.
434
+ query = attn.to_q(hidden_states)
435
+ key = attn.to_k(hidden_states)
436
+ value = attn.to_v(hidden_states)
437
+
438
+ inner_dim = key.shape[-1]
439
+ head_dim = inner_dim // attn.heads
440
+
441
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
442
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
443
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
444
+
445
+ if attn.norm_q is not None:
446
+ query = attn.norm_q(query)
447
+ if attn.norm_k is not None:
448
+ key = attn.norm_k(key)
449
+
450
+ # hidden states are the image patches (B, 4096, hidden_dim)
451
+
452
+ # encoder_hidden_states are the text tokens (B, 512, hidden_dim)
453
+
454
+ # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
455
+ if encoder_hidden_states is not None:
456
+ # `context` projections.
457
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
458
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
459
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
460
+
461
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
462
+ batch_size, -1, attn.heads, head_dim
463
+ ).transpose(1, 2)
464
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
465
+ batch_size, -1, attn.heads, head_dim
466
+ ).transpose(1, 2)
467
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
468
+ batch_size, -1, attn.heads, head_dim
469
+ ).transpose(1, 2)
470
+
471
+ if attn.norm_added_q is not None:
472
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
473
+ if attn.norm_added_k is not None:
474
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
475
+
476
+ # concat inputs for attention -> (B, num_heads, 512 + 4096, head_dim)
477
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
478
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
479
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
480
+
481
+ # TODO: try first without mask
482
+ # Cache Q, K, V
483
+ # extend the mask to match key and values dimension:
484
+ # Shape of mask is: (num_image_tokens, 1)
485
+ mask = self.q_mask.permute(1, 0).unsqueeze(0).unsqueeze(-1) # Shape: (1, 1, num_image_tokens, 1)
486
+ # put mask on gpu
487
+ mask = mask.to(key.device)
488
+ # first check that we inject only kv in images:
489
+ if self.inject_kv is not None and self.inject_kv != "image":
490
+ raise NotImplementedError("Injecting is implemented only for images.")
491
+ # the second tensor dimension is the number of attention heads
492
+ # Batch layout: element 0 is the background image, element 1 is the foreground image, and the
493
+ # remaining elements are the compositions into which queries/keys/values are injected according
494
+ # to the query mask: the background (element 0) is used where the mask is False, and the foreground
495
+ # (element 1) or the composition's own values where it is True, depending on inject_kv_foreground and the inject_q/k/v flags
496
+
497
+ if image_rotary_emb is not None:
498
+ query = apply_rotary_emb(query, image_rotary_emb)
499
+ key = apply_rotary_emb(key, image_rotary_emb)
500
+
501
+ # Get the index range after the text tokens
502
+ start_idx = self.text_seq_length
503
+
504
+ # Batch is formed as follow:
505
+ # - background image (0)
506
+ # - foreground image (1)
507
+ # - composition(s) (2, 3, ...)
508
+ # Create the combined attention mask, by forming Q_comp and K_comp, taking the Q and K of the background image
509
+ # when outside of the mask, the one of the foreground image when inside the mask
510
+
511
+ if self.num_calls is None or self.num_calls > 0:
512
+ if self.inject_kv_foreground:
513
+ if self.inject_k:
514
+ key[2:, :, start_idx:] = torch.where(mask, key[1:2, :, start_idx:], key[0:1, :, start_idx:])
515
+ if self.inject_q:
516
+ query[2:, :, start_idx:] = torch.where(mask, query[1:2, :, start_idx:], query[0:1, :, start_idx:])
517
+ if self.inject_v:
518
+ value[2:, :, start_idx:] = torch.where(mask, value[1:2, :, start_idx:], value[0:1, :, start_idx:])
519
+ else:
520
+ if self.inject_k:
521
+ key[2:, :, start_idx:] = torch.where(mask, key[2:, :, start_idx:], key[0:1, :, start_idx:])
522
+ if self.inject_q:
523
+ query[2:, :, start_idx:] = torch.where(mask, query[2:, :, start_idx:], query[0:1, :, start_idx:])
524
+ if self.inject_v:
525
+ value[2:, :, start_idx:] = torch.where(mask, value[2:, :, start_idx:], value[0:1, :, start_idx:])
526
+
527
+ if self.num_calls is not None:
528
+ self.num_calls -= 1
529
+
530
+
531
+ # Use the combined attention map to compute attention using V from the composition image
532
+ hidden_states = F.scaled_dot_product_attention(
533
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
534
+ )
535
+
536
+ # hidden_states[2:, :, start_idx:] = torch.where(mask, weightage * hidden_states[1:2, :, start_idx:] + (1-weightage) * hidden_states[2:, :, start_idx:], hidden_states[2:, :, start_idx:])
537
+
538
+ # merge attention heads back into the hidden dimension
539
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
540
+ hidden_states = hidden_states.to(query.dtype)
541
+
542
+ if encoder_hidden_states is not None:
543
+ encoder_hidden_states, hidden_states = (
544
+ hidden_states[:, : encoder_hidden_states.shape[1]],
545
+ hidden_states[:, encoder_hidden_states.shape[1] :],
546
+ )
547
+
548
+ # linear proj
549
+ hidden_states = attn.to_out[0](hidden_states)
550
+ # dropout
551
+ hidden_states = attn.to_out[1](hidden_states)
552
+
553
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
554
+
555
+ return hidden_states, encoder_hidden_states
556
+ else:
557
+ return hidden_states
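To make the masked Q/K/V composition concrete, a toy example of the torch.where pattern used above (ignoring the text tokens; batch of 3 = background, foreground, composition; 1 head; 4 image tokens; all values illustrative):

import torch

# (batch, heads, tokens, head_dim)
key = torch.arange(3 * 1 * 4 * 2, dtype=torch.float32).reshape(3, 1, 4, 2)
mask = torch.tensor([True, True, False, False]).view(1, 1, 4, 1)   # True = inside the user box

# composition rows take the foreground's keys inside the mask and the background's outside
key[2:] = torch.where(mask, key[1:2], key[0:1])
print(key[2])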
SDLens/cache_and_edit/scheduler_inversion.py ADDED
@@ -0,0 +1,98 @@
1
+ from typing import List, Optional, Tuple, Union
2
+ import torch
3
+ from diffusers.configuration_utils import register_to_config
4
+ from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler, FlowMatchEulerDiscreteSchedulerOutput
5
+
6
+
7
+ class FlowMatchEulerDiscreteSchedulerForInversion(FlowMatchEulerDiscreteScheduler):
8
+
9
+ @register_to_config
10
+ def __init__(self, inverse: bool, **kwargs):
11
+ super().__init__(**kwargs)
12
+ self.inverse = inverse
13
+
14
+
15
+ def step(
16
+ self,
17
+ model_output: torch.FloatTensor,
18
+ timestep: Union[float, torch.FloatTensor],
19
+ sample: torch.FloatTensor,
20
+ s_churn: float = 0.0,
21
+ s_tmin: float = 0.0,
22
+ s_tmax: float = float("inf"),
23
+ s_noise: float = 1.0,
24
+ generator: Optional[torch.Generator] = None,
25
+ return_dict: bool = True,
26
+ ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
27
+ """
28
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
29
+ process from the learned model outputs (most often the predicted noise).
30
+
31
+ Args:
32
+ model_output (`torch.FloatTensor`):
33
+ The direct output from learned diffusion model.
34
+ timestep (`float`):
35
+ The current discrete timestep in the diffusion chain.
36
+ sample (`torch.FloatTensor`):
37
+ A current instance of a sample created by the diffusion process.
38
+ s_churn (`float`):
39
+ s_tmin (`float`):
40
+ s_tmax (`float`):
41
+ s_noise (`float`, defaults to 1.0):
42
+ Scaling factor for noise added to the sample.
43
+ generator (`torch.Generator`, *optional*):
44
+ A random number generator.
45
+ return_dict (`bool`):
46
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
47
+ tuple.
48
+
49
+ Returns:
50
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
51
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
52
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
53
+ """
54
+
55
+ if (
56
+ isinstance(timestep, int)
57
+ or isinstance(timestep, torch.IntTensor)
58
+ or isinstance(timestep, torch.LongTensor)
59
+ ):
60
+ raise ValueError(
61
+ (
62
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
63
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
64
+ " one of the `scheduler.timesteps` as a timestep."
65
+ ),
66
+ )
67
+
68
+ if self.step_index is None:
69
+ self._init_step_index(timestep)
70
+
71
+ # Upcast to avoid precision issues when computing prev_sample
72
+ sample = sample.to(torch.float32)
73
+
74
+ sigma = self.sigmas[self.step_index]
75
+ sigma_next = self.sigmas[self.step_index + 1]
76
+
77
+ if self.inverse:
78
+ next_sample = sample + (sigma - sigma_next) * model_output
79
+ # Cast sample back to model compatible dtype
80
+ next_sample = next_sample.to(model_output.dtype)
81
+ # upon completion increase step index by one
82
+ self._step_index -= 1
83
+
84
+ if not return_dict:
85
+ return (next_sample,)
86
+
87
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=next_sample)
88
+ else:
89
+ prev_sample = sample + (sigma_next - sigma) * model_output
90
+ # Cast sample back to model compatible dtype
91
+ prev_sample = prev_sample.to(model_output.dtype)
92
+ # upon completion increase step index by one
93
+ self._step_index += 1
94
+
95
+ if not return_dict:
96
+ return (prev_sample,)
97
+
98
+ return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
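As a sanity check on the update rule, a small numeric sketch of one forward and one inverse flow-matching Euler step; the numbers are made up, and in practice the model output differs between the two passes:

import torch

sigma, sigma_next = 0.8, 0.7
sample = torch.tensor([1.0])
model_output = torch.tensor([0.5])                               # predicted velocity

prev_sample = sample + (sigma_next - sigma) * model_output       # forward step: 1.0 - 0.1 * 0.5 = 0.95
recovered = prev_sample + (sigma - sigma_next) * model_output    # inverse step gives back 1.0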
SDLens/hooked_scheduler.py ADDED
@@ -0,0 +1,40 @@
1
+ from diffusers import DDPMScheduler
2
+ import torch
3
+
4
+ class HookedNoiseScheduler:
5
+ scheduler: DDPMScheduler
6
+ pre_hooks: list
7
+ post_hooks: list
8
+
9
+ def __init__(self, scheduler):
10
+ object.__setattr__(self, 'scheduler', scheduler)
11
+ object.__setattr__(self, 'pre_hooks', [])
12
+ object.__setattr__(self, 'post_hooks', [])
13
+
14
+ def step(
15
+ self,
16
+ model_output, timestep, sample, generator, return_dict
17
+ ):
18
+ assert return_dict == False, "return_dict == True is not implemented"
19
+ for hook in self.pre_hooks:
20
+ hook_output = hook(model_output, timestep, sample, generator)
21
+ if hook_output is not None:
22
+ model_output, timestep, sample, generator = hook_output
23
+
24
+ (pred_prev_sample, ) = self.scheduler.step(model_output, timestep, sample, generator, return_dict)
25
+
26
+ for hook in self.post_hooks:
27
+ hook_output = hook(pred_prev_sample)
28
+ if hook_output is not None:
29
+ pred_prev_sample = hook_output
30
+
31
+ return (pred_prev_sample, )
32
+
33
+ def __getattr__(self, name):
34
+ return getattr(self.scheduler, name)
35
+
36
+ def __setattr__(self, name, value):
37
+ if name in {'scheduler', 'pre_hooks', 'post_hooks'}:
38
+ object.__setattr__(self, name, value)
39
+ else:
40
+ setattr(self.scheduler, name, value)
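A short sketch of attaching a post-hook that records every denoised sample; the scheduler config is illustrative:

from diffusers import DDPMScheduler

scheduler = HookedNoiseScheduler(DDPMScheduler(num_train_timesteps=1000))

captured = []
def record_sample(pred_prev_sample):
    captured.append(pred_prev_sample)     # returning None keeps the sample unchanged

scheduler.post_hooks.append(record_sample)
# scheduler.step(model_output, timestep, sample, generator=None, return_dict=False)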
SDLens/hooked_sd_pipeline.py ADDED
@@ -0,0 +1,319 @@
1
+ import einops
2
+ from diffusers import StableDiffusionXLPipeline, IFPipeline
3
+ from typing import List, Dict, Callable, Union
4
+ import torch
5
+ from .hooked_scheduler import HookedNoiseScheduler
6
+
7
+ def retrieve(io):
8
+ if isinstance(io, tuple):
9
+ if len(io) == 1:
10
+ return io[0]
11
+ else:
12
+ raise ValueError("A tuple should have length of 1")
13
+ elif isinstance(io, torch.Tensor):
14
+ return io
15
+ else:
16
+ raise ValueError("Input/Output must be a tensor, or 1-element tuple")
17
+
18
+
19
+ class HookedDiffusionAbstractPipeline:
20
+ parent_cls = None
21
+ pipe = None
22
+
23
+ def __init__(self, pipe: parent_cls, use_hooked_scheduler: bool = False):
24
+ if use_hooked_scheduler:
25
+ pipe.scheduler = HookedNoiseScheduler(pipe.scheduler)
26
+ self.__dict__['pipe'] = pipe
27
+ self.use_hooked_scheduler = use_hooked_scheduler
28
+
29
+ @classmethod
30
+ def from_pretrained(cls, *args, **kwargs):
31
+ return cls(cls.parent_cls.from_pretrained(*args, **kwargs))
32
+
33
+
34
+ def run_with_hooks(self,
35
+ *args,
36
+ position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
37
+ **kwargs
38
+ ):
39
+ '''
40
+ Run the pipeline with hooks at specified positions.
41
+ Returns the final output.
42
+
43
+ Args:
44
+ *args: Arguments to pass to the pipeline.
45
+ position_hook_dict: A dictionary mapping positions to hooks.
46
+ The keys are positions in the pipeline where the hooks should be registered.
47
+ The values are either a single hook or a list of hooks to be registered at the specified position.
48
+ Each hook should be a callable that takes three arguments: (module, input, output).
49
+ **kwargs: Keyword arguments to pass to the pipeline.
50
+ '''
51
+ hooks = []
52
+ for position, hook in position_hook_dict.items():
53
+ if isinstance(hook, list):
54
+ for h in hook:
55
+ hooks.append(self._register_general_hook(position, h))
56
+ else:
57
+ hooks.append(self._register_general_hook(position, hook))
58
+
59
+ hooks = [hook for hook in hooks if hook is not None]
60
+
61
+ try:
62
+ output = self.pipe(*args, **kwargs)
63
+ finally:
64
+ for hook in hooks:
65
+ hook.remove()
66
+ if self.use_hooked_scheduler:
67
+ self.pipe.scheduler.pre_hooks = []
68
+ self.pipe.scheduler.post_hooks = []
69
+
70
+ return output
71
+
72
+ def run_with_cache(self,
73
+ *args,
74
+ positions_to_cache: List[str],
75
+ save_input: bool = False,
76
+ save_output: bool = True,
77
+ **kwargs
78
+ ):
79
+ '''
80
+ Run the pipeline with caching at specified positions.
81
+
82
+ This method allows you to cache the intermediate inputs and/or outputs of the pipeline
83
+ at certain positions. The final output of the pipeline and a dictionary of cached values
84
+ are returned.
85
+
86
+ Args:
87
+ *args: Arguments to pass to the pipeline.
88
+ positions_to_cache (List[str]): A list of positions in the pipeline where intermediate
89
+ inputs/outputs should be cached.
90
+ save_input (bool, optional): If True, caches the input at each specified position.
91
+ Defaults to False.
92
+ save_output (bool, optional): If True, caches the output at each specified position.
93
+ Defaults to True.
94
+ **kwargs: Keyword arguments to pass to the pipeline.
95
+
96
+ Returns:
97
+ final_output: The final output of the pipeline after execution.
98
+ cache_dict (Dict[str, Dict[str, Any]]): A dictionary where keys are the specified positions
99
+ and values are dictionaries containing the cached 'input' and/or 'output' at each position,
100
+ depending on the flags `save_input` and `save_output`.
101
+ '''
102
+ cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
103
+ hooks = [
104
+ self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
105
+ ]
106
+ hooks = [hook for hook in hooks if hook is not None]
107
+ output = self.pipe(*args, **kwargs)
108
+ for hook in hooks:
109
+ hook.remove()
110
+ if self.use_hooked_scheduler:
111
+ self.pipe.scheduler.pre_hooks = []
112
+ self.pipe.scheduler.post_hooks = []
113
+
114
+ cache_dict = {}
115
+ if save_input:
116
+ for position, block in cache_input.items():
117
+ cache_input[position] = torch.stack(block, dim=1)
118
+ cache_dict['input'] = cache_input
119
+
120
+ if save_output:
121
+ for position, block in cache_output.items():
122
+ cache_output[position] = torch.stack(block, dim=1)
123
+ cache_dict['output'] = cache_output
124
+ return output, cache_dict
125
+
126
+ def run_with_hooks_and_cache(self,
127
+ *args,
128
+ position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
129
+ positions_to_cache: List[str] = [],
130
+ save_input: bool = False,
131
+ save_output: bool = True,
132
+ **kwargs
133
+ ):
134
+ '''
135
+ Run the pipeline with hooks and caching at specified positions.
136
+
137
+ This method allows you to register hooks at certain positions in the pipeline and
138
+ cache intermediate inputs and/or outputs at specified positions. Hooks can be used
139
+ for inspecting or modifying the pipeline's execution, and caching stores intermediate
140
+ values for later inspection or use.
141
+
142
+ Args:
143
+ *args: Arguments to pass to the pipeline.
144
+ position_hook_dict (Dict[str, Union[Callable, List[Callable]]]):
145
+ A dictionary where the keys are the positions in the pipeline, and the values
146
+ are hooks (either a single hook or a list of hooks) to be registered at those positions.
147
+ Each hook should be a callable that accepts three arguments: (module, input, output).
148
+ positions_to_cache (List[str], optional): A list of positions in the pipeline where
149
+ intermediate inputs/outputs should be cached. Defaults to an empty list.
150
+ save_input (bool, optional): If True, caches the input at each specified position.
151
+ Defaults to False.
152
+ save_output (bool, optional): If True, caches the output at each specified position.
153
+ Defaults to True.
154
+ **kwargs: Additional keyword arguments to pass to the pipeline.
155
+
156
+ Returns:
157
+ final_output: The final output of the pipeline after execution.
158
+ cache_dict (Dict[str, Dict[str, Any]]): A dictionary where keys are the specified positions
159
+ and values are dictionaries containing the cached 'input' and/or 'output' at each position,
160
+ depending on the flags `save_input` and `save_output`.
161
+ '''
162
+ cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
163
+ hooks = [
164
+ self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
165
+ ]
166
+
167
+ for position, hook in position_hook_dict.items():
168
+ if isinstance(hook, list):
169
+ for h in hook:
170
+ hooks.append(self._register_general_hook(position, h))
171
+ else:
172
+ hooks.append(self._register_general_hook(position, hook))
173
+
174
+ hooks = [hook for hook in hooks if hook is not None]
175
+ output = self.pipe(*args, **kwargs)
176
+ for hook in hooks:
177
+ hook.remove()
178
+ if self.use_hooked_scheduler:
179
+ self.pipe.scheduler.pre_hooks = []
180
+ self.pipe.scheduler.post_hooks = []
181
+
182
+ cache_dict = {}
183
+ if save_input:
184
+ for position, block in cache_input.items():
185
+ cache_input[position] = torch.stack(block, dim=1)
186
+ cache_dict['input'] = cache_input
187
+
188
+ if save_output:
189
+ for position, block in cache_output.items():
190
+ cache_output[position] = torch.stack(block, dim=1)
191
+ cache_dict['output'] = cache_output
192
+
193
+ return output, cache_dict
194
+
195
+
196
+ def _locate_block(self, position: str):
197
+ '''
198
+ Locate the block at the specified position in the pipeline.
199
+ '''
200
+ block = self.pipe
201
+ for step in position.split('.'):
202
+ if step.isdigit():
203
+ step = int(step)
204
+ block = block[step]
205
+ else:
206
+ block = getattr(block, step)
207
+ return block
208
+
209
+
210
+ def _register_cache_hook(self, position: str, cache_input: Dict, cache_output: Dict):
211
+
212
+ if position.endswith('$self_attention') or position.endswith('$cross_attention'):
213
+ return self._register_cache_attention_hook(position, cache_output)
214
+
215
+ if position == 'noise':
216
+ def hook(model_output, timestep, sample, generator):
217
+ if position not in cache_output:
218
+ cache_output[position] = []
219
+ cache_output[position].append(sample)
220
+
221
+ if self.use_hooked_scheduler:
222
+ self.pipe.scheduler.post_hooks.append(hook)
223
+ else:
224
+ raise ValueError('Cannot cache noise without using hooked scheduler')
225
+ return
226
+
227
+ block = self._locate_block(position)
228
+
229
+ def hook(module, input, kwargs, output):
230
+ if cache_input is not None:
231
+ if position not in cache_input:
232
+ cache_input[position] = []
233
+ cache_input[position].append(retrieve(input))
234
+
235
+ if cache_output is not None:
236
+ if position not in cache_output:
237
+ cache_output[position] = []
238
+ cache_output[position].append(retrieve(output))
239
+
240
+ return block.register_forward_hook(hook, with_kwargs=True)
241
+
242
+ def _register_cache_attention_hook(self, position, cache):
243
+ attn_block = self._locate_block(position.split('$')[0])
244
+ if position.endswith('$self_attention'):
245
+ attn_block = attn_block.attn1
246
+ elif position.endswith('$cross_attention'):
247
+ attn_block = attn_block.attn2
248
+ else:
249
+ raise ValueError('Wrong attention type')
250
+
251
+ def hook(module, args, kwargs, output):
252
+ hidden_states = args[0]
253
+ encoder_hidden_states = kwargs['encoder_hidden_states']
254
+ attention_mask = kwargs['attention_mask']
255
+ batch_size, sequence_length, _ = hidden_states.shape
256
+ attention_mask = attn_block.prepare_attention_mask(attention_mask, sequence_length, batch_size)
257
+ query = attn_block.to_q(hidden_states)
258
+
259
+
260
+ if encoder_hidden_states is None:
261
+ encoder_hidden_states = hidden_states
262
+ elif attn_block.norm_cross is not None:
263
+ encoder_hidden_states = attn_block.norm_cross(encoder_hidden_states)
264
+
265
+ key = attn_block.to_k(encoder_hidden_states)
266
+ value = attn_block.to_v(encoder_hidden_states)
267
+
268
+ query = attn_block.head_to_batch_dim(query)
269
+ key = attn_block.head_to_batch_dim(key)
270
+ value = attn_block.head_to_batch_dim(value)
271
+
272
+ attention_probs = attn_block.get_attention_scores(query, key, attention_mask)
273
+ attention_probs = attention_probs.view(
274
+ batch_size,
275
+ attention_probs.shape[0] // batch_size,
276
+ attention_probs.shape[1],
277
+ attention_probs.shape[2]
278
+ )
279
+ if position not in cache:
280
+ cache[position] = []
281
+ cache[position].append(attention_probs)
282
+
283
+ return attn_block.register_forward_hook(hook, with_kwargs=True)
284
+
285
+ def _register_general_hook(self, position, hook):
286
+ if position == 'scheduler_pre':
287
+ if not self.use_hooked_scheduler:
288
+ raise ValueError('Cannot register hooks on scheduler without using hooked scheduler')
289
+ self.pipe.scheduler.pre_hooks.append(hook)
290
+ return
291
+ elif position == 'scheduler_post':
292
+ if not self.use_hooked_scheduler:
293
+ raise ValueError('Cannot register hooks on scheduler without using hooked scheduler')
294
+ self.pipe.scheduler.post_hooks.append(hook)
295
+ return
296
+
297
+ block = self._locate_block(position)
298
+ return block.register_forward_hook(hook)
299
+
300
+ def to(self, *args, **kwargs):
301
+ self.pipe = self.pipe.to(*args, **kwargs)
302
+ return self
303
+
304
+ def __getattr__(self, name):
305
+ return getattr(self.pipe, name)
306
+
307
+ def __setattr__(self, name, value):
308
+ return setattr(self.pipe, name, value)
309
+
310
+ def __call__(self, *args, **kwargs):
311
+ return self.pipe(*args, **kwargs)
312
+
313
+
314
+ class HookedStableDiffusionXLPipeline(HookedDiffusionAbstractPipeline):
315
+ parent_cls = StableDiffusionXLPipeline
316
+
317
+
318
+ class HookedIFPipeline(HookedDiffusionAbstractPipeline):
319
+ parent_cls = IFPipeline
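
As a quick illustration of the caching API above, here is a minimal sketch of calling `run_with_cache` on an SDXL Turbo checkpoint, mirroring how the method is invoked elsewhere in this repo (the prompt and block name are placeholders; the 'input'/'output' keys of the returned dictionary follow the code above):

import torch
from SDLens.hooked_sd_pipeline import HookedStableDiffusionXLPipeline

# Wrap the diffusers pipeline so forward hooks can be attached by module path.
pipe = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')
pipe.to('cuda')

block = 'unet.down_blocks.2.attentions.1'
output, cache = pipe.run_with_cache(
    'A cat sitting on a windowsill',
    positions_to_cache=[block],
    save_input=True,
    save_output=True,
    num_inference_steps=1,
    guidance_scale=0.0,
    generator=torch.Generator(device='cpu').manual_seed(42),
)

# Cached tensors are stacked over timesteps; the SAEs in this repo operate on the
# residual update of the block, i.e. its output minus its input.
diff = cache['output'][block] - cache['input'][block]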
app.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,768 @@
1
+ from functools import partial
2
+ import json
3
+ import gradio as gr
4
+ import os
5
+
6
+ # environment
7
+ os.environ['HF_HOME'] = '/dlabscratch1/anmari'
8
+ os.environ['TRANSFORMERS_CACHE'] = '/dlabscratch1/anmari'
9
+ os.environ['HF_DATASETS_CACHE'] = '/dlabscratch1/anmari'
10
+ # os.environ["HF_TOKEN"] = ""
11
+ import torch
12
+ from PIL import Image
13
+ from SDLens import HookedStableDiffusionXLPipeline, CachedPipeline as CachedFLuxPipeline
14
+ from SDLens.cache_and_edit.flux_pipeline import EditedFluxPipeline
15
+ from SAE import SparseAutoencoder
16
+ from utils import TimedHook, add_feature_on_area_base, replace_with_feature_base, add_feature_on_area_turbo, replace_with_feature_turbo, add_feature_on_area_flux
17
+ import numpy as np
18
+ import matplotlib.pyplot as plt
19
+ from matplotlib.colors import ListedColormap
20
+ import threading
21
+ from einops import rearrange
22
+ import spaces
23
+ # from retrieval import FeatureRetriever
24
+
25
+
26
+ code_to_block_sd = {
27
+ "down.2.1": "unet.down_blocks.2.attentions.1",
28
+ "mid.0": "unet.mid_block.attentions.0",
29
+ "up.0.1": "unet.up_blocks.0.attentions.1",
30
+ "up.0.0": "unet.up_blocks.0.attentions.0"
31
+ }
32
+ code_to_block_flux = {"18": "transformer.transformer_blocks.18"}
33
+
34
+ FLUX_NAMES = ["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev"]
35
+ MODELS_CONFIG = {
36
+ "stabilityai/stable-diffusion-xl-base-1.0": {
37
+ "steps": 25,
38
+ "guidance_scale": 8.0,
39
+ "choices": ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
40
+ "value": "down.2.1 (composition)",
41
+ "code_to_block": code_to_block_sd,
42
+ "max_steps": 50,
43
+ "is_flux": False,
44
+ "downsample_factor": 16,
45
+ "add_feature_on_area": add_feature_on_area_base,
46
+ "num_features": 5120,
47
+
48
+ },
49
+ "stabilityai/sdxl-turbo": {
50
+ "steps": 1,
51
+ "guidance_scale": 0.0,
52
+ "choices": ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
53
+ "value": "down.2.1 (composition)",
54
+ "code_to_block": code_to_block_sd,
55
+ "max_steps": 4,
56
+ "is_flux": False,
57
+ "downsample_factor": 32,
58
+ "add_feature_on_area": add_feature_on_area_turbo,
59
+ "num_features": 5120,
60
+ },
61
+ "black-forest-labs/FLUX.1-schnell": {
62
+ "steps": 1,
63
+ "guidance_scale": 0.0,
64
+ "choices": ["18"],
65
+ "value": "18",
66
+ "code_to_block": code_to_block_flux,
67
+ "max_steps": 4,
68
+ "is_flux": True,
69
+ "exclude_list": [2462, 2974, 1577, 786, 3188, 9986, 4693, 8472, 8248, 325, 9596, 2813, 10803, 11773, 11410, 1067, 2965, 10488, 4537, 2102],
70
+ "downsample_factor": 8,
71
+ "add_feature_on_area": add_feature_on_area_flux,
72
+ "num_features": 12288
73
+
74
+ },
75
+
76
+ "black-forest-labs/FLUX.1-dev": {
77
+ "steps": 25,
78
+ "guidance_scale": 0.0,
79
+ "choices": ["18"],
80
+ "value": "18",
81
+ "code_to_block": code_to_block_flux,
82
+ "max_steps": 50,
83
+ "is_flux": True,
84
+ "exclude_list": [2462, 2974, 1577, 786, 3188, 9986, 4693, 8472, 8248, 325, 9596, 2813, 10803, 11773, 11410, 1067, 2965, 10488, 4537, 2102],
85
+ "downsample_factor": 8,
86
+ "add_feature_on_area": add_feature_on_area_flux,
87
+ "num_features": 12288
88
+
89
+ }
90
+ }
91
+
92
+
93
+
94
+
95
+ lock = threading.Lock()
96
+
97
+
98
+
99
+
100
+
101
+ def process_cache(cache, saes_dict, model_config, timestep=None):
102
+
103
+ top_features_dict = {}
104
+ sparse_maps_dict = {}
105
+
106
+ for code in model_config['code_to_block'].keys():
107
+ block = model_config["code_to_block"][code]
108
+ sae = saes_dict[code]
109
+
110
+
111
+ if model_config["is_flux"]:
112
+
113
+ with torch.no_grad():
114
+ features = sae.encode(torch.stack(cache.image_activation)) # shape: [timestep, batch, seq_len, num_features]
115
+ features[..., model_config["exclude_list"]] = 0
116
+
117
+ if timestep is not None and timestep < features.shape[0]:
118
+ features = features[timestep:timestep+1]
119
+
120
+ # I want to get [batch, timestep, 64, 64, num_features]
121
+ sparse_maps = rearrange(features, "t b (w h) n -> b t w h n", w=64, h=64).squeeze(0).squeeze(0)
122
+
123
+ else:
124
+
125
+ diff = cache["output"][block] - cache["input"][block]
126
+ if diff.shape[0] == 2: # guidance is on and we need to select the second output
127
+ diff = diff[1].unsqueeze(0)
128
+
129
+ # If a specific timestep is provided, select that timestep from the cached activations
130
+ if timestep is not None and timestep < diff.shape[1]:
131
+ diff = diff[:, timestep:timestep+1]
132
+
133
+ diff = diff.permute(0, 1, 3, 4, 2).squeeze(0).squeeze(0)
134
+ with torch.no_grad():
135
+ sparse_maps = sae.encode(diff)
136
+
137
+ averages = torch.mean(sparse_maps, dim=(0, 1))
138
+
139
+ top_features = torch.topk(averages, 10).indices
140
+
141
+ top_features_dict[code] = top_features.cpu().tolist()
142
+ sparse_maps_dict[code] = sparse_maps.cpu().numpy()
143
+
144
+ return top_features_dict, sparse_maps_dict
145
+
146
+
147
+ def plot_image_heatmap(cache, block_select, radio, model_config):
148
+ code = block_select.split()[0]
149
+ feature = int(radio)
150
+
151
+ heatmap = cache["heatmaps"][code][:, :, feature]
152
+ scaling_factor = 16 if model_config["is_flux"] else 32
153
+ heatmap = np.kron(heatmap, np.ones((scaling_factor, scaling_factor)))
154
+ image = cache["image"].convert("RGBA")
155
+
156
+ jet = plt.cm.jet
157
+ cmap = jet(np.arange(jet.N))
158
+ cmap[:1, -1] = 0
159
+ cmap[1:, -1] = 0.6
160
+ cmap = ListedColormap(cmap)
161
+ heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap))
162
+ heatmap_rgba = cmap(heatmap)
163
+ heatmap_image = Image.fromarray((heatmap_rgba * 255).astype(np.uint8))
164
+ heatmap_with_transparency = Image.alpha_composite(image, heatmap_image)
165
+
166
+ return heatmap_with_transparency
167
+
168
+
169
+ def create_prompt_part(pipe, saes_dict, demo):
170
+
171
+ model_config = MODELS_CONFIG[pipe.pipe.name_or_path]
172
+ @spaces.GPU
173
+ def image_gen(prompt, timestep=None, num_steps=None, guidance_scale=None):
174
+ lock.acquire()
175
+ try:
176
+ # Default values
177
+ default_n_steps = model_config["steps"]
178
+ default_guidance = model_config["guidance_scale"]
179
+
180
+ # Use provided values if available, otherwise use defaults
181
+ n_steps = default_n_steps if num_steps is None else int(num_steps)
182
+ guidance = default_guidance if guidance_scale is None else float(guidance_scale)
183
+
184
+ # Convert timestep to integer if it's not None
185
+ timestep_int = None if timestep is None else int(timestep)
186
+
187
+ if "FLUX" in pipe.pipe.name_or_path:
188
+ images = pipe.run(
189
+ prompt,
190
+ num_inference_steps=n_steps,
191
+ width=1024,
192
+ height=1024,
193
+ cache_activations=True,
194
+ guidance_scale=guidance,
195
+ positions_to_cache = list(model_config["code_to_block"].values()),
196
+ inverse=False,
197
+ )
198
+ cache = pipe.activation_cache
199
+
200
+ else:
201
+ images, cache = pipe.run_with_cache(
202
+ prompt,
203
+ positions_to_cache=list(model_config["code_to_block"].values()),
204
+ num_inference_steps=n_steps,
205
+ generator=torch.Generator(device="cpu").manual_seed(42),
206
+ guidance_scale=guidance,
207
+ save_input=True,
208
+ save_output=True
209
+ )
210
+ finally:
211
+ lock.release()
212
+
213
+ top_features_dict, top_sparse_maps_dict = process_cache(cache, saes_dict, model_config, timestep_int)
214
+ return images.images[0], {
215
+ "image": images.images[0],
216
+ "heatmaps": top_sparse_maps_dict,
217
+ "features": top_features_dict
218
+ }
219
+
220
+ def update_radio(cache, block_select):
221
+ code = block_select.split()[0]
222
+ return gr.update(choices=cache["features"][code])
223
+
224
+ def update_img(cache, block_select, radio):
225
+ new_img = plot_image_heatmap(cache, block_select, radio, model_config)
226
+ return new_img
227
+
228
+ with gr.Tab("Explore", elem_classes="tabs") as explore_tab:
229
+ cache = gr.State(value={
230
+ "image": None,
231
+ "heatmaps": None,
232
+ "features": []
233
+ })
234
+ with gr.Row():
235
+ with gr.Column(scale=7):
236
+ with gr.Row(equal_height=True):
237
+ prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A cinematic shot of a professor sloth wearing a tuxedo at a BBQ party and eating a dish with peas.")
238
+ button = gr.Button("Generate", elem_classes="generate_button1")
239
+
240
+ with gr.Row():
241
+ image = gr.Image(width=512, height=512, image_mode="RGB", label="Generated image")
242
+
243
+ with gr.Column(scale=4):
244
+ block_select = gr.Dropdown(
245
+ choices=model_config["choices"], # replace this for flux
246
+ value=model_config["value"],
247
+ label="Select block",
248
+ elem_id="block_select",
249
+ interactive=True
250
+ )
251
+
252
+ with gr.Group() as sdxl_base_controls:
253
+ steps_slider = gr.Slider(
254
+ minimum=1,
255
+ maximum=model_config["max_steps"],
256
+ value= model_config["steps"],
257
+ step=1,
258
+ label="Number of steps",
259
+ elem_id="steps_slider",
260
+ interactive=True,
261
+ visible=True
262
+ )
263
+
264
+ # Add timestep selector
265
+ # TODO: check this
266
+ timestep_selector = gr.Slider(
267
+ minimum=0,
268
+ maximum=model_config["max_steps"]-1,
269
+ value=None,
270
+ step=1,
271
+ label="Timestep (leave empty for average across all steps)",
272
+ elem_id="timestep_selector",
273
+ interactive=True,
274
+ visible=True,
275
+ )
276
+ recompute_button = gr.Button("Recompute", elem_id="recompute_button")
277
+ # Update max timestep when steps change
278
+ steps_slider.change(lambda s: gr.update(maximum=s-1), [steps_slider], [timestep_selector])
279
+
280
+ radio = gr.Radio(choices=[], label="Select a feature", interactive=True)
281
+
282
+ button.click(image_gen, [prompt_field, timestep_selector, steps_slider], outputs=[image, cache])
283
+ cache.change(update_radio, [cache, block_select], outputs=[radio])
284
+ block_select.select(update_radio, [cache, block_select], outputs=[radio])
285
+ radio.select(update_img, [cache, block_select, radio], outputs=[image])
286
+ recompute_button.click(image_gen, [prompt_field, timestep_selector, steps_slider], outputs=[image, cache])
287
+ demo.load(image_gen, [prompt_field, timestep_selector, steps_slider], outputs=[image, cache])
288
+
289
+ return explore_tab
290
+
291
+ def downsample_mask(image, factor):
292
+ downsampled = image.reshape(
293
+ (image.shape[0] // factor, factor,
294
+ image.shape[1] // factor, factor)
295
+ )
296
+ downsampled = downsampled.mean(axis=(1, 3))
297
+ return downsampled
298
+
299
+ def create_intervene_part(pipe: HookedStableDiffusionXLPipeline, saes_dict, means_dict, demo):
300
+ model_config = MODELS_CONFIG[pipe.pipe.name_or_path]
301
+
302
+ @spaces.GPU
303
+ def image_gen(prompt, num_steps, guidance_scale=None):
304
+ lock.acquire()
305
+ guidance = model_config["guidance_scale"] if guidance_scale is None else float(guidance_scale)
306
+ try:
307
+
308
+ if "FLUX" in pipe.pipe.name_or_path:
309
+ images = pipe.run(
310
+ prompt,
311
+ num_inference_steps=int(num_steps),
312
+ width=1024,
313
+ height=1024,
314
+ cache_activations=False,
315
+ guidance_scale=guidance,
316
+ inverse=False,
317
+ )
318
+ else:
319
+ images = pipe.run_with_hooks(
320
+ prompt,
321
+ position_hook_dict={},
322
+ num_inference_steps=int(num_steps),
323
+ generator=torch.Generator(device="cpu").manual_seed(42),
324
+ guidance_scale=guidance,
325
+ )
326
+ finally:
327
+ lock.release()
328
+ if images.images[0].size == (1024, 1024):
329
+ return images.images[0].resize((512, 512))
330
+ else:
331
+ return images.images[0]
332
+
333
+ @spaces.GPU
334
+ def image_mod(prompt, block_str, brush_index, strength, num_steps, input_image, start_index=None, end_index=None, guidance_scale=None):
335
+ block = block_str.split(" ")[0]
336
+
337
+ mask = (input_image["layers"][0] > 0)[:, :, -1].astype(float)
338
+ mask = downsample_mask(mask, model_config["downsample_factor"])
339
+ mask = torch.tensor(mask, dtype=torch.float32, device="cuda")
340
+
341
+ if mask.sum() == 0:
342
+ gr.Info("No mask selected, please draw on the input image")
343
+
344
+
345
+ # Set default values for start_index and end_index if not provided
346
+ if start_index is None:
347
+ start_index = 0
348
+ if end_index is None:
349
+ end_index = int(num_steps)
350
+
351
+ # Ensure start_index and end_index are within valid ranges
352
+ start_index = max(0, min(int(start_index), int(num_steps)))
353
+ end_index = max(0, min(int(end_index), int(num_steps)))
354
+
355
+ # Ensure start_index is less than end_index
356
+ if start_index >= end_index:
357
+ start_index = max(0, end_index - 1)
358
+
359
+
360
+ def myhook(module, input, output):
361
+ return model_config["add_feature_on_area"](
362
+ saes_dict[block],
363
+ brush_index,
364
+ mask * means_dict[block][brush_index] * strength,
365
+ module,
366
+ input,
367
+ output)
368
+ hook = TimedHook(myhook, int(num_steps), np.arange(start_index, end_index))
369
+
370
+ lock.acquire()
371
+ guidance = model_config["guidance_scale"] if guidance_scale is None else float(guidance_scale)
372
+
373
+ try:
374
+
375
+ if model_config["is_flux"]:
376
+ image = pipe.run_with_edit(
377
+ prompt,
378
+ seed=42,
379
+ num_inference_steps=int(num_steps),
380
+ edit_fn= lambda input, output: hook(None, input, output),
381
+ layers_for_edit_fn=[i for i in range(18, 57)],
382
+ stream="image").images[0]
383
+ else:
384
+
385
+ image = pipe.run_with_hooks(
386
+ prompt,
387
+ position_hook_dict={model_config["code_to_block"][block]: hook},
388
+ num_inference_steps=int(num_steps),
389
+ generator=torch.Generator(device="cpu").manual_seed(42),
390
+ guidance_scale=guidance
391
+ ).images[0]
392
+ finally:
393
+ lock.release()
394
+ return image
395
+
396
+ def feature_icon(block_str, brush_index, guidance_scale=None):
397
+ block = block_str.split(" ")[0]
+ # Whether the underlying pipeline is SDXL base (vs. turbo); selects the replace hook below.
+ is_base_model = pipe.pipe.name_or_path == "stabilityai/stable-diffusion-xl-base-1.0"
398
+ if block in ["mid.0", "up.0.0"]:
399
+ gr.Info("Note that Feature Icon works best with down.2.1 and up.0.1 blocks but feel free to explore", duration=3)
400
+
401
+ def hook(module, input, output):
402
+ if is_base_model:
403
+ return replace_with_feature_base(
404
+ saes_dict[block],
405
+ brush_index,
406
+ means_dict[block][brush_index] * saes_dict[block].k,
407
+ module,
408
+ input,
409
+ output
410
+ )
411
+ else:
412
+ return replace_with_feature_turbo(
413
+ saes_dict[block],
414
+ brush_index,
415
+ means_dict[block][brush_index] * saes_dict[block].k,
416
+ module,
417
+ input,
418
+ output)
419
+ lock.acquire()
420
+ guidance = model_config["guidance_scale"] if guidance_scale is None else float(guidance_scale)
421
+
422
+ try:
423
+ image = pipe.run_with_hooks(
424
+ "",
425
+ position_hook_dict={model_config["code_to_block"][block]: hook},
426
+ num_inference_steps=model_config["steps"],
427
+ generator=torch.Generator(device="cpu").manual_seed(42),
428
+ guidance_scale=guidance,
429
+ ).images[0]
430
+ finally:
431
+ lock.release()
432
+ return image
433
+
434
+ with gr.Tab("Paint!", elem_classes="tabs") as intervene_tab:
435
+ image_state = gr.State(value=None)
436
+ with gr.Row():
437
+ with gr.Column(scale=3):
438
+ # Generation column
439
+ with gr.Row():
440
+ # prompt and num_steps
441
+ prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A dog plays with a ball, cartoon", elem_id="prompt_input")
442
+
443
+ with gr.Row():
444
+ num_steps = gr.Number(value=model_config["steps"], label="Number of steps", minimum=1, maximum=model_config["max_steps"], elem_id="num_steps", precision=0)
445
+
446
+ with gr.Row():
447
+ # Generate button
448
+ button_generate = gr.Button("Generate", elem_id="generate_button")
449
+ with gr.Column(scale=3):
450
+ # Intervention column
451
+ with gr.Row():
452
+ # dropdowns and number inputs
453
+ with gr.Column(scale=7):
454
+ with gr.Row():
455
+ block_select = gr.Dropdown(
456
+ choices=model_config["choices"],
457
+ value=model_config["value"],
458
+ label="Select block",
459
+ elem_id="block_select"
460
+ )
461
+ brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=model_config["num_features"]-1, elem_id="brush_index", precision=0)
462
+ # with gr.Row():
463
+ # button_icon = gr.Button('Feature Icon', elem_id="feature_icon_button")
464
+ with gr.Row():
465
+ gr.Markdown("**TimedHook Range** (which steps to apply the feature)", visible=True)
466
+ with gr.Row():
467
+ start_index = gr.Number(value=0, label="Start index", minimum=0, maximum=model_config["max_steps"], elem_id="start_index", precision=0, visible=True)
468
+ end_index = gr.Number(value=model_config["steps"], label="End index", minimum=0, maximum=model_config["max_steps"], elem_id="end_index", precision=0, visible=True)
469
+ with gr.Column(scale=3):
470
+ with gr.Row():
471
+ strength = gr.Number(value=10, label="Strength", minimum=-40, maximum=40, elem_id="strength", precision=2)
472
+ with gr.Row():
473
+ button = gr.Button('Apply', elem_id="apply_button")
474
+
475
+ with gr.Row():
476
+ with gr.Column():
477
+ # Input image
478
+ i_image = gr.Sketchpad(
479
+ height=610,
480
+ layers=False, transforms=[], placeholder="Generate and paint!",
481
+ brush=gr.Brush(default_size=64, color_mode="fixed", colors=['black']),
482
+ container=False,
483
+ canvas_size=(512, 512),
484
+ label="Input Image")
485
+ clear_button = gr.Button("Clear")
486
+ clear_button.click(lambda x: x, [image_state], [i_image])
487
+ # Output image
488
+ o_image = gr.Image(width=512, height=512, label="Output Image")
489
+
490
+ # Set up the click events
491
+ button_generate.click(image_gen, inputs=[prompt_field, num_steps], outputs=[image_state])
492
+ image_state.change(lambda x: x, [image_state], [i_image])
493
+
494
+ # Update max values for start_index and end_index when num_steps changes
495
+ def update_index_maxes(steps):
496
+ return gr.update(maximum=steps), gr.update(maximum=steps)
497
+
498
+ num_steps.change(update_index_maxes, [num_steps], [start_index, end_index])
499
+
500
+ button.click(image_mod,
501
+ inputs=[prompt_field, block_select, brush_index, strength, num_steps, i_image, start_index, end_index],
502
+ outputs=o_image)
503
+ # button_icon.click(feature_icon, inputs=[block_select, brush_index], outputs=o_image)
504
+ demo.load(image_gen, [prompt_field, num_steps], outputs=[image_state])
505
+
506
+
507
+ return intervene_tab
508
+
509
+
510
+
511
+ def create_top_images_part(demo, pipe):
512
+
513
+ model_config = MODELS_CONFIG[pipe.pipe.name_or_path]
514
+
515
+ if isinstance(pipe, HookedStableDiffusionXLPipeline):
516
+ is_flux = False
517
+ elif isinstance(pipe, CachedFLuxPipeline):
518
+ is_flux = True
519
+ else:
520
+ raise AssertionError(f"Unknown pipe class: {type(pipe)}")
521
+
522
+ def update_top_images(block_select, brush_index):
523
+ block = block_select.split(" ")[0]
524
+ # Define path for fetching image
525
+ if is_flux:
526
+ part = 1 if brush_index <= 7000 else 2
527
+ url = f"https://huggingface.co/datasets/antoniomari/flux_sae_images/resolve/main/{block}/part{part}/{brush_index}.jpg"
528
+ else:
529
+ url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{brush_index}.jpg"
530
+ return url
531
+
532
+ with gr.Tab("Top Images", elem_classes="tabs") as top_images_tab:
533
+ with gr.Row():
534
+ block_select = gr.Dropdown(
535
+ choices=["flux_18"] if is_flux else ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
536
+ value="flux_18" if is_flux else "down.2.1 (composition)",
537
+ label="Select block"
538
+ )
539
+ brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=model_config["num_features"]-1, precision=0)
540
+ with gr.Row():
541
+ image = gr.Image(width=600, height=600, label="Top Images")
542
+
543
+ block_select.select(update_top_images, [block_select, brush_index], outputs=[image])
544
+ brush_index.change(update_top_images, [block_select, brush_index], outputs=[image])
545
+ demo.load(update_top_images, [block_select, brush_index], outputs=[image])
546
+ return top_images_tab
547
+
548
+
549
+ def create_top_images_plus_search_part(retriever, demo, pipe):
550
+
551
+ model_config = MODELS_CONFIG[pipe.pipe.name_or_path]
552
+
553
+
554
+
555
+ if isinstance(pipe, HookedStableDiffusionXLPipeline):
556
+ is_flux = False
557
+ elif isinstance(pipe, CachedFLuxPipeline):
558
+ is_flux = True
559
+ else:
560
+ raise AssertionError(f"Unknown pipe class: {type(pipe)}")
561
+
562
+ def update_cache(block_select, search_by_text, search_by_index):
563
+ if search_by_text == "":
564
+ top_indices = []
565
+ index = search_by_index
566
+ block = block_select.split(" ")[0]
567
+
568
+ # Define path for fetching image
569
+ if is_flux:
570
+ part = 1 if index <= 7000 else 2
571
+ url = f"https://huggingface.co/antoniomari/flux_sae_images/resolve/main/{block}/part{part}/{index}.jpg"
572
+ else:
573
+ url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{index}.jpg"
574
+ return url, {"image": url, "feature_idx": index, "features": top_indices}
575
+ else:
576
+ # TODO
577
+ if retriever is None:
578
+ raise ValueError("Feature retrieval is not enabled")
579
+ lock.acquire()
580
+ try:
581
+ top_indices = list(retriever.query_text(search_by_text, block_select.split(" ")[0]).keys())
582
+ finally:
583
+ lock.release()
584
+ block = block_select.split(" ")[0]
585
+ top_indices = list(map(int, top_indices))
586
+ index = top_indices[0]
587
+ url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{index}.jpg"
588
+ return url, {"image": url, "feature_idx": index, "features": top_indices[:20]}
589
+
590
+ def update_radio(cache):
591
+ return gr.update(choices=cache["features"], value=cache["feature_idx"])
592
+
593
+ def update_img(cache, block_select, index):
594
+ block = block_select.split(" ")[0]
595
+ url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{index}.jpg"
596
+ return url
597
+
598
+ with gr.Tab("Top Images", elem_classes="tabs") as explore_tab:
599
+ cache = gr.State(value={
600
+ "image": None,
601
+ "feature_idx": None,
602
+ "features": []
603
+ })
604
+ with gr.Row():
605
+ with gr.Column(scale=7):
606
+ with gr.Row():
607
+ # top images
608
+ image = gr.Image(width=600, height=600, image_mode="RGB", label="Top images")
609
+
610
+ with gr.Column(scale=4):
611
+ block_select = gr.Dropdown(
612
+ choices=["flux_18"] if is_flux else ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
613
+ value="flux_18" if is_flux else "down.2.1 (composition)",
614
+ label="Select block",
615
+ elem_id="block_select",
616
+ interactive=True
617
+ )
618
+ search_by_index = gr.Number(value=0, label="Search by index", minimum=0, maximum=model_config["num_features"]-1, precision=0)
619
+ search_by_text = gr.Textbox(lines=1, label="Search by text", value="", visible=False)
620
+ radio = gr.Radio(choices=[], label="Select a feature", interactive=True, visible=False)
621
+
622
+
623
+ search_by_text.change(update_cache,
624
+ [block_select, search_by_text, search_by_index],
625
+ outputs=[image, cache])
626
+ block_select.select(update_cache,
627
+ [block_select, search_by_text, search_by_index],
628
+ outputs=[image, cache])
629
+ cache.change(update_radio, [cache], outputs=[radio])
630
+ radio.select(update_img, [cache, block_select, radio], outputs=[image])
631
+ search_by_index.change(update_img, [cache, block_select, search_by_index], outputs=[image])
632
+ demo.load(update_img,
633
+ [cache, block_select, search_by_index],
634
+ outputs=[image])
635
+
636
+ return explore_tab
637
+
638
+
639
+ def create_intro_part():
640
+ with gr.Tab("Instructions", elem_classes="tabs") as intro_tab:
641
+ gr.Markdown(
642
+ '''# Unpacking SDXL Turbo with Sparse Autoencoders
643
+ ## Demo Overview
644
+ This demo showcases the use of Sparse Autoencoders (SAEs) to understand the features learned by the Stable Diffusion XL Turbo model.
645
+
646
+ ## How to Use
647
+ ### Explore
648
+ * Enter a prompt in the text box and click on the "Generate" button to generate an image.
649
+ * You can observe the active features in different blocks plot on top of the generated image.
650
+ ### Top Images
651
+ * For each feature, you can view the top images that activate the feature the most.
652
+ ### Paint!
653
+ * Generate an image using the prompt.
654
+ * Paint on the generated image to apply interventions.
655
+ * Use the "Feature Icon" button to understand how the selected brush functions.
656
+
657
+ ### Remarks
658
+ * Not all brushes mix well with all images. Experiment with different brushes and strengths.
659
+ * Feature Icon works best with `down.2.1 (composition)` and `up.0.1 (style)` blocks.
660
+ * This demo is provided for research purposes only. We do not take responsibility for the content generated by the demo.
661
+
662
+ ### Interesting features to try
663
+ To get started, try the following features:
664
+ - down.2.1 (composition): 2301 (evil) 3747 (image frame) 4998 (cartoon)
665
+ - up.0.1 (style): 4977 (tiger stripes) 90 (fur) 2615 (twilight blur)
666
+ '''
667
+ )
668
+
669
+ return intro_tab
670
+
671
+
672
+ def create_demo(pipe, saes_dict, means_dict, use_retrieval=True):
673
+ custom_css = """
674
+ .tabs button {
675
+ font-size: 20px !important; /* Adjust font size for tab text */
676
+ padding: 10px !important; /* Adjust padding to make the tabs bigger */
677
+ font-weight: bold !important; /* Adjust font weight to make the text bold */
678
+ }
679
+ .generate_button1 {
680
+ max-width: 160px !important;
681
+ margin-top: 20px !important;
682
+ margin-bottom: 20px !important;
683
+ }
684
+ """
685
+ if use_retrieval:
686
+ retriever = None # FeatureRetriever()
687
+ else:
688
+ retriever = None
689
+
690
+ with gr.Blocks(css=custom_css) as demo:
691
+ # with create_intro_part():
692
+ # pass
693
+ with create_prompt_part(pipe, saes_dict, demo):
694
+ pass
695
+ with create_top_images_part(demo, pipe):
696
+ pass
697
+ with create_intervene_part(pipe, saes_dict, means_dict, demo):
698
+ pass
699
+
700
+ return demo
701
+
702
+
703
+ if __name__ == "__main__":
704
+ import os
705
+ import gradio as gr
706
+ import torch
707
+ from SDLens import HookedStableDiffusionXLPipeline
708
+ from SAE import SparseAutoencoder
709
+ from huggingface_hub import hf_hub_download
710
+
711
+ dtype = torch.float16
712
+ pipe = EditedFluxPipeline.from_pretrained(
713
+ "black-forest-labs/FLUX.1-schnell",
714
+ device_map="balanced",
715
+ torch_dtype=dtype
716
+ )
717
+ pipe.set_progress_bar_config(disable=True)
718
+ pipe = CachedFLuxPipeline(pipe)
719
+
720
+ # Parameters
721
+ DEVICE = "cuda"
722
+
723
+ # Hugging Face repo setup
724
+ HF_REPO_ID = "antoniomari/SAE_flux_18"
725
+ HF_BRANCH = "main"
726
+
727
+ # Command-line arguments
728
+ block_code = "18"
729
+ block_name = code_to_block_flux[block_code]
730
+
731
+ saes_dict = {}
732
+ means_dict = {}
733
+
734
+ # Download files from the root of the repo
735
+ state_dict_path = hf_hub_download(
736
+ repo_id=HF_REPO_ID,
737
+ filename="state_dict.pth",
738
+ revision=HF_BRANCH
739
+ )
740
+
741
+ config_path = hf_hub_download(
742
+ repo_id=HF_REPO_ID,
743
+ filename="config.json",
744
+ revision=HF_BRANCH
745
+ )
746
+
747
+ mean_path = hf_hub_download(
748
+ repo_id=HF_REPO_ID,
749
+ filename="mean.pt",
750
+ revision=HF_BRANCH
751
+ )
752
+
753
+ # Load config and model
754
+ with open(config_path, "r") as f:
755
+ config = json.load(f)
756
+
757
+ sae = SparseAutoencoder(**config)
758
+ checkpoint = torch.load(state_dict_path, map_location=DEVICE)
759
+ state_dict = checkpoint["state_dict"]
760
+ sae.load_state_dict(state_dict)
761
+ sae = sae.to(DEVICE, dtype=torch.float16).eval()
762
+ means = torch.load(mean_path, map_location=DEVICE).to(dtype)
763
+
764
+ saes_dict[block_code] = sae
765
+ means_dict[block_code] = means
766
+
767
+ demo = create_demo(pipe, saes_dict, means_dict)
768
+ demo.launch()
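
The __main__ block above wires up the FLUX SAE from the Hub; the SDXL checkpoints bundled under checkpoints/ in this commit can be loaded with the same pattern. A hedged sketch follows (whether the local state_dict.pth nests its weights under a "state_dict" key is an assumption, so both layouts are handled):

import json
import torch
from SAE import SparseAutoencoder

ckpt_dir = "checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final"

with open(f"{ckpt_dir}/config.json") as f:
    config = json.load(f)  # {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, ...}

sae = SparseAutoencoder(**config)
state = torch.load(f"{ckpt_dir}/state_dict.pth", map_location="cuda")
# The checkpoint may store the weights directly or under a "state_dict" key.
sae.load_state_dict(state["state_dict"] if "state_dict" in state else state)
sae = sae.to("cuda", dtype=torch.float16).eval()

# Per-feature activation means, used to scale the "brush" strength in the demo.
means = torch.load(f"{ckpt_dir}/mean.pt", map_location="cuda")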
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:387f2b6f8c4e4a6f1227921f28f00dfa4beb2bd4e422b7eb592cd8627af0e58f
3
+ size 21581
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39e3c6d17aa572a53368ca8ba8f82757947a3caf14fe654e84b175d0dc0a4650
3
+ size 52497831
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6ca694c9504a7a8aa827004d3fdec5c1cb8fcf3904acc3562d1861fc6e65c19
3
+ size 21576
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80790481d0e56ac3fa36599703cee7a05cfb4cc078db57c8f9180e860c330e1d
3
+ size 21581
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:49d38d9178c2a2780e04a5482a2feb9548c6e9a636ed1bf85291acf42e0ffa34
3
+ size 52497831
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bb6bfc7ce5e596f8aa048ab262ca56841868c222bf07eb2ed35b6e4f7094fea6
3
+ size 21576
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de036d0fb9ee663f7bdf60e4a5d89d038516dae637531676b53ff75d05eab46b
3
+ size 21581
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:14c45efd9cce0258f014c49babdcd0e9ce8b266fe31eed72db1a45b990a1a0f8
3
+ size 52497831
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb9c04499ccae041987cc262894e254c2f04288857a8a0470cfb1b86a8ecfa09
3
+ size 21576
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96dbf6fffe9d62c3b3352f8e4fe48c54dfd69906cf8ad6828d5ce93db9a5f0dc
3
+ size 21581
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8eed82f4bcb2f010ae9075f10a1ece801ee3dec46dba7fadccc35f6c0a7836b
3
+ size 52497831
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe5c5be0c4c2d2b57e7888319053cb64929559f947c8ce445ddd6a397302afab
3
+ size 21576
colab_requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ diffusers==0.29.2
2
+ gradio==5.23.2
3
+ numpy
4
+ matplotlib
5
+ pillow
6
+ einops
7
+ transformers
8
+ huggingface_hub
example.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ diffusers==0.29.2
2
+ gradio==4.44.1
3
+ torch>=2.4.0
4
+ numpy
5
+ matplotlib
6
+ pillow
7
+ wandb
8
+ einops
9
+ transformers
10
+ accelerate
11
+ huggingface_hub
12
+ git+https://github.com/wendlerc/clip-retrieval.git
resourses/image.png ADDED

Git LFS Details

  • SHA256: 86594c5876d61a3eac5238b739eeec41418995c7696b6453d70b4e683ebd82df
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
retrieval.py ADDED
@@ -0,0 +1,71 @@
1
+ from huggingface_hub import snapshot_download
2
+ from clip_retrieval.clip_back import load_clip_indices, KnnService, ClipOptions
3
+ from collections import defaultdict
4
+ import os
5
+ import glob
6
+ import shutil
7
+ import random
8
+
9
+ class FeatureRetriever:
10
+ def __init__(self,
11
+ num_images=50,
12
+ imgs_per_dir=15,
13
+ force_download=False):
14
+
15
+ if force_download or not os.path.exists("./clip"):
16
+ print("Downloading clip resources")
17
+ rand_num = random.randint(0, 100000)
18
+ tmp_dir = f"./tmp_{rand_num}"
19
+ snapshot_download(repo_type="dataset", repo_id="wendlerc/sdxl-unbox-clip-indices", cache_dir=tmp_dir)
20
+ clip_dirs = glob.glob(f"{tmp_dir}/**/down_10_5120", recursive=True)
21
+ if len(clip_dirs) > 0:
22
+ shutil.copytree(clip_dirs[0].replace("down_10_5120", ""), "./clip", dirs_exist_ok=True)
23
+ shutil.rmtree(tmp_dir)
24
+ else:
25
+ raise ValueError("Could not find clip indices in the downloaded repo.")
26
+
27
+ # Initialize CLIP service
28
+ clip_options = ClipOptions(
29
+ indice_folder="currently unused by knn.query()",
30
+ clip_model="ViT-B/32", #"open_clip:ViT-H-14",
31
+ enable_hdf5=False,
32
+ enable_faiss_memory_mapping=True,
33
+ columns_to_return=["image_path", "similarity"],
34
+ reorder_metadata_by_ivf_index=False,
35
+ enable_mclip_option=False,
36
+ use_jit=False,
37
+ use_arrow=False,
38
+ provide_safety_model=False,
39
+ provide_violence_detector=False,
40
+ provide_aesthetic_embeddings=False,
41
+ )
42
+ self.names = ["down.2.1", "mid.0", "up.0.0", "up.0.1"]
43
+ self.paths = ["./clip/down_10_5120/indices_paths.json",
44
+ "./clip/mid_10_5120/indices_paths.json",
45
+ "./clip/up0_10_5120/indices_paths.json",
46
+ "./clip/up_10_5120/indices_paths.json",]
47
+ self.knn_service = {}
48
+ for name, path in zip(self.names, self.paths):
49
+ resources = load_clip_indices(path, clip_options)
50
+ self.knn_service[name] = KnnService(clip_resources=resources)
51
+ self.num_images = num_images
52
+ self.imgs_per_dir = imgs_per_dir
53
+
54
+ def query_text(self, query, block):
55
+ if block not in self.names:
56
+ raise ValueError(f"Block must be one of {self.names}")
57
+ results = self.knn_service[block].query(
58
+ text_input=query,
59
+ num_images=self.num_images,
60
+ num_result_ids=self.num_images,
61
+ deduplicate=True,
62
+ )
63
+ feat_sims = defaultdict(list)
64
+ feat_scores = {}
65
+ for result in results:
66
+ feature_id = result["image_path"].split("/")[-2]
67
+ feat_sims[feature_id] += [result["similarity"]]
68
+ for fid, sims in feat_sims.items():
69
+ feat_scores[fid] = (sum(sims) / len(sims)) * (len(sims)/self.imgs_per_dir)
70
+
71
+ return dict(sorted(feat_scores.items(), key=lambda item: -item[1]))
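
A hypothetical usage sketch of the retriever above (construction downloads the CLIP indices on first run and requires the clip-retrieval dependency from requirements.txt):

retriever = FeatureRetriever()
# Scores map feature ids (directory names, as strings) to a relevance score, best first.
scores = retriever.query_text("tiger stripes", "up.0.1")
for feature_id, score in list(scores.items())[:5]:
    print(feature_id, round(score, 3))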
scripts/collect_latents_dataset.py ADDED
@@ -0,0 +1,96 @@
1
+ import os
2
+ import sys
3
+ import io
4
+ import tarfile
5
+ import torch
6
+ import webdataset as wds
7
+ import numpy as np
8
+
9
+ from tqdm import tqdm
10
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
11
+ from SDLens.hooked_sd_pipeline import HookedStableDiffusionXLPipeline
12
+
13
+ import datetime
14
+ from datasets import load_dataset
15
+ from torch.utils.data import DataLoader
16
+ import diffusers
17
+ import fire
18
+
19
+ def main(save_path, start_at=0, finish_at=30000, dataset_batch_size=50):
20
+ blocks_to_save = [
21
+ 'unet.down_blocks.2.attentions.1',
22
+ 'unet.mid_block.attentions.0',
23
+ 'unet.up_blocks.0.attentions.0',
24
+ 'unet.up_blocks.0.attentions.1',
25
+ ]
26
+
27
+ # Initialization
28
+ dataset = load_dataset("guangyil/laion-coco-aesthetic", split="train", columns=["caption"], streaming=True).shuffle(seed=42)
29
+ pipe = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')
30
+ pipe.to('cuda')
31
+ pipe.set_progress_bar_config(disable=True)
32
+ dataloader = DataLoader(dataset, batch_size=dataset_batch_size)
33
+
34
+ ct = datetime.datetime.now()
35
+ save_path = os.path.join(save_path, str(ct))
36
+ # Collecting dataset
37
+ os.makedirs(save_path, exist_ok=True)
38
+
39
+ writers = {
40
+ block: wds.TarWriter(f'{save_path}/{block}.tar') for block in blocks_to_save
41
+ }
42
+
43
+ writers.update({'images': wds.TarWriter(f'{save_path}/images.tar')})
44
+
45
+ def to_kwargs(kwargs_to_save):
46
+ kwargs = kwargs_to_save.copy()
47
+ seed = kwargs['seed']
48
+ del kwargs['seed']
49
+ kwargs['generator'] = torch.Generator(device="cpu").manual_seed(num_document)
50
+ return kwargs
51
+
52
+ dataloader_iter = iter(dataloader)
53
+ for num_document, batch in tqdm(enumerate(dataloader)):
54
+ if num_document < start_at:
55
+ continue
56
+
57
+ if num_document >= finish_at:
58
+ break
59
+
60
+ kwargs_to_save = {
61
+ 'prompt': batch['caption'],
62
+ 'positions_to_cache': blocks_to_save,
63
+ 'save_input': True,
64
+ 'save_output': True,
65
+ 'num_inference_steps': 1,
66
+ 'guidance_scale': 0.0,
67
+ 'seed': num_document,
68
+ 'output_type': 'pil'
69
+ }
70
+
71
+ kwargs = to_kwargs(kwargs_to_save)
72
+
73
+ output, cache = pipe.run_with_cache(
74
+ **kwargs
75
+ )
76
+
77
+ blocks = cache['input'].keys()
78
+ for block in blocks:
79
+ sample = {
80
+ "__key__": f"sample_{num_document}",
81
+ "output.pth": cache['output'][block],
82
+ "diff.pth": cache['output'][block] - cache['input'][block],
83
+ "gen_args.json": kwargs_to_save
84
+ }
85
+
86
+ writers[block].write(sample)
87
+ writers['images'].write({
88
+ "__key__": f"sample_{num_document}",
89
+ "images.npy": np.stack(output.images)
90
+ })
91
+
92
+ for block, writer in writers.items():
93
+ writer.close()
94
+
95
+ if __name__ == '__main__':
96
+ fire.Fire(main)
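
The tar shards written above can be read back with webdataset; a small sketch (the shard path is hypothetical, and each .pth entry is decoded explicitly with torch.load rather than relying on default decoders):

import io
import torch
import webdataset as wds

# Hypothetical shard produced by the script: <save_path>/<timestamp>/<block>.tar
shard = "latents/2024-01-01/unet.down_blocks.2.attentions.1.tar"

for sample in wds.WebDataset(shard):
    # Raw samples are dicts of bytes keyed by extension; decode the tensor by hand.
    diff = torch.load(io.BytesIO(sample["diff.pth"]))
    print(sample["__key__"], diff.shape)
    break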
scripts/train_sae.py ADDED
@@ -0,0 +1,308 @@
1
+ '''
2
+ Adapted from
3
+ https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/train.py
4
+ '''
5
+
6
+
7
+ import os
8
+ import sys
9
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
10
+ from typing import Callable, Iterable, Iterator
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from torch.distributed import ReduceOp
17
+ from SAE.dataset_iterator import ActivationsDataloader
18
+ from SAE.sae import SparseAutoencoder, unit_norm_decoder_, unit_norm_decoder_grad_adjustment_
19
+ from SAE.sae_utils import SAETrainingConfig, Config
20
+
21
+ from types import SimpleNamespace
22
+ from typing import Optional, List
23
+ import json
24
+
25
+ import tqdm
26
+
27
+ def weighted_average(points: torch.Tensor, weights: torch.Tensor):
28
+ weights = weights / weights.sum()
29
+ return (points * weights.view(-1, 1)).sum(dim=0)
30
+
31
+
32
+ @torch.no_grad()
33
+ def geometric_median_objective(
34
+ median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
35
+ ) -> torch.Tensor:
36
+
37
+ norms = torch.linalg.norm(points - median.view(1, -1), dim=1) # type: ignore
38
+
39
+ return (norms * weights).sum()
40
+
41
+
42
+ def compute_geometric_median(
43
+ points: torch.Tensor,
44
+ weights: Optional[torch.Tensor] = None,
45
+ eps: float = 1e-6,
46
+ maxiter: int = 100,
47
+ ftol: float = 1e-20,
48
+ do_log: bool = False,
49
+ ):
50
+ """
51
+ :param points: ``torch.Tensor`` of shape ``(n, d)``
52
+ :param weights: Optional ``torch.Tensor`` of shape :math:``(n,)``.
53
+ :param eps: Smallest allowed value of denominator, to avoid divide by zero.
54
+ Equivalently, this is a smoothing parameter. Default 1e-6.
55
+ :param maxiter: Maximum number of Weiszfeld iterations. Default 100
56
+ :param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
57
+ :param do_log: If true will return a log of function values encountered through the course of the algorithm
58
+ :return: SimpleNamespace object with fields
59
+ - `median`: estimate of the geometric median, which is a ``torch.Tensor`` object of shape :math:``(d,)``
60
+ - `termination`: string explaining how the algorithm terminated.
61
+ - `logs`: function values encountered through the course of the algorithm in a list (None if do_log is false).
62
+ """
63
+ with torch.no_grad():
64
+
65
+ if weights is None:
66
+ weights = torch.ones((points.shape[0],), device=points.device)
67
+ # initialize median estimate at mean
68
+ new_weights = weights
69
+ median = weighted_average(points, weights)
70
+ objective_value = geometric_median_objective(median, points, weights)
71
+ if do_log:
72
+ logs = [objective_value]
73
+ else:
74
+ logs = None
75
+
76
+ # Weiszfeld iterations
77
+ early_termination = False
78
+ pbar = tqdm.tqdm(range(maxiter))
79
+ for _ in pbar:
80
+ prev_obj_value = objective_value
81
+
82
+ norms = torch.linalg.norm(points - median.view(1, -1), dim=1) # type: ignore
83
+ new_weights = weights / torch.clamp(norms, min=eps)
84
+ median = weighted_average(points, new_weights)
85
+ objective_value = geometric_median_objective(median, points, weights)
86
+
87
+ if logs is not None:
88
+ logs.append(objective_value)
89
+ if abs(prev_obj_value - objective_value) <= ftol * objective_value:
90
+ early_termination = True
91
+ break
92
+
93
+ pbar.set_description(f"Objective value: {objective_value:.4f}")
94
+
95
+ median = weighted_average(points, new_weights) # allow autodiff to track it
96
+ return SimpleNamespace(
97
+ median=median,
98
+ new_weights=new_weights,
99
+ termination=(
100
+ "function value converged within tolerance"
101
+ if early_termination
102
+ else "maximum iterations reached"
103
+ ),
104
+ logs=logs,
105
+ )
106
+
107
+ def maybe_transpose(x):
108
+ return x.T if not x.is_contiguous() and x.T.is_contiguous() else x
109
+
110
+ import wandb
111
+
112
+ RANK = 0
113
+
114
+ class Logger:
115
+ def __init__(self, sae_name, **kws):
116
+ self.vals = {}
117
+ self.enabled = (RANK == 0) and not kws.pop("dummy", False)
118
+ self.sae_name = sae_name
119
+
120
+ def logkv(self, k, v):
121
+ if self.enabled:
122
+ self.vals[f'{self.sae_name}/{k}'] = v.detach() if isinstance(v, torch.Tensor) else v
123
+ return v
124
+
125
+ def dumpkvs(self, step):
126
+ if self.enabled:
127
+ wandb.log(self.vals, step=step)
128
+ self.vals = {}
129
+
130
+
131
+ class FeaturesStats:
132
+ def __init__(self, dim, logger):
133
+ self.dim = dim
134
+ self.logger = logger
135
+ self.reinit()
136
+
137
+ def reinit(self):
138
+ self.n_activated = torch.zeros(self.dim, dtype=torch.long, device="cuda")
139
+ self.n = 0
140
+
141
+ def update(self, inds):
142
+ self.n += inds.shape[0]
143
+ inds = inds.flatten().detach()
144
+ self.n_activated.scatter_add_(0, inds, torch.ones_like(inds))
145
+
146
+ def log(self):
147
+ self.logger.logkv('activated', (self.n_activated / self.n + 1e-9).log10().cpu().numpy())
148
+
149
+ def training_loop_(
150
+ aes,
151
+ train_acts_iter,
152
+ loss_fn,
153
+ log_interval,
154
+ save_interval,
155
+ loggers,
156
+ sae_cfgs,
157
+ ):
158
+ sae_packs = []
159
+ for ae, cfg, logger in zip(aes, sae_cfgs, loggers):
160
+ pbar = tqdm.tqdm(unit=" steps", desc="Training Loss: ")
161
+ fstats = FeaturesStats(ae.n_dirs, logger)
162
+ opt = torch.optim.Adam(ae.parameters(), lr=cfg.lr, eps=cfg.eps, fused=True)
163
+ sae_packs.append((ae, cfg, logger, pbar, fstats, opt))
164
+
165
+ for i, flat_acts_train_batch in enumerate(train_acts_iter):
166
+ flat_acts_train_batch = flat_acts_train_batch.cuda()
167
+
168
+ for ae, cfg, logger, pbar, fstats, opt in sae_packs:
169
+ recons, info = ae(flat_acts_train_batch)
170
+ loss = loss_fn(ae, cfg, flat_acts_train_batch, recons, info, logger)
171
+
172
+ fstats.update(info['inds'])
173
+
174
+ bs = flat_acts_train_batch.shape[0]
175
+ logger.logkv('not-activated 1e4', (ae.stats_last_nonzero > 1e4 / bs).mean(dtype=float).item())
176
+ logger.logkv('not-activated 1e6', (ae.stats_last_nonzero > 1e6 / bs).mean(dtype=float).item())
177
+ logger.logkv('not-activated 1e7', (ae.stats_last_nonzero > 1e7 / bs).mean(dtype=float).item())
178
+
179
+ logger.logkv('explained variance', explained_variance(recons, flat_acts_train_batch))
180
+ logger.logkv('l2_div', (torch.linalg.norm(recons, dim=1) / torch.linalg.norm(flat_acts_train_batch, dim=1)).mean())
181
+
182
+ if (i + 1) % log_interval == 0:
183
+ fstats.log()
184
+ fstats.reinit()
185
+
186
+ if (i + 1) % save_interval == 0:
187
+ ae.save_to_disk(f"{cfg.save_path}/{i + 1}")
188
+
189
+ loss.backward()
190
+
191
+ unit_norm_decoder_(ae)
192
+ unit_norm_decoder_grad_adjustment_(ae)
193
+
194
+ opt.step()
195
+ opt.zero_grad()
196
+ logger.dumpkvs(i)
197
+
198
+ pbar.set_description(f"Training Loss {loss.item():.4f}")
199
+ pbar.update(1)
200
+
201
+
202
+ for ae, cfg, logger, pbar, fstats, opt in sae_packs:
203
+ pbar.close()
204
+ ae.save_to_disk(f"{cfg.save_path}/final")
205
+
206
+
207
+ def init_from_data_(ae, stats_acts_sample):
208
+ ae.pre_bias.data = (
209
+ compute_geometric_median(stats_acts_sample[:32768].float().cpu()).median.cuda().float()
210
+ )
211
+
212
+
213
+ def mse(recons, x):
214
+ # return ((recons - x) ** 2).sum(dim=-1).mean()
215
+ return ((recons - x) ** 2).mean()
216
+
217
+ def normalized_mse(recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
218
+ # only used for auxk
219
+ xs_mu = xs.mean(dim=0)
220
+
221
+ loss = mse(recon, xs) / mse(
222
+ xs_mu[None, :].broadcast_to(xs.shape), xs
223
+ )
224
+
225
+ return loss
226
+
227
+ def explained_variance(recons, x):
228
+ # Compute the variance of the difference
229
+ diff = x - recons
230
+ diff_var = torch.var(diff, dim=0, unbiased=False)
231
+
232
+ # Compute the variance of the original tensor
233
+ x_var = torch.var(x, dim=0, unbiased=False)
234
+
235
+ # Avoid division by zero
236
+ explained_var = 1 - diff_var / (x_var + 1e-8)
237
+
238
+ return explained_var.mean()
239
+
240
+
241
+ def main():
242
+ cfg = Config(json.load(open('SAE/config.json')))
243
+
244
+ dataloader = ActivationsDataloader(cfg.paths_to_latents, cfg.block_name, cfg.bs)
245
+
246
+ acts_iter = dataloader.iterate()
247
+ stats_acts_sample = torch.cat([
248
+ next(acts_iter).cpu() for _ in range(10)
249
+ ], dim=0)
250
+
251
+ aes = [
252
+ SparseAutoencoder(
253
+ n_dirs_local=sae.n_dirs,
254
+ d_model=sae.d_model,
255
+ k=sae.k,
256
+ auxk=sae.auxk,
257
+ dead_steps_threshold=sae.dead_toks_threshold // cfg.bs,
258
+ ).cuda()
259
+ for sae in cfg.saes
260
+ ]
261
+
262
+ for ae in aes:
263
+ init_from_data_(ae, stats_acts_sample)
264
+
265
+ mse_scale = (
266
+ 1 / ((stats_acts_sample.float().mean(dim=0) - stats_acts_sample.float()) ** 2).mean()
267
+ )
268
+ mse_scale = mse_scale.item()
269
+ del stats_acts_sample
270
+
271
+ wandb.init(
272
+ project=cfg.wandb_project,
273
+ name=cfg.wandb_name,
274
+ )
275
+
276
+ loggers = [Logger(
277
+ sae_name=cfg_sae.sae_name,
278
+ dummy=False,
279
+ ) for cfg_sae in cfg.saes]
280
+
281
+ training_loop_(
282
+ aes,
283
+ acts_iter,
284
+ lambda ae, cfg_sae, flat_acts_train_batch, recons, info, logger: (
285
+ # MSE
286
+ logger.logkv("train_recons", mse_scale * mse(recons, flat_acts_train_batch))
287
+ # AuxK
288
+ + logger.logkv(
289
+ "train_maxk_recons",
290
+ cfg_sae.auxk_coef
291
+ * normalized_mse(
292
+ ae.decode_sparse(
293
+ info["auxk_inds"],
294
+ info["auxk_vals"],
295
+ ),
296
+ flat_acts_train_batch - recons.detach() + ae.pre_bias.detach(),
297
+ ).nan_to_num(0),
298
+ )
299
+ ),
300
+ sae_cfgs = cfg.saes,
301
+ loggers=loggers,
302
+ log_interval=cfg.log_interval,
303
+ save_interval=cfg.save_interval,
304
+ )
305
+
306
+
307
+ if __name__ == "__main__":
308
+ main()
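
main() reads its settings from SAE/config.json (shipped in this commit). The sketch below reconstructs the expected fields purely from the attribute accesses above (cfg.paths_to_latents, cfg.block_name, cfg.bs, cfg.log_interval, cfg.save_interval, cfg.wandb_*, and per-SAE n_dirs, d_model, k, auxk, dead_toks_threshold, lr, eps, auxk_coef, sae_name, save_path); the concrete values, in particular auxk_coef and eps, are assumptions, so defer to the actual config file in the repo.

import json

config = {
    "paths_to_latents": ["latents/2024-01-01"],   # hypothetical output of collect_latents_dataset.py
    "block_name": "unet.down_blocks.2.attentions.1",
    "bs": 4096,
    "log_interval": 100,
    "save_interval": 5000,
    "wandb_project": "sdxl-sae",
    "wandb_name": "down.2.1_k10_hidden5120",
    "saes": [{
        "sae_name": "unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001",
        "n_dirs": 5120,
        "d_model": 1280,
        "k": 10,
        "auxk": 256,
        "dead_toks_threshold": 10_000_000,   # // bs = 2441 steps, matching the shipped checkpoints
        "auxk_coef": 1 / 32,                 # assumption (OpenAI default)
        "lr": 1e-4,
        "eps": 6.25e-10,                     # assumption (OpenAI default)
        "save_path": "checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001",
    }],
}

with open("SAE/config.json", "w") as f:
    json.dump(config, f, indent=2)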
utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .hooks import *
utils/hooks.py ADDED
@@ -0,0 +1,145 @@
1
+ from typing import Callable, List, Optional
2
+ import torch
3
+
4
+ class TimedHook:
5
+ def __init__(self, hook_fn, total_steps, apply_at_steps=None):
6
+ self.hook_fn = hook_fn
7
+ self.total_steps = total_steps
8
+ self.apply_at_steps = apply_at_steps
9
+ self.current_step = 0
10
+
11
+ def identity(self, module, input, output):
12
+ return output
13
+
14
+ def __call__(self, module, input, output):
15
+ if self.apply_at_steps is not None:
16
+ if self.current_step in self.apply_at_steps:
17
+ self.__increment()
18
+ return self.hook_fn(module, input, output)
19
+ else:
20
+ self.__increment()
21
+ return self.identity(module, input, output)
22
+
23
+ return self.identity(module, input, output)
24
+
25
+ def __increment(self):
26
+ if self.current_step < self.total_steps:
27
+ self.current_step += 1
28
+ else:
29
+ self.current_step = 0
30
+
31
+ @torch.no_grad()
32
+ def add_feature(sae, feature_idx, value, module, input, output):
33
+ diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
34
+ activated = sae.encode(diff)
35
+ mask = torch.zeros_like(activated, device=diff.device)
36
+ mask[..., feature_idx] = value
37
+ to_add = mask @ sae.decoder.weight.T
38
+ return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
39
+
40
+ @torch.no_grad()
41
+ def add_feature_on_area_base(sae, feature_idx, activation_map, module, input, output):
42
+ return add_feature_on_area_base_both(sae, feature_idx, activation_map, module, input, output)
43
+
44
+ @torch.no_grad()
45
+ def add_feature_on_area_base_both(sae, feature_idx, activation_map, module, input, output):
46
+ # add the feature to cond and subtract from uncond
47
+ # this assumes diff.shape[0] == 2
48
+ diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
49
+ activated = sae.encode(diff)
50
+ mask = torch.zeros_like(activated, device=diff.device)
51
+ if len(activation_map) == 2:
52
+ activation_map = activation_map.unsqueeze(0)
53
+ mask[..., feature_idx] = activation_map.to(mask.device)
54
+ to_add = mask @ sae.decoder.weight.T
55
+ to_add = to_add.chunk(2)
56
+ output[0][0] -= to_add[0].permute(0, 3, 1, 2).to(output[0].device)[0]
57
+ output[0][1] += to_add[1].permute(0, 3, 1, 2).to(output[0].device)[0]
58
+ return output
59
+
60
+
61
+ @torch.no_grad()
62
+ def add_feature_on_area_base_cond(sae, feature_idx, activation_map, module, input, output):
63
+ # add the feature to cond
64
+ # this assumes diff.shape[0] == 2
65
+ diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
66
+ diff_uncond, diff_cond = diff.chunk(2)
67
+ activated = sae.encode(diff_cond)
68
+ mask = torch.zeros_like(activated, device=diff_cond.device)
69
+ if len(activation_map) == 2:
70
+ activation_map = activation_map.unsqueeze(0)
71
+ mask[..., feature_idx] = activation_map.to(mask.device)
72
+ to_add = mask @ sae.decoder.weight.T
73
+ output[0][1] += to_add.permute(0, 3, 1, 2).to(output[0].device)[0]
74
+ return output
75
+
76
+
77
+ @torch.no_grad()
78
+ def replace_with_feature_base(sae, feature_idx, value, module, input, output):
79
+ # this assumes diff.shape[0] == 2
80
+ diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
81
+ diff_uncond, diff_cond = diff.chunk(2)
82
+ activated = sae.encode(diff_cond)
83
+ mask = torch.zeros_like(activated, device=diff_cond.device)
84
+ mask[..., feature_idx] = value
85
+ to_add = mask @ sae.decoder.weight.T
86
+ input[0][1] += to_add.permute(0, 3, 1, 2).to(output[0].device)[0]
87
+ return input
88
+
89
+
90
+ @torch.no_grad()
91
+ def add_feature_on_area_turbo(sae, feature_idx, activation_map, module, input, output):
92
+ diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
93
+ activated = sae.encode(diff)
94
+ mask = torch.zeros_like(activated, device=diff.device)
95
+ if len(activation_map) == 2:
96
+ activation_map = activation_map.unsqueeze(0)
97
+ mask[..., feature_idx] = activation_map.to(mask.device)
98
+ to_add = mask @ sae.decoder.weight.T
99
+ return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
100
+
101
+ @torch.no_grad()
102
+ def add_feature_on_area_flux(
103
+ sae,
104
+ feature_idx,
105
+ activation_map,
106
+ module,
107
+ input: torch.Tensor,
108
+ output: torch.Tensor,
109
+ ):
110
+
111
+ diff = (output - input).to(sae.device)
112
+ activated = sae.encode(diff)
113
+
114
+ # TODO: check
115
+ if len(activation_map) == 2:
116
+ activation_map = activation_map.unsqueeze(0)
117
+ mask = torch.zeros_like(activated, device=diff.device)
118
+ activation_map = activation_map.flatten()
119
+ mask[..., feature_idx] = activation_map.to(mask.device)
120
+ to_add = mask @ sae.decoder.weight.T
121
+ return output + to_add.to(output.device, output.dtype)
122
+
123
+
124
+
125
+ @torch.no_grad()
126
+ def replace_with_feature_turbo(sae, feature_idx, value, module, input, output):
127
+ diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
128
+ activated = sae.encode(diff)
129
+ mask = torch.zeros_like(activated, device=diff.device)
130
+ mask[..., feature_idx] = value
131
+ to_add = mask @ sae.decoder.weight.T
132
+ return (input[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
133
+
134
+
135
+ @torch.no_grad()
136
+ def reconstruct_sae_hook(sae, module, input, output):
137
+ diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
138
+ activated = sae.encode(diff)
139
+ reconstructed = sae.decoder(activated) + sae.pre_bias
140
+ return (input[0] + reconstructed.permute(0, 3, 1, 2).to(output[0].device),)
141
+
142
+
143
+ @torch.no_grad()
144
+ def ablate_block(module, input, output):
145
+ return input
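
To show how these hooks are meant to be combined with the hooked pipeline, here is a condensed sketch mirroring image_mod in app.py (the feature index, mask, and strength are placeholders; `pipe`, `sae`, and `means` are assumed to be loaded as shown earlier for sdxl-turbo):

import numpy as np
import torch

feature_idx = 4998                          # placeholder "cartoon" brush from the instructions tab
strength = 10.0
num_steps = 1
mask = torch.ones(16, 16, device="cuda")    # 512 / downsample_factor 32 = 16 latent patches per side

# Apply the feature only on the selected steps via TimedHook.
hook_fn = lambda module, input, output: add_feature_on_area_turbo(
    sae, feature_idx, mask * means[feature_idx] * strength, module, input, output)
hook = TimedHook(hook_fn, num_steps, apply_at_steps=np.arange(0, num_steps))

image = pipe.run_with_hooks(
    "A dog plays with a ball, cartoon",
    position_hook_dict={"unet.down_blocks.2.attentions.1": hook},
    num_inference_steps=num_steps,
    guidance_scale=0.0,
    generator=torch.Generator(device="cpu").manual_seed(42),
).images[0]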