feat(colab): require less memory
tools/inference/inference_pipeline.ipynb  (+862 -478)
CHANGED
@@ -1,481 +1,865 @@
  {
- [old lines 2-82 are deleted by this commit; their content is not visible in this extract]
- "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
- "\n",
- "# CLIP model\n",
- "CLIP_REPO = \"openai/clip-vit-large-patch14\"\n",
- "CLIP_COMMIT_ID = None"
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "If your hardware can handle it, you can use dalle-mega instead of dalle-mini.\n",
- "\n",
- "**Note: on free Colab, you will most likely not be able to load dalle-mega so don't run the below cell to use dalle-mini instead.**"
- ],
- "metadata": {
- "id": "Jy01V3OltG5q"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "# dalle-mega\n",
- "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\""
- ],
- "metadata": {
- "id": "VRRX_g9vtMZG"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Yv-aR3t4Oe5v"
- },
- "outputs": [],
- "source": [
- "import jax\n",
- "import jax.numpy as jnp\n",
- "\n",
- "# check how many devices are available\n",
- "jax.local_device_count()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "92zYmvsQ38vL"
- },
- "outputs": [],
- "source": [
- "# Load models & tokenizer\n",
- "from dalle_mini import DalleBart, DalleBartProcessor\n",
- "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
- "from transformers import CLIPProcessor, FlaxCLIPModel\n",
- "\n",
- "# Load dalle-mini\n",
- "model, params = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False)\n",
- "\n",
- "# Load VQGAN\n",
- "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
- "\n",
- "# Load CLIP\n",
- "clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
- "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "o_vH2X1tDtzA"
- },
- "source": [
- "Model parameters are replicated on each device for faster inference."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "wtvLoM48EeVw"
- },
- "outputs": [],
- "source": [
- "from flax.jax_utils import replicate\n",
- "\n",
- "params = replicate(params)\n",
- "vqgan._params = replicate(vqgan.params)\n",
- "clip._params = replicate(clip.params)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "0A9AHQIgZ_qw"
- },
- "source": [
- "Model functions are compiled and parallelized to take advantage of multiple devices."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "sOtoOmYsSYPz"
- },
- "outputs": [],
- "source": [
- "from functools import partial\n",
- "\n",
- "# model inference\n",
- "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
- "def p_generate(\n",
- "    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
- "):\n",
- "    return model.generate(\n",
- "        **tokenized_prompt,\n",
- "        prng_key=key,\n",
- "        params=params,\n",
- "        top_k=top_k,\n",
- "        top_p=top_p,\n",
- "        temperature=temperature,\n",
- "        condition_scale=condition_scale,\n",
- "    )\n",
- "\n",
- "\n",
- "# decode image\n",
- "@partial(jax.pmap, axis_name=\"batch\")\n",
- "def p_decode(indices, params):\n",
- "    return vqgan.decode_code(indices, params=params)\n",
- "\n",
- "\n",
- "# score images\n",
- "@partial(jax.pmap, axis_name=\"batch\")\n",
- "def p_clip(inputs, params):\n",
- "    logits = clip(params=params, **inputs).logits_per_image\n",
- "    return logits"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "HmVN6IBwapBA"
- },
- "source": [
- "Keys are passed to the model on each device to generate unique inference per device."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "4CTXmlUkThhX"
- },
- "outputs": [],
- "source": [
- "import random\n",
- "\n",
- "# create a random key\n",
- "seed = random.randint(0, 2**32 - 1)\n",
- "key = jax.random.PRNGKey(seed)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "BrnVyCo81pij"
- },
- "source": [
- "## 🖍 Text Prompt"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "rsmj0Aj5OQox"
- },
- "source": [
- "Our model requires processing prompts."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "YjjhUychOVxm"
- },
- "outputs": [],
- "source": [
- "from dalle_mini import DalleBartProcessor\n",
- "\n",
- "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "BQ7fymSPyvF_"
- },
- "source": [
- "Let's define a text prompt."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "x_0vI9ge1oKr"
- },
- "outputs": [],
- "source": [
- "prompt = \"sunset over a lake in the mountains\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "VKjEZGjtO49k"
- },
- "outputs": [],
- "source": [
- "tokenized_prompt = processor([prompt])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "-CEJBnuJOe5z"
- },
- "source": [
- "Finally we replicate it onto each device."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "lQePgju5Oe5z"
- },
- "outputs": [],
- "source": [
- "tokenized_prompt = replicate(tokenized_prompt)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "phQ9bhjRkgAZ"
- },
- "source": [
- "## 🎨 Generate images\n",
- "\n",
- "We generate images using dalle-mini model and decode them with the VQGAN."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "d0wVkXpKqnHA"
- },
- "outputs": [],
- "source": [
- "# number of predictions\n",
- "n_predictions = 16\n",
- "\n",
- "# We can customize top_k/top_p used for generating samples\n",
- "gen_top_k = None\n",
- "gen_top_p = None\n",
- "temperature = None\n",
- "cond_scale = 3.0"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "SDjEx9JxR3v8"
- },
- "outputs": [],
- "source": [
- "from flax.training.common_utils import shard_prng_key\n",
- "import numpy as np\n",
- "from PIL import Image\n",
- "from tqdm.notebook import trange\n",
- "\n",
- "# generate images\n",
- "images = []\n",
- "for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
- "    # get a new key\n",
- "    key, subkey = jax.random.split(key)\n",
- "    # generate images\n",
- "    encoded_images = p_generate(\n",
- "        tokenized_prompt,\n",
- "        shard_prng_key(subkey),\n",
- "        params,\n",
- "        gen_top_k,\n",
- "        gen_top_p,\n",
- "        temperature,\n",
- "        cond_scale,\n",
- "    )\n",
- "    # remove BOS\n",
- "    encoded_images = encoded_images.sequences[..., 1:]\n",
- "    # decode images\n",
- "    decoded_images = p_decode(encoded_images, vqgan.params)\n",
- "    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
- "    for img in decoded_images:\n",
- "        images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "tw02wG9zGmyB"
- },
- "source": [
- "Let's calculate their score with CLIP."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "FoLXpjCmGpju"
- },
- "outputs": [],
- "source": [
- "from flax.training.common_utils import shard\n",
- "\n",
- "# get clip scores\n",
- "clip_inputs = clip_processor(\n",
- "    text=[prompt] * jax.device_count(),\n",
- "    images=images,\n",
- "    return_tensors=\"np\",\n",
- "    padding=\"max_length\",\n",
- "    max_length=77,\n",
- "    truncation=True,\n",
- ").data\n",
- "logits = p_clip(shard(clip_inputs), clip.params)\n",
- "logits = logits.squeeze().flatten()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "4AAWRm70LgED"
- },
- "source": [
- "Let's display images ranked by CLIP score."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "zsgxxubLLkIu"
- },
- "outputs": [],
- "source": [
- "print(f\"Prompt: {prompt}\\n\")\n",
- "for idx in logits.argsort()[::-1]:\n",
- "    display(images[idx])\n",
- "    print(f\"Score: {logits[idx]:.2f}\\n\")"
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
  "colab": {
- [old lines 455-477 are deleted; their content is not visible in this extract apart from a stray quote character on old line 466]
  },
- [old lines 479-481 are deleted; their content is not visible in this extract]
  {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "118UKH5bWCGa"
+ },
+ "source": [
+ "# DALL·E mini - Inference pipeline\n",
+ "\n",
+ "*Generate images from a text prompt*\n",
+ "\n",
+ "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
+ "\n",
+ "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
+ "\n",
+ "Just want to play? Use directly [DALL·E mini app](https://huggingface.co/spaces/dalle-mini/dalle-mini).\n",
+ "\n",
+ "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "dS8LbaonYm3a"
+ },
+ "source": [
+ "## 🛠️ Installation and set-up"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "uzjAM2GBYpZX"
+ },
+ "outputs": [],
+ "source": [
+ "# Install required libraries\n",
+ "!pip install -q git+https://github.com/huggingface/transformers.git\n",
+ "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
+ "!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ozHzTkyv8cqU"
+ },
+ "source": [
+ "We load required models:\n",
+ "* DALL·E mini for text to encoded images\n",
+ "* VQGAN for decoding images\n",
+ "* CLIP for scoring predictions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "K6CxW2o42f-w"
+ },
+ "outputs": [],
+ "source": [
+ "# Model references\n",
+ "\n",
+ "# dalle-mega\n",
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
+ "DALLE_COMMIT_ID = None\n",
+ "\n",
+ "# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line\n",
+ "# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n",
+ "\n",
+ "# VQGAN model\n",
+ "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
+ "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
  "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Yv-aR3t4Oe5v",
+ "outputId": "3097b2c7-5dac-475f-edde-898799dd7294"
+ },
+ "outputs": [],
+ "source": [
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "# check how many devices are available\n",
+ "jax.local_device_count()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "92zYmvsQ38vL",
+ "outputId": "d897dfdb-dae7-4026-da36-8b23dce066e8"
+ },
+ "outputs": [],
+ "source": [
+ "# Load models & tokenizer\n",
+ "from dalle_mini import DalleBart, DalleBartProcessor\n",
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
+ "from transformers import CLIPProcessor, FlaxCLIPModel\n",
+ "\n",
+ "# Load dalle-mini\n",
+ "model, params = DalleBart.from_pretrained(\n",
+ "    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
+ ")\n",
+ "\n",
+ "# Load VQGAN\n",
+ "vqgan, vqgan_params = VQModel.from_pretrained(\n",
+ "    VQGAN_REPO, revision=VQGAN_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "o_vH2X1tDtzA"
+ },
+ "source": [
+ "Model parameters are replicated on each device for faster inference."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "wtvLoM48EeVw"
+ },
+ "outputs": [],
+ "source": [
+ "from flax.jax_utils import replicate\n",
+ "\n",
+ "params = replicate(params)\n",
+ "vqgan_params = replicate(vqgan_params)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "0A9AHQIgZ_qw"
+ },
+ "source": [
+ "Model functions are compiled and parallelized to take advantage of multiple devices."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "sOtoOmYsSYPz"
+ },
+ "outputs": [],
+ "source": [
+ "from functools import partial\n",
+ "\n",
+ "# model inference\n",
+ "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
+ "def p_generate(\n",
+ "    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
+ "):\n",
+ "    return model.generate(\n",
+ "        **tokenized_prompt,\n",
+ "        prng_key=key,\n",
+ "        params=params,\n",
+ "        top_k=top_k,\n",
+ "        top_p=top_p,\n",
+ "        temperature=temperature,\n",
+ "        condition_scale=condition_scale,\n",
+ "    )\n",
+ "\n",
+ "\n",
+ "# decode image\n",
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
+ "def p_decode(indices, params):\n",
+ "    return vqgan.decode_code(indices, params=params)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "HmVN6IBwapBA"
+ },
+ "source": [
+ "Keys are passed to the model on each device to generate unique inference per device."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "4CTXmlUkThhX"
+ },
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "\n",
+ "# create a random key\n",
+ "seed = random.randint(0, 2**32 - 1)\n",
+ "key = jax.random.PRNGKey(seed)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BrnVyCo81pij"
+ },
+ "source": [
+ "## 🖍 Text Prompt"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "rsmj0Aj5OQox"
+ },
+ "source": [
+ "Our model requires processing prompts."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "YjjhUychOVxm",
+ "outputId": "a286f17a-a388-4754-ec4d-0464c0666c90"
+ },
+ "outputs": [],
+ "source": [
+ "from dalle_mini import DalleBartProcessor\n",
+ "\n",
+ "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "BQ7fymSPyvF_"
+ },
+ "source": [
+ "Let's define a text prompt."
+ ]
  },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "x_0vI9ge1oKr"
+ },
+ "outputs": [],
+ "source": [
+ "prompt = \"sunset over a lake in the mountains\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "VKjEZGjtO49k"
+ },
+ "outputs": [],
+ "source": [
+ "tokenized_prompt = processor([prompt])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "-CEJBnuJOe5z"
+ },
+ "source": [
+ "Finally we replicate it onto each device."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "lQePgju5Oe5z"
+ },
+ "outputs": [],
+ "source": [
+ "tokenized_prompt = replicate(tokenized_prompt)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "phQ9bhjRkgAZ"
+ },
+ "source": [
+ "## 🎨 Generate images\n",
+ "\n",
+ "We generate images using dalle-mini model and decode them with the VQGAN."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "d0wVkXpKqnHA"
+ },
+ "outputs": [],
+ "source": [
+ "# number of predictions\n",
+ "n_predictions = 8\n",
+ "\n",
+ "# We can customize generation parameters\n",
+ "gen_top_k = None\n",
+ "gen_top_p = None\n",
+ "temperature = None\n",
+ "cond_scale = 3.0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 1000,
+ "referenced_widgets": [
+ "cef76449b8d74217ae36c56be3990eec",
+ "7be07ba7cfe642a596509c756dcefddc",
+ "2a02378499fc414299f17a2d5dcac867",
+ "427d47d9423441d286ae80a637ae35a0",
+ "cb157fd4e37041d1beae29eaa729c8ff",
+ "73413668398b45dfa8484a2c2be778ec",
+ "e7d108a4b168442fb2048f58ddeb0a18",
+ "5e81a141422f432395055f5cafb07016",
+ "5f476a929da84fa985b2e980459da7b9",
+ "f3b643a0ca2444fd959fff9b45d79d27",
+ "82b87345233549d699ce3fd8080fa988"
+ ]
+ },
+ "id": "SDjEx9JxR3v8",
+ "outputId": "8f4287a7-aff9-41ef-a026-02265de0c205"
+ },
+ "outputs": [],
+ "source": [
+ "from flax.training.common_utils import shard_prng_key\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "from tqdm.notebook import trange\n",
+ "\n",
+ "print(f\"Prompt: {prompt}\\n\")\n",
+ "# generate images\n",
+ "images = []\n",
+ "for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
+ "    # get a new key\n",
+ "    key, subkey = jax.random.split(key)\n",
+ "    # generate images\n",
+ "    encoded_images = p_generate(\n",
+ "        tokenized_prompt,\n",
+ "        shard_prng_key(subkey),\n",
+ "        params,\n",
+ "        gen_top_k,\n",
+ "        gen_top_p,\n",
+ "        temperature,\n",
+ "        cond_scale,\n",
+ "    )\n",
+ "    # remove BOS\n",
+ "    encoded_images = encoded_images.sequences[..., 1:]\n",
+ "    # decode images\n",
+ "    decoded_images = p_decode(encoded_images, vqgan_params)\n",
+ "    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
+ "    for decoded_img in decoded_images:\n",
+ "        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n",
+ "        images.append(img)\n",
+ "        display(img)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "tw02wG9zGmyB"
+ },
+ "source": [
+ "## 🏅 Optional: Rank images by CLIP score\n",
+ "\n",
+ "We can rank images according to CLIP.\n",
+ "\n",
+ "**Note: your session may crash if you don't have a subscription to Colab Pro.**"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "RGjlIW_f6GA0"
+ },
+ "outputs": [],
+ "source": [
+ "# CLIP model\n",
+ "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n",
+ "CLIP_COMMIT_ID = None\n",
+ "\n",
+ "# Load CLIP\n",
+ "clip, clip_params = FlaxCLIPModel.from_pretrained(\n",
+ "    CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
+ ")\n",
+ "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
+ "clip_params = replicate(clip_params)\n",
+ "\n",
+ "# score images\n",
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
+ "def p_clip(inputs, params):\n",
+ "    logits = clip(params=params, **inputs).logits_per_image\n",
+ "    return logits"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "FoLXpjCmGpju"
+ },
+ "outputs": [],
+ "source": [
+ "from flax.training.common_utils import shard\n",
+ "\n",
+ "# CLIP model\n",
+ "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n",
+ "CLIP_COMMIT_ID = None\n",
+ "\n",
+ "# Load CLIP\n",
+ "clip, clip_params = FlaxCLIPModel.from_pretrained(\n",
+ "    CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
+ ")\n",
+ "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
+ "clip_params = replicate(clip_params)\n",
+ "\n",
+ "# score images\n",
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
+ "def p_clip(inputs, params):\n",
+ "    logits = clip(params=params, **inputs).logits_per_image\n",
+ "    return logits\n",
+ "\n",
+ "\n",
+ "# get clip scores\n",
+ "clip_inputs = clip_processor(\n",
+ "    text=[prompt] * jax.device_count(),\n",
+ "    images=images,\n",
+ "    return_tensors=\"np\",\n",
+ "    padding=\"max_length\",\n",
+ "    max_length=77,\n",
+ "    truncation=True,\n",
+ ").data\n",
+ "logits = p_clip(shard(clip_inputs), clip.params)\n",
+ "logits = logits.squeeze().flatten()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4AAWRm70LgED"
+ },
+ "source": [
+ "Let's now display images ranked by CLIP score."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "zsgxxubLLkIu"
+ },
+ "outputs": [],
+ "source": [
+ "print(f\"Prompt: {prompt}\\n\")\n",
+ "for idx in logits.argsort()[::-1]:\n",
+ "    display(images[idx])\n",
+ "    print(f\"Score: {logits[idx]:.2f}\\n\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "machine_shape": "hm",
+ "name": "DALL·E mini - Inference pipeline.ipynb",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ [new lines 518-859: auto-generated Colab ipywidgets state for the generation progress bar (HBox, HTML, FloatProgress and their Layout/Style models); omitted here]
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ }
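In short, the notebook now needs less memory because it loads the DALL·E model and the VQGAN in float16 with `_do_init=False` (so the weights come back as a separate params tree instead of being initialized a second time inside the model object), replicates only those params, halves the default number of predictions (16 to 8), and moves CLIP into an optional ranking section that uses the smaller openai/clip-vit-base-patch32 checkpoint. Below is a minimal sketch of that loading pattern, assembled from the cells in the new version above; it is not a drop-in replacement for the notebook (the pip installs, prompt processing and pmap'd generate/decode functions are unchanged and omitted).

# Sketch of the memory-saving load pattern from the updated notebook
# (same repos and revisions as defined in the cells above).
import jax.numpy as jnp
from flax.jax_utils import replicate
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel

DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"
DALLE_COMMIT_ID = None
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# fp16 weights, returned as a separate params tree (_do_init=False), so no
# extra randomly-initialized copy of the parameters is created at load time.
model, params = DalleBart.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
)
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO, revision=VQGAN_COMMIT_ID, dtype=jnp.float16, _do_init=False
)
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)

# Replicate only what generation needs; CLIP (openai/clip-vit-base-patch32) is
# loaded later, and only if the optional ranking section is run.
params = replicate(params)
vqgan_params = replicate(vqgan_params)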