Upload 47 files
- .gitattributes +1 -0
- LICENSE +21 -0
- README.MD +82 -0
- SAE/__init__.py +1 -0
- SAE/config.json +23 -0
- SAE/dataset_iterator.py +53 -0
- SAE/sae.py +216 -0
- SAE/sae_utils.py +48 -0
- SDLens/__init__.py +2 -0
- SDLens/cache_and_edit/__init__.py +1 -0
- SDLens/cache_and_edit/activation_cache.py +147 -0
- SDLens/cache_and_edit/cached_pipeline.py +342 -0
- SDLens/cache_and_edit/edits.py +223 -0
- SDLens/cache_and_edit/flux_pipeline.py +998 -0
- SDLens/cache_and_edit/hooks.py +108 -0
- SDLens/cache_and_edit/inversion.py +568 -0
- SDLens/cache_and_edit/metrics.py +116 -0
- SDLens/cache_and_edit/qkv_cache.py +557 -0
- SDLens/cache_and_edit/scheduler_inversion.py +98 -0
- SDLens/hooked_scheduler.py +40 -0
- SDLens/hooked_sd_pipeline.py +319 -0
- app.ipynb +0 -0
- app.py +768 -0
- checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
- checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
- checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
- checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
- checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
- checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
- checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
- checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
- checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
- checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
- checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
- checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
- checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json +1 -0
- checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt +3 -0
- checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth +3 -0
- checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt +3 -0
- colab_requirements.txt +8 -0
- example.ipynb +0 -0
- requirements.txt +12 -0
- resourses/image.png +3 -0
- retrieval.py +71 -0
- scripts/collect_latents_dataset.py +96 -0
- scripts/train_sae.py +308 -0
- utils/__init__.py +1 -0
- utils/hooks.py +145 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+resourses/image.png filter=lfs diff=lfs merge=lfs -text
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Viacheslav Surkov
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.MD
ADDED
@@ -0,0 +1,82 @@
+# Unpacking SDXL Turbo: Interpreting Text-to-Image Models with Sparse Autoencoders
+
+[arXiv](https://arxiv.org/abs/2410.22366)
+[HuggingFace Demo](https://huggingface.co/spaces/surokpro2/Unboxing_SDXL_with_SAEs)
+[Open in Colab](https://colab.research.google.com/drive/1lWZ2yCRwCf4iuykvb-91QYUNkuzIwI3k?usp=sharing)
+
+
+![](resourses/image.png)
+
+This repository contains code to reproduce results from our paper on using sparse autoencoders (SAEs) to analyze and interpret the internal representations of text-to-image diffusion models, specifically SDXL Turbo.
+
+
+## Repository Structure
+
+```
+|-- SAE/                        # Core sparse autoencoder implementation
+|-- SDLens/                     # Tools for analyzing diffusion models
+|   `-- hooked_sd_pipeline.py   # Modified stable diffusion pipeline
+|-- scripts/
+|   |-- collect_latents_dataset.py  # Generate training data
+|   `-- train_sae.py                # Train SAE models
+|-- utils/
+|   `-- hooks.py                # Hook utility functions
+|-- checkpoints/                # Pretrained SAE model checkpoints
+|-- app.py                      # Demo application
+|-- app.ipynb                   # Interactive notebook demo
+|-- example.ipynb               # Usage examples
+`-- requirements.txt            # Python dependencies
+```
+
+## Installation
+
+```bash
+pip install -r requirements.txt
+```
+
+## Demo Application
+
+You can try our Gradio demo application (`app.ipynb`) to browse and experiment with 20K+ features of our trained SAEs out of the box. The same notebook is available on [Google Colab](https://colab.research.google.com/drive/1lWZ2yCRwCf4iuykvb-91QYUNkuzIwI3k?usp=sharing).
+
+## Usage
+
+1. Collect latent data from SDXL Turbo:
+```bash
+python scripts/collect_latents_dataset.py --save_path={your_save_path}
+```
+
+2. Train sparse autoencoders:
+
+2.1. Set the paths to the stored latents and the checkpoint output directory in `SAE/config.json`
+
+2.2. Run the training script:
+
+```bash
+python scripts/train_sae.py
+```
+
+## Pretrained Models
+
+We provide pretrained SAE checkpoints for 4 key transformer blocks in SDXL Turbo's U-Net in the `checkpoints` folder. See `example.ipynb` for analysis examples and visualization of learned features. More pretrained SAEs with different parameters are available in the [HuggingFace repo](https://huggingface.co/surokpro2/sdxl-saes/tree/main).
+
+
+## Citation
+
+If you find this code useful in your research, please cite our paper:
+
+```bibtex
+@misc{surkov2024unpackingsdxlturbointerpreting,
+      title={Unpacking SDXL Turbo: Interpreting Text-to-Image Models with Sparse Autoencoders},
+      author={Viacheslav Surkov and Chris Wendler and Mikhail Terekhov and Justin Deschenaux and Robert West and Caglar Gulcehre},
+      year={2024},
+      eprint={2410.22366},
+      archivePrefix={arXiv},
+      primaryClass={cs.LG},
+      url={https://arxiv.org/abs/2410.22366},
+}
+```
+
+## Acknowledgements
+
+The SAE component is based on the [`openai/sparse_autoencoder`](https://github.com/openai/sparse_autoencoder) repository.
+
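
Editor's note: as a rough illustration of how the shipped checkpoints can be loaded, here is a minimal sketch based on `SAE/sae.py` below and the `checkpoints/` layout above. Treating `mean.pt`/`std.pt` as the per-feature mean and std used to standardize activations is an assumption, and devices may need adjusting (`torch.load` restores tensors on the device they were saved from).

```python
import torch
from SAE.sae import SparseAutoencoder

# Checkpoint path follows the layout shown in this upload.
ckpt = "checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final"

sae = SparseAutoencoder.load_from_disk(ckpt).eval()
mean = torch.load(f"{ckpt}/mean.pt")    # assumed: per-feature mean of the training activations
std = torch.load(f"{ckpt}/std.pt")      # assumed: per-feature std of the training activations

x = torch.randn(8, sae.d_model)         # stand-in for cached SDXL Turbo activations
latents = sae.encode((x - mean) / std)  # sparse top-k feature activations, shape (8, n_dirs)
```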
SAE/__init__.py
ADDED
@@ -0,0 +1 @@
+from .sae import SparseAutoencoder
SAE/config.json
ADDED
@@ -0,0 +1,23 @@
+{
+    "sae_configs": [
+        {
+            "d_model": 1280,
+            "n_dirs": 5120,
+            "k": 20
+        },
+        {
+            "d_model": 1280,
+            "n_dirs": 640,
+            "k": 20
+        }
+    ],
+    "bs": 4096,
+    "log_interval": 500,
+    "save_interval": 5000,
+
+    "paths_to_latents": [
+        "PASS YOUR PATHS HERE. Example /home/username/latents/<timestamp>. It should contain tar archives with latents."
+    ],
+    "save_path_base": "<Your SAE save path>",
+    "block_name": "unet.down_blocks.2.attentions.1"
+}
SAE/dataset_iterator.py
ADDED
@@ -0,0 +1,53 @@
+import webdataset as wds
+import os
+import torch
+
+class ActivationsDataloader:
+    def __init__(self, paths_to_datasets, block_name, batch_size, output_or_diff='diff', num_in_buffer=50):
+        assert output_or_diff in ['diff', 'output'], "Provide 'output' or 'diff'"
+
+        self.dataset = wds.WebDataset(
+            [os.path.join(path_to_dataset, f"{block_name}.tar")
+             for path_to_dataset in paths_to_datasets]
+        ).decode("torch")
+        self.iter = iter(self.dataset)
+        self.buffer = None
+        self.pointer = 0
+        self.num_in_buffer = num_in_buffer
+        self.output_or_diff = output_or_diff
+        self.batch_size = batch_size
+        self.one_size = None
+
+    def renew_buffer(self, to_retrieve):
+        to_merge = []
+        if self.buffer is not None and self.buffer.shape[0] > self.pointer:
+            to_merge = [self.buffer[self.pointer:].clone()]
+            del self.buffer
+        for _ in range(to_retrieve):
+            sample = next(self.iter)
+            latents = sample['output.pth'] if self.output_or_diff == 'output' else sample['diff.pth']
+            latents = latents.permute((0, 1, 3, 4, 2))
+            latents = latents.reshape((-1, latents.shape[-1]))
+            to_merge.append(latents.to('cuda'))
+            self.one_size = latents.shape[0]
+        self.buffer = torch.cat(to_merge, dim=0)
+        shuffled_indices = torch.randperm(self.buffer.shape[0])
+        self.buffer = self.buffer[shuffled_indices]
+        self.pointer = 0
+
+    def iterate(self):
+        while True:
+            if self.buffer is None or self.buffer.shape[0] - self.pointer < self.num_in_buffer * self.one_size * 4 // 5:
+                try:
+                    to_retrieve = self.num_in_buffer if self.buffer is None else self.num_in_buffer // 5
+                    self.renew_buffer(to_retrieve)
+                except StopIteration:
+                    break
+
+            batch = self.buffer[self.pointer: self.pointer + self.batch_size]
+            self.pointer += self.batch_size
+
+            assert batch.shape[0] == self.batch_size
+            yield batch
+
+
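
Editor's note: a minimal usage sketch for the loader above. The dataset path and block name are placeholders; it assumes tar archives produced by `scripts/collect_latents_dataset.py` and a CUDA device, since the loader moves batches to `'cuda'`.

```python
from SAE.dataset_iterator import ActivationsDataloader

dataloader = ActivationsDataloader(
    paths_to_datasets=["/path/to/latents/<timestamp>"],   # placeholder path
    block_name="unet.down_blocks.2.attentions.1",
    batch_size=4096,
)

for batch in dataloader.iterate():
    # batch: (batch_size, d_model) activations, shuffled within the in-memory buffer
    print(batch.shape)
    break
```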
SAE/sae.py
ADDED
@@ -0,0 +1,216 @@
+'''
+Adapted from
+https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/model.py
+'''
+
+import torch
+import torch.nn as nn
+import os
+import json
+
+class SparseAutoencoder(nn.Module):
+    """
+    Top-K Autoencoder with sparse kernels. Implements:
+
+        latents = relu(topk(encoder(x - pre_bias) + latent_bias))
+        recons = decoder(latents) + pre_bias
+    """
+
+    def __init__(
+        self,
+        n_dirs_local: int,
+        d_model: int,
+        k: int,
+        auxk: int | None,
+        dead_steps_threshold: int,
+    ):
+        super().__init__()
+        self.n_dirs_local = n_dirs_local
+        self.d_model = d_model
+        self.k = k
+        self.auxk = auxk
+        self.dead_steps_threshold = dead_steps_threshold
+
+        self.encoder = nn.Linear(d_model, n_dirs_local, bias=False)
+        self.decoder = nn.Linear(n_dirs_local, d_model, bias=False)
+
+        self.pre_bias = nn.Parameter(torch.zeros(d_model))
+        self.latent_bias = nn.Parameter(torch.zeros(n_dirs_local))
+
+        self.stats_last_nonzero: torch.Tensor
+        self.register_buffer("stats_last_nonzero", torch.zeros(n_dirs_local, dtype=torch.long))
+
+        def auxk_mask_fn(x):
+            dead_mask = self.stats_last_nonzero > dead_steps_threshold
+            x.data *= dead_mask  # inplace to save memory
+            return x
+
+        self.auxk_mask_fn = auxk_mask_fn
+
+        ## initialization
+
+        # "tied" init
+        self.decoder.weight.data = self.encoder.weight.data.T.clone()
+
+        # store decoder in column major layout for kernel
+        self.decoder.weight.data = self.decoder.weight.data.T.contiguous().T
+
+        unit_norm_decoder_(self)
+
+    def save_to_disk(self, path: str):
+        PATH_TO_CFG = 'config.json'
+        PATH_TO_WEIGHTS = 'state_dict.pth'
+
+        cfg = {
+            "n_dirs_local": self.n_dirs_local,
+            "d_model": self.d_model,
+            "k": self.k,
+            "auxk": self.auxk,
+            "dead_steps_threshold": self.dead_steps_threshold,
+        }
+
+        os.makedirs(path, exist_ok=True)
+
+        with open(os.path.join(path, PATH_TO_CFG), 'w') as f:
+            json.dump(cfg, f)
+
+        torch.save({
+            "state_dict": self.state_dict(),
+        }, os.path.join(path, PATH_TO_WEIGHTS))
+
+    @classmethod
+    def load_from_disk(cls, path: str):
+        PATH_TO_CFG = 'config.json'
+        PATH_TO_WEIGHTS = 'state_dict.pth'
+
+        with open(os.path.join(path, PATH_TO_CFG), 'r') as f:
+            cfg = json.load(f)
+
+        ae = cls(
+            n_dirs_local=cfg["n_dirs_local"],
+            d_model=cfg["d_model"],
+            k=cfg["k"],
+            auxk=cfg["auxk"],
+            dead_steps_threshold=cfg["dead_steps_threshold"],
+        )
+
+        state_dict = torch.load(os.path.join(path, PATH_TO_WEIGHTS))["state_dict"]
+        ae.load_state_dict(state_dict)
+
+        return ae
+
+    @property
+    def n_dirs(self):
+        return self.n_dirs_local
+
+    def encode(self, x):
+        x = x - self.pre_bias
+        latents_pre_act = self.encoder(x) + self.latent_bias
+
+        vals, inds = torch.topk(
+            latents_pre_act,
+            k=self.k,
+            dim=-1
+        )
+
+        latents = torch.zeros_like(latents_pre_act)
+        latents.scatter_(-1, inds, torch.relu(vals))
+
+        return latents
+
+    def forward(self, x):
+        x = x - self.pre_bias
+        latents_pre_act = self.encoder(x) + self.latent_bias
+        vals, inds = torch.topk(
+            latents_pre_act,
+            k=self.k,
+            dim=-1
+        )
+
+        ## set num nonzero stat ##
+        tmp = torch.zeros_like(self.stats_last_nonzero)
+        tmp.scatter_add_(
+            0,
+            inds.reshape(-1),
+            (vals > 1e-3).to(tmp.dtype).reshape(-1),
+        )
+        self.stats_last_nonzero *= 1 - tmp.clamp(max=1)
+        self.stats_last_nonzero += 1
+        ## end stats ##
+
+        ## auxk
+        if self.auxk is not None:  # for auxk
+            # IMPORTANT: has to go after stats update!
+            # WARN: auxk_mask_fn can mutate latents_pre_act!
+            auxk_vals, auxk_inds = torch.topk(
+                self.auxk_mask_fn(latents_pre_act),
+                k=self.auxk,
+                dim=-1
+            )
+        else:
+            auxk_inds = None
+            auxk_vals = None
+
+        ## end auxk
+
+        vals = torch.relu(vals)
+        if auxk_vals is not None:
+            auxk_vals = torch.relu(auxk_vals)
+
+        rows, cols = latents_pre_act.size()
+        row_indices = torch.arange(rows).unsqueeze(1).expand(-1, self.k).reshape(-1)
+        vals = vals.reshape(-1)
+        inds = inds.reshape(-1)
+
+        indices = torch.stack([row_indices.to(inds.device), inds])
+
+        sparse_tensor = torch.sparse_coo_tensor(indices, vals, torch.Size([rows, cols]))
+
+        recons = torch.sparse.mm(sparse_tensor, self.decoder.weight.T) + self.pre_bias
+
+        return recons, {
+            "inds": inds,
+            "vals": vals,
+            "auxk_inds": auxk_inds,
+            "auxk_vals": auxk_vals,
+        }
+
+    def decode_sparse(self, inds, vals):
+        rows, cols = inds.shape[0], self.n_dirs
+
+        row_indices = torch.arange(rows).unsqueeze(1).expand(-1, inds.shape[1]).reshape(-1)
+        vals = vals.reshape(-1)
+        inds = inds.reshape(-1)
+
+        indices = torch.stack([row_indices.to(inds.device), inds])
+
+        sparse_tensor = torch.sparse_coo_tensor(indices, vals, torch.Size([rows, cols]))
+
+        recons = torch.sparse.mm(sparse_tensor, self.decoder.weight.T) + self.pre_bias
+        return recons
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
+
+
+def unit_norm_decoder_(autoencoder: SparseAutoencoder) -> None:
+    """
+    Unit normalize the decoder weights of an autoencoder.
+    """
+    autoencoder.decoder.weight.data /= autoencoder.decoder.weight.data.norm(dim=0)
+
+
+def unit_norm_decoder_grad_adjustment_(autoencoder) -> None:
+    """project out gradient information parallel to the dictionary vectors - assumes that the decoder is already unit normed"""
+
+    assert autoencoder.decoder.weight.grad is not None
+
+    autoencoder.decoder.weight.grad +=\
+        torch.einsum("bn,bn->n", autoencoder.decoder.weight.data, autoencoder.decoder.weight.grad) *\
+        autoencoder.decoder.weight.data * -1
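
Editor's note: to make the forward contract above concrete, here is a small sketch of one training-style step. The hyperparameters are illustrative only; the actual training loop lives in `scripts/train_sae.py`, which is not reproduced in this diff, and the aux-k loss term is omitted.

```python
import torch
from SAE.sae import SparseAutoencoder, unit_norm_decoder_, unit_norm_decoder_grad_adjustment_

sae = SparseAutoencoder(n_dirs_local=5120, d_model=1280, k=20, auxk=256,
                        dead_steps_threshold=1000)  # illustrative threshold
opt = torch.optim.Adam(sae.parameters(), lr=1e-4)

x = torch.randn(4096, 1280)             # stand-in for a batch of cached activations
recons, info = sae(x)                   # info holds top-k and aux-k indices/values
loss = ((recons - x) ** 2).mean()       # reconstruction term only in this sketch

opt.zero_grad()
loss.backward()
unit_norm_decoder_grad_adjustment_(sae)  # project out gradient parallel to dictionary vectors
opt.step()
unit_norm_decoder_(sae)                  # re-normalize decoder columns after the update
```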
SAE/sae_utils.py
ADDED
@@ -0,0 +1,48 @@
+import torch
+from dataclasses import dataclass, field
+import os
+
+@dataclass
+class SAETrainingConfig:
+    d_model: int
+    n_dirs: int
+    k: int
+    block_name: str
+    bs: int
+    save_path_base: str
+    auxk: int = 256
+    lr: float = 1e-4
+    eps: float = 6.25e-10
+    dead_toks_threshold: int = 10_000_000
+    auxk_coef: float = 1/32
+
+    @property
+    def sae_name(self):
+        return f'{self.block_name}_k{self.k}_hidden{self.n_dirs}_auxk{self.auxk}_bs{self.bs}_lr{self.lr}'
+
+    @property
+    def save_path(self):
+        return os.path.join(self.save_path_base, self.sae_name)
+
+
+@dataclass
+class Config:
+    saes: list[SAETrainingConfig]
+    paths_to_latents: list[str]
+    log_interval: int
+    save_interval: int
+    bs: int
+    block_name: str
+    wandb_project: str = 'sdxl_sae_train'
+    wandb_name: str = 'multiple_sae'
+
+    def __init__(self, cfg_json):
+        self.saes = [SAETrainingConfig(**sae_cfg, block_name=cfg_json['block_name'], bs=cfg_json['bs'], save_path_base=cfg_json['save_path_base'])
+                     for sae_cfg in cfg_json['sae_configs']]
+
+        self.save_path_base = cfg_json['save_path_base']
+        self.paths_to_latents = cfg_json['paths_to_latents']
+        self.log_interval = cfg_json['log_interval']
+        self.save_interval = cfg_json['save_interval']
+        self.bs = cfg_json['bs']
+        self.block_name = cfg_json['block_name']
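
Editor's note: a small sketch of how `SAE/config.json` maps onto these dataclasses (it assumes you run it from the repository root with the config filled in).

```python
import json
from SAE.sae_utils import Config

with open("SAE/config.json") as f:
    cfg = Config(json.load(f))

# One SAETrainingConfig per entry in "sae_configs", all trained on the same block
for sae_cfg in cfg.saes:
    print(sae_cfg.sae_name, "->", sae_cfg.save_path)
```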
SDLens/__init__.py
ADDED
@@ -0,0 +1,2 @@
+from .hooked_sd_pipeline import HookedIFPipeline, HookedStableDiffusionXLPipeline
+from .cache_and_edit import CachedPipeline
SDLens/cache_and_edit/__init__.py
ADDED
@@ -0,0 +1 @@
+from .cached_pipeline import CachedPipeline
SDLens/cache_and_edit/activation_cache.py
ADDED
@@ -0,0 +1,147 @@
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from typing import List
+from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxSingleTransformerBlock
+from SDLens.cache_and_edit.hooks import fix_inf_values_hook, register_general_hook
+import torch
+
+class ModelActivationCache(ABC):
+    """
+    Cache for an inference pass of a Diffusion Transformer.
+    Used to cache residual streams and activations.
+    """
+    def __init__(self):
+
+        # Initialize caches for "double transformer" blocks using the subclass-defined NUM_TRANSFORMER_BLOCKS
+        if hasattr(self, 'NUM_TRANSFORMER_BLOCKS'):
+            self.image_residual = []
+            self.image_activation = []
+            self.text_residual = []
+            self.text_activation = []
+
+        # Initialize caches for "single transformer" blocks if defined (using NUM_SINGLE_TRANSFORMER_BLOCKS)
+        if hasattr(self, 'NUM_SINGLE_TRANSFORMER_BLOCKS'):
+            self.text_image_residual = []
+            self.text_image_activation = []
+
+    def __str__(self):
+        lines = [f"{self.__class__.__name__}:"]
+        for attr_name, value in self.__dict__.items():
+            if isinstance(value, list) and all(isinstance(v, torch.Tensor) for v in value):
+                shapes = value[0].shape
+                lines.append(f" {attr_name}: len={len(value)}, shapes={shapes}")
+            else:
+                lines.append(f" {attr_name}: {type(value)}")
+        return "\n".join(lines)
+
+    def _repr_pretty_(self, p, cycle):
+        p.text(str(self))
+
+    @abstractmethod
+    def get_cache_info(self):
+        """
+        Return details about the cache configuration.
+        Subclasses must implement this to provide info on their transformer block counts.
+        """
+        pass
+
+
+class FluxActivationCache(ModelActivationCache):
+    # Define number of blocks for double and single transformer caches
+    NUM_TRANSFORMER_BLOCKS = 19
+    NUM_SINGLE_TRANSFORMER_BLOCKS = 38
+
+    def __init__(self):
+        super().__init__()
+
+    def get_cache_info(self):
+        return {
+            "transformer_blocks": self.NUM_TRANSFORMER_BLOCKS,
+            "single_transformer_blocks": self.NUM_SINGLE_TRANSFORMER_BLOCKS,
+        }
+
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+
+class PixartActivationCache(ModelActivationCache):
+    # Define number of blocks for the double transformer cache only
+    NUM_TRANSFORMER_BLOCKS = 28
+
+    def __init__(self):
+        super().__init__()
+
+    def get_cache_info(self):
+        return {
+            "double_transformer_blocks": self.NUM_TRANSFORMER_BLOCKS,
+        }
+
+
+class ActivationCacheHandler:
+    """Used to manage the ModelActivationCache of a Diffusion Transformer."""
+
+    def __init__(self, cache: ModelActivationCache, positions_to_cache: List[str] = None):
+        """Constructor.
+
+        Args:
+            cache (ModelActivationCache): cache used to store tensors.
+            positions_to_cache (List[str], optional): names of the modules to cache.
+                If None, all modules listed in `cache.get_cache_info()` will be cached. Defaults to None.
+        """
+        self.cache = cache
+        self.positions_to_cache = positions_to_cache
+
+    @torch.no_grad()
+    def cache_residual_and_activation_hook(self, *args):
+        """
+        To be used as a forward hook on a Transformer Block.
+        It caches both residual_stream and activation (defined as output - residual_stream).
+        """
+
+        if len(args) == 3:
+            module, input, output = args
+        elif len(args) == 4:
+            module, input, kwinput, output = args
+
+        if isinstance(module, FluxTransformerBlock):
+            encoder_hidden_states = output[0]
+            hidden_states = output[1]
+
+            self.cache.image_activation.append(hidden_states - kwinput["hidden_states"])
+            self.cache.text_activation.append(encoder_hidden_states - kwinput["encoder_hidden_states"])
+            self.cache.image_residual.append(kwinput["hidden_states"])
+            self.cache.text_residual.append(kwinput["encoder_hidden_states"])
+
+        elif isinstance(module, FluxSingleTransformerBlock):
+            self.cache.text_image_activation.append(output - kwinput["hidden_states"])
+            self.cache.text_image_residual.append(kwinput["hidden_states"])
+        else:
+            raise NotImplementedError(f"Caching not implemented for {type(module)}")
+
+    @property
+    def forward_hooks_dict(self):
+
+        # insert cache storing in dict
+        hooks = defaultdict(list)
+
+        if self.positions_to_cache is None:
+            for block_type, num_layers in self.cache.get_cache_info().items():
+                for i in range(num_layers):
+                    module_name: str = f"transformer.{block_type}.{i}"
+                    hooks[module_name].append(fix_inf_values_hook)
+                    hooks[module_name].append(self.cache_residual_and_activation_hook)
+        else:
+            for module_name in self.positions_to_cache:
+                hooks[module_name].append(fix_inf_values_hook)
+                hooks[module_name].append(self.cache_residual_and_activation_hook)
+
+        return hooks
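
Editor's note: for intuition, the handler's `forward_hooks_dict` simply maps module paths to lists of hooks. A minimal sketch that only inspects it (no pipeline needed; it assumes the repo's requirements, including a recent `diffusers`, are installed):

```python
from SDLens.cache_and_edit.activation_cache import FluxActivationCache, ActivationCacheHandler

handler = ActivationCacheHandler(FluxActivationCache())
hooks = handler.forward_hooks_dict     # e.g. "transformer.transformer_blocks.0" -> [fix_inf_values_hook, cache hook]

print(len(hooks))                      # 19 double blocks + 38 single blocks = 57 entries
print(next(iter(hooks.items())))
```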
SDLens/cache_and_edit/cached_pipeline.py
ADDED
@@ -0,0 +1,342 @@
+from collections import defaultdict
+from functools import partial
+import gc
+from typing import Callable, Dict, List, Literal, Optional, Type, Union
+import torch
+from SDLens.cache_and_edit.activation_cache import FluxActivationCache, ModelActivationCache, PixartActivationCache, ActivationCacheHandler
+from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxSingleTransformerBlock
+from SDLens.cache_and_edit.hooks import locate_block, register_general_hook, fix_inf_values_hook, edit_streams_hook
+from SDLens.cache_and_edit.qkv_cache import QKVCacheFluxHandler, QKVCache, CachedFluxAttnProcessor3_0
+from SDLens.cache_and_edit.scheduler_inversion import FlowMatchEulerDiscreteSchedulerForInversion
+from SDLens.cache_and_edit.flux_pipeline import EditedFluxPipeline
+
+from diffusers.pipelines import FluxPipeline
+
+
+class CachedPipeline:
+
+    def __init__(self, pipe: EditedFluxPipeline, text_seq_length: int = 512):
+
+        assert isinstance(pipe, (EditedFluxPipeline, FluxPipeline)), "Use EditedFluxPipeline class in `cache_and_edit/flux_pipeline.py`"
+        self.pipe = pipe
+        self.text_seq_length = text_seq_length
+
+        # Cache handlers
+        self.activation_cache_handler = None
+        self.qkv_cache_handler = None
+        # keeps references to all registered hooks
+        self.registered_hooks = []
+
+    def setup_cache(self, use_activation_cache=True,
+                    use_qkv_cache=False,
+                    positions_to_cache: List[str] = None,
+                    positions_to_cache_foreground: List[str] = None,
+                    qkv_to_inject: QKVCache = None,
+                    inject_kv_mode: Literal["image", "text", "both"] = None,
+                    q_mask=None,
+                    processor_class: Optional[Type] = CachedFluxAttnProcessor3_0
+                    ) -> None:
+        """
+        Sets up activation_cache and/or qkv_cache, registering the required hooks.
+        If positions_to_cache is None, all modules will be cached.
+        If inject_kv_mode is None, the QKV cache will be stored; otherwise qkv_to_inject will be injected.
+        """
+
+        if use_activation_cache:
+            if isinstance(self.pipe, (EditedFluxPipeline, FluxPipeline)):
+                activation_cache = FluxActivationCache()
+            else:
+                raise AssertionError(f"activation cache not implemented for {type(self.pipe)}")
+
+            self.activation_cache_handler = ActivationCacheHandler(activation_cache, positions_to_cache)
+            # register hooks created by activation_cache
+            self._set_hooks(position_hook_dict=self.activation_cache_handler.forward_hooks_dict,
+                            with_kwargs=True)
+
+        if use_qkv_cache:
+            if isinstance(self.pipe, (EditedFluxPipeline, FluxPipeline)):
+                self.qkv_cache_handler = QKVCacheFluxHandler(self.pipe,
+                                                             positions_to_cache,
+                                                             positions_to_cache_foreground,
+                                                             inject_kv=inject_kv_mode,
+                                                             text_seq_length=self.text_seq_length,
+                                                             q_mask=q_mask,
+                                                             processor_class=processor_class,
+                                                             )
+            else:
+                raise AssertionError(f"QKV cache not implemented for {type(self.pipe)}")
+
+            # qkv_cache does not use hooks
+
+    @property
+    def activation_cache(self) -> ModelActivationCache:
+        return self.activation_cache_handler.cache if hasattr(self, "activation_cache_handler") and self.activation_cache_handler else None
+
+    @property
+    def qkv_cache(self) -> QKVCache:
+        return self.qkv_cache_handler.cache if hasattr(self, "qkv_cache_handler") and self.qkv_cache_handler else None
+
+    @torch.no_grad
+    def run(self,
+            prompt: Union[str, List[str]],
+            num_inference_steps: int = 1,
+            seed: int = 42,
+            width=1024,
+            height=1024,
+            cache_activations: bool = False,
+            cache_qkv: bool = False,
+            guidance_scale: float = 0.0,
+            positions_to_cache: List[str] = None,
+            empty_clip_embeddings: bool = True,
+            inverse: bool = False,
+            **kwargs):
+        """Run the pipeline, optionally caching activations or QKV.
+
+        Args:
+            prompt (str): Prompt to run the pipeline (NOTE: for Flux, the parameters passed are prompt='' and prompt_2=prompt)
+            num_inference_steps (int, optional): Number of inference steps. Defaults to 1.
+            seed (int, optional): Seed for the generators. Defaults to 42.
+            cache_activations (bool, optional): Whether to cache activations. Defaults to False.
+            cache_qkv (bool, optional): Whether to cache queries, keys, values. Defaults to False.
+            positions_to_cache (List[str], optional): List of blocks to cache.
+                If None, all transformer blocks will be cached. Defaults to None.
+
+        Returns:
+            Same output as the wrapped pipeline.
+        """
+
+        # First, clear all registered hooks
+        self.clear_all_hooks()
+
+        # Delete cache already present
+        if self.activation_cache or self.qkv_cache:
+
+            if self.activation_cache:
+                del(self.activation_cache_handler.cache)
+                del(self.activation_cache_handler)
+
+            if self.qkv_cache:
+                # Necessary to delete the old cache.
+                self.qkv_cache_handler.clear_cache()
+                del(self.qkv_cache_handler)
+
+            gc.collect()  # force Python to clean up unreachable objects
+            torch.cuda.empty_cache()  # tell PyTorch to release unused GPU memory from its cache
+
+        # Setup cache again for the current inference pass
+        self.setup_cache(cache_activations, cache_qkv, positions_to_cache, inject_kv_mode=None)
+
+        assert isinstance(seed, int)
+
+        if isinstance(prompt, str):
+            empty_prompt = [""]
+            prompt = [prompt]
+        else:
+            empty_prompt = [""] * len(prompt)
+
+        gen = [torch.Generator(device="cpu").manual_seed(seed) for _ in range(len(prompt))]
+
+        if inverse:
+            # maybe create scheduler for inversion
+            if not hasattr(self, "inversion_scheduler"):
+                self.inversion_scheduler = FlowMatchEulerDiscreteSchedulerForInversion.from_config(
+                    self.pipe.scheduler.config,
+                    inverse=True
+                )
+                self.og_scheduler = self.pipe.scheduler
+
+            self.pipe.scheduler = self.inversion_scheduler
+
+        output = self.pipe(
+            prompt=empty_prompt if empty_clip_embeddings else prompt,
+            prompt_2=prompt,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            generator=gen,
+            width=width,
+            height=height,
+            **kwargs
+        )
+
+        # Restore original scheduler
+        if inverse:
+            self.pipe.scheduler = self.og_scheduler
+
+        return output
+
+    @torch.no_grad
+    def run_inject_qkv(self,
+                       prompt: Union[str, List[str]],
+                       positions_to_inject: List[str] = None,
+                       positions_to_inject_foreground: List[str] = None,
+                       inject_kv_mode: Literal["image", "text", "both"] = "image",
+                       num_inference_steps: int = 1,
+                       guidance_scale: float = 0.0,
+                       seed: int = 42,
+                       empty_clip_embeddings: bool = True,
+                       q_mask=None,
+                       width: int = 1024,
+                       height: int = 1024,
+                       processor_class: Optional[Type] = CachedFluxAttnProcessor3_0,
+                       **kwargs):
+        """Run the pipeline while injecting a previously stored QKV cache.
+
+        Args:
+            prompt (str): Prompt to run the pipeline (NOTE: for Flux, the parameters passed are prompt='' and prompt_2=prompt)
+            num_inference_steps (int, optional): Number of inference steps. Defaults to 1.
+            seed (int, optional): Seed for the generators. Defaults to 42.
+            positions_to_inject (List[str], optional): List of blocks to inject into.
+                If None, all transformer blocks will be used. Defaults to None.
+
+        Returns:
+            Same output as the wrapped pipeline.
+        """
+
+        # First, clear all registered hooks
+        self.clear_all_hooks()
+
+        # Delete previous QKVCache
+        if hasattr(self, "qkv_cache_handler") and self.qkv_cache_handler is not None:
+            self.qkv_cache_handler.clear_cache()
+            del(self.qkv_cache_handler)
+            gc.collect()  # force Python to clean up unreachable objects
+            torch.cuda.empty_cache()  # tell PyTorch to release unused GPU memory from its cache
+
+        # Will setup existing QKV cache to be injected
+        self.setup_cache(use_activation_cache=False,
+                         use_qkv_cache=True,
+                         positions_to_cache=positions_to_inject,
+                         positions_to_cache_foreground=positions_to_inject_foreground,
+                         inject_kv_mode=inject_kv_mode,
+                         q_mask=q_mask,
+                         processor_class=processor_class,
+                         )
+
+        assert isinstance(seed, int)
+
+        if isinstance(prompt, str):
+            empty_prompt = [""]
+            prompt = [prompt]
+        else:
+            empty_prompt = [""] * len(prompt)
+
+        gen = [torch.Generator(device="cpu").manual_seed(seed) for _ in range(len(prompt))]
+
+        output = self.pipe(
+            prompt=empty_prompt if empty_clip_embeddings else prompt,
+            prompt_2=prompt,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            generator=gen,
+            width=width,
+            height=height,
+            **kwargs
+        )
+
+        return output
+
+    def clear_all_hooks(self):
+
+        # 1. Clear all registered hooks
+        for hook in self.registered_hooks:
+            hook.remove()
+        self.registered_hooks = []
+
+        # 2. Eventually clear other hooks registered in the pipeline but not present here
+        # TODO: make it general for other models
+        for i in range(len(locate_block(self.pipe, "transformer.transformer_blocks"))):
+            locate_block(self.pipe, f"transformer.transformer_blocks.{i}")._forward_hooks.clear()
+
+        for i in range(len(locate_block(self.pipe, "transformer.single_transformer_blocks"))):
+            locate_block(self.pipe, f"transformer.single_transformer_blocks.{i}")._forward_hooks.clear()
+
+    def _set_hooks(self,
+                   position_hook_dict: Dict[str, List[Callable]] = {},
+                   position_pre_hook_dict: Dict[str, List[Callable]] = {},
+                   with_kwargs=False
+                   ):
+        '''
+        Set hooks at specified positions and register them.
+        Args:
+            position_hook_dict: A dictionary mapping positions to hooks.
+                The keys are positions in the pipeline where the hooks should be registered.
+                The values are lists of hooks to be registered at the specified position.
+                Each hook should be a callable that takes three arguments: (module, input, output).
+            **kwargs: Keyword arguments to pass to the pipeline.
+        '''
+
+        # Register hooks
+        for is_pre_hook, hook_dict in [(True, position_pre_hook_dict), (False, position_hook_dict)]:
+            for position, hook in hook_dict.items():
+                assert isinstance(hook, list)
+                for h in hook:
+                    self.registered_hooks.append(register_general_hook(self.pipe, position, h, with_kwargs, is_pre_hook))
+
+    def run_with_edit(self,
+                      prompt: str,
+                      edit_fn: callable,
+                      layers_for_edit_fn: List[int],
+                      stream: Literal['text', 'image', 'both'],
+                      guidance_scale: float = 0.0,
+                      seed=42,
+                      num_inference_steps=1,
+                      empty_clip_embeddings: bool = True,
+                      width: int = 1024,
+                      height: int = 1024,
+                      **kwargs,
+                      ):
+
+        assert isinstance(seed, int)
+
+        self.clear_all_hooks()
+
+        # Setup hooks for edit_fn at the specified layers
+        # NOTE: edit_fn_hooks has to be Dict[str, List[Callable]]
+        edit_fn_hooks = {f"transformer.transformer_blocks.{layer}": [lambda *args: edit_streams_hook(*args, recompute_fn=edit_fn, stream=stream)]
+                         for layer in layers_for_edit_fn if layer < 19}
+        edit_fn_hooks.update({f"transformer.single_transformer_blocks.{layer - 19}": [lambda *args: edit_streams_hook(*args, recompute_fn=edit_fn, stream=stream)]
+                              for layer in layers_for_edit_fn if layer >= 19})
+
+        # register hooks in the pipe
+        self._set_hooks(position_hook_dict=edit_fn_hooks, with_kwargs=True)
+
+        # Create generators
+        if isinstance(prompt, str):
+            empty_prompt = [""]
+            prompt = [prompt]
+        else:
+            empty_prompt = [""] * len(prompt)
+
+        gen = [torch.Generator(device="cpu").manual_seed(seed) for _ in range(len(prompt))]
+
+        with torch.no_grad():
+            output = self.pipe(
+                prompt=empty_prompt if empty_clip_embeddings else prompt,
+                prompt_2=prompt,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=guidance_scale,
+                generator=gen,
+                width=width,
+                height=height,
+                **kwargs
+            )
+
+        return output
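
Editor's note: a hedged end-to-end sketch of the caching wrapper. The FLUX checkpoint name, dtype, and device are assumptions, and whether `EditedFluxPipeline` (from `flux_pipeline.py`, whose diff is not rendered below) exposes `from_pretrained` exactly like the diffusers `FluxPipeline` it extends is also an assumption.

```python
import torch
from SDLens.cache_and_edit import CachedPipeline
from SDLens.cache_and_edit.flux_pipeline import EditedFluxPipeline

# Assumed checkpoint and loading path; adjust to your setup.
pipe = EditedFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

cached = CachedPipeline(pipe)
out = cached.run("a photo of a corgi on a skateboard",
                 num_inference_steps=1,
                 cache_activations=True)

image = out.images[0]
cache = cached.activation_cache          # FluxActivationCache with per-block residuals/activations
print(cache.image_activation[0].shape)   # activation of the first double transformer block
```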
SDLens/cache_and_edit/edits.py
ADDED
@@ -0,0 +1,223 @@
+from functools import partial
+from typing import Callable, List
+
+import torch
+
+
+class Edit:
+
+    def __init__(self, ablator, vanilla_pre_forward_dict: Callable[[str, int], dict],
+                 vanilla_forward_dict: Callable[[str, int], dict],
+                 ablated_pre_forward_dict: Callable[[str, int], dict],
+                 ablated_forward_dict: Callable[[str, int], dict],):
+        self.ablator = ablator
+        self.vanilla_seed = 42
+        self.vanilla_pre_forward_dict = vanilla_pre_forward_dict
+        self.vanilla_forward_dict = vanilla_forward_dict
+
+        self.ablated_seed = 42
+        self.ablated_pre_forward_dict = ablated_pre_forward_dict
+        self.ablated_forward_dict = ablated_forward_dict
+
+
+def get_edit(name: str, **kwargs):
+    # NOTE: TransformerActivationCache and Ablation are assumed to be defined in the surrounding package.
+
+    if name == "edit_streams":
+        ablator = TransformerActivationCache()
+        stream: str = kwargs["stream"]
+        layers = kwargs["layers"]
+        edit_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = kwargs["edit_fn"]
+
+        interventions = {f"transformer.transformer_blocks.{layer}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer < 19}
+        interventions.update({f"transformer.single_transformer_blocks.{layer - 19}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer >= 19})
+
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
+                        vanilla_forward_dict=lambda block_type, layer_num: {},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
+                        ablated_forward_dict=lambda block_type, layer_num: interventions,
+                        )
+
+
+"""
+def get_ablation(name: str, **kwargs):
+
+    if name == "intermediate_text_stream_to_input":
+
+        ablator = TransformerActivationCache()
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
+                        vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, stream="text")},
+                        ablated_forward_dict=lambda block_type, layer_num: {})
+    elif name == "input_to_intermediate_text_stream":
+        ablator = TransformerActivationCache()
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
+                        vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.replace_stream_input(*args, stream="text")},
+                        ablated_forward_dict=lambda block_type, layer_num: {})
+
+    elif name == "set_input_text":
+
+        tensor: torch.Tensor = kwargs["tensor"]
+
+        ablator = TransformerActivationCache()
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
+                        vanilla_forward_dict=lambda block_type, layer_num: {},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.replace_stream_input(*args, use_tensor=tensor, stream="text")},
+                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.clamp_output(*args)})
+
+    elif name == "replace_text_stream_activation":
+        ablator = AttentionAblationCacheHook()
+        weight = kwargs["weight"] if "weight" in kwargs else 1.0
+
+
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
+                        vanilla_forward_dict=lambda block_type, layer_num: {},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
+                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.set_ablated_attention(*args, weight=weight)})
+
+    elif name == "replace_text_stream":
+        ablator = TransformerActivationCache()
+        weight = kwargs["weight"] if "weight" in kwargs else 1.0
+
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
+                        vanilla_forward_dict=lambda block_type, layer_num: {},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
+                        ablated_forward_dict=lambda block_type, layer_num: {})
+
+
+    elif name == "input=output":
+        return Ablation(None,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
+                        vanilla_forward_dict=lambda block_type, layer_num: {},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
+                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablate_block(*args)})
+
+    elif name == "reweight_text_stream":
+        ablator = TransformerActivationCache()
+
+        residual_w=kwargs["residual_w"]
+        activation_w=kwargs["activation_w"]
+
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
+                        vanilla_forward_dict=lambda block_type, layer_num: {},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
+                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.reweight_text_stream(*args, residual_w=residual_w, activation_w=activation_w)})
+
+    elif name == "add_input_text":
+
+        tensor: torch.Tensor = kwargs["tensor"]
+
+        ablator = TransformerActivationCache()
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
+                        vanilla_forward_dict=lambda block_type, layer_num: {},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.add_text_stream_input(*args, use_tensor=tensor)},
+                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.0": lambda *args: ablator.clamp_output(*args)})
+
+    elif name == "nothing":
+        ablator = TransformerActivationCache()
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
+                        vanilla_forward_dict=lambda block_type, layer_num: {},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
+                        ablated_forward_dict=lambda block_type, layer_num: {})
+
+    elif name == "reweight_image_stream":
+        ablator = TransformerActivationCache()
+        residual_w=kwargs["residual_w"]
+        activation_w=kwargs["activation_w"]
+
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
+                        vanilla_forward_dict=lambda block_type, layer_num: {},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
+                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.reweight_image_stream(*args, residual_w=residual_w, activation_w=activation_w)})
+
+    if name == "intermediate_image_stream_to_input":
+
+        ablator = TransformerActivationCache()
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
+                        vanilla_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": lambda *args: ablator.cache_attention_activation(*args, full_output=True)},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, stream='image')},
+                        ablated_forward_dict=lambda block_type, layer_num: {})
+
+
+    elif name == "replace_text_stream_one_layer":
+        ablator = AttentionAblationCacheHook()
+        weight = kwargs["weight"] if "weight" in kwargs else 1.0
+
+
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_text_stream},
+                        vanilla_forward_dict=lambda block_type, layer_num: {},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.cache_and_inject_pre_forward},
+                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer.{block_type}.{layer_num}": ablator.restore_text_stream})
+
+    elif name == "replace_intermediate_representation":
+        ablator = TransformerActivationCache()
+        tensor: torch.Tensor = kwargs["tensor"]
+
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
+                        vanilla_forward_dict=lambda block_type, layer_num: {},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.0": lambda *args: ablator.replace_stream_input(*args, use_tensor=tensor, stream='text_image')},
+                        ablated_forward_dict=lambda block_type, layer_num: {})
+
+    elif name == "destroy_registers":
+        ablator = TransformerActivationCache()
+        layers: List[int] = kwargs['layers']
+        k: float = kwargs["k"]
+        stream: str = kwargs['stream']
+        random: bool = kwargs["random"] if "random" in kwargs else False
+        lowest_norm: bool = kwargs["lowest_norm"] if "lowest_norm" in kwargs else False
+
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
+                        vanilla_forward_dict=lambda block_type, layer_num: {},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.destroy_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
+                        ablated_forward_dict=lambda block_type, layer_num: {})
+
+    elif name == "patch_registers":
+        ablator = TransformerActivationCache()
+        layers: List[int] = kwargs['layers']
+        k: float = kwargs["k"]
+        stream: str = kwargs['stream']
+        random: bool = kwargs["random"] if "random" in kwargs else False
+        lowest_norm: bool = kwargs["lowest_norm"] if "lowest_norm" in kwargs else False
+
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.destroy_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
+                        vanilla_forward_dict=lambda block_type, layer_num: {},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer.single_transformer_blocks.{i}": lambda *args: ablator.set_cached_registers(*args, k=k, stream=stream, random_ablation=random, lowest_norm=lowest_norm) for i in layers},
+                        ablated_forward_dict=lambda block_type, layer_num: {})
+
+    elif name == "add_registers":
+        ablator = TransformerActivationCache()
+        num_registers: int = kwargs["num_registers"]
+
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
+                        vanilla_forward_dict=lambda block_type, layer_num: {},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {f"transformer": lambda *args: insert_extra_registers(*args, num_registers=num_registers)},
+                        ablated_forward_dict=lambda block_type, layer_num: {f"transformer": lambda *args: discard_extra_registers(*args, num_registers=num_registers)},)
+
+
+    elif name == "edit_streams":
+        ablator = TransformerActivationCache()
+        stream: str = kwargs["stream"]
+        layers = kwargs["layers"]
+        edit_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = kwargs["edit_fn"]
+
+        interventions = {f"transformer.transformer_blocks.{layer}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer < 19}
+        interventions.update({f"transformer.single_transformer_blocks.{layer - 19}": lambda *args: ablator.edit_streams(*args, recompute_fn=partial(edit_fn, layer=layer), stream=stream) for layer in layers if layer >= 19})
+
+        return Ablation(ablator,
+                        vanilla_pre_forward_dict=lambda block_type, layer_num: {},
+                        vanilla_forward_dict=lambda block_type, layer_num: {},
+                        ablated_pre_forward_dict=lambda block_type, layer_num: {},
+                        ablated_forward_dict=lambda block_type, layer_num: interventions,
+                        )
+"""
SDLens/cache_and_edit/flux_pipeline.py
ADDED
@@ -0,0 +1,998 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
from transformers import (
|
21 |
+
CLIPImageProcessor,
|
22 |
+
CLIPTextModel,
|
23 |
+
CLIPTokenizer,
|
24 |
+
CLIPVisionModelWithProjection,
|
25 |
+
T5EncoderModel,
|
26 |
+
T5TokenizerFast,
|
27 |
+
)
|
28 |
+
|
29 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
30 |
+
from diffusers.loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
|
31 |
+
from diffusers.models.autoencoders import AutoencoderKL
|
32 |
+
from diffusers.models.transformers import FluxTransformer2DModel
|
33 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
34 |
+
from diffusers.utils import (
|
35 |
+
USE_PEFT_BACKEND,
|
36 |
+
is_torch_xla_available,
|
37 |
+
logging,
|
38 |
+
replace_example_docstring,
|
39 |
+
scale_lora_layers,
|
40 |
+
unscale_lora_layers,
|
41 |
+
)
|
42 |
+
from diffusers.utils.torch_utils import randn_tensor
|
43 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
44 |
+
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
|
45 |
+
|
46 |
+
|
47 |
+
if is_torch_xla_available():
|
48 |
+
import torch_xla.core.xla_model as xm
|
49 |
+
|
50 |
+
XLA_AVAILABLE = True
|
51 |
+
else:
|
52 |
+
XLA_AVAILABLE = False
|
53 |
+
|
54 |
+
|
55 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
56 |
+
|
57 |
+
EXAMPLE_DOC_STRING = """
|
58 |
+
Examples:
|
59 |
+
```py
|
60 |
+
>>> import torch
|
61 |
+
>>> from diffusers import FluxPipeline
|
62 |
+
|
63 |
+
>>> pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
|
64 |
+
>>> pipe.to("cuda")
|
65 |
+
>>> prompt = "A cat holding a sign that says hello world"
|
66 |
+
>>> # Depending on the variant being used, the pipeline call will slightly vary.
|
67 |
+
>>> # Refer to the pipeline documentation for more details.
|
68 |
+
>>> image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
|
69 |
+
>>> image.save("flux.png")
|
70 |
+
```
|
71 |
+
"""
|
72 |
+
|
73 |
+
|
74 |
+
def calculate_shift(
|
75 |
+
image_seq_len,
|
76 |
+
base_seq_len: int = 256,
|
77 |
+
max_seq_len: int = 4096,
|
78 |
+
base_shift: float = 0.5,
|
79 |
+
max_shift: float = 1.16,
|
80 |
+
):
|
81 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
82 |
+
b = base_shift - m * base_seq_len
|
83 |
+
mu = image_seq_len * m + b
|
84 |
+
return mu
|
85 |
+
|
86 |
+
|
87 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
88 |
+
def retrieve_timesteps(
|
89 |
+
scheduler,
|
90 |
+
num_inference_steps: Optional[int] = None,
|
91 |
+
device: Optional[Union[str, torch.device]] = None,
|
92 |
+
timesteps: Optional[List[int]] = None,
|
93 |
+
sigmas: Optional[List[float]] = None,
|
94 |
+
**kwargs,
|
95 |
+
):
|
96 |
+
r"""
|
97 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
98 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
scheduler (`SchedulerMixin`):
|
102 |
+
The scheduler to get timesteps from.
|
103 |
+
num_inference_steps (`int`):
|
104 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
105 |
+
must be `None`.
|
106 |
+
device (`str` or `torch.device`, *optional*):
|
107 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
108 |
+
timesteps (`List[int]`, *optional*):
|
109 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
110 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
111 |
+
sigmas (`List[float]`, *optional*):
|
112 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
113 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
117 |
+
second element is the number of inference steps.
|
118 |
+
"""
|
119 |
+
if timesteps is not None and sigmas is not None:
|
120 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
121 |
+
if timesteps is not None:
|
122 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
123 |
+
if not accepts_timesteps:
|
124 |
+
raise ValueError(
|
125 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
126 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
127 |
+
)
|
128 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
129 |
+
timesteps = scheduler.timesteps
|
130 |
+
num_inference_steps = len(timesteps)
|
131 |
+
elif sigmas is not None:
|
132 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
133 |
+
if not accept_sigmas:
|
134 |
+
raise ValueError(
|
135 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
136 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
137 |
+
)
|
138 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
139 |
+
timesteps = scheduler.timesteps
|
140 |
+
num_inference_steps = len(timesteps)
|
141 |
+
else:
|
142 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
143 |
+
timesteps = scheduler.timesteps
|
144 |
+
return timesteps, num_inference_steps
|
145 |
+
|
146 |
+
|
147 |
+
class EditedFluxPipeline(
|
148 |
+
DiffusionPipeline,
|
149 |
+
FluxLoraLoaderMixin,
|
150 |
+
FromSingleFileMixin,
|
151 |
+
TextualInversionLoaderMixin,
|
152 |
+
FluxIPAdapterMixin,
|
153 |
+
):
|
154 |
+
r"""
|
155 |
+
The Flux pipeline for text-to-image generation.
|
156 |
+
|
157 |
+
Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
|
158 |
+
|
159 |
+
Args:
|
160 |
+
transformer ([`FluxTransformer2DModel`]):
|
161 |
+
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
162 |
+
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
163 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
164 |
+
vae ([`AutoencoderKL`]):
|
165 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
166 |
+
text_encoder ([`CLIPTextModel`]):
|
167 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
168 |
+
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
|
169 |
+
text_encoder_2 ([`T5EncoderModel`]):
|
170 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
171 |
+
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
172 |
+
tokenizer (`CLIPTokenizer`):
|
173 |
+
Tokenizer of class
|
174 |
+
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
|
175 |
+
tokenizer_2 (`T5TokenizerFast`):
|
176 |
+
Second Tokenizer of class
|
177 |
+
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
|
178 |
+
"""
|
179 |
+
|
180 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
|
181 |
+
_optional_components = ["image_encoder", "feature_extractor"]
|
182 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
183 |
+
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
187 |
+
vae: AutoencoderKL,
|
188 |
+
text_encoder: CLIPTextModel,
|
189 |
+
tokenizer: CLIPTokenizer,
|
190 |
+
text_encoder_2: T5EncoderModel,
|
191 |
+
tokenizer_2: T5TokenizerFast,
|
192 |
+
transformer: FluxTransformer2DModel,
|
193 |
+
image_encoder: CLIPVisionModelWithProjection = None,
|
194 |
+
feature_extractor: CLIPImageProcessor = None,
|
195 |
+
):
|
196 |
+
super().__init__()
|
197 |
+
|
198 |
+
self.register_modules(
|
199 |
+
vae=vae,
|
200 |
+
text_encoder=text_encoder,
|
201 |
+
text_encoder_2=text_encoder_2,
|
202 |
+
tokenizer=tokenizer,
|
203 |
+
tokenizer_2=tokenizer_2,
|
204 |
+
transformer=transformer,
|
205 |
+
scheduler=scheduler,
|
206 |
+
image_encoder=image_encoder,
|
207 |
+
feature_extractor=feature_extractor,
|
208 |
+
)
|
209 |
+
self.vae_scale_factor = (
|
210 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
211 |
+
)
|
212 |
+
# Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
|
213 |
+
# by the patch size. So the vae scale factor is multiplied by the patch size to account for this
|
214 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
215 |
+
self.tokenizer_max_length = (
|
216 |
+
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
217 |
+
)
|
218 |
+
self.default_sample_size = 128
|
219 |
+
|
220 |
+
def _get_t5_prompt_embeds(
|
221 |
+
self,
|
222 |
+
prompt: Union[str, List[str]] = None,
|
223 |
+
num_images_per_prompt: int = 1,
|
224 |
+
max_sequence_length: int = 512,
|
225 |
+
device: Optional[torch.device] = None,
|
226 |
+
dtype: Optional[torch.dtype] = None,
|
227 |
+
):
|
228 |
+
device = device or self._execution_device
|
229 |
+
dtype = dtype or self.text_encoder.dtype
|
230 |
+
|
231 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
232 |
+
batch_size = len(prompt)
|
233 |
+
|
234 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
235 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
|
236 |
+
|
237 |
+
text_inputs = self.tokenizer_2(
|
238 |
+
prompt,
|
239 |
+
padding="max_length",
|
240 |
+
max_length=max_sequence_length,
|
241 |
+
truncation=True,
|
242 |
+
return_length=False,
|
243 |
+
return_overflowing_tokens=False,
|
244 |
+
return_tensors="pt",
|
245 |
+
)
|
246 |
+
text_input_ids = text_inputs.input_ids
|
247 |
+
untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
|
248 |
+
|
249 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
250 |
+
removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
251 |
+
logger.warning(
|
252 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
253 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
254 |
+
)
|
255 |
+
|
256 |
+
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
257 |
+
|
258 |
+
dtype = self.text_encoder_2.dtype
|
259 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
260 |
+
|
261 |
+
_, seq_len, _ = prompt_embeds.shape
|
262 |
+
|
263 |
+
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
264 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
265 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
266 |
+
|
267 |
+
return prompt_embeds
|
268 |
+
|
269 |
+
def _get_clip_prompt_embeds(
|
270 |
+
self,
|
271 |
+
prompt: Union[str, List[str]],
|
272 |
+
num_images_per_prompt: int = 1,
|
273 |
+
device: Optional[torch.device] = None,
|
274 |
+
):
|
275 |
+
device = device or self._execution_device
|
276 |
+
|
277 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
278 |
+
batch_size = len(prompt)
|
279 |
+
|
280 |
+
if isinstance(self, TextualInversionLoaderMixin):
|
281 |
+
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
282 |
+
|
283 |
+
text_inputs = self.tokenizer(
|
284 |
+
prompt,
|
285 |
+
padding="max_length",
|
286 |
+
max_length=self.tokenizer_max_length,
|
287 |
+
truncation=True,
|
288 |
+
return_overflowing_tokens=False,
|
289 |
+
return_length=False,
|
290 |
+
return_tensors="pt",
|
291 |
+
)
|
292 |
+
|
293 |
+
text_input_ids = text_inputs.input_ids
|
294 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
295 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
296 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
|
297 |
+
logger.warning(
|
298 |
+
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
299 |
+
f" {self.tokenizer_max_length} tokens: {removed_text}"
|
300 |
+
)
|
301 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
|
302 |
+
|
303 |
+
# Use pooled output of CLIPTextModel
|
304 |
+
prompt_embeds = prompt_embeds.pooler_output
|
305 |
+
prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
|
306 |
+
|
307 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
308 |
+
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
|
309 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
310 |
+
|
311 |
+
return prompt_embeds
|
312 |
+
|
313 |
+
def encode_prompt(
|
314 |
+
self,
|
315 |
+
prompt: Union[str, List[str]],
|
316 |
+
prompt_2: Union[str, List[str]],
|
317 |
+
device: Optional[torch.device] = None,
|
318 |
+
num_images_per_prompt: int = 1,
|
319 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
320 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
321 |
+
max_sequence_length: int = 512,
|
322 |
+
lora_scale: Optional[float] = None,
|
323 |
+
):
|
324 |
+
r"""
|
325 |
+
|
326 |
+
Args:
|
327 |
+
prompt (`str` or `List[str]`, *optional*):
|
328 |
+
prompt to be encoded
|
329 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
330 |
+
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
331 |
+
used in all text-encoders
|
332 |
+
device: (`torch.device`):
|
333 |
+
torch device
|
334 |
+
num_images_per_prompt (`int`):
|
335 |
+
number of images that should be generated per prompt
|
336 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
337 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
338 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
339 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
340 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
341 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
342 |
+
lora_scale (`float`, *optional*):
|
343 |
+
A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
|
344 |
+
"""
|
345 |
+
device = device or self._execution_device
|
346 |
+
|
347 |
+
# set lora scale so that monkey patched LoRA
|
348 |
+
# function of text encoder can correctly access it
|
349 |
+
if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
|
350 |
+
self._lora_scale = lora_scale
|
351 |
+
|
352 |
+
# dynamically adjust the LoRA scale
|
353 |
+
if self.text_encoder is not None and USE_PEFT_BACKEND:
|
354 |
+
scale_lora_layers(self.text_encoder, lora_scale)
|
355 |
+
if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
|
356 |
+
scale_lora_layers(self.text_encoder_2, lora_scale)
|
357 |
+
|
358 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
359 |
+
|
360 |
+
if prompt_embeds is None:
|
361 |
+
prompt_2 = prompt_2 or prompt
|
362 |
+
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
|
363 |
+
|
364 |
+
# We only use the pooled prompt output from the CLIPTextModel
|
365 |
+
pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
366 |
+
prompt=prompt,
|
367 |
+
device=device,
|
368 |
+
num_images_per_prompt=num_images_per_prompt,
|
369 |
+
)
|
370 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
371 |
+
prompt=prompt_2,
|
372 |
+
num_images_per_prompt=num_images_per_prompt,
|
373 |
+
max_sequence_length=max_sequence_length,
|
374 |
+
device=device,
|
375 |
+
)
|
376 |
+
|
377 |
+
if self.text_encoder is not None:
|
378 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
379 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
380 |
+
unscale_lora_layers(self.text_encoder, lora_scale)
|
381 |
+
|
382 |
+
if self.text_encoder_2 is not None:
|
383 |
+
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
384 |
+
# Retrieve the original scale by scaling back the LoRA layers
|
385 |
+
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
386 |
+
|
387 |
+
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
388 |
+
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
389 |
+
|
390 |
+
return prompt_embeds, pooled_prompt_embeds, text_ids
|
391 |
+
|
392 |
+
def encode_image(self, image, device, num_images_per_prompt):
|
393 |
+
dtype = next(self.image_encoder.parameters()).dtype
|
394 |
+
|
395 |
+
if not isinstance(image, torch.Tensor):
|
396 |
+
image = self.feature_extractor(image, return_tensors="pt").pixel_values
|
397 |
+
|
398 |
+
image = image.to(device=device, dtype=dtype)
|
399 |
+
image_embeds = self.image_encoder(image).image_embeds
|
400 |
+
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
|
401 |
+
return image_embeds
|
402 |
+
|
403 |
+
def prepare_ip_adapter_image_embeds(
|
404 |
+
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
|
405 |
+
):
|
406 |
+
image_embeds = []
|
407 |
+
if ip_adapter_image_embeds is None:
|
408 |
+
if not isinstance(ip_adapter_image, list):
|
409 |
+
ip_adapter_image = [ip_adapter_image]
|
410 |
+
|
411 |
+
if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
|
412 |
+
raise ValueError(
|
413 |
+
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
414 |
+
)
|
415 |
+
|
416 |
+
for single_ip_adapter_image, image_proj_layer in zip(
|
417 |
+
ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
|
418 |
+
):
|
419 |
+
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
|
420 |
+
|
421 |
+
image_embeds.append(single_image_embeds[None, :])
|
422 |
+
else:
|
423 |
+
for single_image_embeds in ip_adapter_image_embeds:
|
424 |
+
image_embeds.append(single_image_embeds)
|
425 |
+
|
426 |
+
ip_adapter_image_embeds = []
|
427 |
+
for i, single_image_embeds in enumerate(image_embeds):
|
428 |
+
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
429 |
+
single_image_embeds = single_image_embeds.to(device=device)
|
430 |
+
ip_adapter_image_embeds.append(single_image_embeds)
|
431 |
+
|
432 |
+
return ip_adapter_image_embeds
|
433 |
+
|
434 |
+
def check_inputs(
|
435 |
+
self,
|
436 |
+
prompt,
|
437 |
+
prompt_2,
|
438 |
+
height,
|
439 |
+
width,
|
440 |
+
negative_prompt=None,
|
441 |
+
negative_prompt_2=None,
|
442 |
+
prompt_embeds=None,
|
443 |
+
negative_prompt_embeds=None,
|
444 |
+
pooled_prompt_embeds=None,
|
445 |
+
negative_pooled_prompt_embeds=None,
|
446 |
+
callback_on_step_end_tensor_inputs=None,
|
447 |
+
max_sequence_length=None,
|
448 |
+
):
|
449 |
+
if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
|
450 |
+
logger.warning(
|
451 |
+
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
452 |
+
)
|
453 |
+
|
454 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
455 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
456 |
+
):
|
457 |
+
raise ValueError(
|
458 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
459 |
+
)
|
460 |
+
|
461 |
+
if prompt is not None and prompt_embeds is not None:
|
462 |
+
raise ValueError(
|
463 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
464 |
+
" only forward one of the two."
|
465 |
+
)
|
466 |
+
elif prompt_2 is not None and prompt_embeds is not None:
|
467 |
+
raise ValueError(
|
468 |
+
f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
469 |
+
" only forward one of the two."
|
470 |
+
)
|
471 |
+
elif prompt is None and prompt_embeds is None:
|
472 |
+
raise ValueError(
|
473 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
474 |
+
)
|
475 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
476 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
477 |
+
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
|
478 |
+
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
|
479 |
+
|
480 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
481 |
+
raise ValueError(
|
482 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
483 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
484 |
+
)
|
485 |
+
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
|
486 |
+
raise ValueError(
|
487 |
+
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
|
488 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
489 |
+
)
|
490 |
+
|
491 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
492 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
493 |
+
raise ValueError(
|
494 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
495 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
496 |
+
f" {negative_prompt_embeds.shape}."
|
497 |
+
)
|
498 |
+
|
499 |
+
if prompt_embeds is not None and pooled_prompt_embeds is None:
|
500 |
+
raise ValueError(
|
501 |
+
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
502 |
+
)
|
503 |
+
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
|
504 |
+
raise ValueError(
|
505 |
+
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
|
506 |
+
)
|
507 |
+
|
508 |
+
if max_sequence_length is not None and max_sequence_length > 512:
|
509 |
+
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
510 |
+
|
511 |
+
@staticmethod
|
512 |
+
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
513 |
+
latent_image_ids = torch.zeros(height, width, 3)
|
514 |
+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
|
515 |
+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
|
516 |
+
|
517 |
+
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
518 |
+
|
519 |
+
latent_image_ids = latent_image_ids.reshape(
|
520 |
+
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
521 |
+
)
|
522 |
+
|
523 |
+
return latent_image_ids.to(device=device, dtype=dtype)
|
524 |
+
|
525 |
+
@staticmethod
|
526 |
+
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
|
527 |
+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
528 |
+
latents = latents.permute(0, 2, 4, 1, 3, 5)
|
529 |
+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
|
530 |
+
|
531 |
+
return latents
|
532 |
+
|
533 |
+
@staticmethod
|
534 |
+
def _unpack_latents(latents, height, width, vae_scale_factor):
|
535 |
+
batch_size, num_patches, channels = latents.shape
|
536 |
+
|
537 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
538 |
+
# latent height and width to be divisible by 2.
|
539 |
+
height = 2 * (int(height) // (vae_scale_factor * 2))
|
540 |
+
width = 2 * (int(width) // (vae_scale_factor * 2))
|
541 |
+
|
542 |
+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
|
543 |
+
latents = latents.permute(0, 3, 1, 4, 2, 5)
|
544 |
+
|
545 |
+
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
|
546 |
+
|
547 |
+
return latents
|
548 |
+
|
549 |
+
def enable_vae_slicing(self):
|
550 |
+
r"""
|
551 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
552 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
553 |
+
"""
|
554 |
+
self.vae.enable_slicing()
|
555 |
+
|
556 |
+
def disable_vae_slicing(self):
|
557 |
+
r"""
|
558 |
+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
|
559 |
+
computing decoding in one step.
|
560 |
+
"""
|
561 |
+
self.vae.disable_slicing()
|
562 |
+
|
563 |
+
def enable_vae_tiling(self):
|
564 |
+
r"""
|
565 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
566 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
567 |
+
processing larger images.
|
568 |
+
"""
|
569 |
+
self.vae.enable_tiling()
|
570 |
+
|
571 |
+
def disable_vae_tiling(self):
|
572 |
+
r"""
|
573 |
+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
|
574 |
+
computing decoding in one step.
|
575 |
+
"""
|
576 |
+
self.vae.disable_tiling()
|
577 |
+
|
578 |
+
|
579 |
+
def prepare_latents(
|
580 |
+
self,
|
581 |
+
batch_size,
|
582 |
+
num_channels_latents,
|
583 |
+
height,
|
584 |
+
width,
|
585 |
+
dtype,
|
586 |
+
device,
|
587 |
+
generator,
|
588 |
+
latents=None,
|
589 |
+
):
|
590 |
+
# VAE applies 8x compression on images but we must also account for packing which requires
|
591 |
+
# latent height and width to be divisible by 2.
|
592 |
+
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
593 |
+
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
594 |
+
|
595 |
+
shape = (batch_size, num_channels_latents, height, width)
|
596 |
+
|
597 |
+
if latents is not None:
|
598 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
599 |
+
return latents.to(device=device, dtype=dtype), latent_image_ids
|
600 |
+
|
601 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
602 |
+
raise ValueError(
|
603 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
604 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
605 |
+
)
|
606 |
+
|
607 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
608 |
+
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
|
609 |
+
|
610 |
+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
|
611 |
+
|
612 |
+
return latents, latent_image_ids
|
613 |
+
|
614 |
+
@property
|
615 |
+
def guidance_scale(self):
|
616 |
+
return self._guidance_scale
|
617 |
+
|
618 |
+
@property
|
619 |
+
def joint_attention_kwargs(self):
|
620 |
+
return self._joint_attention_kwargs
|
621 |
+
|
622 |
+
@property
|
623 |
+
def num_timesteps(self):
|
624 |
+
return self._num_timesteps
|
625 |
+
|
626 |
+
@property
|
627 |
+
def interrupt(self):
|
628 |
+
return self._interrupt
|
629 |
+
|
630 |
+
@torch.no_grad()
|
631 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
632 |
+
def __call__(
|
633 |
+
self,
|
634 |
+
prompt: Union[str, List[str]] = None,
|
635 |
+
prompt_2: Optional[Union[str, List[str]]] = None,
|
636 |
+
negative_prompt: Union[str, List[str]] = None,
|
637 |
+
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
638 |
+
true_cfg_scale: float = 1.0,
|
639 |
+
height: Optional[int] = None,
|
640 |
+
width: Optional[int] = None,
|
641 |
+
num_inference_steps: int = 28,
|
642 |
+
sigmas: Optional[List[float]] = None,
|
643 |
+
guidance_scale: float = 3.5,
|
644 |
+
num_images_per_prompt: Optional[int] = 1,
|
645 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
646 |
+
latents: Optional[torch.FloatTensor] = None,
|
647 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
648 |
+
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
649 |
+
ip_adapter_image: Optional[PipelineImageInput] = None,
|
650 |
+
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
651 |
+
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
|
652 |
+
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
|
653 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
654 |
+
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
655 |
+
output_type: Optional[str] = "pil",
|
656 |
+
return_dict: bool = True,
|
657 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
658 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
659 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
660 |
+
max_sequence_length: int = 512,
|
661 |
+
is_inverted_generation: bool = False,
|
662 |
+
inverted_latents_list: List[torch.Tensor] = None,
|
663 |
+
tau_b: Optional[float] = None,
|
664 |
+
bg_consistency_mask: Optional[torch.Tensor] = None,
|
665 |
+
):
|
666 |
+
r"""
|
667 |
+
Function invoked when calling the pipeline for generation.
|
668 |
+
|
669 |
+
Args:
|
670 |
+
prompt (`str` or `List[str]`, *optional*):
|
671 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
672 |
+
instead.
|
673 |
+
prompt_2 (`str` or `List[str]`, *optional*):
|
674 |
+
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
|
675 |
+
will be used instead
|
676 |
+
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
677 |
+
The height in pixels of the generated image. This is set to 1024 by default for the best results.
|
678 |
+
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
679 |
+
The width in pixels of the generated image. This is set to 1024 by default for the best results.
|
680 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
681 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
682 |
+
expense of slower inference.
|
683 |
+
sigmas (`List[float]`, *optional*):
|
684 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
685 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
686 |
+
will be used.
|
687 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
688 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
689 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
690 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
691 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
692 |
+
usually at the expense of lower image quality.
|
693 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
694 |
+
The number of images to generate per prompt.
|
695 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
696 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
697 |
+
to make generation deterministic.
|
698 |
+
latents (`torch.FloatTensor`, *optional*):
|
699 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
700 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
701 |
+
tensor will ge generated by sampling using the supplied random `generator`.
|
702 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
703 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
704 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
705 |
+
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
|
706 |
+
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
|
707 |
+
If not provided, pooled text embeddings will be generated from `prompt` input argument.
|
708 |
+
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
709 |
+
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
710 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
711 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
712 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
713 |
+
negative_ip_adapter_image:
|
714 |
+
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
|
715 |
+
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
|
716 |
+
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
|
717 |
+
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
|
718 |
+
provided, embeddings are computed from the `ip_adapter_image` input argument.
|
719 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
720 |
+
The output format of the generate image. Choose between
|
721 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
722 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
723 |
+
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
|
724 |
+
joint_attention_kwargs (`dict`, *optional*):
|
725 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
726 |
+
`self.processor` in
|
727 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
728 |
+
callback_on_step_end (`Callable`, *optional*):
|
729 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
730 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
731 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
732 |
+
`callback_on_step_end_tensor_inputs`.
|
733 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
734 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
735 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
736 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
737 |
+
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
|
738 |
+
tau_b (`float`, *optional*): Proportion of steps during which the background consistency is applied.
|
739 |
+
bg_consistency_mask (`torch.Tensor`, *optional*): Mask to use when applying background consistency. The mask
|
740 |
+
background consistency will be applied to the areas outside of the mask.
|
741 |
+
|
742 |
+
Examples:
|
743 |
+
|
744 |
+
Returns:
|
745 |
+
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
|
746 |
+
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
|
747 |
+
images.
|
748 |
+
"""
|
749 |
+
|
750 |
+
height = height or self.default_sample_size * self.vae_scale_factor
|
751 |
+
width = width or self.default_sample_size * self.vae_scale_factor
|
752 |
+
|
753 |
+
# 1. Check inputs. Raise error if not correct
|
754 |
+
self.check_inputs(
|
755 |
+
prompt,
|
756 |
+
prompt_2,
|
757 |
+
height,
|
758 |
+
width,
|
759 |
+
negative_prompt=negative_prompt,
|
760 |
+
negative_prompt_2=negative_prompt_2,
|
761 |
+
prompt_embeds=prompt_embeds,
|
762 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
763 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
764 |
+
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
765 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
766 |
+
max_sequence_length=max_sequence_length,
|
767 |
+
)
|
768 |
+
|
769 |
+
self._guidance_scale = guidance_scale
|
770 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
771 |
+
self._interrupt = False
|
772 |
+
|
773 |
+
# 2. Define call parameters
|
774 |
+
if prompt is not None and isinstance(prompt, str):
|
775 |
+
batch_size = 1
|
776 |
+
elif prompt is not None and isinstance(prompt, list):
|
777 |
+
batch_size = len(prompt)
|
778 |
+
else:
|
779 |
+
batch_size = prompt_embeds.shape[0]
|
780 |
+
|
781 |
+
device = self._execution_device
|
782 |
+
|
783 |
+
lora_scale = (
|
784 |
+
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
|
785 |
+
)
|
786 |
+
do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
|
787 |
+
(
|
788 |
+
prompt_embeds,
|
789 |
+
pooled_prompt_embeds,
|
790 |
+
text_ids,
|
791 |
+
) = self.encode_prompt(
|
792 |
+
prompt=prompt,
|
793 |
+
prompt_2=prompt_2,
|
794 |
+
prompt_embeds=prompt_embeds,
|
795 |
+
pooled_prompt_embeds=pooled_prompt_embeds,
|
796 |
+
device=device,
|
797 |
+
num_images_per_prompt=num_images_per_prompt,
|
798 |
+
max_sequence_length=max_sequence_length,
|
799 |
+
lora_scale=lora_scale,
|
800 |
+
)
|
801 |
+
if do_true_cfg:
|
802 |
+
(
|
803 |
+
negative_prompt_embeds,
|
804 |
+
negative_pooled_prompt_embeds,
|
805 |
+
_,
|
806 |
+
) = self.encode_prompt(
|
807 |
+
prompt=negative_prompt,
|
808 |
+
prompt_2=negative_prompt_2,
|
809 |
+
prompt_embeds=negative_prompt_embeds,
|
810 |
+
pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
811 |
+
device=device,
|
812 |
+
num_images_per_prompt=num_images_per_prompt,
|
813 |
+
max_sequence_length=max_sequence_length,
|
814 |
+
lora_scale=lora_scale,
|
815 |
+
)
|
816 |
+
|
817 |
+
# 4. Prepare latent variables
|
818 |
+
num_channels_latents = self.transformer.config.in_channels // 4
|
819 |
+
latents, latent_image_ids = self.prepare_latents(
|
820 |
+
batch_size * num_images_per_prompt,
|
821 |
+
num_channels_latents,
|
822 |
+
height,
|
823 |
+
width,
|
824 |
+
prompt_embeds.dtype,
|
825 |
+
device,
|
826 |
+
generator,
|
827 |
+
latents,
|
828 |
+
)
|
829 |
+
|
830 |
+
# 5. Prepare timesteps
|
831 |
+
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
832 |
+
image_seq_len = latents.shape[1]
|
833 |
+
mu = calculate_shift(
|
834 |
+
image_seq_len,
|
835 |
+
self.scheduler.config.base_image_seq_len,
|
836 |
+
self.scheduler.config.max_image_seq_len,
|
837 |
+
self.scheduler.config.base_shift,
|
838 |
+
self.scheduler.config.max_shift,
|
839 |
+
)
|
840 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
841 |
+
self.scheduler,
|
842 |
+
num_inference_steps,
|
843 |
+
device,
|
844 |
+
sigmas=sigmas,
|
845 |
+
mu=mu,
|
846 |
+
)
|
847 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
848 |
+
self._num_timesteps = len(timesteps)
|
849 |
+
|
850 |
+
if is_inverted_generation:
|
851 |
+
timesteps = reversed(timesteps)
|
852 |
+
|
853 |
+
|
854 |
+
# handle guidance
|
855 |
+
if self.transformer.config.guidance_embeds:
|
856 |
+
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
857 |
+
guidance = guidance.expand(latents.shape[0])
|
858 |
+
else:
|
859 |
+
guidance = None
|
860 |
+
|
861 |
+
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
|
862 |
+
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
|
863 |
+
):
|
864 |
+
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
865 |
+
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
|
866 |
+
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
|
867 |
+
):
|
868 |
+
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
|
869 |
+
|
870 |
+
if self.joint_attention_kwargs is None:
|
871 |
+
self._joint_attention_kwargs = {}
|
872 |
+
|
873 |
+
image_embeds = None
|
874 |
+
negative_image_embeds = None
|
875 |
+
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
|
876 |
+
image_embeds = self.prepare_ip_adapter_image_embeds(
|
877 |
+
ip_adapter_image,
|
878 |
+
ip_adapter_image_embeds,
|
879 |
+
device,
|
880 |
+
batch_size * num_images_per_prompt,
|
881 |
+
)
|
882 |
+
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
|
883 |
+
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
|
884 |
+
negative_ip_adapter_image,
|
885 |
+
negative_ip_adapter_image_embeds,
|
886 |
+
device,
|
887 |
+
batch_size * num_images_per_prompt,
|
888 |
+
)
|
889 |
+
|
890 |
+
# 6. Denoising loop
|
891 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
892 |
+
for i, t in enumerate(timesteps):
|
893 |
+
|
894 |
+
if self.interrupt:
|
895 |
+
continue
|
896 |
+
|
897 |
+
if image_embeds is not None:
|
898 |
+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
|
899 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
900 |
+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
|
901 |
+
|
902 |
+
noise_pred = self.transformer(
|
903 |
+
hidden_states=latents,
|
904 |
+
timestep=timestep / 1000,
|
905 |
+
guidance=guidance,
|
906 |
+
pooled_projections=pooled_prompt_embeds,
|
907 |
+
encoder_hidden_states=prompt_embeds,
|
908 |
+
txt_ids=text_ids,
|
909 |
+
img_ids=latent_image_ids,
|
910 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
911 |
+
return_dict=False,
|
912 |
+
)[0]
|
913 |
+
|
914 |
+
if do_true_cfg:
|
915 |
+
if negative_image_embeds is not None:
|
916 |
+
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
|
917 |
+
neg_noise_pred = self.transformer(
|
918 |
+
hidden_states=latents,
|
919 |
+
timestep=timestep / 1000,
|
920 |
+
guidance=guidance,
|
921 |
+
pooled_projections=negative_pooled_prompt_embeds,
|
922 |
+
encoder_hidden_states=negative_prompt_embeds,
|
923 |
+
txt_ids=text_ids,
|
924 |
+
img_ids=latent_image_ids,
|
925 |
+
joint_attention_kwargs=self.joint_attention_kwargs,
|
926 |
+
return_dict=False,
|
927 |
+
)[0]
|
928 |
+
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
|
929 |
+
|
930 |
+
# compute the previous noisy sample x_t -> x_t-1
|
931 |
+
latents_dtype = latents.dtype
|
932 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
933 |
+
|
934 |
+
if tau_b:
|
935 |
+
if bg_consistency_mask is None:
|
936 |
+
raise ValueError("if tau_b is set, bg_consistency_mask must be provided for background consistency to work.")
|
937 |
+
|
938 |
+
assert latents.shape[0] >= 3, "Three processes are required for background consistency injection (being background, foreground and composed process)."
|
939 |
+
assert latents.shape[1] == bg_consistency_mask.shape[0], f"Latents and segmentation mask must have the same number of timesteps. Got {latents.shape[1]} and {bg_consistency_mask.shape[0]}."
|
940 |
+
|
941 |
+
bg_consistency_mask = bg_consistency_mask.to(device=latents.device, dtype=torch.int32)
|
942 |
+
|
943 |
+
# TF-ICON background consistency: if we're in the first tau_b part of the de-noising process,
|
944 |
+
# overwrite the latents of the composed image with those of the background process (only outside the segmentation mask)
|
945 |
+
if i <= tau_b * num_inference_steps:
|
946 |
+
latents[2, :, :] = latents[0, :, :] * (1 - bg_consistency_mask) + latents[2, :, :] * bg_consistency_mask
|
947 |
+
|
948 |
+
# NOTE: this was the added part for inversion
|
949 |
+
if is_inverted_generation:
|
950 |
+
inverted_latents_list.append(latents)
|
951 |
+
else:
|
952 |
+
if inverted_latents_list is not None:
|
953 |
+
if isinstance(inverted_latents_list[0], torch.Tensor):
|
954 |
+
latents[0] = inverted_latents_list[-i][0]
|
955 |
+
else:
|
956 |
+
assert isinstance(inverted_latents_list[0], tuple)
|
957 |
+
for j, tensor_tuple in enumerate(inverted_latents_list[-i]):
|
958 |
+
latents[j] = tensor_tuple
|
959 |
+
|
960 |
+
|
961 |
+
|
962 |
+
if latents.dtype != latents_dtype:
|
963 |
+
if torch.backends.mps.is_available():
|
964 |
+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
965 |
+
latents = latents.to(latents_dtype)
|
966 |
+
|
967 |
+
if callback_on_step_end is not None:
|
968 |
+
callback_kwargs = {}
|
969 |
+
for k in callback_on_step_end_tensor_inputs:
|
970 |
+
callback_kwargs[k] = locals()[k]
|
971 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
972 |
+
|
973 |
+
latents = callback_outputs.pop("latents", latents)
|
974 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
975 |
+
|
976 |
+
# call the callback, if provided
|
977 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
978 |
+
progress_bar.update()
|
979 |
+
|
980 |
+
if XLA_AVAILABLE:
|
981 |
+
xm.mark_step()
|
982 |
+
|
983 |
+
if output_type == "latent":
|
984 |
+
image = latents
|
985 |
+
|
986 |
+
else:
|
987 |
+
latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
|
988 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
989 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
990 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
991 |
+
|
992 |
+
# Offload all models
|
993 |
+
self.maybe_free_model_hooks()
|
994 |
+
|
995 |
+
if not return_dict:
|
996 |
+
return (image,)
|
997 |
+
|
998 |
+
return FluxPipelineOutput(images=image)
|
SDLens/cache_and_edit/hooks.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable, Literal
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from diffusers.models.transformers.transformer_flux import FluxTransformerBlock, FluxSingleTransformerBlock
|
5 |
+
|
6 |
+
|
7 |
+
def register_general_hook(pipe, position, hook, with_kwargs=False, is_pre_hook=False):
|
8 |
+
"""Registers a forward hook in a module of the pipeline specified with 'position'
|
9 |
+
|
10 |
+
Args:
|
11 |
+
pipe (_type_): _description_
|
12 |
+
position (_type_): _description_
|
13 |
+
hook (_type_): _description_
|
14 |
+
with_kwargs (bool, optional): _description_. Defaults to False.
|
15 |
+
is_pre_hook (bool, optional): _description_. Defaults to False.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
_type_: _description_
|
19 |
+
"""
|
20 |
+
|
21 |
+
block: nn.Module = locate_block(pipe, position)
|
22 |
+
|
23 |
+
if is_pre_hook:
|
24 |
+
return block.register_forward_pre_hook(hook, with_kwargs=with_kwargs)
|
25 |
+
else:
|
26 |
+
return block.register_forward_hook(hook, with_kwargs=with_kwargs)
|
27 |
+
|
28 |
+
|
29 |
+
def locate_block(pipe, position: str) -> nn.Module:
|
30 |
+
'''
|
31 |
+
Locate the block at the specified position in the pipeline.
|
32 |
+
'''
|
33 |
+
block = pipe
|
34 |
+
for step in position.split('.'):
|
35 |
+
if step.isdigit():
|
36 |
+
step = int(step)
|
37 |
+
block = block[step]
|
38 |
+
else:
|
39 |
+
block = getattr(block, step)
|
40 |
+
return block
|
41 |
+
|
42 |
+
|
43 |
+
def _safe_clip(x: torch.Tensor):
|
44 |
+
if x.dtype == torch.float16:
|
45 |
+
x[torch.isposinf(x)] = 65504
|
46 |
+
x[torch.isneginf(x)] = -65504
|
47 |
+
return x
|
48 |
+
|
49 |
+
|
50 |
+
@torch.no_grad()
|
51 |
+
def fix_inf_values_hook(*args):
|
52 |
+
|
53 |
+
# Case 1: no kwards are passed to the module
|
54 |
+
if len(args) == 3:
|
55 |
+
module, input, output = args
|
56 |
+
# Case 2: when kwargs are passed to the model as input
|
57 |
+
elif len(args) == 4:
|
58 |
+
module, input, kwinput, output = args
|
59 |
+
|
60 |
+
if isinstance(module, FluxTransformerBlock):
|
61 |
+
return _safe_clip(output[0]), _safe_clip(output[1])
|
62 |
+
|
63 |
+
elif isinstance(module, FluxSingleTransformerBlock):
|
64 |
+
return _safe_clip(output)
|
65 |
+
|
66 |
+
|
67 |
+
@torch.no_grad()
|
68 |
+
def edit_streams_hook(*args,
|
69 |
+
recompute_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
|
70 |
+
stream: Literal["text", "image", "both"]):
|
71 |
+
"""
|
72 |
+
recompute_fn will get as input the input tensor and the output tensor for such stream
|
73 |
+
and returns what should be the new modified output
|
74 |
+
"""
|
75 |
+
|
76 |
+
# Case 1: no kwards are passed to the module
|
77 |
+
if len(args) == 3:
|
78 |
+
module, input, output = args
|
79 |
+
# Case 2: when kwargs are passed to the model as input
|
80 |
+
elif len(args) == 4:
|
81 |
+
module, input, kwinput, output = args
|
82 |
+
else:
|
83 |
+
raise AssertionError(f'Unexpected number of hook arguments: {len(args)}')
|
84 |
+
|
85 |
+
if isinstance(module, FluxTransformerBlock):
|
86 |
+
|
87 |
+
if stream == 'text':
|
88 |
+
output_text = recompute_fn(kwinput["encoder_hidden_states"], output[0])
|
89 |
+
output_image = output[1]
|
90 |
+
elif stream == 'image':
|
91 |
+
output_image = recompute_fn(kwinput["hidden_states"], output[1])
|
92 |
+
output_text = output[0]
|
93 |
+
else:
|
94 |
+
raise AssertionError("Branch not supported for this layer.")
|
95 |
+
|
96 |
+
return _safe_clip(output_text), _safe_clip(output_image)
|
97 |
+
|
98 |
+
elif isinstance(module, FluxSingleTransformerBlock):
|
99 |
+
|
100 |
+
if stream == 'text':
|
101 |
+
output[:, :512] = recompute_fn(kwinput["hidden_states"][:, :512], output[:, :512])
|
102 |
+
elif stream == 'image':
|
103 |
+
output[:, 512:] = recompute_fn(kwinput["hidden_states"][:, 512:], output[:, 512:])
|
104 |
+
else:
|
105 |
+
output = recompute_fn(kwinput["hidden_states"], output)
|
106 |
+
|
107 |
+
return _safe_clip(output)
|
108 |
+
|
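A short usage sketch for the helpers above. The block path "transformer.transformer_blocks.0" and the `pipe` object are assumptions for illustration; `edit_streams_hook` is bound with functools.partial because its options are keyword-only, and the hook needs with_kwargs=True so that the block's keyword inputs are visible to it:

    from functools import partial
    import torch

    def amplify_image_stream(stream_in: torch.Tensor, stream_out: torch.Tensor) -> torch.Tensor:
        # Example recompute_fn: exaggerate the block's update to the image stream by 50%.
        return stream_in + 1.5 * (stream_out - stream_in)

    hook = partial(edit_streams_hook, recompute_fn=amplify_image_stream, stream="image")
    handle = register_general_hook(pipe, "transformer.transformer_blocks.0", hook, with_kwargs=True)
    # ... run the pipeline, then detach the hook ...
    handle.remove()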
SDLens/cache_and_edit/inversion.py
ADDED
@@ -0,0 +1,568 @@
|
1 |
+
from typing import Optional, Tuple
|
2 |
+
import torch
|
3 |
+
import torchvision.transforms.functional as TF
|
4 |
+
from PIL import Image
|
5 |
+
from cache_and_edit import CachedPipeline
|
6 |
+
import numpy as np
|
7 |
+
from IPython.display import display
|
8 |
+
|
9 |
+
from cache_and_edit.flux_pipeline import EditedFluxPipeline
|
10 |
+
|
11 |
+
def image2latent(pipe, image, latent_nudging_scalar = 1.15):
|
12 |
+
image = pipe.image_processor.preprocess(image).type(pipe.vae.dtype).to("cuda")
|
13 |
+
latents = pipe.vae.encode(image)["latent_dist"].mean
|
14 |
+
latents = (latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
|
15 |
+
latents = latents * latent_nudging_scalar
|
16 |
+
|
17 |
+
latents = pipe._pack_latents(
|
18 |
+
latents=latents,
|
19 |
+
batch_size=1,
|
20 |
+
num_channels_latents=16,
|
21 |
+
height=image.size(2) // 8,
|
22 |
+
width= image.size(3) // 8
|
23 |
+
)
|
24 |
+
|
25 |
+
return latents
|
26 |
+
|
27 |
+
|
28 |
+
def get_inverted_input_noise(pipe: CachedPipeline,
|
29 |
+
image,
|
30 |
+
prompt: str = "",
|
31 |
+
num_steps: int = 28,
|
32 |
+
latent_nudging_scalar: float = 1.15):
|
33 |
+
"""_summary_
|
34 |
+
|
35 |
+
Args:
|
36 |
+
pipe (CachedPipeline): cached pipeline wrapper used to run the inverse diffusion process.
|
37 |
+
image (PIL.Image.Image): image to invert into noise.
|
38 |
+
num_steps (int, optional): number of inversion steps. Defaults to 28.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
list | torch.Tensor: per-step inverted latents for an EditedFluxPipeline, otherwise the final inverted noise latent.
|
42 |
+
"""
|
43 |
+
|
44 |
+
width, height = image.size
|
45 |
+
inverted_latents_list = []
|
46 |
+
|
47 |
+
if isinstance(pipe.pipe, EditedFluxPipeline):
|
48 |
+
|
49 |
+
_ = pipe.run(
|
50 |
+
prompt,
|
51 |
+
num_inference_steps=num_steps,
|
52 |
+
seed=42,
|
53 |
+
guidance_scale=1,
|
54 |
+
output_type="latent",
|
55 |
+
latents=image2latent(pipe.pipe, image, latent_nudging_scalar=latent_nudging_scalar),
|
56 |
+
empty_clip_embeddings=False,
|
57 |
+
inverse=True,
|
58 |
+
width=width,
|
59 |
+
height=height,
|
60 |
+
is_inverted_generation=True,
|
61 |
+
inverted_latents_list=inverted_latents_list
|
62 |
+
).images[0]
|
63 |
+
|
64 |
+
return inverted_latents_list
|
65 |
+
|
66 |
+
|
67 |
+
else:
|
68 |
+
noise = pipe.run(
|
69 |
+
prompt,
|
70 |
+
num_inference_steps=num_steps,
|
71 |
+
seed=42,
|
72 |
+
guidance_scale=1,
|
73 |
+
output_type="latent",
|
74 |
+
latents=image2latent(pipe.pipe, image, latent_nudging_scalar=latent_nudging_scalar),
|
75 |
+
empty_clip_embeddings=False,
|
76 |
+
inverse=True,
|
77 |
+
width=width,
|
78 |
+
height=height
|
79 |
+
).images[0]
|
80 |
+
|
81 |
+
return noise
|
82 |
+
|
83 |
+
|
84 |
+
|
85 |
+
|
86 |
+
def resize_bounding_box(
|
87 |
+
bb_mask: torch.Tensor,
|
88 |
+
target_size: Tuple[int, int] = (64, 64),
|
89 |
+
) -> torch.Tensor:
|
90 |
+
"""
|
91 |
+
Downsamples a bounding-box mask patch-wise into a mask of the target size.
|
92 |
+
The mask is a 2D tensor of shape (H, W) where each element is either 0 or 1.
|
93 |
+
Any patch that contains at least one 1 in the original mask will be set to 1 in the output mask.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
bb_mask (torch.Tensor): The bounding box mask as a boolean tensor of shape (H, W).
|
97 |
+
target_size (Tuple[int, int]): The size of the target mask as a tuple (H, W).
|
98 |
+
|
99 |
+
Returns:
|
100 |
+
torch.Tensor: The resized bounding box mask as a boolean tensor of shape (H, W).
|
101 |
+
"""
|
102 |
+
|
103 |
+
w_mask, h_mask = bb_mask.shape[-2:]
|
104 |
+
w_target, h_target = target_size
|
105 |
+
|
106 |
+
# Make sure the sizes are compatible
|
107 |
+
if w_mask % w_target != 0 or h_mask % h_target != 0:
|
108 |
+
raise ValueError(
|
109 |
+
f"Mask size {bb_mask.shape[-2:]} is not compatible with target size {target_size}"
|
110 |
+
)
|
111 |
+
|
112 |
+
# Compute the size of a patch
|
113 |
+
patch_size = (w_mask // w_target, h_mask // h_target)
|
114 |
+
|
115 |
+
# Iterate over the mask, one patch at a time, and save a 0 patch if the patch is empty or a 1 patch if the patch is not empty
|
116 |
+
out_mask = torch.zeros((w_target, h_target), dtype=bb_mask.dtype, device=bb_mask.device)
|
117 |
+
for i in range(w_target):
|
118 |
+
for j in range(h_target):
|
119 |
+
patch = bb_mask[
|
120 |
+
i * patch_size[0] : (i + 1) * patch_size[0],
|
121 |
+
j * patch_size[1] : (j + 1) * patch_size[1],
|
122 |
+
]
|
123 |
+
if torch.sum(patch) > 0:
|
124 |
+
out_mask[i, j] = 1
|
125 |
+
else:
|
126 |
+
out_mask[i, j] = 0
|
127 |
+
|
128 |
+
return out_mask
|
129 |
+
|
130 |
+
|
131 |
+
def place_image_in_bounding_box(
|
132 |
+
image_tensor_whc: torch.Tensor,
|
133 |
+
mask_tensor_wh: torch.Tensor
|
134 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
135 |
+
"""
|
136 |
+
Resizes an input image to fit within a bounding box (from a mask)
|
137 |
+
preserving aspect ratio, and places it centered on a new canvas.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
image_tensor_whc: Input image tensor, shape [width, height, channels].
|
141 |
+
mask_tensor_wh: Bounding box mask, shape [width, height]. Defines canvas size
|
142 |
+
and contains a rectangle of 1s for the BB.
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
A tuple:
|
146 |
+
- output_image_whc (torch.Tensor): Canvas with the resized image placed.
|
147 |
+
Shape [canvas_width, canvas_height, channels].
|
148 |
+
- new_mask_wh (torch.Tensor): Mask showing the actual placement of the image.
|
149 |
+
Shape [canvas_width, canvas_height].
|
150 |
+
"""
|
151 |
+
|
152 |
+
# Validate input image dimensions
|
153 |
+
if not (image_tensor_whc.ndim == 3 and image_tensor_whc.shape[0] > 0 and image_tensor_whc.shape[1] > 0):
|
154 |
+
raise ValueError(
|
155 |
+
"Input image_tensor_whc must be a 3D tensor [width, height, channels] "
|
156 |
+
"with width > 0 and height > 0."
|
157 |
+
)
|
158 |
+
img_orig_w, img_orig_h, num_channels = image_tensor_whc.shape
|
159 |
+
|
160 |
+
# Validate mask tensor dimensions
|
161 |
+
if not (mask_tensor_wh.ndim == 2):
|
162 |
+
raise ValueError("Input mask_tensor_wh must be a 2D tensor [width, height].")
|
163 |
+
canvas_w, canvas_h = mask_tensor_wh.shape
|
164 |
+
|
165 |
+
# Prepare default empty outputs for early exit scenarios
|
166 |
+
empty_output_image = torch.zeros(
|
167 |
+
canvas_w, canvas_h, num_channels,
|
168 |
+
dtype=image_tensor_whc.dtype, device=image_tensor_whc.device
|
169 |
+
)
|
170 |
+
empty_new_mask = torch.zeros(
|
171 |
+
canvas_w, canvas_h,
|
172 |
+
dtype=mask_tensor_wh.dtype, device=mask_tensor_wh.device
|
173 |
+
)
|
174 |
+
|
175 |
+
# 1. Find Bounding Box (BB) coordinates from the input mask_tensor_wh
|
176 |
+
# fg_coords shape: [N, 2], where N is num_nonzero. Each row: [x_coord, y_coord].
|
177 |
+
fg_coords = torch.nonzero(mask_tensor_wh, as_tuple=False)
|
178 |
+
|
179 |
+
if fg_coords.numel() == 0: # No bounding box found in mask
|
180 |
+
return empty_output_image, empty_new_mask
|
181 |
+
|
182 |
+
# Determine min/max extents of the bounding box
|
183 |
+
x_min_bb, y_min_bb = fg_coords[:, 0].min(), fg_coords[:, 1].min()
|
184 |
+
x_max_bb, y_max_bb = fg_coords[:, 0].max(), fg_coords[:, 1].max()
|
185 |
+
|
186 |
+
bb_target_w = x_max_bb - x_min_bb + 1
|
187 |
+
bb_target_h = y_max_bb - y_min_bb + 1
|
188 |
+
|
189 |
+
if bb_target_w <= 0 or bb_target_h <= 0: # Should not happen if fg_coords not empty
|
190 |
+
return empty_output_image, empty_new_mask
|
191 |
+
|
192 |
+
# 2. Prepare image for resizing: TF.resize expects [C, H, W]
|
193 |
+
# Input image_tensor_whc is [W, H, C]. Permute to [C, H_orig, W_orig].
|
194 |
+
image_tensor_chw = image_tensor_whc.permute(2, 1, 0)
|
195 |
+
|
196 |
+
# 3. Calculate new dimensions for the image to fit in BB, preserving aspect ratio
|
197 |
+
scale_factor_w = bb_target_w / img_orig_w
|
198 |
+
scale_factor_h = bb_target_h / img_orig_h
|
199 |
+
scale = min(scale_factor_w, scale_factor_h) # Fit entirely within BB
|
200 |
+
|
201 |
+
resized_img_w = int(img_orig_w * scale)
|
202 |
+
resized_img_h = int(img_orig_h * scale)
|
203 |
+
|
204 |
+
if resized_img_w == 0 or resized_img_h == 0: # Image scaled to nothing
|
205 |
+
return empty_output_image, empty_new_mask
|
206 |
+
|
207 |
+
# 4. Resize the image. TF.resize expects size as [H, W].
|
208 |
+
try:
|
209 |
+
# antialias=True for better quality (requires torchvision >= 0.8.0 approx)
|
210 |
+
resized_image_chw = TF.resize(image_tensor_chw, [resized_img_h, resized_img_w], antialias=True)
|
211 |
+
except TypeError: # Fallback for older torchvision versions
|
212 |
+
resized_image_chw = TF.resize(image_tensor_chw, [resized_img_h, resized_img_w])
|
213 |
+
|
214 |
+
# Permute resized image back to [W, H, C] format
|
215 |
+
resized_image_whc = resized_image_chw.permute(2, 1, 0)
|
216 |
+
|
217 |
+
# 5. Create the output canvas image (initialized to zeros)
|
218 |
+
output_image_whc = torch.zeros(
|
219 |
+
canvas_w, canvas_h, num_channels,
|
220 |
+
dtype=image_tensor_whc.dtype, device=image_tensor_whc.device
|
221 |
+
)
|
222 |
+
|
223 |
+
# 6. Calculate pasting coordinates to center the resized image within the original BB
|
224 |
+
offset_x = (bb_target_w - resized_img_w) // 2
|
225 |
+
offset_y = (bb_target_h - resized_img_h) // 2
|
226 |
+
|
227 |
+
paste_x_start = x_min_bb + offset_x
|
228 |
+
paste_y_start = y_min_bb + offset_y
|
229 |
+
|
230 |
+
paste_x_end = paste_x_start + resized_img_w
|
231 |
+
paste_y_end = paste_y_start + resized_img_h
|
232 |
+
|
233 |
+
# Place the resized image onto the canvas
|
234 |
+
output_image_whc[paste_x_start:paste_x_end, paste_y_start:paste_y_end, :] = resized_image_whc
|
235 |
+
|
236 |
+
# 7. Create the new mask representing where the image was actually placed
|
237 |
+
new_mask_wh = torch.zeros(
|
238 |
+
canvas_w, canvas_h,
|
239 |
+
dtype=mask_tensor_wh.dtype, device=mask_tensor_wh.device
|
240 |
+
)
|
241 |
+
new_mask_wh[paste_x_start:paste_x_end, paste_y_start:paste_y_end] = 1
|
242 |
+
|
243 |
+
return output_image_whc, new_mask_wh
|
244 |
+
|
245 |
+
|
246 |
+
|
247 |
+
### Compose inverted noises by placing the (optionally segmented) foreground into the target bounding box
|
248 |
+
def compose_noise_masks(cached_pipe,
|
249 |
+
foreground_image: Image,
|
250 |
+
background_image: Image,
|
251 |
+
target_mask: torch.Tensor,
|
252 |
+
foreground_mask: torch.Tensor,
|
253 |
+
option: str = "bg", # bg, bg_fg, segmentation1, tf_icon
|
254 |
+
photoshop_fg_noise: bool = False,
|
255 |
+
num_inversion_steps: int = 100,
|
256 |
+
):
|
257 |
+
|
258 |
+
"""
|
259 |
+
Composes noise masks for image generation using different strategies.
|
260 |
+
This function composes noise masks for stable diffusion inversion, with several composition strategies:
|
261 |
+
- "bg": Uses only background noise
|
262 |
+
- "bg_fg": Combines background and foreground noise using a target mask
|
263 |
+
- "segmentation1": Uses segmentation mask to compose foreground and background noise
|
264 |
+
- "segmentation2": Implements advanced composition with additional boundary noise
|
265 |
+
Parameters:
|
266 |
+
----------
|
267 |
+
cached_pipe : object
|
268 |
+
The cached stable diffusion pipeline used for noise inversion
|
269 |
+
foreground_image : PIL.Image
|
270 |
+
The foreground image to be placed in the background
|
271 |
+
background_image : PIL.Image
|
272 |
+
The background image
|
273 |
+
target_mask : torch.Tensor
|
274 |
+
Target mask indicating the position where the foreground should be placed
|
275 |
+
foreground_mask : torch.Tensor
|
276 |
+
Segmentation mask of the foreground object
|
277 |
+
option : str, default="bg"
|
278 |
+
Composition strategy: "bg", "bg_fg", "segmentation1", or "segmentation2"
|
279 |
+
photoshop_fg_noise : bool, default=False
|
280 |
+
Whether to generate noise from a photoshopped composition of foreground and background
|
281 |
+
num_inversion_steps : int, default=100
|
282 |
+
Number of steps for the inversion process
|
283 |
+
Returns:
|
284 |
+
-------
|
285 |
+
dict
|
286 |
+
A dictionary containing:
|
287 |
+
- "noise": Dictionary of generated noises (composed_noise, foreground_noise, background_noise)
|
288 |
+
- "latent_masks": Dictionary of latent masks used for composition
|
289 |
+
"""
|
290 |
+
|
291 |
+
# assert options
|
292 |
+
assert option in ["bg", "bg_fg", "segmentation1", "segmentation2"], f"Invalid option: {option}"
|
293 |
+
|
294 |
+
# calculate size of latent noise for mask resizing
|
295 |
+
PATCH_SIZE = 16
|
296 |
+
latent_size = background_image.size[0] // PATCH_SIZE
|
297 |
+
latents = (latent_size, latent_size)
|
298 |
+
|
299 |
+
# process the options
|
300 |
+
if option == "bg":
|
301 |
+
# only background noise
|
302 |
+
bg_noise = get_inverted_input_noise(cached_pipe, background_image, num_steps=num_inversion_steps)
|
303 |
+
composed_noise = bg_noise
|
304 |
+
|
305 |
+
all_noise = {
|
306 |
+
"composed_noise": composed_noise,
|
307 |
+
"background_noise": bg_noise,
|
308 |
+
}
|
309 |
+
all_latent_masks = {}
|
310 |
+
|
311 |
+
|
312 |
+
elif option == "bg_fg":
|
313 |
+
|
314 |
+
# resize and scale the image to the bounding box
|
315 |
+
reframed_fg_img, resized_mask = place_image_in_bounding_box(
|
316 |
+
torch.from_numpy(np.array(foreground_image)),
|
317 |
+
(torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
|
318 |
+
)
|
319 |
+
|
320 |
+
#print("Placed Foreground Image")
|
321 |
+
reframed_fg_img = Image.fromarray(reframed_fg_img.numpy())
|
322 |
+
#display(reframed_fg_img)
|
323 |
+
|
324 |
+
#print("Placed Mask")
|
325 |
+
resized_mask_img = Image.fromarray((resized_mask.numpy() * 255).astype(np.uint8))
|
326 |
+
#display(resized_mask_img)
|
327 |
+
|
328 |
+
# invert resized & padded image
|
329 |
+
if photoshop_fg_noise:
|
330 |
+
#print("Photoshopping FG IMAGE")
|
331 |
+
photoshop_img = Image.fromarray(
|
332 |
+
(torch.tensor(np.array(background_image)) * ~resized_mask.cpu().unsqueeze(-1) + torch.tensor(np.array(reframed_fg_img)) * resized_mask.cpu().unsqueeze(-1)).numpy()
|
333 |
+
)
|
334 |
+
#display(photoshop_img)
|
335 |
+
fg_noise = get_inverted_input_noise(cached_pipe, photoshop_img, num_steps=num_inversion_steps)
|
336 |
+
else:
|
337 |
+
fg_noise = get_inverted_input_noise(cached_pipe, reframed_fg_img, num_steps=num_inversion_steps)
|
338 |
+
bg_noise = get_inverted_input_noise(cached_pipe, background_image, num_steps=num_inversion_steps)
|
339 |
+
|
340 |
+
# overwrite get masked in latent space
|
341 |
+
latent_mask = resize_bounding_box(
|
342 |
+
resized_mask,
|
343 |
+
target_size=latents,
|
344 |
+
).flatten().unsqueeze(-1).to("cuda")
|
345 |
+
|
346 |
+
# compose the noise
|
347 |
+
composed_noise = bg_noise * (~latent_mask) + fg_noise * latent_mask
|
348 |
+
all_latent_masks = {
|
349 |
+
"latent_mask": latent_mask,
|
350 |
+
}
|
351 |
+
all_noise = {
|
352 |
+
"composed_noise": composed_noise,
|
353 |
+
"foreground_noise": fg_noise,
|
354 |
+
"background_noise": bg_noise,
|
355 |
+
}
|
356 |
+
|
357 |
+
elif option == "segmentation1":
|
358 |
+
# cut out the object and compose it with the background noise
|
359 |
+
|
360 |
+
# segmented foreground image
|
361 |
+
segmented_fg_image = torch.tensor(
|
362 |
+
np.array(
|
363 |
+
foreground_mask.resize(foreground_image.size)
|
364 |
+
)).to(torch.bool).unsqueeze(-1) * torch.tensor(
|
365 |
+
np.array(foreground_image)
|
366 |
+
)
|
367 |
+
|
368 |
+
# resize and scale the image to the bounding box
|
369 |
+
reframed_fg_img, resized_mask = place_image_in_bounding_box(
|
370 |
+
segmented_fg_image,
|
371 |
+
(torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
|
372 |
+
)
|
373 |
+
|
374 |
+
reframed_fg_img = Image.fromarray(reframed_fg_img.numpy())
|
375 |
+
#display(reframed_fg_img)
|
376 |
+
|
377 |
+
resized_mask_img = Image.fromarray((resized_mask.numpy() * 255).astype(np.uint8))
|
378 |
+
|
379 |
+
# resize and scale the mask itself
|
380 |
+
foreground_mask = foreground_mask.convert("RGB") # 3-channel mask so place_image_in_bounding_box receives the expected [W, H, C] layout
|
381 |
+
reframed_segmentation_mask, resized_mask = place_image_in_bounding_box(
|
382 |
+
torch.from_numpy(np.array(foreground_mask)),
|
383 |
+
(torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
|
384 |
+
)
|
385 |
+
|
386 |
+
reframed_segmentation_mask = reframed_segmentation_mask.numpy()
|
387 |
+
reframed_segmentation_mask_img = Image.fromarray(reframed_segmentation_mask)
|
388 |
+
#print("Placed Segmentation Mask")
|
389 |
+
#display(reframed_segmentation_mask_img)
|
390 |
+
|
391 |
+
# invert resized & padded image
|
392 |
+
# fg_noise = get_inverted_input_noise(cached_pipe, reframed_fg_img, num_steps=num_inversion_steps)
|
393 |
+
|
394 |
+
if photoshop_fg_noise:
|
395 |
+
# temporarily convert to apply mask
|
396 |
+
#print("Photoshopping FG IMAGE")
|
397 |
+
seg_mask_temp = torch.from_numpy(reframed_segmentation_mask).bool()
|
398 |
+
bg_temp = torch.tensor(np.array(background_image))
|
399 |
+
fg_temp = torch.tensor(np.array(reframed_fg_img))
|
400 |
+
|
401 |
+
photoshop_img = Image.fromarray(
|
402 |
+
(bg_temp * (~seg_mask_temp) + fg_temp * seg_mask_temp).numpy()
|
403 |
+
).convert("RGB")
|
404 |
+
#display(photoshop_img)
|
405 |
+
fg_noise = get_inverted_input_noise(cached_pipe, photoshop_img, num_steps=num_inversion_steps)
|
406 |
+
else:
|
407 |
+
fg_noise = get_inverted_input_noise(cached_pipe, reframed_fg_img, num_steps=num_inversion_steps)
|
408 |
+
|
409 |
+
|
410 |
+
bg_noise = get_inverted_input_noise(cached_pipe, background_image, num_steps=num_inversion_steps)
|
411 |
+
bg_noise_init = bg_noise[-1].squeeze(0) if isinstance(bg_noise, list) else bg_noise
|
412 |
+
fg_noise_init = fg_noise[-1].squeeze(0) if isinstance(fg_noise, list) else fg_noise
|
413 |
+
|
414 |
+
# overwrite background in resized mask
|
415 |
+
# convert mask from 512x512x3 to 512x512 first
|
416 |
+
reframed_segmentation_mask = reframed_segmentation_mask[:, :, 0]
|
417 |
+
reframed_segmentation_mask = torch.from_numpy(reframed_segmentation_mask).to(dtype=bool)
|
418 |
+
latent_mask = resize_bounding_box(
|
419 |
+
reframed_segmentation_mask,
|
420 |
+
target_size=latents,
|
421 |
+
).flatten().unsqueeze(-1).to("cuda")
|
422 |
+
bb_mask = resize_bounding_box(
|
423 |
+
resized_mask,
|
424 |
+
target_size=latents,
|
425 |
+
).flatten().unsqueeze(-1).to("cuda")
|
426 |
+
|
427 |
+
# compose noise
|
428 |
+
composed_noise = bg_noise_init * (~latent_mask) + fg_noise_init * latent_mask
|
429 |
+
|
430 |
+
all_latent_masks = {
|
431 |
+
"latent_segmentation_mask": latent_mask,
|
432 |
+
# FIXME: handle bounding box better (making sure shapes are correct, especially when bg and fg images have different sizes, e.g. test image 69)
|
433 |
+
"bb_mask": bb_mask,
|
434 |
+
}
|
435 |
+
all_noise = {
|
436 |
+
"composed_noise": composed_noise,
|
437 |
+
"foreground_noise": fg_noise_init,
|
438 |
+
"background_noise": bg_noise_init,
|
439 |
+
"foreground_noise_list": fg_noise if isinstance(fg_noise, list) else None,
|
440 |
+
"background_noise_list": bg_noise if isinstance(bg_noise, list) else None,
|
441 |
+
}
|
442 |
+
|
443 |
+
|
444 |
+
elif option == "segmentation2":
|
445 |
+
# add random noise in the background
|
446 |
+
|
447 |
+
# segmented foreground image
|
448 |
+
segmented_fg_image = torch.tensor(
|
449 |
+
np.array(
|
450 |
+
foreground_mask.resize(foreground_image.size)
|
451 |
+
)).to(torch.bool).unsqueeze(-1) * torch.tensor(
|
452 |
+
np.array(foreground_image)
|
453 |
+
)
|
454 |
+
|
455 |
+
# resize and scale the image to the bounding box
|
456 |
+
reframed_fg_img, resized_mask = place_image_in_bounding_box(
|
457 |
+
segmented_fg_image,
|
458 |
+
(torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
|
459 |
+
)
|
460 |
+
|
461 |
+
#print("Segmented and Placed FG Image")
|
462 |
+
reframed_fg_img = Image.fromarray(reframed_fg_img.numpy())
|
463 |
+
#display(reframed_fg_img)
|
464 |
+
|
465 |
+
# resize and scale the mask itself
|
466 |
+
foreground_mask = foreground_mask.convert("RGB")
|
467 |
+
reframed_segmentation_mask, resized_mask = place_image_in_bounding_box(
|
468 |
+
torch.from_numpy(np.array(foreground_mask)),
|
469 |
+
(torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
|
470 |
+
)
|
471 |
+
|
472 |
+
reframed_segmentation_mask = reframed_segmentation_mask.numpy()
|
473 |
+
reframed_segmentation_mask_img = Image.fromarray(reframed_segmentation_mask)
|
474 |
+
#print("Reframed Segmentation Mask")
|
475 |
+
#display(reframed_segmentation_mask_img)
|
476 |
+
|
477 |
+
xor_mask = target_mask ^ np.array(reframed_segmentation_mask_img.convert("L"))
|
478 |
+
#print("XOR Mask")
|
479 |
+
#display(Image.fromarray(xor_mask))
|
480 |
+
|
481 |
+
# invert resized & padded image
|
482 |
+
# fg_noise = get_inverted_input_noise(cached_pipe, reframed_fg_img, num_steps=num_inversion_steps)
|
483 |
+
if photoshop_fg_noise:
|
484 |
+
#print("Photoshopping FG IMAGE")
|
485 |
+
# temporarily convert to apply mask
|
486 |
+
seg_mask_temp = torch.from_numpy(reframed_segmentation_mask).bool()
|
487 |
+
bg_temp = torch.tensor(np.array(background_image))
|
488 |
+
fg_temp = torch.tensor(np.array(reframed_fg_img))
|
489 |
+
|
490 |
+
photoshop_img = Image.fromarray(
|
491 |
+
(bg_temp * (~seg_mask_temp) + fg_temp * seg_mask_temp).numpy()
|
492 |
+
).convert("RGB")
|
493 |
+
#display(photoshop_img)
|
494 |
+
fg_noise = get_inverted_input_noise(cached_pipe, photoshop_img, num_steps=num_inversion_steps)
|
495 |
+
else:
|
496 |
+
fg_noise = get_inverted_input_noise(cached_pipe, reframed_fg_img, num_steps=num_inversion_steps)
|
497 |
+
bg_noise = get_inverted_input_noise(cached_pipe, background_image, num_steps=num_inversion_steps)
|
498 |
+
|
499 |
+
# overwrite background in resized mask
|
500 |
+
# convert mask from 512x512x3 to 512x512
|
501 |
+
reframed_segmentation_mask = reframed_segmentation_mask[:, :, 0]
|
502 |
+
reframed_segmentation_mask = torch.from_numpy(reframed_segmentation_mask).to(dtype=bool)
|
503 |
+
|
504 |
+
# get all masks in latents and move to device
|
505 |
+
latent_seg_mask = resize_bounding_box(
|
506 |
+
reframed_segmentation_mask,
|
507 |
+
target_size=latents,
|
508 |
+
).flatten().unsqueeze(-1).to("cuda")
|
509 |
+
print(latent_seg_mask.shape)
|
510 |
+
|
511 |
+
|
512 |
+
latent_xor_mask = resize_bounding_box(
|
513 |
+
torch.from_numpy(xor_mask),
|
514 |
+
target_size=latents,
|
515 |
+
).flatten().unsqueeze(-1).to("cuda")
|
516 |
+
|
517 |
+
|
518 |
+
print(resized_mask.shape)
|
519 |
+
latent_target_mask = resize_bounding_box(
|
520 |
+
resized_mask,
|
521 |
+
target_size=latents,
|
522 |
+
).flatten().unsqueeze(-1).to("cuda")
|
523 |
+
|
524 |
+
# implement x∗T = xrT ⊙Mseg +xmT ⊙(1−Muser)+z⊙(Muser ⊕Mseg)
|
525 |
+
bg_noise_init = bg_noise[-1].squeeze(0) if isinstance(bg_noise, list) else bg_noise
|
526 |
+
fg_noise_init = fg_noise[-1].squeeze(0) if isinstance(fg_noise, list) else fg_noise
|
527 |
+
|
528 |
+
bg = bg_noise_init[-1] * (~latent_target_mask)
|
529 |
+
fg = fg_noise_init[-1] * latent_seg_mask
|
530 |
+
boundary = latent_xor_mask * torch.randn(latent_xor_mask.shape).to("cuda")
|
531 |
+
composed_noise = bg + fg + boundary
|
532 |
+
|
533 |
+
all_latent_masks = {
|
534 |
+
"latent_target_mask": latent_target_mask,
|
535 |
+
"latent_segmentation_mask": latent_seg_mask,
|
536 |
+
"latent_xor_mask": latent_xor_mask,
|
537 |
+
}
|
538 |
+
all_noise = {
|
539 |
+
"composed_noise": composed_noise,
|
540 |
+
"foreground_noise": fg_noise_init,
|
541 |
+
"background_noise": bg_noise_init,
|
542 |
+
"foreground_noise_list": fg_noise if isinstance(fg_noise, list) else None,
|
543 |
+
"background_noise_list": bg_noise if isinstance(bg_noise, list) else None,
|
544 |
+
}
|
545 |
+
|
546 |
+
# always add latent bbox mask (for bg consistency or any other future application)
|
547 |
+
latent_bbox_mask = resize_bounding_box(
|
548 |
+
torch.from_numpy(np.array(target_mask.resize(background_image.size))), # resize just to be sure sizes match
|
549 |
+
target_size=latents,
|
550 |
+
).flatten().unsqueeze(-1).to("cuda")
|
551 |
+
all_latent_masks["latent_bbox_mask"] = latent_bbox_mask
|
552 |
+
|
553 |
+
# always add latent segmentation mask
|
554 |
+
reframed_fg_img, resized_mask = place_image_in_bounding_box(
|
555 |
+
torch.from_numpy(np.array(foreground_image)),
|
556 |
+
(torch.from_numpy(np.array(target_mask)) / 255.0).to(dtype=bool)
|
557 |
+
)
|
558 |
+
bb_mask = resize_bounding_box(
|
559 |
+
resized_mask,
|
560 |
+
target_size=latents,
|
561 |
+
).flatten().unsqueeze(-1).to("cuda")
|
562 |
+
all_latent_masks["latent_segmentation_mask"] = bb_mask
|
563 |
+
|
564 |
+
# output
|
565 |
+
return {
|
566 |
+
"noise": all_noise,
|
567 |
+
"latent_masks": all_latent_masks,
|
568 |
+
}
|
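An illustrative call into the composition entry point above; the image paths, the `cached_pipe` object and the chosen option are placeholders for illustration, not defaults shipped with the repository:

    from PIL import Image

    fg = Image.open("foreground.png").convert("RGB")
    bg = Image.open("background.png").convert("RGB")
    bbox_mask = Image.open("target_bbox_mask.png").convert("L")      # 0/255 rectangle marking the target region
    seg_mask = Image.open("foreground_seg_mask.png").convert("L")    # 0/255 segmentation of the foreground object

    result = compose_noise_masks(
        cached_pipe,                  # a CachedPipeline wrapping the Flux pipeline
        foreground_image=fg,
        background_image=bg,
        target_mask=bbox_mask,
        foreground_mask=seg_mask,
        option="segmentation1",
        num_inversion_steps=28,
    )
    composed_noise = result["noise"]["composed_noise"]   # starting latent for the composited generation
    latent_masks = result["latent_masks"]                # latent-space masks reusable for later edits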
SDLens/cache_and_edit/metrics.py
ADDED
@@ -0,0 +1,116 @@
|
1 |
+
import numpy as np
|
2 |
+
from PIL import Image
|
3 |
+
from typing import Union
|
4 |
+
import torch
|
5 |
+
from transformers import CLIPProcessor, CLIPModel
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from transformers import AutoModel, AutoImageProcessor
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
def masked_mse_tiled_mask(
|
12 |
+
image1_pil: Image.Image,
|
13 |
+
image2_pil: Image.Image,
|
14 |
+
tile_mask: Union[np.ndarray, torch.Tensor],
|
15 |
+
tile_size: int = 16
|
16 |
+
) -> float:
|
17 |
+
# Convert images to float32 numpy arrays, normalized [0, 1]
|
18 |
+
img1 = np.asarray(image1_pil).astype(np.float32) / 255.0
|
19 |
+
img2 = np.asarray(image2_pil).astype(np.float32) / 255.0
|
20 |
+
|
21 |
+
# Convert mask to numpy if it's a torch tensor
|
22 |
+
if isinstance(tile_mask, torch.Tensor):
|
23 |
+
tile_mask = tile_mask.detach().cpu().numpy()
|
24 |
+
|
25 |
+
tile_mask = tile_mask.astype(np.float32)
|
26 |
+
|
27 |
+
# Upsample mask using np.kron to match image resolution
|
28 |
+
upsampled_mask = np.expand_dims(np.kron(tile_mask, np.ones((tile_size, tile_size), dtype=np.float32)), axis=-1)
|
29 |
+
|
30 |
+
# Invert mask: 1 = exclude → 0; 0 = include → 1
|
31 |
+
include_mask = 1.0 - upsampled_mask
|
32 |
+
|
33 |
+
# Compute squared difference
|
34 |
+
diff_squared = (img1 - img2) ** 2
|
35 |
+
masked_diff = diff_squared * include_mask
|
36 |
+
|
37 |
+
# Sum and normalize by valid (included) pixels
|
38 |
+
valid_pixel_count = np.sum(include_mask)
|
39 |
+
if valid_pixel_count == 0:
|
40 |
+
raise ValueError("All pixels are masked out. Cannot compute MSE.")
|
41 |
+
|
42 |
+
mse = np.sum(masked_diff) / valid_pixel_count
|
43 |
+
return float(mse)
|
44 |
+
|
45 |
+
|
46 |
+
def compute_clip_similarity(image: Image.Image, prompt: str) -> float:
|
47 |
+
"""
|
48 |
+
Compute CLIP similarity between a PIL image and a text prompt.
|
49 |
+
Loads CLIP model only once and caches it.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
image (PIL.Image.Image): Input image.
|
53 |
+
prompt (str): Text prompt.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
float: Cosine similarity between image and text.
|
57 |
+
"""
|
58 |
+
if not hasattr(compute_clip_similarity, "model"):
|
59 |
+
compute_clip_similarity.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
|
60 |
+
compute_clip_similarity.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
61 |
+
compute_clip_similarity.model.eval()
|
62 |
+
|
63 |
+
model = compute_clip_similarity.model
|
64 |
+
processor = compute_clip_similarity.processor
|
65 |
+
|
66 |
+
image = image.convert("RGB")
|
67 |
+
image_inputs = processor(images=image, return_tensors="pt")
|
68 |
+
text_inputs = processor(text=[prompt], return_tensors="pt")
|
69 |
+
|
70 |
+
|
71 |
+
with torch.no_grad():
|
72 |
+
image_features = model.get_image_features(**image_inputs)
|
73 |
+
text_features = model.get_text_features(**text_inputs)
|
74 |
+
|
75 |
+
image_features = F.normalize(image_features, p=2, dim=-1)
|
76 |
+
text_features = F.normalize(text_features, p=2, dim=-1)
|
77 |
+
|
78 |
+
similarity = (image_features @ text_features.T).item()
|
79 |
+
|
80 |
+
return similarity
|
81 |
+
|
82 |
+
|
83 |
+
def compute_dinov2_similarity(image1: Image.Image, image2: Image.Image) -> float:
|
84 |
+
"""
|
85 |
+
Compute perceptual similarity between two images using DINOv2 embeddings.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
image1 (PIL.Image.Image): First image.
|
89 |
+
image2 (PIL.Image.Image): Second image.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
float: Cosine similarity between DINOv2 embeddings of the images.
|
93 |
+
"""
|
94 |
+
# Load model and processor only once
|
95 |
+
if not hasattr(compute_dinov2_similarity, "model"):
|
96 |
+
compute_dinov2_similarity.processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
|
97 |
+
compute_dinov2_similarity.model = AutoModel.from_pretrained("facebook/dinov2-base")
|
98 |
+
compute_dinov2_similarity.model.eval()
|
99 |
+
|
100 |
+
processor = compute_dinov2_similarity.processor
|
101 |
+
model = compute_dinov2_similarity.model
|
102 |
+
|
103 |
+
# Preprocess both images
|
104 |
+
inputs = processor(images=[image1.convert("RGB"), image2.convert("RGB")], return_tensors="pt")
|
105 |
+
|
106 |
+
with torch.no_grad():
|
107 |
+
outputs = model(**inputs)
|
108 |
+
features = outputs.last_hidden_state.mean(dim=1) # mean-pooled patch features
|
109 |
+
|
110 |
+
# Normalize
|
111 |
+
features = F.normalize(features, p=2, dim=-1)
|
112 |
+
|
113 |
+
# Cosine similarity
|
114 |
+
similarity = (features[0] @ features[1].T).item()
|
115 |
+
|
116 |
+
return similarity
|
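A hedged example combining the three metrics above; the file names, the prompt and the 512x512 image / 16-pixel-tile geometry are assumptions for illustration:

    import numpy as np
    from PIL import Image

    generated = Image.open("generated.png").convert("RGB")   # assumed 512x512
    reference = Image.open("reference.png").convert("RGB")   # assumed 512x512

    # A 32x32 tile mask with tile_size=16 covers a 512x512 image; 1 marks tiles to exclude from the MSE.
    tile_mask = np.zeros((32, 32), dtype=np.float32)
    tile_mask[8:24, 8:24] = 1.0   # ignore the edited centre region, score background preservation only

    bg_mse = masked_mse_tiled_mask(generated, reference, tile_mask, tile_size=16)
    clip_score = compute_clip_similarity(generated, "a cat sitting on a bench")
    dino_score = compute_dinov2_similarity(generated, reference)
    print(f"background MSE: {bg_mse:.4f}  CLIP: {clip_score:.3f}  DINOv2: {dino_score:.3f}")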
SDLens/cache_and_edit/qkv_cache.py
ADDED
@@ -0,0 +1,557 @@
|
1 |
+
# Add parent directory to sys.path
|
2 |
+
from collections import defaultdict
|
3 |
+
import gc
|
4 |
+
import os, sys
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
from SDLens.cache_and_edit.flux_pipeline import EditedFluxPipeline
|
8 |
+
parent_dir = Path.cwd().parent.resolve()
|
9 |
+
if str(parent_dir) not in sys.path:
|
10 |
+
sys.path.insert(0, str(parent_dir))
|
11 |
+
|
12 |
+
from typing import Dict, List, Literal, Optional, TypedDict, Type, Union
|
13 |
+
import torch
|
14 |
+
from diffusers.models.attention_processor import Attention
|
15 |
+
from diffusers.models.transformers import FluxTransformer2DModel
|
16 |
+
from diffusers import FluxPipeline
|
17 |
+
from diffusers.models.embeddings import apply_rotary_emb
|
18 |
+
from SDLens.cache_and_edit.hooks import locate_block
|
19 |
+
import torch.nn.functional as F
|
20 |
+
from diffusers.models.attention_processor import FluxAttnProcessor2_0
|
21 |
+
|
22 |
+
class QKVCache(TypedDict):
|
23 |
+
query: List[torch.Tensor]
|
24 |
+
key: List[torch.Tensor]
|
25 |
+
value: List[torch.Tensor]
|
26 |
+
|
27 |
+
|
28 |
+
class CachedFluxAttnProcessor2_0:
|
29 |
+
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
30 |
+
|
31 |
+
def __init__(self, external_cache: QKVCache,
|
32 |
+
inject_kv: Literal["image", "text", "both"]= None,
|
33 |
+
text_seq_length: int = 512):
|
34 |
+
"""Constructor for Cached attention processor.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
external_cache (QKVCache): cache to store/inject values.
|
38 |
+
inject_kv (Literal["image", "text", "both"], optional): whether to inject image, text or both streams KV.
|
39 |
+
If None, it does not perform injection but the full cache is stored. Defaults to None.
|
40 |
+
"""
|
41 |
+
|
42 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
43 |
+
raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
44 |
+
self.cache = external_cache
|
45 |
+
self.inject_kv = inject_kv
|
46 |
+
self.text_seq_length = text_seq_length
|
47 |
+
assert all((cache_key in external_cache) for cache_key in {"query", "key", "value"}), "Cache has to contain 'query', 'key' and 'value' keys."
|
48 |
+
|
49 |
+
def __call__(
|
50 |
+
self,
|
51 |
+
attn: Attention,
|
52 |
+
hidden_states: torch.FloatTensor,
|
53 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
54 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
55 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
56 |
+
) -> torch.FloatTensor:
|
57 |
+
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
58 |
+
|
59 |
+
# `sample` projections.
|
60 |
+
query = attn.to_q(hidden_states)
|
61 |
+
key = attn.to_k(hidden_states)
|
62 |
+
value = attn.to_v(hidden_states)
|
63 |
+
|
64 |
+
inner_dim = key.shape[-1]
|
65 |
+
head_dim = inner_dim // attn.heads
|
66 |
+
|
67 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
68 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
69 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
70 |
+
|
71 |
+
if attn.norm_q is not None:
|
72 |
+
query = attn.norm_q(query)
|
73 |
+
if attn.norm_k is not None:
|
74 |
+
key = attn.norm_k(key)
|
75 |
+
|
76 |
+
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
77 |
+
if encoder_hidden_states is not None:
|
78 |
+
# `context` projections.
|
79 |
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
80 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
81 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
82 |
+
|
83 |
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
84 |
+
batch_size, -1, attn.heads, head_dim
|
85 |
+
).transpose(1, 2)
|
86 |
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
87 |
+
batch_size, -1, attn.heads, head_dim
|
88 |
+
).transpose(1, 2)
|
89 |
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
90 |
+
batch_size, -1, attn.heads, head_dim
|
91 |
+
).transpose(1, 2)
|
92 |
+
|
93 |
+
if attn.norm_added_q is not None:
|
94 |
+
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
95 |
+
if attn.norm_added_k is not None:
|
96 |
+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
97 |
+
|
98 |
+
# attention
|
99 |
+
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
100 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
101 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
102 |
+
|
103 |
+
if image_rotary_emb is not None:
|
104 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
105 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
106 |
+
|
107 |
+
# Cache Q, K, V
|
108 |
+
if self.inject_kv == "image":
|
109 |
+
# NOTE: I am replacing key and values only for the image branch
|
110 |
+
# NOTE: in default settings, encoder_hidden_states_key_proj.shape[2] == 512
|
111 |
+
# the first element of the batch is the image whose key and value will be injected into all the other images
|
112 |
+
key[1:, :, self.text_seq_length:] = key[:1, :, self.text_seq_length:]
|
113 |
+
value[1:, :, self.text_seq_length:] = value[:1, :, self.text_seq_length:]
|
114 |
+
elif self.inject_kv == "text":
|
115 |
+
key[1:, :, :self.text_seq_length] = key[:1, :, :self.text_seq_length]
|
116 |
+
value[1:, :, :self.text_seq_length] = value[:1, :, :self.text_seq_length]
|
117 |
+
elif self.inject_kv == "both":
|
118 |
+
key[1:] = key[:1]
|
119 |
+
value[1:] = value[:1]
|
120 |
+
else: # Don't inject, store cache!
|
121 |
+
self.cache["query"].append(query)
|
122 |
+
self.cache["key"].append(key)
|
123 |
+
self.cache["value"].append(value)
|
124 |
+
|
125 |
+
hidden_states = F.scaled_dot_product_attention(
|
126 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
127 |
+
)
|
128 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
129 |
+
hidden_states = hidden_states.to(query.dtype)
|
130 |
+
|
131 |
+
|
132 |
+
if encoder_hidden_states is not None:
|
133 |
+
encoder_hidden_states, hidden_states = (
|
134 |
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
135 |
+
hidden_states[:, encoder_hidden_states.shape[1] :],
|
136 |
+
)
|
137 |
+
|
138 |
+
# linear proj
|
139 |
+
hidden_states = attn.to_out[0](hidden_states)
|
140 |
+
# dropout
|
141 |
+
hidden_states = attn.to_out[1](hidden_states)
|
142 |
+
|
143 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
144 |
+
|
145 |
+
return hidden_states, encoder_hidden_states
|
146 |
+
else:
|
147 |
+
return hidden_states
|
148 |
+
|
149 |
+
|
150 |
+
class CachedFluxAttnProcessor3_0:
|
151 |
+
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
152 |
+
|
153 |
+
def __init__(self, external_cache: QKVCache,
|
154 |
+
inject_kv: Literal["image", "text", "both"]= None,
|
155 |
+
inject_kv_foreground: bool = False,
|
156 |
+
text_seq_length: int = 512,
|
157 |
+
q_mask: Optional[torch.Tensor] = None,):
|
158 |
+
"""Constructor for Cached attention processor.
|
159 |
+
|
160 |
+
Args:
|
161 |
+
external_cache (QKVCache): cache to store/inject values.
|
162 |
+
inject_kv (Literal["image", "text", "both"], optional): whether to inject image, text or both streams KV.
|
163 |
+
If None, it does not perform injection but the full cache is stored. Defaults to None.
|
164 |
+
"""
|
165 |
+
|
166 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
167 |
+
raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
168 |
+
self.cache = external_cache
|
169 |
+
self.inject_kv = inject_kv
|
170 |
+
self.inject_kv_foreground = inject_kv_foreground
|
171 |
+
self.text_seq_length = text_seq_length
|
172 |
+
self.q_mask = q_mask
|
173 |
+
assert all((cache_key in external_cache) for cache_key in {"query", "key", "value"}), "Cache has to contain 'query', 'key' and 'value' keys."
|
174 |
+
|
175 |
+
def __call__(
|
176 |
+
self,
|
177 |
+
attn: Attention,
|
178 |
+
hidden_states: torch.FloatTensor,
|
179 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
180 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
181 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
182 |
+
) -> torch.FloatTensor:
|
183 |
+
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
184 |
+
|
185 |
+
# `sample` projections.
|
186 |
+
query = attn.to_q(hidden_states)
|
187 |
+
key = attn.to_k(hidden_states)
|
188 |
+
value = attn.to_v(hidden_states)
|
189 |
+
|
190 |
+
inner_dim = key.shape[-1]
|
191 |
+
head_dim = inner_dim // attn.heads
|
192 |
+
|
193 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
194 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
195 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
196 |
+
|
197 |
+
if attn.norm_q is not None:
|
198 |
+
query = attn.norm_q(query)
|
199 |
+
if attn.norm_k is not None:
|
200 |
+
key = attn.norm_k(key)
|
201 |
+
|
202 |
+
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
203 |
+
if encoder_hidden_states is not None:
|
204 |
+
# `context` projections.
|
205 |
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
206 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
207 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
208 |
+
|
209 |
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
210 |
+
batch_size, -1, attn.heads, head_dim
|
211 |
+
).transpose(1, 2)
|
212 |
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
213 |
+
batch_size, -1, attn.heads, head_dim
|
214 |
+
).transpose(1, 2)
|
215 |
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
216 |
+
batch_size, -1, attn.heads, head_dim
|
217 |
+
).transpose(1, 2)
|
218 |
+
|
219 |
+
if attn.norm_added_q is not None:
|
220 |
+
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
221 |
+
if attn.norm_added_k is not None:
|
222 |
+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
223 |
+
|
224 |
+
# attention
|
225 |
+
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
226 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
227 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
228 |
+
|
229 |
+
|
230 |
+
# # Cache Q, K, V
|
231 |
+
# if self.inject_kv == "image":
|
232 |
+
# # NOTE: I am replacing key and values only for the image branch
|
233 |
+
# # NOTE: in default settings, encoder_hidden_states_key_proh.shape[2] == 512
|
234 |
+
# # the first element of the batch is the image whose key and value will be injected into all the other images
|
235 |
+
# key[1:, :, self.text_seq_length:] = key[:1, :, self.text_seq_length:]
|
236 |
+
# value[1:, :, self.text_seq_length:] = value[:1, :, self.text_seq_length:]
|
237 |
+
# elif self.inject_kv == "text":
|
238 |
+
# key[1:, :, :self.text_seq_length] = key[:1, :, :self.text_seq_length]
|
239 |
+
# value[1:, :, :self.text_seq_length] = value[:1, :, :self.text_seq_length]
|
240 |
+
# elif self.inject_kv == "both":
|
241 |
+
# key[1:] = key[:1]
|
242 |
+
# value[1:] = value[:1]
|
243 |
+
# else: # Don't inject, store cache!
|
244 |
+
# self.cache["query"].append(query)
|
245 |
+
# self.cache["key"].append(key)
|
246 |
+
# self.cache["value"].append(value)
|
247 |
+
|
248 |
+
# extend the mask to match key and values dimension:
|
249 |
+
# Shape of mask is: (num_image_tokens, 1)
|
250 |
+
mask = self.q_mask.permute(1, 0).unsqueeze(0).unsqueeze(-1) # Shape: (1, 1, num_image_tokens, 1)
|
251 |
+
# put mask on gpu
|
252 |
+
mask = mask.to(key.device)
|
253 |
+
# first check that we inject only kv in images:
|
254 |
+
if self.inject_kv is not None and self.inject_kv != "image":
|
255 |
+
raise NotImplementedError("Injecting is implemented only for images.")
|
256 |
+
# the second dimension of key/value is the number of heads
|
257 |
+
# The first element of the batch represents the background image, the second element of the batch
|
258 |
+
# represents the foreground image. The third element represents the image where we want to inject
|
259 |
+
# the key and value of the background image and foreground image according to the query mask.
|
260 |
+
# Inject from background (element 0) where mask is True
|
261 |
+
|
262 |
+
if image_rotary_emb is not None:
|
263 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
264 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
265 |
+
|
266 |
+
# Get the index range after the text tokens
|
267 |
+
start_idx = self.text_seq_length
|
268 |
+
|
269 |
+
if self.inject_kv_foreground and self.inject_kv == "image":
|
270 |
+
key[2:, :, start_idx:] = torch.where(mask, key[1:2, :, start_idx:], key[:1, :, start_idx:])
|
271 |
+
value[2:, :, start_idx:] = torch.where(mask, value[1:2, :, start_idx:], value[:1, :, start_idx:])
|
272 |
+
elif self.inject_kv == "image" and not self.inject_kv_foreground:
|
273 |
+
key[2:, :, start_idx:] = torch.where(mask, key[2:, :, start_idx:], key[:1, :, start_idx:])
|
274 |
+
value[2:, :, start_idx:] = torch.where(mask, value[2:, :, start_idx:], value[:1, :, start_idx:])
|
275 |
+
elif self.inject_kv is None and self.inject_kv_foreground:
|
276 |
+
key[2:, :, start_idx:] = torch.where(mask, key[1:2, :, start_idx:], key[2:, :, start_idx:])
|
277 |
+
value[2:, :, start_idx:] = torch.where(mask, value[1:2, :, start_idx:], value[2:, :, start_idx:])
|
278 |
+
|
279 |
+
hidden_states = F.scaled_dot_product_attention(
|
280 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
281 |
+
)
|
282 |
+
# mask hidden states from bg:
|
283 |
+
# hidden_states = hidden_states_fg[:, :, start_idx:] * mask + hidden_states_bg[:, :, start_idx:] * (~mask)
|
284 |
+
|
285 |
+
# concatenate the text
|
286 |
+
#hidden_states = torch.cat([hidden_states_bg[:, :, :start_idx], hidden_states], dim=2)
|
287 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
288 |
+
hidden_states = hidden_states.to(query.dtype)
|
289 |
+
|
290 |
+
|
291 |
+
if encoder_hidden_states is not None:
|
292 |
+
encoder_hidden_states, hidden_states = (
|
293 |
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
294 |
+
hidden_states[:, encoder_hidden_states.shape[1] :],
|
295 |
+
)
|
296 |
+
|
297 |
+
# linear proj
|
298 |
+
hidden_states = attn.to_out[0](hidden_states)
|
299 |
+
# dropout
|
300 |
+
hidden_states = attn.to_out[1](hidden_states)
|
301 |
+
|
302 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
303 |
+
|
304 |
+
return hidden_states, encoder_hidden_states
|
305 |
+
else:
|
306 |
+
return hidden_states
|
307 |
+
|
308 |
+
|
309 |
+
class QKVCacheFluxHandler:
|
310 |
+
"""Used to cache queries, keys and values of a FluxPipeline.
|
311 |
+
"""
|
312 |
+
|
313 |
+
def __init__(self, pipe: Union[FluxPipeline, EditedFluxPipeline],
|
314 |
+
positions_to_cache: List[str] = None,
|
315 |
+
positions_to_cache_foreground: List[str] = None,
|
316 |
+
inject_kv: Literal["image", "text", "both"] = None,
|
317 |
+
text_seq_length: int = 512,
|
318 |
+
q_mask: Optional[torch.Tensor] = None,
|
319 |
+
processor_class: Optional[Type] = CachedFluxAttnProcessor3_0
|
320 |
+
):
|
321 |
+
|
322 |
+
print(type(pipe))
|
323 |
+
if not isinstance(pipe, FluxPipeline) and not isinstance(pipe, EditedFluxPipeline):
|
324 |
+
raise NotImplementedError(f"QKVCache not yet implemented for {type(pipe)}.")
|
325 |
+
|
326 |
+
self.pipe = pipe
|
327 |
+
|
328 |
+
if positions_to_cache is not None:
|
329 |
+
self.positions_to_cache = positions_to_cache
|
330 |
+
else:
|
331 |
+
# act on all transformer layers
|
332 |
+
self.positions_to_cache = []
|
333 |
+
|
334 |
+
if positions_to_cache_foreground is not None:
|
335 |
+
self.positions_to_cache_foreground = positions_to_cache_foreground
|
336 |
+
else:
|
337 |
+
self.positions_to_cache_foreground = []
|
338 |
+
|
339 |
+
self._cache = {"query": [], "key": [], "value": []}
|
340 |
+
|
341 |
+
# Set Cached Processor to perform editing
|
342 |
+
|
343 |
+
all_layers = [f"transformer.transformer_blocks.{i}" for i in range(19)] + \
|
344 |
+
[f"transformer.single_transformer_blocks.{i}" for i in range(38)]
|
345 |
+
for module_name in all_layers:
|
346 |
+
|
347 |
+
inject_kv = "image" if module_name in self.positions_to_cache else None
|
348 |
+
inject_kv_foreground = module_name in self.positions_to_cache_foreground
|
349 |
+
|
350 |
+
|
351 |
+
module = locate_block(pipe, module_name)
|
352 |
+
module.attn.set_processor(processor_class(external_cache=self._cache,
|
353 |
+
inject_kv=inject_kv,
|
354 |
+
inject_kv_foreground=inject_kv_foreground,
|
355 |
+
text_seq_length=text_seq_length,
|
356 |
+
q_mask=q_mask,
|
357 |
+
))
|
358 |
+
|
359 |
+
|
360 |
+
@property
|
361 |
+
def cache(self) -> QKVCache:
|
362 |
+
"""Returns a dictionary initialized as {"query": [], "key": [], "value": []}.
|
363 |
+
After calling a forward pass for pipe, queries, keys and values will be
|
364 |
+
appended in the respective list for each layer.
|
365 |
+
|
366 |
+
Returns:
|
367 |
+
Dict[str, List[torch.Tensor]]: cache dictionary containing 'query', 'key' and 'value'
|
368 |
+
"""
|
369 |
+
return self._cache
|
370 |
+
|
371 |
+
def clear_cache(self) -> None:
|
372 |
+
# TODO: check if we have to force clean GPU memory
|
373 |
+
del(self._cache)
|
374 |
+
gc.collect() # force Python to clean up unreachable objects
|
375 |
+
torch.cuda.empty_cache() # tell PyTorch to release unused GPU memory from its cache
|
376 |
+
self._cache = {"query": [], "key": [], "value": []}
|
377 |
+
|
378 |
+
for module_name in self.positions_to_cache:
|
379 |
+
module = locate_block(self.pipe, module_name)
|
380 |
+
module.attn.set_processor(FluxAttnProcessor2_0())
|
381 |
+
|
382 |
+
|
383 |
+
class TFICONAttnProcessor:
|
384 |
+
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
385 |
+
|
386 |
+
def __init__(self,
|
387 |
+
external_cache: QKVCache,
|
388 |
+
inject_kv: Literal["image", "text", "both"]= None,
|
389 |
+
inject_kv_foreground: bool = False,
|
390 |
+
text_seq_length: int = 512,
|
391 |
+
q_mask: Optional[torch.Tensor] = None,
|
392 |
+
call_max_times = None,
|
393 |
+
inject_q = True,
|
394 |
+
inject_k = True,
|
395 |
+
inject_v = True,
|
396 |
+
):
|
397 |
+
"""Constructor for Cached attention processor.
|
398 |
+
|
399 |
+
Args:
|
400 |
+
external_cache (QKVCache): cache to store/inject values.
|
401 |
+
inject_kv (Literal["image", "text", "both"], optional): whether to inject image, text or both streams KV.
|
402 |
+
If None, it does not perform injection but the full cache is stored. Defaults to None.
|
403 |
+
"""
|
404 |
+
|
405 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
406 |
+
raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
407 |
+
self.cache = external_cache
|
408 |
+
self.inject_kv = inject_kv
|
409 |
+
self.inject_kv_foreground = inject_kv_foreground
|
410 |
+
self.text_seq_length = text_seq_length
|
411 |
+
self.q_mask = q_mask
|
412 |
+
self.inject_q = inject_q
|
413 |
+
self.inject_k = inject_k
|
414 |
+
self.inject_v = inject_v
|
415 |
+
|
416 |
+
self.call_max_times = call_max_times
|
417 |
+
if self.call_max_times is not None:
|
418 |
+
self.num_calls = call_max_times
|
419 |
+
else:
|
420 |
+
self.num_calls = None
|
421 |
+
assert all((cache_key in external_cache) for cache_key in {"query", "key", "value"}), "Cache has to contain 'query', 'key' and 'value' keys."
|
422 |
+
|
423 |
+
def __call__(
|
424 |
+
self,
|
425 |
+
attn: Attention,
|
426 |
+
hidden_states: torch.FloatTensor,
|
427 |
+
encoder_hidden_states: torch.FloatTensor = None,
|
428 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
429 |
+
image_rotary_emb: Optional[torch.Tensor] = None,
|
430 |
+
) -> torch.FloatTensor:
|
431 |
+
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
432 |
+
|
433 |
+
# `sample` projections.
|
434 |
+
query = attn.to_q(hidden_states)
|
435 |
+
key = attn.to_k(hidden_states)
|
436 |
+
value = attn.to_v(hidden_states)
|
437 |
+
|
438 |
+
inner_dim = key.shape[-1]
|
439 |
+
head_dim = inner_dim // attn.heads
|
440 |
+
|
441 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
442 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
443 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
444 |
+
|
445 |
+
if attn.norm_q is not None:
|
446 |
+
query = attn.norm_q(query)
|
447 |
+
if attn.norm_k is not None:
|
448 |
+
key = attn.norm_k(key)
|
449 |
+
|
450 |
+
# hidden states are the image patches (B, 4096, hidden_dim)
|
451 |
+
|
452 |
+
# encoder_hidden_states are the text tokens (B, 512, hidden_dim)
|
453 |
+
|
454 |
+
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
|
455 |
+
if encoder_hidden_states is not None:
|
456 |
+
# `context` projections.
|
457 |
+
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
|
458 |
+
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
|
459 |
+
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
|
460 |
+
|
461 |
+
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
|
462 |
+
batch_size, -1, attn.heads, head_dim
|
463 |
+
).transpose(1, 2)
|
464 |
+
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
|
465 |
+
batch_size, -1, attn.heads, head_dim
|
466 |
+
).transpose(1, 2)
|
467 |
+
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
|
468 |
+
batch_size, -1, attn.heads, head_dim
|
469 |
+
).transpose(1, 2)
|
470 |
+
|
471 |
+
if attn.norm_added_q is not None:
|
472 |
+
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
|
473 |
+
if attn.norm_added_k is not None:
|
474 |
+
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
|
475 |
+
|
476 |
+
# concat inputs for attention -> (B, num_heads, 512 + 4096, head_dim)
|
477 |
+
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
|
478 |
+
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
|
479 |
+
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
|
480 |
+
|
481 |
+
# TODO: try first without mask
|
482 |
+
# Cache Q, K, V
|
483 |
+
# extend the mask to match key and values dimension:
|
484 |
+
# Shape of mask is: (num_image_tokens, 1)
|
485 |
+
mask = self.q_mask.permute(1, 0).unsqueeze(0).unsqueeze(-1)  # Shape: (1, 1, num_image_tokens, 1), broadcastable over (batch, heads, tokens, head_dim)
|
486 |
+
# put mask on gpu
|
487 |
+
mask = mask.to(key.device)
|
488 |
+
# first check that we inject only kv in images:
|
489 |
+
if self.inject_kv is not None and self.inject_kv != "image":
|
490 |
+
raise NotImplementedError("Injecting is implemented only for images.")
|
491 |
+
# the second dimension of the Q/K/V tensors is the number of heads
|
492 |
+
# The first element of the batch represents the background image, the second element of the batch
|
493 |
+
# represents the foreground image. The third element represents the image where we want to inject
|
494 |
+
# the key and value of the background image and foreground image according to the query mask.
|
495 |
+
# Inject from background (element 0) where mask is True
|
496 |
+
|
497 |
+
if image_rotary_emb is not None:
|
498 |
+
query = apply_rotary_emb(query, image_rotary_emb)
|
499 |
+
key = apply_rotary_emb(key, image_rotary_emb)
|
500 |
+
|
501 |
+
# Get the index range after the text tokens
|
502 |
+
start_idx = self.text_seq_length
|
503 |
+
|
504 |
+
# The batch is formed as follows:
|
505 |
+
# - background image (0)
|
506 |
+
# - foreground image (1)
|
507 |
+
# - composition(s) (2, 3, ...)
|
508 |
+
# Create the combined Q/K by forming Q_comp and K_comp: take the Q and K of the background image
|
509 |
+
# outside of the mask, and those of the foreground image inside the mask
|
510 |
+
|
511 |
+
if self.num_calls is None or self.num_calls > 0:
|
512 |
+
if self.inject_kv_foreground:
|
513 |
+
if self.inject_k:
|
514 |
+
key[2:, :, start_idx:] = torch.where(mask, key[1:2, :, start_idx:], key[0:1, :, start_idx:])
|
515 |
+
if self.inject_q:
|
516 |
+
query[2:, :, start_idx:] = torch.where(mask, query[1:2, :, start_idx:], query[0:1, :, start_idx:])
|
517 |
+
if self.inject_v:
|
518 |
+
value[2:, :, start_idx:] = torch.where(mask, value[1:2, :, start_idx:], value[0:1, :, start_idx:])
|
519 |
+
else:
|
520 |
+
if self.inject_k:
|
521 |
+
key[2:, :, start_idx:] = torch.where(mask, key[2:, :, start_idx:], key[0:1, :, start_idx:])
|
522 |
+
if self.inject_q:
|
523 |
+
query[2:, :, start_idx:] = torch.where(mask, query[2:, :, start_idx:], query[0:1, :, start_idx:])
|
524 |
+
if self.inject_v:
|
525 |
+
value[2:, :, start_idx:] = torch.where(mask, value[2:, :, start_idx:], value[0:1, :, start_idx:])
|
526 |
+
|
527 |
+
if self.num_calls is not None:
|
528 |
+
self.num_calls -= 1
|
529 |
+
|
530 |
+
|
531 |
+
# Use the combined attention map to compute attention using V from the composition image
|
532 |
+
hidden_states = F.scaled_dot_product_attention(
|
533 |
+
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
534 |
+
)
|
535 |
+
|
536 |
+
# hidden_states[2:, :, start_idx:] = torch.where(mask, weightage * hidden_states[1:2, :, start_idx:] + (1-weightage) * hidden_states[2:, :, start_idx:], hidden_states[2:, :, start_idx:])
|
537 |
+
|
538 |
+
# merge the attention heads back; the text and image streams are split again below
|
539 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
540 |
+
hidden_states = hidden_states.to(query.dtype)
|
541 |
+
|
542 |
+
if encoder_hidden_states is not None:
|
543 |
+
encoder_hidden_states, hidden_states = (
|
544 |
+
hidden_states[:, : encoder_hidden_states.shape[1]],
|
545 |
+
hidden_states[:, encoder_hidden_states.shape[1] :],
|
546 |
+
)
|
547 |
+
|
548 |
+
# linear proj
|
549 |
+
hidden_states = attn.to_out[0](hidden_states)
|
550 |
+
# dropout
|
551 |
+
hidden_states = attn.to_out[1](hidden_states)
|
552 |
+
|
553 |
+
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
554 |
+
|
555 |
+
return hidden_states, encoder_hidden_states
|
556 |
+
else:
|
557 |
+
return hidden_states
|
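The masked injection above works by broadcasting a (1, 1, num_image_tokens, 1) query mask against the (batch, heads, tokens, head_dim) projections. A minimal, self-contained sketch of the selection pattern used in the `inject_kv_foreground` branch (hypothetical sizes, independent of the pipeline):

import torch

# Hypothetical sizes: 2 composition images, 4 heads, 6 image tokens, 8 dims per head.
batch, heads, tokens, head_dim = 2, 4, 6, 8
background_k = torch.randn(1, heads, tokens, head_dim)    # batch element 0
foreground_k = torch.randn(1, heads, tokens, head_dim)    # batch element 1

# Spatial mask over image tokens, reshaped to broadcast over (batch, heads, tokens, head_dim).
q_mask = torch.tensor([1, 1, 0, 0, 1, 0], dtype=torch.bool)
mask = q_mask.view(1, 1, tokens, 1)

# Inside the mask take the foreground keys, outside take the background keys.
composition_k = torch.where(mask, foreground_k, background_k).expand(batch, -1, -1, -1)
print(composition_k.shape)  # torch.Size([2, 4, 6, 8])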
SDLens/cache_and_edit/scheduler_inversion.py
ADDED
@@ -0,0 +1,98 @@
1 |
+
from typing import List, Optional, Tuple, Union
|
2 |
+
import torch
|
3 |
+
from diffusers.configuration_utils import register_to_config
|
4 |
+
from diffusers.schedulers.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler, FlowMatchEulerDiscreteSchedulerOutput
|
5 |
+
|
6 |
+
|
7 |
+
class FlowMatchEulerDiscreteSchedulerForInversion(FlowMatchEulerDiscreteScheduler):
|
8 |
+
|
9 |
+
@register_to_config
|
10 |
+
def __init__(self, inverse: bool, **kwargs):
|
11 |
+
super().__init__(**kwargs)
|
12 |
+
self.inverse = inverse
|
13 |
+
|
14 |
+
|
15 |
+
def step(
|
16 |
+
self,
|
17 |
+
model_output: torch.FloatTensor,
|
18 |
+
timestep: Union[float, torch.FloatTensor],
|
19 |
+
sample: torch.FloatTensor,
|
20 |
+
s_churn: float = 0.0,
|
21 |
+
s_tmin: float = 0.0,
|
22 |
+
s_tmax: float = float("inf"),
|
23 |
+
s_noise: float = 1.0,
|
24 |
+
generator: Optional[torch.Generator] = None,
|
25 |
+
return_dict: bool = True,
|
26 |
+
) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
|
27 |
+
"""
|
28 |
+
Predict the sample for the next scheduler step: the regular flow-matching Euler update when `inverse` is False, or the reversed (noising) update when `inverse` is True. This function propagates the diffusion
|
29 |
+
process from the learned model outputs (most often the predicted noise).
|
30 |
+
|
31 |
+
Args:
|
32 |
+
model_output (`torch.FloatTensor`):
|
33 |
+
The direct output from learned diffusion model.
|
34 |
+
timestep (`float`):
|
35 |
+
The current discrete timestep in the diffusion chain.
|
36 |
+
sample (`torch.FloatTensor`):
|
37 |
+
A current instance of a sample created by the diffusion process.
|
38 |
+
s_churn (`float`):
|
39 |
+
s_tmin (`float`):
|
40 |
+
s_tmax (`float`):
|
41 |
+
s_noise (`float`, defaults to 1.0):
|
42 |
+
Scaling factor for noise added to the sample.
|
43 |
+
generator (`torch.Generator`, *optional*):
|
44 |
+
A random number generator.
|
45 |
+
return_dict (`bool`):
|
46 |
+
Whether or not to return a [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or
|
47 |
+
tuple.
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
[`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] or `tuple`:
|
51 |
+
If return_dict is `True`, [`~schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteSchedulerOutput`] is
|
52 |
+
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
53 |
+
"""
|
54 |
+
|
55 |
+
if (
|
56 |
+
isinstance(timestep, int)
|
57 |
+
or isinstance(timestep, torch.IntTensor)
|
58 |
+
or isinstance(timestep, torch.LongTensor)
|
59 |
+
):
|
60 |
+
raise ValueError(
|
61 |
+
(
|
62 |
+
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
63 |
+
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
64 |
+
" one of the `scheduler.timesteps` as a timestep."
|
65 |
+
),
|
66 |
+
)
|
67 |
+
|
68 |
+
if self.step_index is None:
|
69 |
+
self._init_step_index(timestep)
|
70 |
+
|
71 |
+
# Upcast to avoid precision issues when computing prev_sample
|
72 |
+
sample = sample.to(torch.float32)
|
73 |
+
|
74 |
+
sigma = self.sigmas[self.step_index]
|
75 |
+
sigma_next = self.sigmas[self.step_index + 1]
|
76 |
+
|
77 |
+
if self.inverse:
|
78 |
+
next_sample = sample + (sigma - sigma_next) * model_output
|
79 |
+
# Cast sample back to model compatible dtype
|
80 |
+
next_sample = next_sample.to(model_output.dtype)
|
81 |
+
# upon completion decrease the step index by one (the inversion walks the sigma schedule backwards)
|
82 |
+
self._step_index -= 1
|
83 |
+
|
84 |
+
if not return_dict:
|
85 |
+
return (next_sample,)
|
86 |
+
|
87 |
+
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=next_sample)
|
88 |
+
else:
|
89 |
+
prev_sample = sample + (sigma_next - sigma) * model_output
|
90 |
+
# Cast sample back to model compatible dtype
|
91 |
+
prev_sample = prev_sample.to(model_output.dtype)
|
92 |
+
# upon completion increase step index by one
|
93 |
+
self._step_index += 1
|
94 |
+
|
95 |
+
if not return_dict:
|
96 |
+
return (prev_sample,)
|
97 |
+
|
98 |
+
return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
|
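The inversion branch above is the flow-matching Euler update with the sign of the sigma difference flipped and the step index walked backwards. A small numeric sketch (toy numbers, not the real sigma schedule) showing that the inverse step undoes a denoising step when the model output is the same:

import torch

sigma, sigma_next = 1.0, 0.75            # toy sigmas
x = torch.tensor([0.5, -1.0, 2.0])       # current latent
v = torch.tensor([0.2, 0.1, -0.3])       # model output (predicted velocity)

x_denoised = x + (sigma_next - sigma) * v             # regular step: prev_sample
x_recovered = x_denoised + (sigma - sigma_next) * v   # inverse step with the same model output
assert torch.allclose(x_recovered, x)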
SDLens/hooked_scheduler.py
ADDED
@@ -0,0 +1,40 @@
1 |
+
from diffusers import DDPMScheduler
|
2 |
+
import torch
|
3 |
+
|
4 |
+
class HookedNoiseScheduler:
|
5 |
+
scheduler: DDPMScheduler
|
6 |
+
pre_hooks: list
|
7 |
+
post_hooks: list
|
8 |
+
|
9 |
+
def __init__(self, scheduler):
|
10 |
+
object.__setattr__(self, 'scheduler', scheduler)
|
11 |
+
object.__setattr__(self, 'pre_hooks', [])
|
12 |
+
object.__setattr__(self, 'post_hooks', [])
|
13 |
+
|
14 |
+
def step(
|
15 |
+
self,
|
16 |
+
model_output, timestep, sample, generator, return_dict
|
17 |
+
):
|
18 |
+
assert not return_dict, "return_dict=True is not implemented"
|
19 |
+
for hook in self.pre_hooks:
|
20 |
+
hook_output = hook(model_output, timestep, sample, generator)
|
21 |
+
if hook_output is not None:
|
22 |
+
model_output, timestep, sample, generator = hook_output
|
23 |
+
|
24 |
+
(pred_prev_sample, ) = self.scheduler.step(model_output, timestep, sample, generator, return_dict)
|
25 |
+
|
26 |
+
for hook in self.post_hooks:
|
27 |
+
hook_output = hook(pred_prev_sample)
|
28 |
+
if hook_output is not None:
|
29 |
+
pred_prev_sample = hook_output
|
30 |
+
|
31 |
+
return (pred_prev_sample, )
|
32 |
+
|
33 |
+
def __getattr__(self, name):
|
34 |
+
return getattr(self.scheduler, name)
|
35 |
+
|
36 |
+
def __setattr__(self, name, value):
|
37 |
+
if name in {'scheduler', 'pre_hooks', 'post_hooks'}:
|
38 |
+
object.__setattr__(self, name, value)
|
39 |
+
else:
|
40 |
+
setattr(self.scheduler, name, value)
|
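A minimal usage sketch of the wrapper above (the import path is assumed from the repo layout; hook signatures follow `step`): pre-hooks receive (model_output, timestep, sample, generator) and may return modified values, post-hooks receive the predicted previous sample.

import torch
from diffusers import DDPMScheduler
from SDLens.hooked_scheduler import HookedNoiseScheduler  # assumed import path

scheduler = HookedNoiseScheduler(DDPMScheduler())

# Log the timestep before every scheduler step (returning None leaves the inputs unchanged).
scheduler.pre_hooks.append(lambda model_output, t, sample, generator: print("step at t =", int(t)))
# Optionally post-process the predicted previous sample.
scheduler.post_hooks.append(lambda prev_sample: prev_sample.clamp(-10, 10))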
SDLens/hooked_sd_pipeline.py
ADDED
@@ -0,0 +1,319 @@
1 |
+
import einops
|
2 |
+
from diffusers import StableDiffusionXLPipeline, IFPipeline
|
3 |
+
from typing import List, Dict, Callable, Union
|
4 |
+
import torch
|
5 |
+
from .hooked_scheduler import HookedNoiseScheduler
|
6 |
+
|
7 |
+
def retrieve(io):
|
8 |
+
if isinstance(io, tuple):
|
9 |
+
if len(io) == 1:
|
10 |
+
return io[0]
|
11 |
+
else:
|
12 |
+
raise ValueError("A tuple should have length of 1")
|
13 |
+
elif isinstance(io, torch.Tensor):
|
14 |
+
return io
|
15 |
+
else:
|
16 |
+
raise ValueError("Input/Output must be a tensor, or 1-element tuple")
|
17 |
+
|
18 |
+
|
19 |
+
class HookedDiffusionAbstractPipeline:
|
20 |
+
parent_cls = None
|
21 |
+
pipe = None
|
22 |
+
|
23 |
+
def __init__(self, pipe: parent_cls, use_hooked_scheduler: bool = False):
|
24 |
+
if use_hooked_scheduler:
|
25 |
+
pipe.scheduler = HookedNoiseScheduler(pipe.scheduler)
|
26 |
+
self.__dict__['pipe'] = pipe
|
27 |
+
self.use_hooked_scheduler = use_hooked_scheduler
|
28 |
+
|
29 |
+
@classmethod
|
30 |
+
def from_pretrained(cls, *args, **kwargs):
|
31 |
+
return cls(cls.parent_cls.from_pretrained(*args, **kwargs))
|
32 |
+
|
33 |
+
|
34 |
+
def run_with_hooks(self,
|
35 |
+
*args,
|
36 |
+
position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
|
37 |
+
**kwargs
|
38 |
+
):
|
39 |
+
'''
|
40 |
+
Run the pipeline with hooks at specified positions.
|
41 |
+
Returns the final output.
|
42 |
+
|
43 |
+
Args:
|
44 |
+
*args: Arguments to pass to the pipeline.
|
45 |
+
position_hook_dict: A dictionary mapping positions to hooks.
|
46 |
+
The keys are positions in the pipeline where the hooks should be registered.
|
47 |
+
The values are either a single hook or a list of hooks to be registered at the specified position.
|
48 |
+
Each hook should be a callable that takes three arguments: (module, input, output).
|
49 |
+
**kwargs: Keyword arguments to pass to the pipeline.
|
50 |
+
'''
|
51 |
+
hooks = []
|
52 |
+
for position, hook in position_hook_dict.items():
|
53 |
+
if isinstance(hook, list):
|
54 |
+
for h in hook:
|
55 |
+
hooks.append(self._register_general_hook(position, h))
|
56 |
+
else:
|
57 |
+
hooks.append(self._register_general_hook(position, hook))
|
58 |
+
|
59 |
+
hooks = [hook for hook in hooks if hook is not None]
|
60 |
+
|
61 |
+
try:
|
62 |
+
output = self.pipe(*args, **kwargs)
|
63 |
+
finally:
|
64 |
+
for hook in hooks:
|
65 |
+
hook.remove()
|
66 |
+
if self.use_hooked_scheduler:
|
67 |
+
self.pipe.scheduler.pre_hooks = []
|
68 |
+
self.pipe.scheduler.post_hooks = []
|
69 |
+
|
70 |
+
return output
|
71 |
+
|
72 |
+
def run_with_cache(self,
|
73 |
+
*args,
|
74 |
+
positions_to_cache: List[str],
|
75 |
+
save_input: bool = False,
|
76 |
+
save_output: bool = True,
|
77 |
+
**kwargs
|
78 |
+
):
|
79 |
+
'''
|
80 |
+
Run the pipeline with caching at specified positions.
|
81 |
+
|
82 |
+
This method allows you to cache the intermediate inputs and/or outputs of the pipeline
|
83 |
+
at certain positions. The final output of the pipeline and a dictionary of cached values
|
84 |
+
are returned.
|
85 |
+
|
86 |
+
Args:
|
87 |
+
*args: Arguments to pass to the pipeline.
|
88 |
+
positions_to_cache (List[str]): A list of positions in the pipeline where intermediate
|
89 |
+
inputs/outputs should be cached.
|
90 |
+
save_input (bool, optional): If True, caches the input at each specified position.
|
91 |
+
Defaults to False.
|
92 |
+
save_output (bool, optional): If True, caches the output at each specified position.
|
93 |
+
Defaults to True.
|
94 |
+
**kwargs: Keyword arguments to pass to the pipeline.
|
95 |
+
|
96 |
+
Returns:
|
97 |
+
final_output: The final output of the pipeline after execution.
|
98 |
+
cache_dict (Dict[str, Dict[str, Any]]): A dictionary where keys are the specified positions
|
99 |
+
and values are dictionaries containing the cached 'input' and/or 'output' at each position,
|
100 |
+
depending on the flags `save_input` and `save_output`.
|
101 |
+
'''
|
102 |
+
cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
|
103 |
+
hooks = [
|
104 |
+
self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
|
105 |
+
]
|
106 |
+
hooks = [hook for hook in hooks if hook is not None]
|
107 |
+
output = self.pipe(*args, **kwargs)
|
108 |
+
for hook in hooks:
|
109 |
+
hook.remove()
|
110 |
+
if self.use_hooked_scheduler:
|
111 |
+
self.pipe.scheduler.pre_hooks = []
|
112 |
+
self.pipe.scheduler.post_hooks = []
|
113 |
+
|
114 |
+
cache_dict = {}
|
115 |
+
if save_input:
|
116 |
+
for position, block in cache_input.items():
|
117 |
+
cache_input[position] = torch.stack(block, dim=1)
|
118 |
+
cache_dict['input'] = cache_input
|
119 |
+
|
120 |
+
if save_output:
|
121 |
+
for position, block in cache_output.items():
|
122 |
+
cache_output[position] = torch.stack(block, dim=1)
|
123 |
+
cache_dict['output'] = cache_output
|
124 |
+
return output, cache_dict
|
125 |
+
|
126 |
+
def run_with_hooks_and_cache(self,
|
127 |
+
*args,
|
128 |
+
position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
|
129 |
+
positions_to_cache: List[str] = [],
|
130 |
+
save_input: bool = False,
|
131 |
+
save_output: bool = True,
|
132 |
+
**kwargs
|
133 |
+
):
|
134 |
+
'''
|
135 |
+
Run the pipeline with hooks and caching at specified positions.
|
136 |
+
|
137 |
+
This method allows you to register hooks at certain positions in the pipeline and
|
138 |
+
cache intermediate inputs and/or outputs at specified positions. Hooks can be used
|
139 |
+
for inspecting or modifying the pipeline's execution, and caching stores intermediate
|
140 |
+
values for later inspection or use.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
*args: Arguments to pass to the pipeline.
|
144 |
+
position_hook_dict Dict[str, Union[Callable, List[Callable]]]:
|
145 |
+
A dictionary where the keys are the positions in the pipeline, and the values
|
146 |
+
are hooks (either a single hook or a list of hooks) to be registered at those positions.
|
147 |
+
Each hook should be a callable that accepts three arguments: (module, input, output).
|
148 |
+
positions_to_cache (List[str], optional): A list of positions in the pipeline where
|
149 |
+
intermediate inputs/outputs should be cached. Defaults to an empty list.
|
150 |
+
save_input (bool, optional): If True, caches the input at each specified position.
|
151 |
+
Defaults to False.
|
152 |
+
save_output (bool, optional): If True, caches the output at each specified position.
|
153 |
+
Defaults to True.
|
154 |
+
**kwargs: Additional keyword arguments to pass to the pipeline.
|
155 |
+
|
156 |
+
Returns:
|
157 |
+
final_output: The final output of the pipeline after execution.
|
158 |
+
cache_dict (Dict[str, Dict[str, Any]]): A dictionary where keys are the specified positions
|
159 |
+
and values are dictionaries containing the cached 'input' and/or 'output' at each position,
|
160 |
+
depending on the flags `save_input` and `save_output`.
|
161 |
+
'''
|
162 |
+
cache_input, cache_output = dict() if save_input else None, dict() if save_output else None
|
163 |
+
hooks = [
|
164 |
+
self._register_cache_hook(position, cache_input, cache_output) for position in positions_to_cache
|
165 |
+
]
|
166 |
+
|
167 |
+
for position, hook in position_hook_dict.items():
|
168 |
+
if isinstance(hook, list):
|
169 |
+
for h in hook:
|
170 |
+
hooks.append(self._register_general_hook(position, h))
|
171 |
+
else:
|
172 |
+
hooks.append(self._register_general_hook(position, hook))
|
173 |
+
|
174 |
+
hooks = [hook for hook in hooks if hook is not None]
|
175 |
+
output = self.pipe(*args, **kwargs)
|
176 |
+
for hook in hooks:
|
177 |
+
hook.remove()
|
178 |
+
if self.use_hooked_scheduler:
|
179 |
+
self.pipe.scheduler.pre_hooks = []
|
180 |
+
self.pipe.scheduler.post_hooks = []
|
181 |
+
|
182 |
+
cache_dict = {}
|
183 |
+
if save_input:
|
184 |
+
for position, block in cache_input.items():
|
185 |
+
cache_input[position] = torch.stack(block, dim=1)
|
186 |
+
cache_dict['input'] = cache_input
|
187 |
+
|
188 |
+
if save_output:
|
189 |
+
for position, block in cache_output.items():
|
190 |
+
cache_output[position] = torch.stack(block, dim=1)
|
191 |
+
cache_dict['output'] = cache_output
|
192 |
+
|
193 |
+
return output, cache_dict
|
194 |
+
|
195 |
+
|
196 |
+
def _locate_block(self, position: str):
|
197 |
+
'''
|
198 |
+
Locate the block at the specified position in the pipeline.
|
199 |
+
'''
|
200 |
+
block = self.pipe
|
201 |
+
for step in position.split('.'):
|
202 |
+
if step.isdigit():
|
203 |
+
step = int(step)
|
204 |
+
block = block[step]
|
205 |
+
else:
|
206 |
+
block = getattr(block, step)
|
207 |
+
return block
|
208 |
+
|
209 |
+
|
210 |
+
def _register_cache_hook(self, position: str, cache_input: Dict, cache_output: Dict):
|
211 |
+
|
212 |
+
if position.endswith('$self_attention') or position.endswith('$cross_attention'):
|
213 |
+
return self._register_cache_attention_hook(position, cache_output)
|
214 |
+
|
215 |
+
if position == 'noise':
|
216 |
+
def hook(model_output, timestep, sample, generator):
|
217 |
+
if position not in cache_output:
|
218 |
+
cache_output[position] = []
|
219 |
+
cache_output[position].append(sample)
|
220 |
+
|
221 |
+
if self.use_hooked_scheduler:
|
222 |
+
self.pipe.scheduler.pre_hooks.append(hook)  # the hook signature matches the scheduler's pre-step hooks
|
223 |
+
else:
|
224 |
+
raise ValueError('Cannot cache noise without using hooked scheduler')
|
225 |
+
return
|
226 |
+
|
227 |
+
block = self._locate_block(position)
|
228 |
+
|
229 |
+
def hook(module, input, kwargs, output):
|
230 |
+
if cache_input is not None:
|
231 |
+
if position not in cache_input:
|
232 |
+
cache_input[position] = []
|
233 |
+
cache_input[position].append(retrieve(input))
|
234 |
+
|
235 |
+
if cache_output is not None:
|
236 |
+
if position not in cache_output:
|
237 |
+
cache_output[position] = []
|
238 |
+
cache_output[position].append(retrieve(output))
|
239 |
+
|
240 |
+
return block.register_forward_hook(hook, with_kwargs=True)
|
241 |
+
|
242 |
+
def _register_cache_attention_hook(self, position, cache):
|
243 |
+
attn_block = self._locate_block(position.split('$')[0])
|
244 |
+
if position.endswith('$self_attention'):
|
245 |
+
attn_block = attn_block.attn1
|
246 |
+
elif position.endswith('$cross_attention'):
|
247 |
+
attn_block = attn_block.attn2
|
248 |
+
else:
|
249 |
+
raise ValueError('Wrong attention type')
|
250 |
+
|
251 |
+
def hook(module, args, kwargs, output):
|
252 |
+
hidden_states = args[0]
|
253 |
+
encoder_hidden_states = kwargs['encoder_hidden_states']
|
254 |
+
attention_mask = kwargs['attention_mask']
|
255 |
+
batch_size, sequence_length, _ = hidden_states.shape
|
256 |
+
attention_mask = attn_block.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
257 |
+
query = attn_block.to_q(hidden_states)
|
258 |
+
|
259 |
+
|
260 |
+
if encoder_hidden_states is None:
|
261 |
+
encoder_hidden_states = hidden_states
|
262 |
+
elif attn_block.norm_cross is not None:
|
263 |
+
encoder_hidden_states = attn_block.norm_cross(encoder_hidden_states)
|
264 |
+
|
265 |
+
key = attn_block.to_k(encoder_hidden_states)
|
266 |
+
value = attn_block.to_v(encoder_hidden_states)
|
267 |
+
|
268 |
+
query = attn_block.head_to_batch_dim(query)
|
269 |
+
key = attn_block.head_to_batch_dim(key)
|
270 |
+
value = attn_block.head_to_batch_dim(value)
|
271 |
+
|
272 |
+
attention_probs = attn_block.get_attention_scores(query, key, attention_mask)
|
273 |
+
attention_probs = attention_probs.view(
|
274 |
+
batch_size,
|
275 |
+
attention_probs.shape[0] // batch_size,
|
276 |
+
attention_probs.shape[1],
|
277 |
+
attention_probs.shape[2]
|
278 |
+
)
|
279 |
+
if position not in cache:
|
280 |
+
cache[position] = []
|
281 |
+
cache[position].append(attention_probs)
|
282 |
+
|
283 |
+
return attn_block.register_forward_hook(hook, with_kwargs=True)
|
284 |
+
|
285 |
+
def _register_general_hook(self, position, hook):
|
286 |
+
if position == 'scheduler_pre':
|
287 |
+
if not self.use_hooked_scheduler:
|
288 |
+
raise ValueError('Cannot register hooks on scheduler without using hooked scheduler')
|
289 |
+
self.pipe.scheduler.pre_hooks.append(hook)
|
290 |
+
return
|
291 |
+
elif position == 'scheduler_post':
|
292 |
+
if not self.use_hooked_scheduler:
|
293 |
+
raise ValueError('Cannot register hooks on scheduler without using hooked scheduler')
|
294 |
+
self.pipe.scheduler.post_hooks.append(hook)
|
295 |
+
return
|
296 |
+
|
297 |
+
block = self._locate_block(position)
|
298 |
+
return block.register_forward_hook(hook)
|
299 |
+
|
300 |
+
def to(self, *args, **kwargs):
|
301 |
+
self.pipe = self.pipe.to(*args, **kwargs)
|
302 |
+
return self
|
303 |
+
|
304 |
+
def __getattr__(self, name):
|
305 |
+
return getattr(self.pipe, name)
|
306 |
+
|
307 |
+
def __setattr__(self, name, value):
|
308 |
+
return setattr(self.pipe, name, value)
|
309 |
+
|
310 |
+
def __call__(self, *args, **kwargs):
|
311 |
+
return self.pipe(*args, **kwargs)
|
312 |
+
|
313 |
+
|
314 |
+
class HookedStableDiffusionXLPipeline(HookedDiffusionAbstractPipeline):
|
315 |
+
parent_cls = StableDiffusionXLPipeline
|
316 |
+
|
317 |
+
|
318 |
+
class HookedIFPipeline(HookedDiffusionAbstractPipeline):
|
319 |
+
parent_cls = IFPipeline
|
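A short usage sketch of the hooked SDXL pipeline defined above (model id, prompt and shapes are illustrative; the block path is one of those used by the demo):

import torch
from SDLens import HookedStableDiffusionXLPipeline

pipe = HookedStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/sdxl-turbo", torch_dtype=torch.float16
).to("cuda")

block = "unet.down_blocks.2.attentions.1"
out, cache = pipe.run_with_cache(
    "a photo of a corgi",
    positions_to_cache=[block],
    num_inference_steps=1,
    guidance_scale=0.0,
    save_input=True,
    save_output=True,
)
# Cached tensors are stacked over denoising steps along dim=1.
print(cache["output"][block].shape)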
app.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
app.py
ADDED
@@ -0,0 +1,768 @@
1 |
+
from functools import partial
|
2 |
+
import json
|
3 |
+
import gradio as gr
|
4 |
+
import os
|
5 |
+
|
6 |
+
# environment
|
7 |
+
os.environ['HF_HOME'] = '/dlabscratch1/anmari'
|
8 |
+
os.environ['TRANSFORMERS_CACHE'] = '/dlabscratch1/anmari'
|
9 |
+
os.environ['HF_DATASETS_CACHE'] = '/dlabscratch1/anmari'
|
10 |
+
# os.environ["HF_TOKEN"] = ""
|
11 |
+
import torch
|
12 |
+
from PIL import Image
|
13 |
+
from SDLens import HookedStableDiffusionXLPipeline, CachedPipeline as CachedFLuxPipeline
|
14 |
+
from SDLens.cache_and_edit.flux_pipeline import EditedFluxPipeline
|
15 |
+
from SAE import SparseAutoencoder
|
16 |
+
from utils import TimedHook, add_feature_on_area_base, replace_with_feature_base, add_feature_on_area_turbo, replace_with_feature_turbo, add_feature_on_area_flux
|
17 |
+
import numpy as np
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
from matplotlib.colors import ListedColormap
|
20 |
+
import threading
|
21 |
+
from einops import rearrange
|
22 |
+
import spaces
|
23 |
+
# from retrieval import FeatureRetriever
|
24 |
+
|
25 |
+
|
26 |
+
code_to_block_sd = {
|
27 |
+
"down.2.1": "unet.down_blocks.2.attentions.1",
|
28 |
+
"mid.0": "unet.mid_block.attentions.0",
|
29 |
+
"up.0.1": "unet.up_blocks.0.attentions.1",
|
30 |
+
"up.0.0": "unet.up_blocks.0.attentions.0"
|
31 |
+
}
|
32 |
+
code_to_block_flux = {"18": "transformer.transformer_blocks.18"}
|
33 |
+
|
34 |
+
FLUX_NAMES = ["black-forest-labs/FLUX.1-schnell", "black-forest-labs/FLUX.1-dev"]
|
35 |
+
MODELS_CONFIG = {
|
36 |
+
"stabilityai/stable-diffusion-xl-base-1.0": {
|
37 |
+
"steps": 25,
|
38 |
+
"guidance_scale": 8.0,
|
39 |
+
"choices": ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
|
40 |
+
"value": "down.2.1 (composition)",
|
41 |
+
"code_to_block": code_to_block_sd,
|
42 |
+
"max_steps": 50,
|
43 |
+
"is_flux": False,
|
44 |
+
"downsample_factor": 16,
|
45 |
+
"add_feature_on_area": add_feature_on_area_base,
|
46 |
+
"num_features": 5120,
|
47 |
+
|
48 |
+
},
|
49 |
+
"stabilityai/sdxl-turbo": {
|
50 |
+
"steps": 1,
|
51 |
+
"guidance_scale": 0.0,
|
52 |
+
"choices": ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
|
53 |
+
"value": "down.2.1 (composition)",
|
54 |
+
"code_to_block": code_to_block_sd,
|
55 |
+
"max_steps": 4,
|
56 |
+
"is_flux": False,
|
57 |
+
"downsample_factor": 32,
|
58 |
+
"add_feature_on_area": add_feature_on_area_turbo,
|
59 |
+
"num_features": 5120,
|
60 |
+
},
|
61 |
+
"black-forest-labs/FLUX.1-schnell": {
|
62 |
+
"steps": 1,
|
63 |
+
"guidance_scale": 0.0,
|
64 |
+
"choices": ["18"],
|
65 |
+
"value": "18",
|
66 |
+
"code_to_block": code_to_block_flux,
|
67 |
+
"max_steps": 4,
|
68 |
+
"is_flux": True,
|
69 |
+
"exclude_list": [2462, 2974, 1577, 786, 3188, 9986, 4693, 8472, 8248, 325, 9596, 2813, 10803, 11773, 11410, 1067, 2965, 10488, 4537, 2102],
|
70 |
+
"downsample_factor": 8,
|
71 |
+
"add_feature_on_area": add_feature_on_area_flux,
|
72 |
+
"num_features": 12288
|
73 |
+
|
74 |
+
},
|
75 |
+
|
76 |
+
"black-forest-labs/FLUX.1-dev": {
|
77 |
+
"steps": 25,
|
78 |
+
"guidance_scale": 0.0,
|
79 |
+
"choices": ["18"],
|
80 |
+
"value": "18",
|
81 |
+
"code_to_block": code_to_block_flux,
|
82 |
+
"max_steps": 50,
|
83 |
+
"is_flux": True,
|
84 |
+
"exclude_list": [2462, 2974, 1577, 786, 3188, 9986, 4693, 8472, 8248, 325, 9596, 2813, 10803, 11773, 11410, 1067, 2965, 10488, 4537, 2102],
|
85 |
+
"downsample_factor": 8,
|
86 |
+
"add_feature_on_area": add_feature_on_area_flux,
|
87 |
+
"num_features": 12288
|
88 |
+
|
89 |
+
}
|
90 |
+
}
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
lock = threading.Lock()
|
96 |
+
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
def process_cache(cache, saes_dict, model_config, timestep=None):
|
102 |
+
|
103 |
+
top_features_dict = {}
|
104 |
+
sparse_maps_dict = {}
|
105 |
+
|
106 |
+
for code in model_config['code_to_block'].keys():
|
107 |
+
block = model_config["code_to_block"][code]
|
108 |
+
sae = saes_dict[code]
|
109 |
+
|
110 |
+
|
111 |
+
if model_config["is_flux"]:
|
112 |
+
|
113 |
+
with torch.no_grad():
|
114 |
+
features = sae.encode(torch.stack(cache.image_activation)) # shape: [timestep, batch, seq_len, num_features]
|
115 |
+
features[..., model_config["exclude_list"]] = 0
|
116 |
+
|
117 |
+
if timestep is not None and timestep < features.shape[0]:
|
118 |
+
features = features[timestep:timestep+1]
|
119 |
+
|
120 |
+
# reshape to [batch, timestep, 64, 64, num_features], then drop the singleton batch and timestep dims
|
121 |
+
sparse_maps = rearrange(features, "t b (w h) n -> b t w h n", w=64, h=64).squeeze(0).squeeze(0)
|
122 |
+
|
123 |
+
else:
|
124 |
+
|
125 |
+
diff = cache["output"][block] - cache["input"][block]
|
126 |
+
if diff.shape[0] == 2: # guidance is on and we need to select the second output
|
127 |
+
diff = diff[1].unsqueeze(0)
|
128 |
+
|
129 |
+
# If a specific timestep is provided, select that timestep from the cached activations
|
130 |
+
if timestep is not None and timestep < diff.shape[1]:
|
131 |
+
diff = diff[:, timestep:timestep+1]
|
132 |
+
|
133 |
+
diff = diff.permute(0, 1, 3, 4, 2).squeeze(0).squeeze(0)
|
134 |
+
with torch.no_grad():
|
135 |
+
sparse_maps = sae.encode(diff)
|
136 |
+
|
137 |
+
averages = torch.mean(sparse_maps, dim=(0, 1))
|
138 |
+
|
139 |
+
top_features = torch.topk(averages, 10).indices
|
140 |
+
|
141 |
+
top_features_dict[code] = top_features.cpu().tolist()
|
142 |
+
sparse_maps_dict[code] = sparse_maps.cpu().numpy()
|
143 |
+
|
144 |
+
return top_features_dict, sparse_maps_dict
|
145 |
+
|
146 |
+
|
147 |
+
def plot_image_heatmap(cache, block_select, radio, model_config):
|
148 |
+
code = block_select.split()[0]
|
149 |
+
feature = int(radio)
|
150 |
+
|
151 |
+
heatmap = cache["heatmaps"][code][:, :, feature]
|
152 |
+
scaling_factor = 16 if model_config["is_flux"] else 32
|
153 |
+
heatmap = np.kron(heatmap, np.ones((scaling_factor, scaling_factor)))
|
154 |
+
image = cache["image"].convert("RGBA")
|
155 |
+
|
156 |
+
jet = plt.cm.jet
|
157 |
+
cmap = jet(np.arange(jet.N))
|
158 |
+
cmap[:1, -1] = 0
|
159 |
+
cmap[1:, -1] = 0.6
|
160 |
+
cmap = ListedColormap(cmap)
|
161 |
+
heatmap = (heatmap - np.min(heatmap)) / (np.max(heatmap) - np.min(heatmap))
|
162 |
+
heatmap_rgba = cmap(heatmap)
|
163 |
+
heatmap_image = Image.fromarray((heatmap_rgba * 255).astype(np.uint8))
|
164 |
+
heatmap_with_transparency = Image.alpha_composite(image, heatmap_image)
|
165 |
+
|
166 |
+
return heatmap_with_transparency
|
167 |
+
|
168 |
+
|
169 |
+
def create_prompt_part(pipe, saes_dict, demo):
|
170 |
+
|
171 |
+
model_config = MODELS_CONFIG[pipe.pipe.name_or_path]
|
172 |
+
@spaces.GPU
|
173 |
+
def image_gen(prompt, timestep=None, num_steps=None, guidance_scale=None):
|
174 |
+
lock.acquire()
|
175 |
+
try:
|
176 |
+
# Default values
|
177 |
+
default_n_steps = model_config["steps"]
|
178 |
+
default_guidance = model_config["guidance_scale"]
|
179 |
+
|
180 |
+
# Use provided values if available, otherwise use defaults
|
181 |
+
n_steps = default_n_steps if num_steps is None else int(num_steps)
|
182 |
+
guidance = default_guidance if guidance_scale is None else float(guidance_scale)
|
183 |
+
|
184 |
+
# Convert timestep to integer if it's not None
|
185 |
+
timestep_int = None if timestep is None else int(timestep)
|
186 |
+
|
187 |
+
if "FLUX" in pipe.pipe.name_or_path:
|
188 |
+
images = pipe.run(
|
189 |
+
prompt,
|
190 |
+
num_inference_steps=n_steps,
|
191 |
+
width=1024,
|
192 |
+
height=1024,
|
193 |
+
cache_activations=True,
|
194 |
+
guidance_scale=guidance,
|
195 |
+
positions_to_cache = list(model_config["code_to_block"].values()),
|
196 |
+
inverse=False,
|
197 |
+
)
|
198 |
+
cache = pipe.activation_cache
|
199 |
+
|
200 |
+
else:
|
201 |
+
images, cache = pipe.run_with_cache(
|
202 |
+
prompt,
|
203 |
+
positions_to_cache=list(model_config["code_to_block"].values()),
|
204 |
+
num_inference_steps=n_steps,
|
205 |
+
generator=torch.Generator(device="cpu").manual_seed(42),
|
206 |
+
guidance_scale=guidance,
|
207 |
+
save_input=True,
|
208 |
+
save_output=True
|
209 |
+
)
|
210 |
+
finally:
|
211 |
+
lock.release()
|
212 |
+
|
213 |
+
top_features_dict, top_sparse_maps_dict = process_cache(cache, saes_dict, model_config, timestep_int)
|
214 |
+
return images.images[0], {
|
215 |
+
"image": images.images[0],
|
216 |
+
"heatmaps": top_sparse_maps_dict,
|
217 |
+
"features": top_features_dict
|
218 |
+
}
|
219 |
+
|
220 |
+
def update_radio(cache, block_select):
|
221 |
+
code = block_select.split()[0]
|
222 |
+
return gr.update(choices=cache["features"][code])
|
223 |
+
|
224 |
+
def update_img(cache, block_select, radio):
|
225 |
+
new_img = plot_image_heatmap(cache, block_select, radio, model_config)
|
226 |
+
return new_img
|
227 |
+
|
228 |
+
with gr.Tab("Explore", elem_classes="tabs") as explore_tab:
|
229 |
+
cache = gr.State(value={
|
230 |
+
"image": None,
|
231 |
+
"heatmaps": None,
|
232 |
+
"features": []
|
233 |
+
})
|
234 |
+
with gr.Row():
|
235 |
+
with gr.Column(scale=7):
|
236 |
+
with gr.Row(equal_height=True):
|
237 |
+
prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A cinematic shot of a professor sloth wearing a tuxedo at a BBQ party and eathing a dish with peas.")
|
238 |
+
button = gr.Button("Generate", elem_classes="generate_button1")
|
239 |
+
|
240 |
+
with gr.Row():
|
241 |
+
image = gr.Image(width=512, height=512, image_mode="RGB", label="Generated image")
|
242 |
+
|
243 |
+
with gr.Column(scale=4):
|
244 |
+
block_select = gr.Dropdown(
|
245 |
+
choices=model_config["choices"], # replace this for flux
|
246 |
+
value=model_config["value"],
|
247 |
+
label="Select block",
|
248 |
+
elem_id="block_select",
|
249 |
+
interactive=True
|
250 |
+
)
|
251 |
+
|
252 |
+
with gr.Group() as sdxl_base_controls:
|
253 |
+
steps_slider = gr.Slider(
|
254 |
+
minimum=1,
|
255 |
+
maximum=model_config["max_steps"],
|
256 |
+
value= model_config["steps"],
|
257 |
+
step=1,
|
258 |
+
label="Number of steps",
|
259 |
+
elem_id="steps_slider",
|
260 |
+
interactive=True,
|
261 |
+
visible=True
|
262 |
+
)
|
263 |
+
|
264 |
+
# Add timestep selector
|
265 |
+
# TODO: check this
|
266 |
+
timestep_selector = gr.Slider(
|
267 |
+
minimum=0,
|
268 |
+
maximum=model_config["max_steps"]-1,
|
269 |
+
value=None,
|
270 |
+
step=1,
|
271 |
+
label="Timestep (leave empty for average across all steps)",
|
272 |
+
elem_id="timestep_selector",
|
273 |
+
interactive=True,
|
274 |
+
visible=True,
|
275 |
+
)
|
276 |
+
recompute_button = gr.Button("Recompute", elem_id="recompute_button")
|
277 |
+
# Update max timestep when steps change
|
278 |
+
steps_slider.change(lambda s: gr.update(maximum=s-1), [steps_slider], [timestep_selector])
|
279 |
+
|
280 |
+
radio = gr.Radio(choices=[], label="Select a feature", interactive=True)
|
281 |
+
|
282 |
+
button.click(image_gen, [prompt_field, timestep_selector, steps_slider], outputs=[image, cache])
|
283 |
+
cache.change(update_radio, [cache, block_select], outputs=[radio])
|
284 |
+
block_select.select(update_radio, [cache, block_select], outputs=[radio])
|
285 |
+
radio.select(update_img, [cache, block_select, radio], outputs=[image])
|
286 |
+
recompute_button.click(image_gen, [prompt_field, timestep_selector, steps_slider], outputs=[image, cache])
|
287 |
+
demo.load(image_gen, [prompt_field, timestep_selector, steps_slider], outputs=[image, cache])
|
288 |
+
|
289 |
+
return explore_tab
|
290 |
+
|
291 |
+
def downsample_mask(image, factor):
|
292 |
+
downsampled = image.reshape(
|
293 |
+
(image.shape[0] // factor, factor,
|
294 |
+
image.shape[1] // factor, factor)
|
295 |
+
)
|
296 |
+
downsampled = downsampled.mean(axis=(1, 3))
|
297 |
+
return downsampled
|
298 |
+
|
299 |
+
def create_intervene_part(pipe: HookedStableDiffusionXLPipeline, saes_dict, means_dict, demo):
|
300 |
+
model_config = MODELS_CONFIG[pipe.pipe.name_or_path]
|
301 |
+
|
302 |
+
@spaces.GPU
|
303 |
+
def image_gen(prompt, num_steps, guidance_scale=None):
|
304 |
+
lock.acquire()
|
305 |
+
guidance = model_config["guidance_scale"] if guidance_scale is None else float(guidance_scale)
|
306 |
+
try:
|
307 |
+
|
308 |
+
if "FLUX" in pipe.pipe.name_or_path:
|
309 |
+
images = pipe.run(
|
310 |
+
prompt,
|
311 |
+
num_inference_steps=int(num_steps),
|
312 |
+
width=1024,
|
313 |
+
height=1024,
|
314 |
+
cache_activations=False,
|
315 |
+
guidance_scale=guidance,
|
316 |
+
inverse=False,
|
317 |
+
)
|
318 |
+
else:
|
319 |
+
images = pipe.run_with_hooks(
|
320 |
+
prompt,
|
321 |
+
position_hook_dict={},
|
322 |
+
num_inference_steps=int(num_steps),
|
323 |
+
generator=torch.Generator(device="cpu").manual_seed(42),
|
324 |
+
guidance_scale=guidance,
|
325 |
+
)
|
326 |
+
finally:
|
327 |
+
lock.release()
|
328 |
+
if images.images[0].size == (1024, 1024):
|
329 |
+
return images.images[0].resize((512, 512))
|
330 |
+
else:
|
331 |
+
return images.images[0]
|
332 |
+
|
333 |
+
@spaces.GPU
|
334 |
+
# NOTE: start_index/end_index precede guidance_scale so the positional Gradio click inputs map correctly
def image_mod(prompt, block_str, brush_index, strength, num_steps, input_image, start_index=None, end_index=None, guidance_scale=None):
|
335 |
+
block = block_str.split(" ")[0]
|
336 |
+
|
337 |
+
mask = (input_image["layers"][0] > 0)[:, :, -1].astype(float)
|
338 |
+
mask = downsample_mask(mask, model_config["downsample_factor"])
|
339 |
+
mask = torch.tensor(mask, dtype=torch.float32, device="cuda")
|
340 |
+
|
341 |
+
if mask.sum() == 0:
|
342 |
+
gr.Info("No mask selected, please draw on the input image")
|
343 |
+
|
344 |
+
|
345 |
+
# Set default values for start_index and end_index if not provided
|
346 |
+
if start_index is None:
|
347 |
+
start_index = 0
|
348 |
+
if end_index is None:
|
349 |
+
end_index = int(num_steps)
|
350 |
+
|
351 |
+
# Ensure start_index and end_index are within valid ranges
|
352 |
+
start_index = max(0, min(int(start_index), int(num_steps)))
|
353 |
+
end_index = max(0, min(int(end_index), int(num_steps)))
|
354 |
+
|
355 |
+
# Ensure start_index is less than end_index
|
356 |
+
if start_index >= end_index:
|
357 |
+
start_index = max(0, end_index - 1)
|
358 |
+
|
359 |
+
|
360 |
+
def myhook(module, input, output):
|
361 |
+
return model_config["add_feature_on_area"](
|
362 |
+
saes_dict[block],
|
363 |
+
brush_index,
|
364 |
+
mask * means_dict[block][brush_index] * strength,
|
365 |
+
module,
|
366 |
+
input,
|
367 |
+
output)
|
368 |
+
hook = TimedHook(myhook, int(num_steps), np.arange(start_index, end_index))
|
369 |
+
|
370 |
+
lock.acquire()
|
371 |
+
guidance = model_config["guidance_scale"] if guidance_scale is None else float(guidance_scale)
|
372 |
+
|
373 |
+
try:
|
374 |
+
|
375 |
+
if model_config["is_flux"]:
|
376 |
+
image = pipe.run_with_edit(
|
377 |
+
prompt,
|
378 |
+
seed=42,
|
379 |
+
num_inference_steps=int(num_steps),
|
380 |
+
edit_fn= lambda input, output: hook(None, input, output),
|
381 |
+
layers_for_edit_fn=[i for i in range(18, 57)],
|
382 |
+
stream="image").images[0]
|
383 |
+
else:
|
384 |
+
|
385 |
+
image = pipe.run_with_hooks(
|
386 |
+
prompt,
|
387 |
+
position_hook_dict={model_config["code_to_block"][block]: hook},
|
388 |
+
num_inference_steps=int(num_steps),
|
389 |
+
generator=torch.Generator(device="cpu").manual_seed(42),
|
390 |
+
guidance_scale=guidance
|
391 |
+
).images[0]
|
392 |
+
finally:
|
393 |
+
lock.release()
|
394 |
+
return image
|
395 |
+
|
396 |
+
def feature_icon(block_str, brush_index, guidance_scale=None):
|
397 |
+
block = block_str.split(" ")[0]
|
398 |
+
if block in ["mid.0", "up.0.0"]:
|
399 |
+
gr.Info("Note that Feature Icon works best with down.2.1 and up.0.1 blocks but feel free to explore", duration=3)
|
400 |
+
|
401 |
+
def hook(module, input, output):
|
402 |
+
if "stable-diffusion-xl-base" in pipe.pipe.name_or_path:  # SDXL base (non-turbo)
|
403 |
+
return replace_with_feature_base(
|
404 |
+
saes_dict[block],
|
405 |
+
brush_index,
|
406 |
+
means_dict[block][brush_index] * saes_dict[block].k,
|
407 |
+
module,
|
408 |
+
input,
|
409 |
+
output
|
410 |
+
)
|
411 |
+
else:
|
412 |
+
return replace_with_feature_turbo(
|
413 |
+
saes_dict[block],
|
414 |
+
brush_index,
|
415 |
+
means_dict[block][brush_index] * saes_dict[block].k,
|
416 |
+
module,
|
417 |
+
input,
|
418 |
+
output)
|
419 |
+
lock.acquire()
|
420 |
+
guidance = model_config["guidance_scale"] if guidance_scale is None else float(guidance_scale)
|
421 |
+
|
422 |
+
try:
|
423 |
+
image = pipe.run_with_hooks(
|
424 |
+
"",
|
425 |
+
position_hook_dict={model_config["code_to_block"][block]: hook},
|
426 |
+
num_inference_steps=model_config["steps"],
|
427 |
+
generator=torch.Generator(device="cpu").manual_seed(42),
|
428 |
+
guidance_scale=guidance,
|
429 |
+
).images[0]
|
430 |
+
finally:
|
431 |
+
lock.release()
|
432 |
+
return image
|
433 |
+
|
434 |
+
with gr.Tab("Paint!", elem_classes="tabs") as intervene_tab:
|
435 |
+
image_state = gr.State(value=None)
|
436 |
+
with gr.Row():
|
437 |
+
with gr.Column(scale=3):
|
438 |
+
# Generation column
|
439 |
+
with gr.Row():
|
440 |
+
# prompt and num_steps
|
441 |
+
prompt_field = gr.Textbox(lines=1, label="Enter prompt here", value="A dog plays with a ball, cartoon", elem_id="prompt_input")
|
442 |
+
|
443 |
+
with gr.Row():
|
444 |
+
num_steps = gr.Number(value=model_config["steps"], label="Number of steps", minimum=1, maximum=model_config["max_steps"], elem_id="num_steps", precision=0)
|
445 |
+
|
446 |
+
with gr.Row():
|
447 |
+
# Generate button
|
448 |
+
button_generate = gr.Button("Generate", elem_id="generate_button")
|
449 |
+
with gr.Column(scale=3):
|
450 |
+
# Intervention column
|
451 |
+
with gr.Row():
|
452 |
+
# dropdowns and number inputs
|
453 |
+
with gr.Column(scale=7):
|
454 |
+
with gr.Row():
|
455 |
+
block_select = gr.Dropdown(
|
456 |
+
choices=model_config["choices"],
|
457 |
+
value=model_config["value"],
|
458 |
+
label="Select block",
|
459 |
+
elem_id="block_select"
|
460 |
+
)
|
461 |
+
brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=model_config["num_features"]-1, elem_id="brush_index", precision=0)
|
462 |
+
# with gr.Row():
|
463 |
+
# button_icon = gr.Button('Feature Icon', elem_id="feature_icon_button")
|
464 |
+
with gr.Row():
|
465 |
+
gr.Markdown("**TimedHook Range** (which steps to apply the feature)", visible=True)
|
466 |
+
with gr.Row():
|
467 |
+
start_index = gr.Number(value=0, label="Start index", minimum=0, maximum=model_config["max_steps"], elem_id="start_index", precision=0, visible=True)
|
468 |
+
end_index = gr.Number(value=model_config["steps"], label="End index", minimum=0, maximum=model_config["max_steps"], elem_id="end_index", precision=0, visible=True)
|
469 |
+
with gr.Column(scale=3):
|
470 |
+
with gr.Row():
|
471 |
+
strength = gr.Number(value=10, label="Strength", minimum=-40, maximum=40, elem_id="strength", precision=2)
|
472 |
+
with gr.Row():
|
473 |
+
button = gr.Button('Apply', elem_id="apply_button")
|
474 |
+
|
475 |
+
with gr.Row():
|
476 |
+
with gr.Column():
|
477 |
+
# Input image
|
478 |
+
i_image = gr.Sketchpad(
|
479 |
+
height=610,
|
480 |
+
layers=False, transforms=[], placeholder="Generate and paint!",
|
481 |
+
brush=gr.Brush(default_size=64, color_mode="fixed", colors=['black']),
|
482 |
+
container=False,
|
483 |
+
canvas_size=(512, 512),
|
484 |
+
label="Input Image")
|
485 |
+
clear_button = gr.Button("Clear")
|
486 |
+
clear_button.click(lambda x: x, [image_state], [i_image])
|
487 |
+
# Output image
|
488 |
+
o_image = gr.Image(width=512, height=512, label="Output Image")
|
489 |
+
|
490 |
+
# Set up the click events
|
491 |
+
button_generate.click(image_gen, inputs=[prompt_field, num_steps], outputs=[image_state])
|
492 |
+
image_state.change(lambda x: x, [image_state], [i_image])
|
493 |
+
|
494 |
+
# Update max values for start_index and end_index when num_steps changes
|
495 |
+
def update_index_maxes(steps):
|
496 |
+
return gr.update(maximum=steps), gr.update(maximum=steps)
|
497 |
+
|
498 |
+
num_steps.change(update_index_maxes, [num_steps], [start_index, end_index])
|
499 |
+
|
500 |
+
button.click(image_mod,
|
501 |
+
inputs=[prompt_field, block_select, brush_index, strength, num_steps, i_image, start_index, end_index],
|
502 |
+
outputs=o_image)
|
503 |
+
# button_icon.click(feature_icon, inputs=[block_select, brush_index], outputs=o_image)
|
504 |
+
demo.load(image_gen, [prompt_field, num_steps], outputs=[image_state])
|
505 |
+
|
506 |
+
|
507 |
+
return intervene_tab
|
508 |
+
|
509 |
+
|
510 |
+
|
511 |
+
def create_top_images_part(demo, pipe):
|
512 |
+
|
513 |
+
model_config = MODELS_CONFIG[pipe.pipe.name_or_path]
|
514 |
+
|
515 |
+
if isinstance(pipe, HookedStableDiffusionXLPipeline):
|
516 |
+
is_flux = False
|
517 |
+
elif isinstance(pipe, CachedFLuxPipeline):
|
518 |
+
is_flux = True
|
519 |
+
else:
|
520 |
+
raise AssertionError(f"Unknown pipe class: {type(pipe)}")
|
521 |
+
|
522 |
+
def update_top_images(block_select, brush_index):
|
523 |
+
block = block_select.split(" ")[0]
|
524 |
+
# Define path for fetching image
|
525 |
+
if is_flux:
|
526 |
+
part = 1 if brush_index <= 7000 else 2
|
527 |
+
url = f"https://huggingface.co/datasets/antoniomari/flux_sae_images/resolve/main/{block}/part{part}/{brush_index}.jpg"
|
528 |
+
else:
|
529 |
+
url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{brush_index}.jpg"
|
530 |
+
return url
|
531 |
+
|
532 |
+
with gr.Tab("Top Images", elem_classes="tabs") as top_images_tab:
|
533 |
+
with gr.Row():
|
534 |
+
block_select = gr.Dropdown(
|
535 |
+
choices=["flux_18"] if is_flux else ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
|
536 |
+
value="flux_18" if is_flux else "down.2.1 (composition)",
|
537 |
+
label="Select block"
|
538 |
+
)
|
539 |
+
brush_index = gr.Number(value=0, label="Brush index", minimum=0, maximum=model_config["num_features"]-1, precision=0)
|
540 |
+
with gr.Row():
|
541 |
+
image = gr.Image(width=600, height=600, label="Top Images")
|
542 |
+
|
543 |
+
block_select.select(update_top_images, [block_select, brush_index], outputs=[image])
|
544 |
+
brush_index.change(update_top_images, [block_select, brush_index], outputs=[image])
|
545 |
+
demo.load(update_top_images, [block_select, brush_index], outputs=[image])
|
546 |
+
return top_images_tab
|
547 |
+
|
548 |
+
|
549 |
+
def create_top_images_plus_search_part(retriever, demo, pipe):
|
550 |
+
|
551 |
+
model_config = MODELS_CONFIG[pipe.pipe.name_or_path]
|
552 |
+
|
553 |
+
|
554 |
+
|
555 |
+
if isinstance(pipe, HookedStableDiffusionXLPipeline):
|
556 |
+
is_flux = False
|
557 |
+
elif isinstance(pipe, CachedFLuxPipeline):
|
558 |
+
is_flux = True
|
559 |
+
else:
|
560 |
+
raise AssertionError(f"Unknown pipe class: {type(pipe)}")
|
561 |
+
|
562 |
+
def update_cache(block_select, search_by_text, search_by_index):
|
563 |
+
if search_by_text == "":
|
564 |
+
top_indices = []
|
565 |
+
index = search_by_index
|
566 |
+
block = block_select.split(" ")[0]
|
567 |
+
|
568 |
+
# Define path for fetching image
|
569 |
+
if is_flux:
|
570 |
+
part = 1 if index <= 7000 else 2
|
571 |
+
url = f"https://huggingface.co/antoniomari/flux_sae_images/resolve/main/{block}/part{part}/{index}.jpg"
|
572 |
+
else:
|
573 |
+
url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{index}.jpg"
|
574 |
+
return url, {"image": url, "feature_idx": index, "features": top_indices}
|
575 |
+
else:
|
576 |
+
# TODO
|
577 |
+
if retriever is None:
|
578 |
+
raise ValueError("Feature retrieval is not enabled")
|
579 |
+
lock.acquire()
|
580 |
+
try:
|
581 |
+
top_indices = list(retriever.query_text(search_by_text, block_select.split(" ")[0]).keys())
|
582 |
+
finally:
|
583 |
+
lock.release()
|
584 |
+
block = block_select.split(" ")[0]
|
585 |
+
top_indices = list(map(int, top_indices))
|
586 |
+
index = top_indices[0]
|
587 |
+
url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{index}.jpg"
|
588 |
+
return url, {"image": url, "feature_idx": index, "features": top_indices[:20]}
|
589 |
+
|
590 |
+
def update_radio(cache):
|
591 |
+
return gr.update(choices=cache["features"], value=cache["feature_idx"])
|
592 |
+
|
593 |
+
def update_img(cache, block_select, index):
|
594 |
+
block = block_select.split(" ")[0]
|
595 |
+
url = f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{index}.jpg"
|
596 |
+
return url
|
597 |
+
|
598 |
+
with gr.Tab("Top Images", elem_classes="tabs") as explore_tab:
|
599 |
+
cache = gr.State(value={
|
600 |
+
"image": None,
|
601 |
+
"feature_idx": None,
|
602 |
+
"features": []
|
603 |
+
})
|
604 |
+
with gr.Row():
|
605 |
+
with gr.Column(scale=7):
|
606 |
+
with gr.Row():
|
607 |
+
# top images
|
608 |
+
image = gr.Image(width=600, height=600, image_mode="RGB", label="Top images")
|
609 |
+
|
610 |
+
with gr.Column(scale=4):
|
611 |
+
block_select = gr.Dropdown(
|
612 |
+
choices=["flux_18"] if is_flux else ["up.0.1 (style)", "down.2.1 (composition)", "up.0.0 (details)", "mid.0"],
|
613 |
+
value="flux_18" if is_flux else "down.2.1 (composition)",
|
614 |
+
label="Select block",
|
615 |
+
elem_id="block_select",
|
616 |
+
interactive=True
|
617 |
+
)
|
618 |
+
search_by_index = gr.Number(value=0, label="Search by index", minimum=0, maximum=model_config["num_features"]-1, precision=0)
|
619 |
+
search_by_text = gr.Textbox(lines=1, label="Search by text", value="", visible=False)
|
620 |
+
radio = gr.Radio(choices=[], label="Select a feature", interactive=True, visible=False)
|
621 |
+
|
622 |
+
|
623 |
+
search_by_text.change(update_cache,
|
624 |
+
[block_select, search_by_text, search_by_index],
|
625 |
+
outputs=[image, cache])
|
626 |
+
block_select.select(update_cache,
|
627 |
+
[block_select, search_by_text, search_by_index],
|
628 |
+
outputs=[image, cache])
|
629 |
+
cache.change(update_radio, [cache], outputs=[radio])
|
630 |
+
radio.select(update_img, [cache, block_select, radio], outputs=[image])
|
631 |
+
search_by_index.change(update_img, [cache, block_select, search_by_index], outputs=[image])
|
632 |
+
demo.load(update_img,
|
633 |
+
[cache, block_select, search_by_index],
|
634 |
+
outputs=[image])
|
635 |
+
|
636 |
+
return explore_tab
|
637 |
+
|
638 |
+
|
639 |
+
def create_intro_part():
|
640 |
+
with gr.Tab("Instructions", elem_classes="tabs") as intro_tab:
|
641 |
+
gr.Markdown(
|
642 |
+
'''# Unpacking SDXL Turbo with Sparse Autoencoders
|
643 |
+
## Demo Overview
|
644 |
+
This demo showcases the use of Sparse Autoencoders (SAEs) to understand the features learned by the Stable Diffusion XL Turbo model.
|
645 |
+
|
646 |
+
## How to Use
|
647 |
+
### Explore
|
648 |
+
* Enter a prompt in the text box and click on the "Generate" button to generate an image.
|
649 |
+
* You can observe the active features in different blocks plotted on top of the generated image.
|
650 |
+
### Top Images
|
651 |
+
* For each feature, you can view the top images that activate the feature the most.
|
652 |
+
### Paint!
|
653 |
+
* Generate an image using the prompt.
|
654 |
+
* Paint on the generated image to apply interventions.
|
655 |
+
* Use the "Feature Icon" button to understand how the selected brush functions.
|
656 |
+
|
657 |
+
### Remarks
|
658 |
+
* Not all brushes mix well with all images. Experiment with different brushes and strengths.
|
659 |
+
* Feature Icon works best with `down.2.1 (composition)` and `up.0.1 (style)` blocks.
|
660 |
+
* This demo is provided for research purposes only. We do not take responsibility for the content generated by the demo.
|
661 |
+
|
662 |
+
### Interesting features to try
|
663 |
+
To get started, try the following features:
|
664 |
+
- down.2.1 (composition): 2301 (evil) 3747 (image frame) 4998 (cartoon)
|
665 |
+
- up.0.1 (style): 4977 (tiger stripes), 90 (fur), 2615 (twilight blur)
|
666 |
+
'''
|
667 |
+
)
|
668 |
+
|
669 |
+
return intro_tab
|
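# Editor's note: a minimal sketch (not part of the uploaded file) showing how the
# "Top Images" tab resolves a feature to its preview grid. It rebuilds the same
# Hugging Face URL that update_img constructs above; the block label and feature
# index below are just the examples suggested in the instructions text.
def top_images_url(block_label: str, feature_idx: int) -> str:
    block = block_label.split(" ")[0]  # "down.2.1 (composition)" -> "down.2.1"
    return f"https://huggingface.co/surokpro2/sdxl_sae_images/resolve/main/{block}/{feature_idx}.jpg"

print(top_images_url("down.2.1 (composition)", 2301))  # the "evil" feature listed above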
670 |
+
|
671 |
+
|
672 |
+
def create_demo(pipe, saes_dict, means_dict, use_retrieval=True):
|
673 |
+
custom_css = """
|
674 |
+
.tabs button {
|
675 |
+
font-size: 20px !important; /* Adjust font size for tab text */
|
676 |
+
padding: 10px !important; /* Adjust padding to make the tabs bigger */
|
677 |
+
font-weight: bold !important; /* Adjust font weight to make the text bold */
|
678 |
+
}
|
679 |
+
.generate_button1 {
|
680 |
+
max-width: 160px !important;
|
681 |
+
margin-top: 20px !important;
|
682 |
+
margin-bottom: 20px !important;
|
683 |
+
}
|
684 |
+
"""
|
685 |
+
if use_retrieval:
|
686 |
+
retriever = None # FeatureRetriever()
|
687 |
+
else:
|
688 |
+
retriever = None
|
689 |
+
|
690 |
+
with gr.Blocks(css=custom_css) as demo:
|
691 |
+
# with create_intro_part():
|
692 |
+
# pass
|
693 |
+
with create_prompt_part(pipe, saes_dict, demo):
|
694 |
+
pass
|
695 |
+
with create_top_images_part(demo, pipe):
|
696 |
+
pass
|
697 |
+
with create_intervene_part(pipe, saes_dict, means_dict, demo):
|
698 |
+
pass
|
699 |
+
|
700 |
+
return demo
|
701 |
+
|
702 |
+
|
703 |
+
if __name__ == "__main__":
|
704 |
+
import os
|
705 |
+
import gradio as gr
|
706 |
+
import torch
|
707 |
+
from SDLens import HookedStableDiffusionXLPipeline
|
708 |
+
from SAE import SparseAutoencoder
|
709 |
+
from huggingface_hub import hf_hub_download
|
710 |
+
|
711 |
+
dtype = torch.float16
|
712 |
+
pipe = EditedFluxPipeline.from_pretrained(
|
713 |
+
"black-forest-labs/FLUX.1-schnell",
|
714 |
+
device_map="balanced",
|
715 |
+
torch_dtype=dtype
|
716 |
+
)
|
717 |
+
pipe.set_progress_bar_config(disable=True)
|
718 |
+
pipe = CachedFLuxPipeline(pipe)
|
719 |
+
|
720 |
+
# Parameters
|
721 |
+
DEVICE = "cuda"
|
722 |
+
|
723 |
+
# Hugging Face repo setup
|
724 |
+
HF_REPO_ID = "antoniomari/SAE_flux_18"
|
725 |
+
HF_BRANCH = "main"
|
726 |
+
|
727 |
+
# Command-line arguments
|
728 |
+
block_code = "18"
|
729 |
+
block_name = code_to_block_flux[block_code]
|
730 |
+
|
731 |
+
saes_dict = {}
|
732 |
+
means_dict = {}
|
733 |
+
|
734 |
+
# Download files from the root of the repo
|
735 |
+
state_dict_path = hf_hub_download(
|
736 |
+
repo_id=HF_REPO_ID,
|
737 |
+
filename="state_dict.pth",
|
738 |
+
revision=HF_BRANCH
|
739 |
+
)
|
740 |
+
|
741 |
+
config_path = hf_hub_download(
|
742 |
+
repo_id=HF_REPO_ID,
|
743 |
+
filename="config.json",
|
744 |
+
revision=HF_BRANCH
|
745 |
+
)
|
746 |
+
|
747 |
+
mean_path = hf_hub_download(
|
748 |
+
repo_id=HF_REPO_ID,
|
749 |
+
filename="mean.pt",
|
750 |
+
revision=HF_BRANCH
|
751 |
+
)
|
752 |
+
|
753 |
+
# Load config and model
|
754 |
+
with open(config_path, "r") as f:
|
755 |
+
config = json.load(f)
|
756 |
+
|
757 |
+
sae = SparseAutoencoder(**config)
|
758 |
+
checkpoint = torch.load(state_dict_path, map_location=DEVICE)
|
759 |
+
state_dict = checkpoint["state_dict"]
|
760 |
+
sae.load_state_dict(state_dict)
|
761 |
+
sae = sae.to(DEVICE, dtype=torch.float16).eval()
|
762 |
+
means = torch.load(mean_path, map_location=DEVICE).to(dtype)
|
763 |
+
|
764 |
+
saes_dict[block_code] = sae
|
765 |
+
means_dict[block_code] = means
|
766 |
+
|
767 |
+
demo = create_demo(pipe, saes_dict, means_dict)
|
768 |
+
demo.launch()
|
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json
ADDED
@@ -0,0 +1 @@
1 |
+
{"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
|
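For reference, a hedged sketch of how one of these checkpoint folders can be loaded back into a SparseAutoencoder, mirroring the loading code in app.py above; the exact payload of state_dict.pth (a raw state dict vs. a dict wrapping a "state_dict" key) is an assumption here, so both cases are handled.

# Hedged sketch: rebuild the SAE for this block from the checkpoint folder below.
import json
import torch
from SAE import SparseAutoencoder

ckpt = "checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final"
with open(f"{ckpt}/config.json") as f:
    config = json.load(f)  # {"n_dirs_local": 5120, "d_model": 1280, "k": 10, ...}

sae = SparseAutoencoder(**config)
payload = torch.load(f"{ckpt}/state_dict.pth", map_location="cpu")
sae.load_state_dict(payload.get("state_dict", payload))  # accept wrapped or raw checkpoints
mean = torch.load(f"{ckpt}/mean.pt", map_location="cpu")  # per-dimension activation mean
std = torch.load(f"{ckpt}/std.pt", map_location="cpu")    # per-dimension activation std
sae.eval()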
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:387f2b6f8c4e4a6f1227921f28f00dfa4beb2bd4e422b7eb592cd8627af0e58f
|
3 |
+
size 21581
|
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:39e3c6d17aa572a53368ca8ba8f82757947a3caf14fe654e84b175d0dc0a4650
|
3 |
+
size 52497831
|
checkpoints/unet.down_blocks.2.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c6ca694c9504a7a8aa827004d3fdec5c1cb8fcf3904acc3562d1861fc6e65c19
|
3 |
+
size 21576
|
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json
ADDED
@@ -0,0 +1 @@
1 |
+
{"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
|
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:80790481d0e56ac3fa36599703cee7a05cfb4cc078db57c8f9180e860c330e1d
|
3 |
+
size 21581
|
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:49d38d9178c2a2780e04a5482a2feb9548c6e9a636ed1bf85291acf42e0ffa34
|
3 |
+
size 52497831
|
checkpoints/unet.mid_block.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:bb6bfc7ce5e596f8aa048ab262ca56841868c222bf07eb2ed35b6e4f7094fea6
|
3 |
+
size 21576
|
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json
ADDED
@@ -0,0 +1 @@
1 |
+
{"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
|
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:de036d0fb9ee663f7bdf60e4a5d89d038516dae637531676b53ff75d05eab46b
|
3 |
+
size 21581
|
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:14c45efd9cce0258f014c49babdcd0e9ce8b266fe31eed72db1a45b990a1a0f8
|
3 |
+
size 52497831
|
checkpoints/unet.up_blocks.0.attentions.0_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cb9c04499ccae041987cc262894e254c2f04288857a8a0470cfb1b86a8ecfa09
|
3 |
+
size 21576
|
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/config.json
ADDED
@@ -0,0 +1 @@
1 |
+
{"n_dirs_local": 5120, "d_model": 1280, "k": 10, "auxk": 256, "dead_steps_threshold": 2441}
|
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/mean.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:96dbf6fffe9d62c3b3352f8e4fe48c54dfd69906cf8ad6828d5ce93db9a5f0dc
|
3 |
+
size 21581
|
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/state_dict.pth
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f8eed82f4bcb2f010ae9075f10a1ece801ee3dec46dba7fadccc35f6c0a7836b
|
3 |
+
size 52497831
|
checkpoints/unet.up_blocks.0.attentions.1_k10_hidden5120_auxk256_bs4096_lr0.0001/final/std.pt
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fe5c5be0c4c2d2b57e7888319053cb64929559f947c8ce445ddd6a397302afab
|
3 |
+
size 21576
|
colab_requirements.txt
ADDED
@@ -0,0 +1,8 @@
1 |
+
diffusers==0.29.2
|
2 |
+
gradio==5.23.2
|
3 |
+
numpy
|
4 |
+
matplotlib
|
5 |
+
pillow
|
6 |
+
einops
|
7 |
+
transformers
|
8 |
+
huggingface_hub
|
example.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
1 |
+
diffusers==0.29.2
|
2 |
+
gradio==4.44.1
|
3 |
+
torch>=2.4.0
|
4 |
+
numpy
|
5 |
+
matplotlib
|
6 |
+
pillow
|
7 |
+
wandb
|
8 |
+
einops
|
9 |
+
transformers
|
10 |
+
accelerate
|
11 |
+
huggingface_hub
|
12 |
+
git+https://github.com/wendlerc/clip-retrieval.git
|
resourses/image.png
ADDED
[image preview omitted; binary file tracked via Git LFS]
retrieval.py
ADDED
@@ -0,0 +1,71 @@
1 |
+
from huggingface_hub import snapshot_download
|
2 |
+
from clip_retrieval.clip_back import load_clip_indices, KnnService, ClipOptions
|
3 |
+
from collections import defaultdict
|
4 |
+
import os
|
5 |
+
import glob
|
6 |
+
import shutil
|
7 |
+
import random
|
8 |
+
|
9 |
+
class FeatureRetriever:
|
10 |
+
def __init__(self,
|
11 |
+
num_images=50,
|
12 |
+
imgs_per_dir=15,
|
13 |
+
force_download=False):
|
14 |
+
|
15 |
+
if force_download or not os.path.exists("./clip"):
|
16 |
+
print("Downloading clip resources")
|
17 |
+
rand_num = random.randint(0, 100000)
|
18 |
+
tmp_dir = f"./tmp_{rand_num}"
|
19 |
+
snapshot_download(repo_type="dataset", repo_id="wendlerc/sdxl-unbox-clip-indices", cache_dir=tmp_dir)
|
20 |
+
clip_dirs = glob.glob(f"{tmp_dir}/**/down_10_5120", recursive=True)
|
21 |
+
if len(clip_dirs) > 0:
|
22 |
+
shutil.copytree(clip_dirs[0].replace("down_10_5120", ""), "./clip", dirs_exist_ok=True)
|
23 |
+
shutil.rmtree(tmp_dir)
|
24 |
+
else:
|
25 |
+
raise ValueError("Could not find clip indices in the downloaded repo.")
|
26 |
+
|
27 |
+
# Initialize CLIP service
|
28 |
+
clip_options = ClipOptions(
|
29 |
+
indice_folder="currently unused by knn.query()",
|
30 |
+
clip_model="ViT-B/32", #"open_clip:ViT-H-14",
|
31 |
+
enable_hdf5=False,
|
32 |
+
enable_faiss_memory_mapping=True,
|
33 |
+
columns_to_return=["image_path", "similarity"],
|
34 |
+
reorder_metadata_by_ivf_index=False,
|
35 |
+
enable_mclip_option=False,
|
36 |
+
use_jit=False,
|
37 |
+
use_arrow=False,
|
38 |
+
provide_safety_model=False,
|
39 |
+
provide_violence_detector=False,
|
40 |
+
provide_aesthetic_embeddings=False,
|
41 |
+
)
|
42 |
+
self.names = ["down.2.1", "mid.0", "up.0.0", "up.0.1"]
|
43 |
+
self.paths = ["./clip/down_10_5120/indices_paths.json",
|
44 |
+
"./clip/mid_10_5120/indices_paths.json",
|
45 |
+
"./clip/up0_10_5120/indices_paths.json",
|
46 |
+
"./clip/up_10_5120/indices_paths.json",]
|
47 |
+
self.knn_service = {}
|
48 |
+
for name, path in zip(self.names, self.paths):
|
49 |
+
resources = load_clip_indices(path, clip_options)
|
50 |
+
self.knn_service[name] = KnnService(clip_resources=resources)
|
51 |
+
self.num_images = num_images
|
52 |
+
self.imgs_per_dir = imgs_per_dir
|
53 |
+
|
54 |
+
def query_text(self, query, block):
|
55 |
+
if block not in self.names:
|
56 |
+
raise ValueError(f"Block must be one of {self.names}")
|
57 |
+
results = self.knn_service[block].query(
|
58 |
+
text_input=query,
|
59 |
+
num_images=self.num_images,
|
60 |
+
num_result_ids=self.num_images,
|
61 |
+
deduplicate=True,
|
62 |
+
)
|
63 |
+
feat_sims = defaultdict(list)
|
64 |
+
feat_scores = {}
|
65 |
+
for result in results:
|
66 |
+
feature_id = result["image_path"].split("/")[-2]
|
67 |
+
feat_sims[feature_id] += [result["similarity"]]
|
68 |
+
for fid, sims in feat_sims.items():
|
69 |
+
feat_scores[fid] = (sum(sims) / len(sims)) * (len(sims)/self.imgs_per_dir)
|
70 |
+
|
71 |
+
return dict(sorted(feat_scores.items(), key=lambda item: -item[1]))
|
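A minimal usage sketch for FeatureRetriever, assuming the CLIP index download above succeeds; the query string and block name are only examples.

# Hedged usage sketch (assumes the clip indices from wendlerc/sdxl-unbox-clip-indices are available).
retriever = FeatureRetriever()
scores = retriever.query_text("tiger stripes", block="up.0.1")
# scores maps feature ids to an average CLIP similarity weighted by how many of the
# feature's top images were retrieved; the dict is sorted with the best match first.
for feature_id, score in list(scores.items())[:5]:
    print(feature_id, round(score, 3))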
scripts/collect_latents_dataset.py
ADDED
@@ -0,0 +1,96 @@
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import io
|
4 |
+
import tarfile
|
5 |
+
import torch
|
6 |
+
import webdataset as wds
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from tqdm import tqdm
|
10 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
11 |
+
from SDLens.hooked_sd_pipeline import HookedStableDiffusionXLPipeline
|
12 |
+
|
13 |
+
import datetime
|
14 |
+
from datasets import load_dataset
|
15 |
+
from torch.utils.data import DataLoader
|
16 |
+
import diffusers
|
17 |
+
import fire
|
18 |
+
|
19 |
+
def main(save_path, start_at=0, finish_at=30000, dataset_batch_size=50):
|
20 |
+
blocks_to_save = [
|
21 |
+
'unet.down_blocks.2.attentions.1',
|
22 |
+
'unet.mid_block.attentions.0',
|
23 |
+
'unet.up_blocks.0.attentions.0',
|
24 |
+
'unet.up_blocks.0.attentions.1',
|
25 |
+
]
|
26 |
+
|
27 |
+
# Initialization
|
28 |
+
dataset = load_dataset("guangyil/laion-coco-aesthetic", split="train", columns=["caption"], streaming=True).shuffle(seed=42)
|
29 |
+
pipe = HookedStableDiffusionXLPipeline.from_pretrained('stabilityai/sdxl-turbo')
|
30 |
+
pipe.to('cuda')
|
31 |
+
pipe.set_progress_bar_config(disable=True)
|
32 |
+
dataloader = DataLoader(dataset, batch_size=dataset_batch_size)
|
33 |
+
|
34 |
+
ct = datetime.datetime.now()
|
35 |
+
save_path = os.path.join(save_path, str(ct))
|
36 |
+
# Collecting dataset
|
37 |
+
os.makedirs(save_path, exist_ok=True)
|
38 |
+
|
39 |
+
writers = {
|
40 |
+
block: wds.TarWriter(f'{save_path}/{block}.tar') for block in blocks_to_save
|
41 |
+
}
|
42 |
+
|
43 |
+
writers.update({'images': wds.TarWriter(f'{save_path}/images.tar')})
|
44 |
+
|
45 |
+
def to_kwargs(kwargs_to_save):
|
46 |
+
kwargs = kwargs_to_save.copy()
|
47 |
+
seed = kwargs['seed']
|
48 |
+
del kwargs['seed']
|
49 |
+
kwargs['generator'] = torch.Generator(device="cpu").manual_seed(seed)
|
50 |
+
return kwargs
|
51 |
+
|
52 |
+
dataloader_iter = iter(dataloader)
|
53 |
+
for num_document, batch in tqdm(enumerate(dataloader)):
|
54 |
+
if num_document < start_at:
|
55 |
+
continue
|
56 |
+
|
57 |
+
if num_document >= finish_at:
|
58 |
+
break
|
59 |
+
|
60 |
+
kwargs_to_save = {
|
61 |
+
'prompt': batch['caption'],
|
62 |
+
'positions_to_cache': blocks_to_save,
|
63 |
+
'save_input': True,
|
64 |
+
'save_output': True,
|
65 |
+
'num_inference_steps': 1,
|
66 |
+
'guidance_scale': 0.0,
|
67 |
+
'seed': num_document,
|
68 |
+
'output_type': 'pil'
|
69 |
+
}
|
70 |
+
|
71 |
+
kwargs = to_kwargs(kwargs_to_save)
|
72 |
+
|
73 |
+
output, cache = pipe.run_with_cache(
|
74 |
+
**kwargs
|
75 |
+
)
|
76 |
+
|
77 |
+
blocks = cache['input'].keys()
|
78 |
+
for block in blocks:
|
79 |
+
sample = {
|
80 |
+
"__key__": f"sample_{num_document}",
|
81 |
+
"output.pth": cache['output'][block],
|
82 |
+
"diff.pth": cache['output'][block] - cache['input'][block],
|
83 |
+
"gen_args.json": kwargs_to_save
|
84 |
+
}
|
85 |
+
|
86 |
+
writers[block].write(sample)
|
87 |
+
writers['images'].write({
|
88 |
+
"__key__": f"sample_{num_document}",
|
89 |
+
"images.npy": np.stack(output.images)
|
90 |
+
})
|
91 |
+
|
92 |
+
for block, writer in writers.items():
|
93 |
+
writer.close()
|
94 |
+
|
95 |
+
if __name__ == '__main__':
|
96 |
+
fire.Fire(main)
|
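Since the entry point is wrapped with fire.Fire, the collection script is driven from the command line; below is a hypothetical invocation (the save path and step counts are placeholders) together with a quick way to list the shards it writes.

# Hypothetical invocation via python-fire (paths and counts are placeholders):
#   python scripts/collect_latents_dataset.py --save_path=./latents --start_at=0 --finish_at=1000
# Each run creates a timestamped folder with one .tar shard per hooked block plus images.tar.
import glob
for shard in sorted(glob.glob("./latents/*/*.tar")):
    print(shard)  # e.g. .../unet.down_blocks.2.attentions.1.tar, .../images.tar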
scripts/train_sae.py
ADDED
@@ -0,0 +1,308 @@
1 |
+
'''
|
2 |
+
Adapted from
|
3 |
+
https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/train.py
|
4 |
+
'''
|
5 |
+
|
6 |
+
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
10 |
+
from typing import Callable, Iterable, Iterator
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.distributed as dist
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from torch.distributed import ReduceOp
|
17 |
+
from SAE.dataset_iterator import ActivationsDataloader
|
18 |
+
from SAE.sae import SparseAutoencoder, unit_norm_decoder_, unit_norm_decoder_grad_adjustment_
|
19 |
+
from SAE.sae_utils import SAETrainingConfig, Config
|
20 |
+
|
21 |
+
from types import SimpleNamespace
|
22 |
+
from typing import Optional, List
|
23 |
+
import json
|
24 |
+
|
25 |
+
import tqdm
|
26 |
+
|
27 |
+
def weighted_average(points: torch.Tensor, weights: torch.Tensor):
|
28 |
+
weights = weights / weights.sum()
|
29 |
+
return (points * weights.view(-1, 1)).sum(dim=0)
|
30 |
+
|
31 |
+
|
32 |
+
@torch.no_grad()
|
33 |
+
def geometric_median_objective(
|
34 |
+
median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
|
35 |
+
) -> torch.Tensor:
|
36 |
+
|
37 |
+
norms = torch.linalg.norm(points - median.view(1, -1), dim=1) # type: ignore
|
38 |
+
|
39 |
+
return (norms * weights).sum()
|
40 |
+
|
41 |
+
|
42 |
+
def compute_geometric_median(
|
43 |
+
points: torch.Tensor,
|
44 |
+
weights: Optional[torch.Tensor] = None,
|
45 |
+
eps: float = 1e-6,
|
46 |
+
maxiter: int = 100,
|
47 |
+
ftol: float = 1e-20,
|
48 |
+
do_log: bool = False,
|
49 |
+
):
|
50 |
+
"""
|
51 |
+
:param points: ``torch.Tensor`` of shape ``(n, d)``
|
52 |
+
:param weights: Optional ``torch.Tensor`` of shape :math:``(n,)``.
|
53 |
+
:param eps: Smallest allowed value of denominator, to avoid divide by zero.
|
54 |
+
Equivalently, this is a smoothing parameter. Default 1e-6.
|
55 |
+
:param maxiter: Maximum number of Weiszfeld iterations. Default 100
|
56 |
+
:param ftol: If objective value does not improve by at least this `ftol` fraction, terminate the algorithm. Default 1e-20.
|
57 |
+
:param do_log: If true will return a log of function values encountered through the course of the algorithm
|
58 |
+
:return: SimpleNamespace object with fields
|
59 |
+
- `median`: estimate of the geometric median, which is a ``torch.Tensor`` object of shape :math:``(d,)``
|
60 |
+
- `termination`: string explaining how the algorithm terminated.
|
61 |
+
- `logs`: function values encountered through the course of the algorithm in a list (None if do_log is false).
|
62 |
+
"""
|
63 |
+
with torch.no_grad():
|
64 |
+
|
65 |
+
if weights is None:
|
66 |
+
weights = torch.ones((points.shape[0],), device=points.device)
|
67 |
+
# initialize median estimate at mean
|
68 |
+
new_weights = weights
|
69 |
+
median = weighted_average(points, weights)
|
70 |
+
objective_value = geometric_median_objective(median, points, weights)
|
71 |
+
if do_log:
|
72 |
+
logs = [objective_value]
|
73 |
+
else:
|
74 |
+
logs = None
|
75 |
+
|
76 |
+
# Weiszfeld iterations
|
77 |
+
early_termination = False
|
78 |
+
pbar = tqdm.tqdm(range(maxiter))
|
79 |
+
for _ in pbar:
|
80 |
+
prev_obj_value = objective_value
|
81 |
+
|
82 |
+
norms = torch.linalg.norm(points - median.view(1, -1), dim=1) # type: ignore
|
83 |
+
new_weights = weights / torch.clamp(norms, min=eps)
|
84 |
+
median = weighted_average(points, new_weights)
|
85 |
+
objective_value = geometric_median_objective(median, points, weights)
|
86 |
+
|
87 |
+
if logs is not None:
|
88 |
+
logs.append(objective_value)
|
89 |
+
if abs(prev_obj_value - objective_value) <= ftol * objective_value:
|
90 |
+
early_termination = True
|
91 |
+
break
|
92 |
+
|
93 |
+
pbar.set_description(f"Objective value: {objective_value:.4f}")
|
94 |
+
|
95 |
+
median = weighted_average(points, new_weights) # allow autodiff to track it
|
96 |
+
return SimpleNamespace(
|
97 |
+
median=median,
|
98 |
+
new_weights=new_weights,
|
99 |
+
termination=(
|
100 |
+
"function value converged within tolerance"
|
101 |
+
if early_termination
|
102 |
+
else "maximum iterations reached"
|
103 |
+
),
|
104 |
+
logs=logs,
|
105 |
+
)
|
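# Editor's note: an illustrative sanity check (not part of the uploaded file) for
# compute_geometric_median, relying only on what the docstring above states:
# points of shape (n, d) and a SimpleNamespace result with a `median` field.
import torch
points = torch.randn(1024, 1280)        # (n, d); d_model is 1280 in the shipped SAE configs
result = compute_geometric_median(points, maxiter=50)
print(result.termination)               # e.g. "function value converged within tolerance"
print(result.median.shape)              # torch.Size([1280])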
106 |
+
|
107 |
+
def maybe_transpose(x):
|
108 |
+
return x.T if not x.is_contiguous() and x.T.is_contiguous() else x
|
109 |
+
|
110 |
+
import wandb
|
111 |
+
|
112 |
+
RANK = 0
|
113 |
+
|
114 |
+
class Logger:
|
115 |
+
def __init__(self, sae_name, **kws):
|
116 |
+
self.vals = {}
|
117 |
+
self.enabled = (RANK == 0) and not kws.pop("dummy", False)
|
118 |
+
self.sae_name = sae_name
|
119 |
+
|
120 |
+
def logkv(self, k, v):
|
121 |
+
if self.enabled:
|
122 |
+
self.vals[f'{self.sae_name}/{k}'] = v.detach() if isinstance(v, torch.Tensor) else v
|
123 |
+
return v
|
124 |
+
|
125 |
+
def dumpkvs(self, step):
|
126 |
+
if self.enabled:
|
127 |
+
wandb.log(self.vals, step=step)
|
128 |
+
self.vals = {}
|
129 |
+
|
130 |
+
|
131 |
+
class FeaturesStats:
|
132 |
+
def __init__(self, dim, logger):
|
133 |
+
self.dim = dim
|
134 |
+
self.logger = logger
|
135 |
+
self.reinit()
|
136 |
+
|
137 |
+
def reinit(self):
|
138 |
+
self.n_activated = torch.zeros(self.dim, dtype=torch.long, device="cuda")
|
139 |
+
self.n = 0
|
140 |
+
|
141 |
+
def update(self, inds):
|
142 |
+
self.n += inds.shape[0]
|
143 |
+
inds = inds.flatten().detach()
|
144 |
+
self.n_activated.scatter_add_(0, inds, torch.ones_like(inds))
|
145 |
+
|
146 |
+
def log(self):
|
147 |
+
self.logger.logkv('activated', (self.n_activated / self.n + 1e-9).log10().cpu().numpy())
|
148 |
+
|
149 |
+
def training_loop_(
|
150 |
+
aes,
|
151 |
+
train_acts_iter,
|
152 |
+
loss_fn,
|
153 |
+
log_interval,
|
154 |
+
save_interval,
|
155 |
+
loggers,
|
156 |
+
sae_cfgs,
|
157 |
+
):
|
158 |
+
sae_packs = []
|
159 |
+
for ae, cfg, logger in zip(aes, sae_cfgs, loggers):
|
160 |
+
pbar = tqdm.tqdm(unit=" steps", desc="Training Loss: ")
|
161 |
+
fstats = FeaturesStats(ae.n_dirs, logger)
|
162 |
+
opt = torch.optim.Adam(ae.parameters(), lr=cfg.lr, eps=cfg.eps, fused=True)
|
163 |
+
sae_packs.append((ae, cfg, logger, pbar, fstats, opt))
|
164 |
+
|
165 |
+
for i, flat_acts_train_batch in enumerate(train_acts_iter):
|
166 |
+
flat_acts_train_batch = flat_acts_train_batch.cuda()
|
167 |
+
|
168 |
+
for ae, cfg, logger, pbar, fstats, opt in sae_packs:
|
169 |
+
recons, info = ae(flat_acts_train_batch)
|
170 |
+
loss = loss_fn(ae, cfg, flat_acts_train_batch, recons, info, logger)
|
171 |
+
|
172 |
+
fstats.update(info['inds'])
|
173 |
+
|
174 |
+
bs = flat_acts_train_batch.shape[0]
|
175 |
+
logger.logkv('not-activated 1e4', (ae.stats_last_nonzero > 1e4 / bs).mean(dtype=float).item())
|
176 |
+
logger.logkv('not-activated 1e6', (ae.stats_last_nonzero > 1e6 / bs).mean(dtype=float).item())
|
177 |
+
logger.logkv('not-activated 1e7', (ae.stats_last_nonzero > 1e7 / bs).mean(dtype=float).item())
|
178 |
+
|
179 |
+
logger.logkv('explained variance', explained_variance(recons, flat_acts_train_batch))
|
180 |
+
logger.logkv('l2_div', (torch.linalg.norm(recons, dim=1) / torch.linalg.norm(flat_acts_train_batch, dim=1)).mean())
|
181 |
+
|
182 |
+
if (i + 1) % log_interval == 0:
|
183 |
+
fstats.log()
|
184 |
+
fstats.reinit()
|
185 |
+
|
186 |
+
if (i + 1) % save_interval == 0:
|
187 |
+
ae.save_to_disk(f"{cfg.save_path}/{i + 1}")
|
188 |
+
|
189 |
+
loss.backward()
|
190 |
+
|
191 |
+
unit_norm_decoder_(ae)
|
192 |
+
unit_norm_decoder_grad_adjustment_(ae)
|
193 |
+
|
194 |
+
opt.step()
|
195 |
+
opt.zero_grad()
|
196 |
+
logger.dumpkvs(i)
|
197 |
+
|
198 |
+
pbar.set_description(f"Training Loss {loss.item():.4f}")
|
199 |
+
pbar.update(1)
|
200 |
+
|
201 |
+
|
202 |
+
for ae, cfg, logger, pbar, fstats, opt in sae_packs:
|
203 |
+
pbar.close()
|
204 |
+
ae.save_to_disk(f"{cfg.save_path}/final")
|
205 |
+
|
206 |
+
|
207 |
+
def init_from_data_(ae, stats_acts_sample):
|
208 |
+
ae.pre_bias.data = (
|
209 |
+
compute_geometric_median(stats_acts_sample[:32768].float().cpu()).median.cuda().float()
|
210 |
+
)
|
211 |
+
|
212 |
+
|
213 |
+
def mse(recons, x):
|
214 |
+
# return ((recons - x) ** 2).sum(dim=-1).mean()
|
215 |
+
return ((recons - x) ** 2).mean()
|
216 |
+
|
217 |
+
def normalized_mse(recon: torch.Tensor, xs: torch.Tensor) -> torch.Tensor:
|
218 |
+
# only used for auxk
|
219 |
+
xs_mu = xs.mean(dim=0)
|
220 |
+
|
221 |
+
loss = mse(recon, xs) / mse(
|
222 |
+
xs_mu[None, :].broadcast_to(xs.shape), xs
|
223 |
+
)
|
224 |
+
|
225 |
+
return loss
|
226 |
+
|
227 |
+
def explained_variance(recons, x):
|
228 |
+
# Compute the variance of the difference
|
229 |
+
diff = x - recons
|
230 |
+
diff_var = torch.var(diff, dim=0, unbiased=False)
|
231 |
+
|
232 |
+
# Compute the variance of the original tensor
|
233 |
+
x_var = torch.var(x, dim=0, unbiased=False)
|
234 |
+
|
235 |
+
# Avoid division by zero
|
236 |
+
explained_var = 1 - diff_var / (x_var + 1e-8)
|
237 |
+
|
238 |
+
return explained_var.mean()
|
239 |
+
|
240 |
+
|
241 |
+
def main():
|
242 |
+
cfg = Config(json.load(open('SAE/config.json')))
|
243 |
+
|
244 |
+
dataloader = ActivationsDataloader(cfg.paths_to_latents, cfg.block_name, cfg.bs)
|
245 |
+
|
246 |
+
acts_iter = dataloader.iterate()
|
247 |
+
stats_acts_sample = torch.cat([
|
248 |
+
next(acts_iter).cpu() for _ in range(10)
|
249 |
+
], dim=0)
|
250 |
+
|
251 |
+
aes = [
|
252 |
+
SparseAutoencoder(
|
253 |
+
n_dirs_local=sae.n_dirs,
|
254 |
+
d_model=sae.d_model,
|
255 |
+
k=sae.k,
|
256 |
+
auxk=sae.auxk,
|
257 |
+
dead_steps_threshold=sae.dead_toks_threshold // cfg.bs,
|
258 |
+
).cuda()
|
259 |
+
for sae in cfg.saes
|
260 |
+
]
|
261 |
+
|
262 |
+
for ae in aes:
|
263 |
+
init_from_data_(ae, stats_acts_sample)
|
264 |
+
|
265 |
+
mse_scale = (
|
266 |
+
1 / ((stats_acts_sample.float().mean(dim=0) - stats_acts_sample.float()) ** 2).mean()
|
267 |
+
)
|
268 |
+
mse_scale = mse_scale.item()
|
269 |
+
del stats_acts_sample
|
270 |
+
|
271 |
+
wandb.init(
|
272 |
+
project=cfg.wandb_project,
|
273 |
+
name=cfg.wandb_name,
|
274 |
+
)
|
275 |
+
|
276 |
+
loggers = [Logger(
|
277 |
+
sae_name=cfg_sae.sae_name,
|
278 |
+
dummy=False,
|
279 |
+
) for cfg_sae in cfg.saes]
|
280 |
+
|
281 |
+
training_loop_(
|
282 |
+
aes,
|
283 |
+
acts_iter,
|
284 |
+
lambda ae, cfg_sae, flat_acts_train_batch, recons, info, logger: (
|
285 |
+
# MSE
|
286 |
+
logger.logkv("train_recons", mse_scale * mse(recons, flat_acts_train_batch))
|
287 |
+
# AuxK
|
288 |
+
+ logger.logkv(
|
289 |
+
"train_maxk_recons",
|
290 |
+
cfg_sae.auxk_coef
|
291 |
+
* normalized_mse(
|
292 |
+
ae.decode_sparse(
|
293 |
+
info["auxk_inds"],
|
294 |
+
info["auxk_vals"],
|
295 |
+
),
|
296 |
+
flat_acts_train_batch - recons.detach() + ae.pre_bias.detach(),
|
297 |
+
).nan_to_num(0),
|
298 |
+
)
|
299 |
+
),
|
300 |
+
sae_cfgs=cfg.saes,
|
301 |
+
loggers=loggers,
|
302 |
+
log_interval=cfg.log_interval,
|
303 |
+
save_interval=cfg.save_interval,
|
304 |
+
)
|
305 |
+
|
306 |
+
|
307 |
+
if __name__ == "__main__":
|
308 |
+
main()
|
utils/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
from .hooks import *
|
utils/hooks.py
ADDED
@@ -0,0 +1,145 @@
1 |
+
from typing import Callable, List, Optional
|
2 |
+
import torch
|
3 |
+
|
4 |
+
class TimedHook:
|
5 |
+
def __init__(self, hook_fn, total_steps, apply_at_steps=None):
|
6 |
+
self.hook_fn = hook_fn
|
7 |
+
self.total_steps = total_steps
|
8 |
+
self.apply_at_steps = apply_at_steps
|
9 |
+
self.current_step = 0
|
10 |
+
|
11 |
+
def identity(self, module, input, output):
|
12 |
+
return output
|
13 |
+
|
14 |
+
def __call__(self, module, input, output):
|
15 |
+
if self.apply_at_steps is not None:
|
16 |
+
if self.current_step in self.apply_at_steps:
|
17 |
+
self.__increment()
|
18 |
+
return self.hook_fn(module, input, output)
|
19 |
+
else:
|
20 |
+
self.__increment()
|
21 |
+
return self.identity(module, input, output)
|
22 |
+
|
23 |
+
return self.identity(module, input, output)
|
24 |
+
|
25 |
+
def __increment(self):
|
26 |
+
if self.current_step < self.total_steps:
|
27 |
+
self.current_step += 1
|
28 |
+
else:
|
29 |
+
self.current_step = 0
|
30 |
+
|
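# Editor's note: a hedged usage sketch (not part of the uploaded file) for TimedHook.
# It wraps an arbitrary forward hook so that it only fires at selected steps; the
# module below is a stand-in, not a block from the actual pipelines.
import torch
import torch.nn as nn

def print_shape_hook(module, input, output):
    print(type(module).__name__, output.shape)
    return output

block = nn.Identity()  # stand-in for a hooked attention block
timed = TimedHook(print_shape_hook, total_steps=4, apply_at_steps=[0, 3])
handle = block.register_forward_hook(timed)
for _ in range(4):  # one forward call per "denoising step"
    block(torch.randn(1, 1280, 16, 16))
handle.remove()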
31 |
+
@torch.no_grad()
|
32 |
+
def add_feature(sae, feature_idx, value, module, input, output):
|
33 |
+
diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
|
34 |
+
activated = sae.encode(diff)
|
35 |
+
mask = torch.zeros_like(activated, device=diff.device)
|
36 |
+
mask[..., feature_idx] = value
|
37 |
+
to_add = mask @ sae.decoder.weight.T
|
38 |
+
return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
|
39 |
+
|
40 |
+
@torch.no_grad()
|
41 |
+
def add_feature_on_area_base(sae, feature_idx, activation_map, module, input, output):
|
42 |
+
return add_feature_on_area_base_both(sae, feature_idx, activation_map, module, input, output)
|
43 |
+
|
44 |
+
@torch.no_grad()
|
45 |
+
def add_feature_on_area_base_both(sae, feature_idx, activation_map, module, input, output):
|
46 |
+
# add the feature to cond and subtract from uncond
|
47 |
+
# this assumes diff.shape[0] == 2
|
48 |
+
diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
|
49 |
+
activated = sae.encode(diff)
|
50 |
+
mask = torch.zeros_like(activated, device=diff.device)
|
51 |
+
if len(activation_map) == 2:
|
52 |
+
activation_map = activation_map.unsqueeze(0)
|
53 |
+
mask[..., feature_idx] = activation_map.to(mask.device)
|
54 |
+
to_add = mask @ sae.decoder.weight.T
|
55 |
+
to_add = to_add.chunk(2)
|
56 |
+
output[0][0] -= to_add[0].permute(0, 3, 1, 2).to(output[0].device)[0]
|
57 |
+
output[0][1] += to_add[1].permute(0, 3, 1, 2).to(output[0].device)[0]
|
58 |
+
return output
|
59 |
+
|
60 |
+
|
61 |
+
@torch.no_grad()
|
62 |
+
def add_feature_on_area_base_cond(sae, feature_idx, activation_map, module, input, output):
|
63 |
+
# add the feature to cond
|
64 |
+
# this assumes diff.shape[0] == 2
|
65 |
+
diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
|
66 |
+
diff_uncond, diff_cond = diff.chunk(2)
|
67 |
+
activated = sae.encode(diff_cond)
|
68 |
+
mask = torch.zeros_like(activated, device=diff_cond.device)
|
69 |
+
if len(activation_map) == 2:
|
70 |
+
activation_map = activation_map.unsqueeze(0)
|
71 |
+
mask[..., feature_idx] = activation_map.to(mask.device)
|
72 |
+
to_add = mask @ sae.decoder.weight.T
|
73 |
+
output[0][1] += to_add.permute(0, 3, 1, 2).to(output[0].device)[0]
|
74 |
+
return output
|
75 |
+
|
76 |
+
|
77 |
+
@torch.no_grad()
|
78 |
+
def replace_with_feature_base(sae, feature_idx, value, module, input, output):
|
79 |
+
# this assumes diff.shape[0] == 2
|
80 |
+
diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
|
81 |
+
diff_uncond, diff_cond = diff.chunk(2)
|
82 |
+
activated = sae.encode(diff_cond)
|
83 |
+
mask = torch.zeros_like(activated, device=diff_cond.device)
|
84 |
+
mask[..., feature_idx] = value
|
85 |
+
to_add = mask @ sae.decoder.weight.T
|
86 |
+
input[0][1] += to_add.permute(0, 3, 1, 2).to(output[0].device)[0]
|
87 |
+
return input
|
88 |
+
|
89 |
+
|
90 |
+
@torch.no_grad()
|
91 |
+
def add_feature_on_area_turbo(sae, feature_idx, activation_map, module, input, output):
|
92 |
+
diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
|
93 |
+
activated = sae.encode(diff)
|
94 |
+
mask = torch.zeros_like(activated, device=diff.device)
|
95 |
+
if len(activation_map) == 2:
|
96 |
+
activation_map = activation_map.unsqueeze(0)
|
97 |
+
mask[..., feature_idx] = activation_map.to(mask.device)
|
98 |
+
to_add = mask @ sae.decoder.weight.T
|
99 |
+
return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
|
100 |
+
|
101 |
+
@torch.no_grad()
|
102 |
+
def add_feature_on_area_flux(
|
103 |
+
sae,
|
104 |
+
feature_idx,
|
105 |
+
activation_map,
|
106 |
+
module,
|
107 |
+
input: torch.Tensor,
|
108 |
+
output: torch.Tensor,
|
109 |
+
):
|
110 |
+
|
111 |
+
diff = (output - input).to(sae.device)
|
112 |
+
activated = sae.encode(diff)
|
113 |
+
|
114 |
+
# TODO: check
|
115 |
+
if len(activation_map) == 2:
|
116 |
+
activation_map = activation_map.unsqueeze(0)
|
117 |
+
mask = torch.zeros_like(activated, device=diff.device)
|
118 |
+
activation_map = activation_map.flatten()
|
119 |
+
mask[..., feature_idx] = activation_map.to(mask.device)
|
120 |
+
to_add = mask @ sae.decoder.weight.T
|
121 |
+
return output + to_add.to(output.device, output.dtype)
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
@torch.no_grad()
|
126 |
+
def replace_with_feature_turbo(sae, feature_idx, value, module, input, output):
|
127 |
+
diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
|
128 |
+
activated = sae.encode(diff)
|
129 |
+
mask = torch.zeros_like(activated, device=diff.device)
|
130 |
+
mask[..., feature_idx] = value
|
131 |
+
to_add = mask @ sae.decoder.weight.T
|
132 |
+
return (input[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
|
133 |
+
|
134 |
+
|
135 |
+
@torch.no_grad()
|
136 |
+
def reconstruct_sae_hook(sae, module, input, output):
|
137 |
+
diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
|
138 |
+
activated = sae.encode(diff)
|
139 |
+
reconstructed = sae.decoder(activated) + sae.pre_bias
|
140 |
+
return (input[0] + reconstructed.permute(0, 3, 1, 2).to(output[0].device),)
|
141 |
+
|
142 |
+
|
143 |
+
@torch.no_grad()
|
144 |
+
def ablate_block(module, input, output):
|
145 |
+
return input
|