boris committed on
Commit 024f8f5 · 1 Parent(s): c3e93df

feat(colab): handle dalle-mega

Files changed (1)
  1. tools/inference/inference_pipeline.ipynb +481 -456
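
In short, this commit points the notebook at the dalle-mega checkpoint and loads its parameters separately so they can be cast to fp16 before replication. Below is a minimal sketch of that loading flow, condensed from the diff that follows; the artifact names and the `_do_init=False` pattern are taken from the notebook itself, and it assumes dalle-mini, vqgan-jax and their JAX dependencies are already installed.

```python
# Sketch of the dalle-mega loading flow introduced by this commit
# (condensed from the notebook diff below; not a standalone reference).
from flax.jax_utils import replicate
from dalle_mini import DalleBart, DalleBartProcessor

# dalle-mega by default; the smaller dalle-mini artifact remains an option
DALLE_MODEL = "dalle-mini/dalle-mini/mega-1:latest"
# DALLE_MODEL = "dalle-mini/dalle-mini/wzoooa1c:latest"  # fallback if memory is tight
DALLE_COMMIT_ID = None

# _do_init=False returns the module and its parameters separately,
# so the weights can be converted and replicated without an extra copy
model, params = DalleBart.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, _do_init=False
)
params = model.to_fp16(params)  # halve memory usage for dalle-mega
params = replicate(params)      # one copy per local device

processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
```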
tools/inference/inference_pipeline.ipynb CHANGED
@@ -1,458 +1,483 @@
1
  {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {
6
- "colab_type": "text",
7
- "id": "view-in-github"
8
- },
9
- "source": [
10
- "<a href=\"https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
- ]
12
  },
13
- {
14
- "cell_type": "markdown",
15
- "metadata": {
16
- "id": "118UKH5bWCGa"
17
- },
18
- "source": [
19
- "# DALL·E mini - Inference pipeline\n",
20
- "\n",
21
- "*Generate images from a text prompt*\n",
22
- "\n",
23
- "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
24
- "\n",
25
- "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
26
- "\n",
27
- "Just want to play? Use [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).\n",
28
- "\n",
29
- "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
30
- ]
31
- },
32
- {
33
- "cell_type": "markdown",
34
- "metadata": {
35
- "id": "dS8LbaonYm3a"
36
- },
37
- "source": [
38
- "## 🛠️ Installation and set-up"
39
- ]
40
- },
41
- {
42
- "cell_type": "code",
43
- "execution_count": null,
44
- "metadata": {
45
- "id": "uzjAM2GBYpZX"
46
- },
47
- "outputs": [],
48
- "source": [
49
- "# Install required libraries\n",
50
- "!pip install -q git+https://github.com/huggingface/transformers.git\n",
51
- "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
52
- "!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
53
- ]
54
- },
55
- {
56
- "cell_type": "markdown",
57
- "metadata": {
58
- "id": "ozHzTkyv8cqU"
59
- },
60
- "source": [
61
- "We load required models:\n",
62
- "* dalle·mini for text to encoded images\n",
63
- "* VQGAN for decoding images\n",
64
- "* CLIP for scoring predictions"
65
- ]
66
- },
67
- {
68
- "cell_type": "code",
69
- "execution_count": null,
70
- "metadata": {
71
- "id": "K6CxW2o42f-w"
72
- },
73
- "outputs": [],
74
- "source": [
75
- "# Model references\n",
76
- "\n",
77
- "# dalle-mini\n",
78
- "DALLE_MODEL = \"dalle-mini/dalle-mini/wzoooa1c:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
79
- "DALLE_COMMIT_ID = None\n",
80
- "\n",
81
- "# VQGAN model\n",
82
- "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
83
- "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
84
- "\n",
85
- "# CLIP model\n",
86
- "CLIP_REPO = \"openai/clip-vit-large-patch14\"\n",
87
- "CLIP_COMMIT_ID = None"
88
- ]
89
- },
90
- {
91
- "cell_type": "code",
92
- "execution_count": null,
93
- "metadata": {
94
- "id": "Yv-aR3t4Oe5v"
95
- },
96
- "outputs": [],
97
- "source": [
98
- "import jax\n",
99
- "import jax.numpy as jnp\n",
100
- "\n",
101
- "# check how many devices are available\n",
102
- "jax.local_device_count()"
103
- ]
104
- },
105
- {
106
- "cell_type": "code",
107
- "execution_count": null,
108
- "metadata": {
109
- "id": "92zYmvsQ38vL"
110
- },
111
- "outputs": [],
112
- "source": [
113
- "# Load models & tokenizer\n",
114
- "from dalle_mini import DalleBart, DalleBartProcessor\n",
115
- "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
116
- "from transformers import CLIPProcessor, FlaxCLIPModel\n",
117
- "\n",
118
- "# Load dalle-mini\n",
119
- "model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n",
120
- "\n",
121
- "# Load VQGAN\n",
122
- "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
123
- "\n",
124
- "# Load CLIP\n",
125
- "clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
126
- "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)"
127
- ]
128
- },
129
- {
130
- "cell_type": "markdown",
131
- "metadata": {
132
- "id": "o_vH2X1tDtzA"
133
- },
134
- "source": [
135
- "Model parameters are replicated on each device for faster inference."
136
- ]
137
- },
138
- {
139
- "cell_type": "code",
140
- "execution_count": null,
141
- "metadata": {
142
- "id": "wtvLoM48EeVw"
143
- },
144
- "outputs": [],
145
- "source": [
146
- "from flax.jax_utils import replicate\n",
147
- "\n",
148
- "model._params = replicate(model.params)\n",
149
- "vqgan._params = replicate(vqgan.params)\n",
150
- "clip._params = replicate(clip.params)"
151
- ]
152
- },
153
- {
154
- "cell_type": "markdown",
155
- "metadata": {
156
- "id": "0A9AHQIgZ_qw"
157
- },
158
- "source": [
159
- "Model functions are compiled and parallelized to take advantage of multiple devices."
160
- ]
161
- },
162
- {
163
- "cell_type": "code",
164
- "execution_count": null,
165
- "metadata": {
166
- "id": "sOtoOmYsSYPz"
167
- },
168
- "outputs": [],
169
- "source": [
170
- "from functools import partial\n",
171
- "\n",
172
- "# model inference\n",
173
- "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
174
- "def p_generate(\n",
175
- " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
176
- "):\n",
177
- " return model.generate(\n",
178
- " **tokenized_prompt,\n",
179
- " prng_key=key,\n",
180
- " params=params,\n",
181
- " top_k=top_k,\n",
182
- " top_p=top_p,\n",
183
- " temperature=temperature,\n",
184
- " condition_scale=condition_scale,\n",
185
- " )\n",
186
- "\n",
187
- "\n",
188
- "# decode image\n",
189
- "@partial(jax.pmap, axis_name=\"batch\")\n",
190
- "def p_decode(indices, params):\n",
191
- " return vqgan.decode_code(indices, params=params)\n",
192
- "\n",
193
- "\n",
194
- "# score images\n",
195
- "@partial(jax.pmap, axis_name=\"batch\")\n",
196
- "def p_clip(inputs, params):\n",
197
- " logits = clip(params=params, **inputs).logits_per_image\n",
198
- " return logits"
199
- ]
200
- },
201
- {
202
- "cell_type": "markdown",
203
- "metadata": {
204
- "id": "HmVN6IBwapBA"
205
- },
206
- "source": [
207
- "Keys are passed to the model on each device to generate unique inference per device."
208
- ]
209
- },
210
- {
211
- "cell_type": "code",
212
- "execution_count": null,
213
- "metadata": {
214
- "id": "4CTXmlUkThhX"
215
- },
216
- "outputs": [],
217
- "source": [
218
- "import random\n",
219
- "\n",
220
- "# create a random key\n",
221
- "seed = random.randint(0, 2**32 - 1)\n",
222
- "key = jax.random.PRNGKey(seed)"
223
- ]
224
- },
225
- {
226
- "cell_type": "markdown",
227
- "metadata": {
228
- "id": "BrnVyCo81pij"
229
- },
230
- "source": [
231
- "## 🖍 Text Prompt"
232
- ]
233
- },
234
- {
235
- "cell_type": "markdown",
236
- "metadata": {
237
- "id": "rsmj0Aj5OQox"
238
- },
239
- "source": [
240
- "Our model requires processing prompts."
241
- ]
242
- },
243
- {
244
- "cell_type": "code",
245
- "execution_count": null,
246
- "metadata": {
247
- "id": "YjjhUychOVxm"
248
- },
249
- "outputs": [],
250
- "source": [
251
- "from dalle_mini import DalleBartProcessor\n",
252
- "\n",
253
- "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
254
- ]
255
- },
256
- {
257
- "cell_type": "markdown",
258
- "metadata": {
259
- "id": "BQ7fymSPyvF_"
260
- },
261
- "source": [
262
- "Let's define a text prompt."
263
- ]
264
- },
265
- {
266
- "cell_type": "code",
267
- "execution_count": null,
268
- "metadata": {
269
- "id": "x_0vI9ge1oKr"
270
- },
271
- "outputs": [],
272
- "source": [
273
- "prompt = \"sunset over a lake in the mountains\""
274
- ]
275
- },
276
- {
277
- "cell_type": "code",
278
- "execution_count": null,
279
- "metadata": {
280
- "id": "VKjEZGjtO49k"
281
- },
282
- "outputs": [],
283
- "source": [
284
- "tokenized_prompt = processor([prompt])"
285
- ]
286
- },
287
- {
288
- "cell_type": "markdown",
289
- "metadata": {
290
- "id": "-CEJBnuJOe5z"
291
- },
292
- "source": [
293
- "Finally we replicate it onto each device."
294
- ]
295
- },
296
- {
297
- "cell_type": "code",
298
- "execution_count": null,
299
- "metadata": {
300
- "id": "lQePgju5Oe5z"
301
- },
302
- "outputs": [],
303
- "source": [
304
- "tokenized_prompt = replicate(tokenized_prompt)"
305
- ]
306
- },
307
- {
308
- "cell_type": "markdown",
309
- "metadata": {
310
- "id": "phQ9bhjRkgAZ"
311
- },
312
- "source": [
313
- "## 🎨 Generate images\n",
314
- "\n",
315
- "We generate images using dalle-mini model and decode them with the VQGAN."
316
- ]
317
- },
318
- {
319
- "cell_type": "code",
320
- "execution_count": null,
321
- "metadata": {
322
- "id": "d0wVkXpKqnHA"
323
- },
324
- "outputs": [],
325
- "source": [
326
- "# number of predictions\n",
327
- "n_predictions = 16\n",
328
- "\n",
329
- "# We can customize top_k/top_p used for generating samples\n",
330
- "gen_top_k = None\n",
331
- "gen_top_p = None\n",
332
- "temperature = 0.85\n",
333
- "cond_scale = 3.0"
334
- ]
335
- },
336
- {
337
- "cell_type": "code",
338
- "execution_count": null,
339
- "metadata": {
340
- "id": "SDjEx9JxR3v8"
341
- },
342
- "outputs": [],
343
- "source": [
344
- "from flax.training.common_utils import shard_prng_key\n",
345
- "import numpy as np\n",
346
- "from PIL import Image\n",
347
- "from tqdm.notebook import trange\n",
348
- "\n",
349
- "# generate images\n",
350
- "images = []\n",
351
- "for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
352
- " # get a new key\n",
353
- " key, subkey = jax.random.split(key)\n",
354
- " # generate images\n",
355
- " encoded_images = p_generate(\n",
356
- " tokenized_prompt,\n",
357
- " shard_prng_key(subkey),\n",
358
- " model.params,\n",
359
- " gen_top_k,\n",
360
- " gen_top_p,\n",
361
- " temperature,\n",
362
- " cond_scale,\n",
363
- " )\n",
364
- " # remove BOS\n",
365
- " encoded_images = encoded_images.sequences[..., 1:]\n",
366
- " # decode images\n",
367
- " decoded_images = p_decode(encoded_images, vqgan.params)\n",
368
- " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
369
- " for img in decoded_images:\n",
370
- " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
371
- ]
372
- },
373
- {
374
- "cell_type": "markdown",
375
- "metadata": {
376
- "id": "tw02wG9zGmyB"
377
- },
378
- "source": [
379
- "Let's calculate their score with CLIP."
380
- ]
381
- },
382
- {
383
- "cell_type": "code",
384
- "execution_count": null,
385
- "metadata": {
386
- "id": "FoLXpjCmGpju"
387
- },
388
- "outputs": [],
389
- "source": [
390
- "from flax.training.common_utils import shard\n",
391
- "\n",
392
- "# get clip scores\n",
393
- "clip_inputs = clip_processor(\n",
394
- " text=[prompt] * jax.device_count(),\n",
395
- " images=images,\n",
396
- " return_tensors=\"np\",\n",
397
- " padding=\"max_length\",\n",
398
- " max_length=77,\n",
399
- " truncation=True,\n",
400
- ").data\n",
401
- "logits = p_clip(shard(clip_inputs), clip.params)\n",
402
- "logits = logits.squeeze().flatten()"
403
- ]
404
- },
405
- {
406
- "cell_type": "markdown",
407
- "metadata": {
408
- "id": "4AAWRm70LgED"
409
- },
410
- "source": [
411
- "Let's display images ranked by CLIP score."
412
- ]
413
- },
414
- {
415
- "cell_type": "code",
416
- "execution_count": null,
417
- "metadata": {
418
- "id": "zsgxxubLLkIu"
419
- },
420
- "outputs": [],
421
- "source": [
422
- "print(f\"Prompt: {prompt}\\n\")\n",
423
- "for idx in logits.argsort()[::-1]:\n",
424
- " display(images[idx])\n",
425
- " print(f\"Score: {logits[idx]:.2f}\\n\")"
426
- ]
427
- }
428
- ],
429
- "metadata": {
430
- "accelerator": "GPU",
431
- "colab": {
432
- "collapsed_sections": [],
433
- "include_colab_link": true,
434
- "machine_shape": "hm",
435
- "name": "DALL·E mini - Inference pipeline.ipynb",
436
- "provenance": []
437
- },
438
- "kernelspec": {
439
- "display_name": "Python 3 (ipykernel)",
440
- "language": "python",
441
- "name": "python3"
442
- },
443
- "language_info": {
444
- "codemirror_mode": {
445
- "name": "ipython",
446
- "version": 3
447
- },
448
- "file_extension": ".py",
449
- "mimetype": "text/x-python",
450
- "name": "python",
451
- "nbconvert_exporter": "python",
452
- "pygments_lexer": "ipython3",
453
- "version": "3.9.7"
454
- }
455
- },
456
- "nbformat": 4,
457
- "nbformat_minor": 0
458
- }
 
1
  {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "view-in-github",
7
+ "colab_type": "text"
8
+ },
9
+ "source": [
10
+ "<a href=\"https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "markdown",
15
+ "metadata": {
16
+ "id": "118UKH5bWCGa"
17
+ },
18
+ "source": [
19
+ "# DALL·E mini - Inference pipeline\n",
20
+ "\n",
21
+ "*Generate images from a text prompt*\n",
22
+ "\n",
23
+ "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
24
+ "\n",
25
+ "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
26
+ "\n",
27
+ "Just want to play? Use [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).\n",
28
+ "\n",
29
+ "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "metadata": {
35
+ "id": "dS8LbaonYm3a"
36
+ },
37
+ "source": [
38
+ "## 🛠️ Installation and set-up"
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {
45
+ "id": "uzjAM2GBYpZX"
46
+ },
47
+ "outputs": [],
48
+ "source": [
49
+ "# Install required libraries\n",
50
+ "!pip install -q git+https://github.com/huggingface/transformers.git\n",
51
+ "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
52
+ "!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "metadata": {
58
+ "id": "ozHzTkyv8cqU"
59
+ },
60
+ "source": [
61
+ "We load required models:\n",
62
+ "* dalle·mini for text to encoded images\n",
63
+ "* VQGAN for decoding images\n",
64
+ "* CLIP for scoring predictions"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": null,
70
+ "metadata": {
71
+ "id": "K6CxW2o42f-w"
72
+ },
73
+ "outputs": [],
74
+ "source": [
75
+ "# Model references\n",
76
+ "\n",
77
+ "# dalle-mini\n",
78
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/wzoooa1c:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
79
+ "DALLE_COMMIT_ID = None\n",
80
+ "\n",
81
+ "# dalle-mega - comment this section if your hardware runs out of memory to use dalle-mini instead\n",
82
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1:latest\"\n",
83
+ "\n",
84
+ "# VQGAN model\n",
85
+ "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
86
+ "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
87
+ "\n",
88
+ "# CLIP model\n",
89
+ "CLIP_REPO = \"openai/clip-vit-large-patch14\"\n",
90
+ "CLIP_COMMIT_ID = None"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "metadata": {
97
+ "id": "Yv-aR3t4Oe5v"
98
+ },
99
+ "outputs": [],
100
+ "source": [
101
+ "import jax\n",
102
+ "import jax.numpy as jnp\n",
103
+ "\n",
104
+ "# check how many devices are available\n",
105
+ "jax.local_device_count()"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "metadata": {
112
+ "id": "92zYmvsQ38vL"
113
+ },
114
+ "outputs": [],
115
+ "source": [
116
+ "# Load models & tokenizer\n",
117
+ "from dalle_mini import DalleBart, DalleBartProcessor\n",
118
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
119
+ "from transformers import CLIPProcessor, FlaxCLIPModel\n",
120
+ "\n",
121
+ "# Load dalle-mini\n",
122
+ "model, params = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, _do_init=False)\n",
123
+ "\n",
124
+ "# Load VQGAN\n",
125
+ "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
126
+ "\n",
127
+ "# Load CLIP\n",
128
+ "clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
129
+ "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "markdown",
134
+ "source": [
135
+ "Reduce memory usage."
136
+ ],
137
+ "metadata": {
138
+ "id": "Q29VMSoXqUdj"
139
+ }
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "source": [
144
+ "params = model.to_fp16(params)\n",
145
+ "#vqgan._params = vqgan.to_bf16(vqgan._params)\n",
146
+ "#clip._params = clip.to_bf16(clip._params)"
147
+ ],
148
+ "metadata": {
149
+ "id": "2r-zWmsPqTjh"
150
+ },
151
+ "execution_count": null,
152
+ "outputs": []
153
+ },
154
+ {
155
+ "cell_type": "markdown",
156
+ "metadata": {
157
+ "id": "o_vH2X1tDtzA"
158
+ },
159
+ "source": [
160
+ "Model parameters are replicated on each device for faster inference."
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": null,
166
+ "metadata": {
167
+ "id": "wtvLoM48EeVw"
168
+ },
169
+ "outputs": [],
170
+ "source": [
171
+ "from flax.jax_utils import replicate\n",
172
+ "\n",
173
+ "params = replicate(params)\n",
174
+ "vqgan._params = replicate(vqgan.params)\n",
175
+ "clip._params = replicate(clip.params)"
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "markdown",
180
+ "metadata": {
181
+ "id": "0A9AHQIgZ_qw"
182
+ },
183
+ "source": [
184
+ "Model functions are compiled and parallelized to take advantage of multiple devices."
185
+ ]
186
+ },
187
+ {
188
+ "cell_type": "code",
189
+ "execution_count": null,
190
+ "metadata": {
191
+ "id": "sOtoOmYsSYPz"
192
+ },
193
+ "outputs": [],
194
+ "source": [
195
+ "from functools import partial\n",
196
+ "\n",
197
+ "# model inference\n",
198
+ "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
199
+ "def p_generate(\n",
200
+ " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
201
+ "):\n",
202
+ " return model.generate(\n",
203
+ " **tokenized_prompt,\n",
204
+ " prng_key=key,\n",
205
+ " params=params,\n",
206
+ " top_k=top_k,\n",
207
+ " top_p=top_p,\n",
208
+ " temperature=temperature,\n",
209
+ " condition_scale=condition_scale,\n",
210
+ " )\n",
211
+ "\n",
212
+ "\n",
213
+ "# decode image\n",
214
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
215
+ "def p_decode(indices, params):\n",
216
+ " return vqgan.decode_code(indices, params=params)\n",
217
+ "\n",
218
+ "\n",
219
+ "# score images\n",
220
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
221
+ "def p_clip(inputs, params):\n",
222
+ " logits = clip(params=params, **inputs).logits_per_image\n",
223
+ " return logits"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "markdown",
228
+ "metadata": {
229
+ "id": "HmVN6IBwapBA"
230
+ },
231
+ "source": [
232
+ "Keys are passed to the model on each device to generate unique inference per device."
233
+ ]
234
+ },
235
+ {
236
+ "cell_type": "code",
237
+ "execution_count": null,
238
+ "metadata": {
239
+ "id": "4CTXmlUkThhX"
240
+ },
241
+ "outputs": [],
242
+ "source": [
243
+ "import random\n",
244
+ "\n",
245
+ "# create a random key\n",
246
+ "seed = random.randint(0, 2**32 - 1)\n",
247
+ "key = jax.random.PRNGKey(seed)"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "markdown",
252
+ "metadata": {
253
+ "id": "BrnVyCo81pij"
254
+ },
255
+ "source": [
256
+ "## 🖍 Text Prompt"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "markdown",
261
+ "metadata": {
262
+ "id": "rsmj0Aj5OQox"
263
+ },
264
+ "source": [
265
+ "Our model requires processing prompts."
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": null,
271
+ "metadata": {
272
+ "id": "YjjhUychOVxm"
273
+ },
274
+ "outputs": [],
275
+ "source": [
276
+ "from dalle_mini import DalleBartProcessor\n",
277
+ "\n",
278
+ "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
279
+ ]
280
+ },
281
+ {
282
+ "cell_type": "markdown",
283
+ "metadata": {
284
+ "id": "BQ7fymSPyvF_"
285
+ },
286
+ "source": [
287
+ "Let's define a text prompt."
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": null,
293
+ "metadata": {
294
+ "id": "x_0vI9ge1oKr"
295
+ },
296
+ "outputs": [],
297
+ "source": [
298
+ "prompt = \"sunset over a lake in the mountains\""
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "execution_count": null,
304
+ "metadata": {
305
+ "id": "VKjEZGjtO49k"
306
+ },
307
+ "outputs": [],
308
+ "source": [
309
+ "tokenized_prompt = processor([prompt])"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "markdown",
314
+ "metadata": {
315
+ "id": "-CEJBnuJOe5z"
316
+ },
317
+ "source": [
318
+ "Finally we replicate it onto each device."
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "code",
323
+ "execution_count": null,
324
+ "metadata": {
325
+ "id": "lQePgju5Oe5z"
326
+ },
327
+ "outputs": [],
328
+ "source": [
329
+ "tokenized_prompt = replicate(tokenized_prompt)"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "markdown",
334
+ "metadata": {
335
+ "id": "phQ9bhjRkgAZ"
336
+ },
337
+ "source": [
338
+ "## 🎨 Generate images\n",
339
+ "\n",
340
+ "We generate images using dalle-mini model and decode them with the VQGAN."
341
+ ]
342
+ },
343
+ {
344
+ "cell_type": "code",
345
+ "execution_count": null,
346
+ "metadata": {
347
+ "id": "d0wVkXpKqnHA"
348
+ },
349
+ "outputs": [],
350
+ "source": [
351
+ "# number of predictions\n",
352
+ "n_predictions = 16\n",
353
+ "\n",
354
+ "# We can customize top_k/top_p used for generating samples\n",
355
+ "gen_top_k = None\n",
356
+ "gen_top_p = None\n",
357
+ "temperature = 0.85\n",
358
+ "cond_scale = 3.0"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "code",
363
+ "execution_count": null,
364
+ "metadata": {
365
+ "id": "SDjEx9JxR3v8"
366
+ },
367
+ "outputs": [],
368
+ "source": [
369
+ "from flax.training.common_utils import shard_prng_key\n",
370
+ "import numpy as np\n",
371
+ "from PIL import Image\n",
372
+ "from tqdm.notebook import trange\n",
373
+ "\n",
374
+ "# generate images\n",
375
+ "images = []\n",
376
+ "for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
377
+ " # get a new key\n",
378
+ " key, subkey = jax.random.split(key)\n",
379
+ " # generate images\n",
380
+ " encoded_images = p_generate(\n",
381
+ " tokenized_prompt,\n",
382
+ " shard_prng_key(subkey),\n",
383
+ " params,\n",
384
+ " gen_top_k,\n",
385
+ " gen_top_p,\n",
386
+ " temperature,\n",
387
+ " cond_scale,\n",
388
+ " )\n",
389
+ " # remove BOS\n",
390
+ " encoded_images = encoded_images.sequences[..., 1:]\n",
391
+ " # decode images\n",
392
+ " decoded_images = p_decode(encoded_images, vqgan.params)\n",
393
+ " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
394
+ " for img in decoded_images:\n",
395
+ " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
396
+ ]
397
+ },
398
+ {
399
+ "cell_type": "markdown",
400
+ "metadata": {
401
+ "id": "tw02wG9zGmyB"
402
+ },
403
+ "source": [
404
+ "Let's calculate their score with CLIP."
405
+ ]
406
+ },
407
+ {
408
+ "cell_type": "code",
409
+ "execution_count": null,
410
+ "metadata": {
411
+ "id": "FoLXpjCmGpju"
412
+ },
413
+ "outputs": [],
414
+ "source": [
415
+ "from flax.training.common_utils import shard\n",
416
+ "\n",
417
+ "# get clip scores\n",
418
+ "clip_inputs = clip_processor(\n",
419
+ " text=[prompt] * jax.device_count(),\n",
420
+ " images=images,\n",
421
+ " return_tensors=\"np\",\n",
422
+ " padding=\"max_length\",\n",
423
+ " max_length=77,\n",
424
+ " truncation=True,\n",
425
+ ").data\n",
426
+ "logits = p_clip(shard(clip_inputs), clip.params)\n",
427
+ "logits = logits.squeeze().flatten()"
428
+ ]
429
+ },
430
+ {
431
+ "cell_type": "markdown",
432
+ "metadata": {
433
+ "id": "4AAWRm70LgED"
434
+ },
435
+ "source": [
436
+ "Let's display images ranked by CLIP score."
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": null,
442
+ "metadata": {
443
+ "id": "zsgxxubLLkIu"
444
+ },
445
+ "outputs": [],
446
+ "source": [
447
+ "print(f\"Prompt: {prompt}\\n\")\n",
448
+ "for idx in logits.argsort()[::-1]:\n",
449
+ " display(images[idx])\n",
450
+ " print(f\"Score: {logits[idx]:.2f}\\n\")"
451
+ ]
452
+ }
453
+ ],
454
+ "metadata": {
455
+ "accelerator": "GPU",
456
+ "colab": {
457
+ "collapsed_sections": [],
458
+ "machine_shape": "hm",
459
+ "name": "DALL·E mini - Inference pipeline.ipynb",
460
+ "provenance": [],
461
+ "include_colab_link": true
462
+ },
463
+ "kernelspec": {
464
+ "display_name": "Python 3 (ipykernel)",
465
+ "language": "python",
466
+ "name": "python3"
467
+ },
468
+ "language_info": {
469
+ "codemirror_mode": {
470
+ "name": "ipython",
471
+ "version": 3
472
+ },
473
+ "file_extension": ".py",
474
+ "mimetype": "text/x-python",
475
+ "name": "python",
476
+ "nbconvert_exporter": "python",
477
+ "pygments_lexer": "ipython3",
478
+ "version": "3.9.7"
479
+ }
480
  },
481
+ "nbformat": 4,
482
+ "nbformat_minor": 0
483
+ }
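
Downstream of loading, the generation loop in the updated notebook threads the standalone `params` object through the pmapped call instead of reading `model.params`. A condensed sketch of that call pattern, assuming `p_generate`, `tokenized_prompt`, `key`, and the sampling settings are defined as in the notebook cells above:

```python
# Condensed from the generation loop in the updated notebook.
import jax
from flax.training.common_utils import shard_prng_key

key, subkey = jax.random.split(key)          # fresh PRNG key for this batch
encoded_images = p_generate(
    tokenized_prompt,                        # replicated tokenized prompt
    shard_prng_key(subkey),                  # one key shard per device
    params,                                  # replicated fp16 dalle-mega params
    gen_top_k,
    gen_top_p,
    temperature,
    cond_scale,
)
encoded_images = encoded_images.sequences[..., 1:]  # drop the BOS token
```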