Mariam-Elz committed
Commit f6981d4 · verified · 1 parent: b1e3fd8

Upload imagedream/ldm/modules/encoders/modules.py with huggingface_hub

imagedream/ldm/modules/encoders/modules.py ADDED
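
# Conditioning encoders for the ImageDream latent-diffusion modules: frozen T5,
# CLIP, and OpenCLIP text encoders, a class-label embedder, and an OpenCLIP
# image pathway for image-prompt ("ip") conditioning.
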
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel

import numpy as np
import open_clip
from PIL import Image
from ...util import default, count_params


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class IdentityEncoder(AbstractEncoder):
    def encode(self, x):
        return x


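# ClassEmbedder turns integer class labels into embeddings for cross-attention
# conditioning; with probability `ucg_rate` a label is swapped for the reserved
# unconditional class (n_classes - 1), which enables classifier-free guidance.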
class ClassEmbedder(nn.Module):
    def __init__(self, embed_dim, n_classes=1000, key="class", ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        if key is None:
            key = self.key
        # this is for use in crossattn
        c = batch[key][:, None]
        if self.ucg_rate > 0.0 and not disable_dropout:
            mask = 1.0 - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
            c = c.long()
        c = self.embedding(c)
        return c

    def get_unconditional_conditioning(self, bs, device="cuda"):
        uc_class = (
            self.n_classes - 1
        )  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        uc = torch.ones((bs,), device=device) * uc_class
        uc = {self.key: uc}
        return uc


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""

    def __init__(
        self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True
    ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        if freeze:
            self.freeze()

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)


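# FrozenCLIPEmbedder returns, depending on `layer`, the last hidden state, the
# pooled output (unsqueezed to [B, 1, D]), or the hidden state at `layer_idx`.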
class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""

    LAYERS = ["last", "pooled", "hidden"]

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        layer_idx=None,
    ):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(
            input_ids=tokens, output_hidden_states=self.layer == "hidden"
        )
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        return self(text)


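# FrozenOpenCLIPEmbedder handles text conditioning and, when `ip_mode` is set,
# image prompts as well: ip_mode=None drops the visual tower, "global" returns
# L2-normalized pooled image features, and modes containing "local" return
# per-token features from the penultimate visual transformer block.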
class FrozenOpenCLIPEmbedder(AbstractEncoder, nn.Module):
    """
    Uses the OpenCLIP transformer encoder for text
    """

    LAYERS = [
        # "pooled",
        "last",
        "penultimate",
    ]

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        ip_mode=None,
    ):
        """
        Args:
            ip_mode (str, optional): the image-prompt processing mode. When None,
                the visual tower is discarded and only text is encoded.
                Defaults to None.
        """
        super().__init__()
        assert layer in self.LAYERS
        model, _, preprocess = open_clip.create_model_and_transforms(
            arch, device=torch.device("cpu"), pretrained=version
        )
        if ip_mode is None:
            del model.visual

        self.model = model
        self.preprocess = preprocess
        self.device = device
        self.max_length = max_length
        self.ip_mode = ip_mode
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = open_clip.tokenize(text)
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def forward_image(self, pil_image):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        if isinstance(pil_image, torch.Tensor):
            pil_image = pil_image.cpu().numpy()
        if isinstance(pil_image, np.ndarray):
            if pil_image.ndim == 3:
                pil_image = pil_image[None, :, :, :]
            pil_image = [Image.fromarray(x) for x in pil_image]

        images = []
        for image in pil_image:
            images.append(self.preprocess(image).to(self.device))

        image = torch.stack(images, 0)  # to [b, 3, h, w]
        if self.ip_mode == "global":
            image_features = self.model.encode_image(image)
            image_features /= image_features.norm(dim=-1, keepdim=True)
        elif "local" in self.ip_mode:
            image_features = self.encode_image_with_transformer(image)
        else:
            raise ValueError(f"unsupported ip_mode: {self.ip_mode}")

        return image_features  # b, l

    def encode_image_with_transformer(self, x):
        visual = self.model.visual
        x = visual.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # class embeddings and positional embeddings
        x = torch.cat(
            [
                visual.class_embedding.to(x.dtype)
                + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
                x,
            ],
            dim=1,
        )  # shape = [*, grid ** 2 + 1, width]
        x = x + visual.positional_embedding.to(x.dtype)

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        # x = visual.patch_dropout(x)
        x = visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        hidden = self.image_transformer_forward(x)
        x = hidden[-2].permute(1, 0, 2)  # LND -> NLD (penultimate block output)
        return x

    def image_transformer_forward(self, x):
        encoder_states = ()
        trans = self.model.visual.transformer
        for r in trans.resblocks:
            if trans.grad_checkpointing and not torch.jit.is_scripting():
                # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
                x = checkpoint(r, x, None, None, None)
            else:
                x = r(x, attn_mask=None)
            encoder_states = encoder_states + (x,)
        return encoder_states

    def encode_with_transformer(self, text):
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if (
                self.model.transformer.grad_checkpointing
                and not torch.jit.is_scripting()
            ):
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)


class FrozenCLIPT5Encoder(AbstractEncoder):
    def __init__(
        self,
        clip_version="openai/clip-vit-large-patch14",
        t5_version="google/t5-v1_1-xl",
        device="cuda",
        clip_max_length=77,
        t5_max_length=77,
    ):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(
            clip_version, device, max_length=clip_max_length
        )
        self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        print(
            f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
        )

    def encode(self, text):
        return self(text)

    def forward(self, text):
        clip_z = self.clip_encoder.encode(text)
        t5_z = self.t5_encoder.encode(text)
        return [clip_z, t5_z]
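

# Minimal usage sketch (illustrative only). It assumes the default
# openai/clip-vit-large-patch14 checkpoint can be downloaded from the
# Hugging Face Hub; because of the package-relative import above, run it as a
# module, e.g. `python -m imagedream.ldm.modules.encoders.modules`.
if __name__ == "__main__":
    encoder = FrozenCLIPEmbedder(device="cpu", layer="last")  # keep everything on CPU
    z = encoder.encode(["a photo of an astronaut riding a horse"])
    print(z.shape)  # expected: torch.Size([1, 77, 768]) for clip-vit-large-patch14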