boris committed
Commit 5f16fb0 · Parent: 89bc9d4

feat(colab): require less memory

Files changed (1):
  1. tools/inference/inference_pipeline.ipynb +862 -478
tools/inference/inference_pipeline.ipynb CHANGED
@@ -1,481 +1,865 @@
1
  {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {
6
- "id": "view-in-github",
7
- "colab_type": "text"
8
- },
9
- "source": [
10
- "<a href=\"https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
11
- ]
12
- },
13
- {
14
- "cell_type": "markdown",
15
- "metadata": {
16
- "id": "118UKH5bWCGa"
17
- },
18
- "source": [
19
- "# DALL·E mini - Inference pipeline\n",
20
- "\n",
21
- "*Generate images from a text prompt*\n",
22
- "\n",
23
- "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
24
- "\n",
25
- "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
26
- "\n",
27
- "Just want to play? Use [the demo](https://huggingface.co/spaces/flax-community/dalle-mini).\n",
28
- "\n",
29
- "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
30
- ]
31
- },
32
- {
33
- "cell_type": "markdown",
34
- "metadata": {
35
- "id": "dS8LbaonYm3a"
36
- },
37
- "source": [
38
- "## 🛠️ Installation and set-up"
39
- ]
40
- },
41
- {
42
- "cell_type": "code",
43
- "execution_count": null,
44
- "metadata": {
45
- "id": "uzjAM2GBYpZX"
46
- },
47
- "outputs": [],
48
- "source": [
49
- "# Install required libraries\n",
50
- "!pip install -q git+https://github.com/huggingface/transformers.git\n",
51
- "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
52
- "!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
53
- ]
54
- },
55
- {
56
- "cell_type": "markdown",
57
- "metadata": {
58
- "id": "ozHzTkyv8cqU"
59
- },
60
- "source": [
61
- "We load required models:\n",
62
- "* dalle·mini for text to encoded images\n",
63
- "* VQGAN for decoding images\n",
64
- "* CLIP for scoring predictions"
65
- ]
66
- },
67
- {
68
- "cell_type": "code",
69
- "execution_count": null,
70
- "metadata": {
71
- "id": "K6CxW2o42f-w"
72
- },
73
- "outputs": [],
74
- "source": [
75
- "# Model references\n",
76
- "\n",
77
- "# dalle-mini\n",
78
- "DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
79
- "DALLE_COMMIT_ID = None\n",
80
- "\n",
81
- "# VQGAN model\n",
82
- "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
83
- "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
84
- "\n",
85
- "# CLIP model\n",
86
- "CLIP_REPO = \"openai/clip-vit-large-patch14\"\n",
87
- "CLIP_COMMIT_ID = None"
88
- ]
89
- },
90
- {
91
- "cell_type": "markdown",
92
- "source": [
93
- "If your hardware can handle it, you can use dalle-mega instead of dalle-mini.\n",
94
- "\n",
95
- "**Note: on free Colab, you will most likely not be able to load dalle-mega so don't run the below cell to use dalle-mini instead.**"
96
- ],
97
- "metadata": {
98
- "id": "Jy01V3OltG5q"
99
- }
100
- },
101
- {
102
- "cell_type": "code",
103
- "source": [
104
- "# dalle-mega\n",
105
- "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\""
106
- ],
107
- "metadata": {
108
- "id": "VRRX_g9vtMZG"
109
- },
110
- "execution_count": null,
111
- "outputs": []
112
- },
113
- {
114
- "cell_type": "code",
115
- "execution_count": null,
116
- "metadata": {
117
- "id": "Yv-aR3t4Oe5v"
118
- },
119
- "outputs": [],
120
- "source": [
121
- "import jax\n",
122
- "import jax.numpy as jnp\n",
123
- "\n",
124
- "# check how many devices are available\n",
125
- "jax.local_device_count()"
126
- ]
127
- },
128
- {
129
- "cell_type": "code",
130
- "execution_count": null,
131
- "metadata": {
132
- "id": "92zYmvsQ38vL"
133
- },
134
- "outputs": [],
135
- "source": [
136
- "# Load models & tokenizer\n",
137
- "from dalle_mini import DalleBart, DalleBartProcessor\n",
138
- "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
139
- "from transformers import CLIPProcessor, FlaxCLIPModel\n",
140
- "\n",
141
- "# Load dalle-mini\n",
142
- "model, params = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False)\n",
143
- "\n",
144
- "# Load VQGAN\n",
145
- "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
146
- "\n",
147
- "# Load CLIP\n",
148
- "clip = FlaxCLIPModel.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
149
- "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)"
150
- ]
151
- },
152
- {
153
- "cell_type": "markdown",
154
- "metadata": {
155
- "id": "o_vH2X1tDtzA"
156
- },
157
- "source": [
158
- "Model parameters are replicated on each device for faster inference."
159
- ]
160
- },
161
- {
162
- "cell_type": "code",
163
- "execution_count": null,
164
- "metadata": {
165
- "id": "wtvLoM48EeVw"
166
- },
167
- "outputs": [],
168
- "source": [
169
- "from flax.jax_utils import replicate\n",
170
- "\n",
171
- "params = replicate(params)\n",
172
- "vqgan._params = replicate(vqgan.params)\n",
173
- "clip._params = replicate(clip.params)"
174
- ]
175
- },
176
- {
177
- "cell_type": "markdown",
178
- "metadata": {
179
- "id": "0A9AHQIgZ_qw"
180
- },
181
- "source": [
182
- "Model functions are compiled and parallelized to take advantage of multiple devices."
183
- ]
184
- },
185
- {
186
- "cell_type": "code",
187
- "execution_count": null,
188
- "metadata": {
189
- "id": "sOtoOmYsSYPz"
190
- },
191
- "outputs": [],
192
- "source": [
193
- "from functools import partial\n",
194
- "\n",
195
- "# model inference\n",
196
- "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
197
- "def p_generate(\n",
198
- " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
199
- "):\n",
200
- " return model.generate(\n",
201
- " **tokenized_prompt,\n",
202
- " prng_key=key,\n",
203
- " params=params,\n",
204
- " top_k=top_k,\n",
205
- " top_p=top_p,\n",
206
- " temperature=temperature,\n",
207
- " condition_scale=condition_scale,\n",
208
- " )\n",
209
- "\n",
210
- "\n",
211
- "# decode image\n",
212
- "@partial(jax.pmap, axis_name=\"batch\")\n",
213
- "def p_decode(indices, params):\n",
214
- " return vqgan.decode_code(indices, params=params)\n",
215
- "\n",
216
- "\n",
217
- "# score images\n",
218
- "@partial(jax.pmap, axis_name=\"batch\")\n",
219
- "def p_clip(inputs, params):\n",
220
- " logits = clip(params=params, **inputs).logits_per_image\n",
221
- " return logits"
222
- ]
223
- },
224
- {
225
- "cell_type": "markdown",
226
- "metadata": {
227
- "id": "HmVN6IBwapBA"
228
- },
229
- "source": [
230
- "Keys are passed to the model on each device to generate unique inference per device."
231
- ]
232
- },
233
- {
234
- "cell_type": "code",
235
- "execution_count": null,
236
- "metadata": {
237
- "id": "4CTXmlUkThhX"
238
- },
239
- "outputs": [],
240
- "source": [
241
- "import random\n",
242
- "\n",
243
- "# create a random key\n",
244
- "seed = random.randint(0, 2**32 - 1)\n",
245
- "key = jax.random.PRNGKey(seed)"
246
- ]
247
- },
248
- {
249
- "cell_type": "markdown",
250
- "metadata": {
251
- "id": "BrnVyCo81pij"
252
- },
253
- "source": [
254
- "## 🖍 Text Prompt"
255
- ]
256
- },
257
- {
258
- "cell_type": "markdown",
259
- "metadata": {
260
- "id": "rsmj0Aj5OQox"
261
- },
262
- "source": [
263
- "Our model requires processing prompts."
264
- ]
265
- },
266
- {
267
- "cell_type": "code",
268
- "execution_count": null,
269
- "metadata": {
270
- "id": "YjjhUychOVxm"
271
- },
272
- "outputs": [],
273
- "source": [
274
- "from dalle_mini import DalleBartProcessor\n",
275
- "\n",
276
- "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
277
- ]
278
- },
279
- {
280
- "cell_type": "markdown",
281
- "metadata": {
282
- "id": "BQ7fymSPyvF_"
283
- },
284
- "source": [
285
- "Let's define a text prompt."
286
- ]
287
- },
288
- {
289
- "cell_type": "code",
290
- "execution_count": null,
291
- "metadata": {
292
- "id": "x_0vI9ge1oKr"
293
- },
294
- "outputs": [],
295
- "source": [
296
- "prompt = \"sunset over a lake in the mountains\""
297
- ]
298
- },
299
- {
300
- "cell_type": "code",
301
- "execution_count": null,
302
- "metadata": {
303
- "id": "VKjEZGjtO49k"
304
- },
305
- "outputs": [],
306
- "source": [
307
- "tokenized_prompt = processor([prompt])"
308
- ]
309
- },
310
- {
311
- "cell_type": "markdown",
312
- "metadata": {
313
- "id": "-CEJBnuJOe5z"
314
- },
315
- "source": [
316
- "Finally we replicate it onto each device."
317
- ]
318
- },
319
- {
320
- "cell_type": "code",
321
- "execution_count": null,
322
- "metadata": {
323
- "id": "lQePgju5Oe5z"
324
- },
325
- "outputs": [],
326
- "source": [
327
- "tokenized_prompt = replicate(tokenized_prompt)"
328
- ]
329
- },
330
- {
331
- "cell_type": "markdown",
332
- "metadata": {
333
- "id": "phQ9bhjRkgAZ"
334
- },
335
- "source": [
336
- "## 🎨 Generate images\n",
337
- "\n",
338
- "We generate images using dalle-mini model and decode them with the VQGAN."
339
- ]
340
- },
341
- {
342
- "cell_type": "code",
343
- "execution_count": null,
344
- "metadata": {
345
- "id": "d0wVkXpKqnHA"
346
- },
347
- "outputs": [],
348
- "source": [
349
- "# number of predictions\n",
350
- "n_predictions = 16\n",
351
- "\n",
352
- "# We can customize top_k/top_p used for generating samples\n",
353
- "gen_top_k = None\n",
354
- "gen_top_p = None\n",
355
- "temperature = None\n",
356
- "cond_scale = 3.0"
357
- ]
358
- },
359
- {
360
- "cell_type": "code",
361
- "execution_count": null,
362
- "metadata": {
363
- "id": "SDjEx9JxR3v8"
364
- },
365
- "outputs": [],
366
- "source": [
367
- "from flax.training.common_utils import shard_prng_key\n",
368
- "import numpy as np\n",
369
- "from PIL import Image\n",
370
- "from tqdm.notebook import trange\n",
371
- "\n",
372
- "# generate images\n",
373
- "images = []\n",
374
- "for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
375
- " # get a new key\n",
376
- " key, subkey = jax.random.split(key)\n",
377
- " # generate images\n",
378
- " encoded_images = p_generate(\n",
379
- " tokenized_prompt,\n",
380
- " shard_prng_key(subkey),\n",
381
- " params,\n",
382
- " gen_top_k,\n",
383
- " gen_top_p,\n",
384
- " temperature,\n",
385
- " cond_scale,\n",
386
- " )\n",
387
- " # remove BOS\n",
388
- " encoded_images = encoded_images.sequences[..., 1:]\n",
389
- " # decode images\n",
390
- " decoded_images = p_decode(encoded_images, vqgan.params)\n",
391
- " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
392
- " for img in decoded_images:\n",
393
- " images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
394
- ]
395
- },
396
- {
397
- "cell_type": "markdown",
398
- "metadata": {
399
- "id": "tw02wG9zGmyB"
400
- },
401
- "source": [
402
- "Let's calculate their score with CLIP."
403
- ]
404
- },
405
- {
406
- "cell_type": "code",
407
- "execution_count": null,
408
- "metadata": {
409
- "id": "FoLXpjCmGpju"
410
- },
411
- "outputs": [],
412
- "source": [
413
- "from flax.training.common_utils import shard\n",
414
- "\n",
415
- "# get clip scores\n",
416
- "clip_inputs = clip_processor(\n",
417
- " text=[prompt] * jax.device_count(),\n",
418
- " images=images,\n",
419
- " return_tensors=\"np\",\n",
420
- " padding=\"max_length\",\n",
421
- " max_length=77,\n",
422
- " truncation=True,\n",
423
- ").data\n",
424
- "logits = p_clip(shard(clip_inputs), clip.params)\n",
425
- "logits = logits.squeeze().flatten()"
426
- ]
427
- },
428
- {
429
- "cell_type": "markdown",
430
- "metadata": {
431
- "id": "4AAWRm70LgED"
432
- },
433
- "source": [
434
- "Let's display images ranked by CLIP score."
435
- ]
436
- },
437
- {
438
- "cell_type": "code",
439
- "execution_count": null,
440
- "metadata": {
441
- "id": "zsgxxubLLkIu"
442
- },
443
- "outputs": [],
444
- "source": [
445
- "print(f\"Prompt: {prompt}\\n\")\n",
446
- "for idx in logits.argsort()[::-1]:\n",
447
- " display(images[idx])\n",
448
- " print(f\"Score: {logits[idx]:.2f}\\n\")"
449
- ]
450
- }
451
- ],
452
- "metadata": {
453
- "accelerator": "GPU",
454
  "colab": {
455
- "collapsed_sections": [],
456
- "machine_shape": "hm",
457
- "name": "DALL·E mini - Inference pipeline.ipynb",
458
- "provenance": [],
459
- "include_colab_link": true
460
- },
461
- "kernelspec": {
462
- "display_name": "Python 3 (ipykernel)",
463
- "language": "python",
464
- "name": "python3"
465
- },
466
- "language_info": {
467
- "codemirror_mode": {
468
- "name": "ipython",
469
- "version": 3
470
- },
471
- "file_extension": ".py",
472
- "mimetype": "text/x-python",
473
- "name": "python",
474
- "nbconvert_exporter": "python",
475
- "pygments_lexer": "ipython3",
476
- "version": "3.9.7"
477
- }
478
  },
479
- "nbformat": 4,
480
- "nbformat_minor": 0
481
- }
1
  {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {
6
+ "id": "118UKH5bWCGa"
7
+ },
8
+ "source": [
9
+ "# DALL·E mini - Inference pipeline\n",
10
+ "\n",
11
+ "*Generate images from a text prompt*\n",
12
+ "\n",
13
+ "<img src=\"https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true\" width=\"200\">\n",
14
+ "\n",
15
+ "This notebook illustrates [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
16
+ "\n",
17
+ "Just want to play? Use directly [DALL·E mini app](https://huggingface.co/spaces/dalle-mini/dalle-mini).\n",
18
+ "\n",
19
+ "For more understanding of the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {
25
+ "id": "dS8LbaonYm3a"
26
+ },
27
+ "source": [
28
+ "## 🛠️ Installation and set-up"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": null,
34
+ "metadata": {
35
+ "id": "uzjAM2GBYpZX"
36
+ },
37
+ "outputs": [],
38
+ "source": [
39
+ "# Install required libraries\n",
40
+ "!pip install -q git+https://github.com/huggingface/transformers.git\n",
41
+ "!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git\n",
42
+ "!pip install -q git+https://github.com/borisdayma/dalle-mini.git"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "markdown",
47
+ "metadata": {
48
+ "id": "ozHzTkyv8cqU"
49
+ },
50
+ "source": [
51
+ "We load required models:\n",
52
+ "* DALL·E mini for text to encoded images\n",
53
+ "* VQGAN for decoding images\n",
54
+ "* CLIP for scoring predictions"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "metadata": {
61
+ "id": "K6CxW2o42f-w"
62
+ },
63
+ "outputs": [],
64
+ "source": [
65
+ "# Model references\n",
66
+ "\n",
67
+ "# dalle-mega\n",
68
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
69
+ "DALLE_COMMIT_ID = None\n",
70
+ "\n",
71
+ "# if the notebook crashes too often you can use dalle-mini instead by uncommenting below line\n",
72
+ "# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n",
73
+ "\n",
74
+ "# VQGAN model\n",
75
+ "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
76
+ "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\""
77
+ ]
78
+ },
79
+ {
80
+ "cell_type": "code",
81
+ "execution_count": null,
82
+ "metadata": {
83
  "colab": {
84
+ "base_uri": "https://localhost:8080/"
85
+ },
86
+ "id": "Yv-aR3t4Oe5v",
87
+ "outputId": "3097b2c7-5dac-475f-edde-898799dd7294"
88
+ },
89
+ "outputs": [],
90
+ "source": [
91
+ "import jax\n",
92
+ "import jax.numpy as jnp\n",
93
+ "\n",
94
+ "# check how many devices are available\n",
95
+ "jax.local_device_count()"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "metadata": {
102
+ "colab": {
103
+ "base_uri": "https://localhost:8080/"
104
+ },
105
+ "id": "92zYmvsQ38vL",
106
+ "outputId": "d897dfdb-dae7-4026-da36-8b23dce066e8"
107
+ },
108
+ "outputs": [],
109
+ "source": [
110
+ "# Load models & tokenizer\n",
111
+ "from dalle_mini import DalleBart, DalleBartProcessor\n",
112
+ "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
113
+ "from transformers import CLIPProcessor, FlaxCLIPModel\n",
114
+ "\n",
115
+ "# Load dalle-mini\n",
116
+ "model, params = DalleBart.from_pretrained(\n",
117
+ " DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
118
+ ")\n",
119
+ "\n",
120
+ "# Load VQGAN\n",
121
+ "vqgan, vqgan_params = VQModel.from_pretrained(\n",
122
+ " VQGAN_REPO, revision=VQGAN_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
123
+ ")"
124
+ ]
125
+ },
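The `_do_init=False` flag is central to this commit's memory savings: the Flax model skips weight initialization and returns the parameters separately instead of storing a second copy on the model. A minimal sketch of the pattern, reusing the CLIP checkpoint referenced later in this notebook:

```python
# Sketch of the _do_init=False loading pattern (Hugging Face Flax API).
# The model object holds no parameters; the caller keeps the single copy
# of the weights and must pass `params=...` explicitly at call time.
import jax.numpy as jnp
from transformers import FlaxCLIPModel

model, params = FlaxCLIPModel.from_pretrained(
    "openai/clip-vit-base-patch32", dtype=jnp.float16, _do_init=False
)
```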
126
+ {
127
+ "cell_type": "markdown",
128
+ "metadata": {
129
+ "id": "o_vH2X1tDtzA"
130
+ },
131
+ "source": [
132
+ "Model parameters are replicated on each device for faster inference."
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": null,
138
+ "metadata": {
139
+ "id": "wtvLoM48EeVw"
140
+ },
141
+ "outputs": [],
142
+ "source": [
143
+ "from flax.jax_utils import replicate\n",
144
+ "\n",
145
+ "params = replicate(params)\n",
146
+ "vqgan_params = replicate(vqgan_params)"
147
+ ]
148
+ },
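For readers unfamiliar with `flax.jax_utils.replicate`: it stacks every parameter along a new leading device axis so that `jax.pmap`-compiled functions can consume one copy per device. A minimal sketch with a toy parameter tree (not from the notebook):

```python
# replicate() adds a leading axis of size jax.local_device_count()
# to every leaf of the parameter tree.
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate

toy_params = {"w": jnp.ones((2, 3))}  # leaf shape (2, 3)
replicated = replicate(toy_params)    # leaf shape (n_devices, 2, 3)
print(replicated["w"].shape)
```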
149
+ {
150
+ "cell_type": "markdown",
151
+ "metadata": {
152
+ "id": "0A9AHQIgZ_qw"
153
+ },
154
+ "source": [
155
+ "Model functions are compiled and parallelized to take advantage of multiple devices."
156
+ ]
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "execution_count": null,
161
+ "metadata": {
162
+ "id": "sOtoOmYsSYPz"
163
+ },
164
+ "outputs": [],
165
+ "source": [
166
+ "from functools import partial\n",
167
+ "\n",
168
+ "# model inference\n",
169
+ "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
170
+ "def p_generate(\n",
171
+ " tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
172
+ "):\n",
173
+ " return model.generate(\n",
174
+ " **tokenized_prompt,\n",
175
+ " prng_key=key,\n",
176
+ " params=params,\n",
177
+ " top_k=top_k,\n",
178
+ " top_p=top_p,\n",
179
+ " temperature=temperature,\n",
180
+ " condition_scale=condition_scale,\n",
181
+ " )\n",
182
+ "\n",
183
+ "\n",
184
+ "# decode image\n",
185
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
186
+ "def p_decode(indices, params):\n",
187
+ " return vqgan.decode_code(indices, params=params)"
188
+ ]
189
+ },
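In `p_generate`, `static_broadcasted_argnums=(3, 4, 5, 6)` marks `top_k`, `top_p`, `temperature`, and `condition_scale` as compile-time constants: they are broadcast to every device rather than sharded, and changing any of them triggers a recompilation. A toy sketch of the mechanism (hypothetical function, not part of the notebook):

```python
# Toy pmap with one static argument: `scale` is baked into the compiled
# program, while `x` is split across devices along its leading axis.
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(1,))
def scaled_sum(x, scale):
    return jax.lax.psum(x, axis_name="batch") * scale

x = jnp.ones((jax.local_device_count(), 4))  # leading axis = devices
print(scaled_sum(x, 2.0))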
190
+ {
191
+ "cell_type": "markdown",
192
+ "metadata": {
193
+ "id": "HmVN6IBwapBA"
194
+ },
195
+ "source": [
196
+ "Keys are passed to the model on each device to generate unique inference per device."
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": null,
202
+ "metadata": {
203
+ "id": "4CTXmlUkThhX"
204
+ },
205
+ "outputs": [],
206
+ "source": [
207
+ "import random\n",
208
+ "\n",
209
+ "# create a random key\n",
210
+ "seed = random.randint(0, 2**32 - 1)\n",
211
+ "key = jax.random.PRNGKey(seed)"
212
+ ]
213
+ },
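The pattern used below splits the key on every iteration and then shards the subkey, so each device receives a distinct key (and hence samples distinct images). A sketch:

```python
# shard_prng_key() splits the key across local devices, producing
# one independent PRNG key per device.
import jax
from flax.training.common_utils import shard_prng_key

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)   # fresh subkey each loop iteration
device_keys = shard_prng_key(subkey)  # shape (n_devices, 2)
print(device_keys.shape)
```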
214
+ {
215
+ "cell_type": "markdown",
216
+ "metadata": {
217
+ "id": "BrnVyCo81pij"
218
+ },
219
+ "source": [
220
+ "## 🖍 Text Prompt"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "markdown",
225
+ "metadata": {
226
+ "id": "rsmj0Aj5OQox"
227
+ },
228
+ "source": [
229
+ "Our model requires processing prompts."
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "metadata": {
236
+ "colab": {
237
+ "base_uri": "https://localhost:8080/"
238
+ },
239
+ "id": "YjjhUychOVxm",
240
+ "outputId": "a286f17a-a388-4754-ec4d-0464c0666c90"
241
+ },
242
+ "outputs": [],
243
+ "source": [
244
+ "from dalle_mini import DalleBartProcessor\n",
245
+ "\n",
246
+ "processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "markdown",
251
+ "metadata": {
252
+ "id": "BQ7fymSPyvF_"
253
+ },
254
+ "source": [
255
+ "Let's define a text prompt."
256
+ ]
257
  },
258
+ {
259
+ "cell_type": "code",
260
+ "execution_count": null,
261
+ "metadata": {
262
+ "id": "x_0vI9ge1oKr"
263
+ },
264
+ "outputs": [],
265
+ "source": [
266
+ "prompt = \"sunset over a lake in the mountains\""
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "code",
271
+ "execution_count": null,
272
+ "metadata": {
273
+ "id": "VKjEZGjtO49k"
274
+ },
275
+ "outputs": [],
276
+ "source": [
277
+ "tokenized_prompt = processor([prompt])"
278
+ ]
279
+ },
280
+ {
281
+ "cell_type": "markdown",
282
+ "metadata": {
283
+ "id": "-CEJBnuJOe5z"
284
+ },
285
+ "source": [
286
+ "Finally we replicate it onto each device."
287
+ ]
288
+ },
289
+ {
290
+ "cell_type": "code",
291
+ "execution_count": null,
292
+ "metadata": {
293
+ "id": "lQePgju5Oe5z"
294
+ },
295
+ "outputs": [],
296
+ "source": [
297
+ "tokenized_prompt = replicate(tokenized_prompt)"
298
+ ]
299
+ },
300
+ {
301
+ "cell_type": "markdown",
302
+ "metadata": {
303
+ "id": "phQ9bhjRkgAZ"
304
+ },
305
+ "source": [
306
+ "## 🎨 Generate images\n",
307
+ "\n",
308
+ "We generate images using dalle-mini model and decode them with the VQGAN."
309
+ ]
310
+ },
311
+ {
312
+ "cell_type": "code",
313
+ "execution_count": null,
314
+ "metadata": {
315
+ "id": "d0wVkXpKqnHA"
316
+ },
317
+ "outputs": [],
318
+ "source": [
319
+ "# number of predictions\n",
320
+ "n_predictions = 8\n",
321
+ "\n",
322
+ "# We can customize generation parameters\n",
323
+ "gen_top_k = None\n",
324
+ "gen_top_p = None\n",
325
+ "temperature = None\n",
326
+ "cond_scale = 3.0"
327
+ ]
328
+ },
329
+ {
330
+ "cell_type": "code",
331
+ "execution_count": null,
332
+ "metadata": {
333
+ "colab": {
334
+ "base_uri": "https://localhost:8080/",
335
+ "height": 1000,
336
+ "referenced_widgets": [
337
+ "cef76449b8d74217ae36c56be3990eec",
338
+ "7be07ba7cfe642a596509c756dcefddc",
339
+ "2a02378499fc414299f17a2d5dcac867",
340
+ "427d47d9423441d286ae80a637ae35a0",
341
+ "cb157fd4e37041d1beae29eaa729c8ff",
342
+ "73413668398b45dfa8484a2c2be778ec",
343
+ "e7d108a4b168442fb2048f58ddeb0a18",
344
+ "5e81a141422f432395055f5cafb07016",
345
+ "5f476a929da84fa985b2e980459da7b9",
346
+ "f3b643a0ca2444fd959fff9b45d79d27",
347
+ "82b87345233549d699ce3fd8080fa988"
348
+ ]
349
+ },
350
+ "id": "SDjEx9JxR3v8",
351
+ "outputId": "8f4287a7-aff9-41ef-a026-02265de0c205"
352
+ },
353
+ "outputs": [],
354
+ "source": [
355
+ "from flax.training.common_utils import shard_prng_key\n",
356
+ "import numpy as np\n",
357
+ "from PIL import Image\n",
358
+ "from tqdm.notebook import trange\n",
359
+ "\n",
360
+ "print(f\"Prompt: {prompt}\\n\")\n",
361
+ "# generate images\n",
362
+ "images = []\n",
363
+ "for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
364
+ " # get a new key\n",
365
+ " key, subkey = jax.random.split(key)\n",
366
+ " # generate images\n",
367
+ " encoded_images = p_generate(\n",
368
+ " tokenized_prompt,\n",
369
+ " shard_prng_key(subkey),\n",
370
+ " params,\n",
371
+ " gen_top_k,\n",
372
+ " gen_top_p,\n",
373
+ " temperature,\n",
374
+ " cond_scale,\n",
375
+ " )\n",
376
+ " # remove BOS\n",
377
+ " encoded_images = encoded_images.sequences[..., 1:]\n",
378
+ " # decode images\n",
379
+ " decoded_images = p_decode(encoded_images, vqgan_params)\n",
380
+ " decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
381
+ " for decoded_img in decoded_images:\n",
382
+ " img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n",
383
+ " images.append(img)\n",
384
+ " display(img)"
385
+ ]
386
+ },
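Each `p_generate` call produces one image per device, which is why the loop above runs `n_predictions // jax.device_count()` times (at least once). A quick sanity check of the arithmetic:

```python
# How many images the loop yields for a given device count.
import jax

n_predictions = 8
n_devices = jax.device_count()
n_iterations = max(n_predictions // n_devices, 1)
print(f"{n_iterations} iterations x {n_devices} device(s) = "
      f"{n_iterations * n_devices} images")
```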
387
+ {
388
+ "cell_type": "markdown",
389
+ "metadata": {
390
+ "id": "tw02wG9zGmyB"
391
+ },
392
+ "source": [
393
+ "## 🏅 Optional: Rank images by CLIP score\n",
394
+ "\n",
395
+ "We can rank images according to CLIP.\n",
396
+ "\n",
397
+ "**Note: your session may crash if you don't have a subscription to Colab Pro.**"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": null,
403
+ "metadata": {
404
+ "id": "RGjlIW_f6GA0"
405
+ },
406
+ "outputs": [],
407
+ "source": [
408
+ "# CLIP model\n",
409
+ "CLIP_REPO = \"openai/clip-vit-base-patch32\"\n",
410
+ "CLIP_COMMIT_ID = None\n",
411
+ "\n",
412
+ "# Load CLIP\n",
413
+ "clip, clip_params = FlaxCLIPModel.from_pretrained(\n",
414
+ " CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
415
+ ")\n",
416
+ "clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
417
+ "clip_params = replicate(clip_params)\n",
418
+ "\n",
419
+ "# score images\n",
420
+ "@partial(jax.pmap, axis_name=\"batch\")\n",
421
+ "def p_clip(inputs, params):\n",
422
+ " logits = clip(params=params, **inputs).logits_per_image\n",
423
+ " return logits"
424
+ ]
425
+ },
426
+ {
427
+ "cell_type": "code",
428
+ "execution_count": null,
429
+ "metadata": {
430
+ "id": "FoLXpjCmGpju"
431
+ },
432
+ "outputs": [],
433
+ "source": [
434
+ "from flax.training.common_utils import shard\n",
435
+ "\n",
454
+ "# get clip scores\n",
455
+ "clip_inputs = clip_processor(\n",
456
+ " text=[prompt] * jax.device_count(),\n",
457
+ " images=images,\n",
458
+ " return_tensors=\"np\",\n",
459
+ " padding=\"max_length\",\n",
460
+ " max_length=77,\n",
461
+ " truncation=True,\n",
462
+ ").data\n",
463
+ "logits = p_clip(shard(clip_inputs), clip.params)\n",
464
+ "logits = logits.squeeze().flatten()"
465
+ ]
466
+ },
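`logits.argsort()[::-1]` gives image indices from highest to lowest CLIP score, which the final cell uses to display the best images first. A plain-NumPy sketch with stand-in scores:

```python
# argsort() is ascending, so reversing it ranks images best-first.
import numpy as np

scores = np.array([21.3, 24.9, 19.7])  # stand-in CLIP logits, one per image
ranking = scores.argsort()[::-1]
print(ranking)                          # -> [1 0 2]
```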
467
+ {
468
+ "cell_type": "markdown",
469
+ "metadata": {
470
+ "id": "4AAWRm70LgED"
471
+ },
472
+ "source": [
473
+ "Let's now display images ranked by CLIP score."
474
+ ]
475
+ },
476
+ {
477
+ "cell_type": "code",
478
+ "execution_count": null,
479
+ "metadata": {
480
+ "id": "zsgxxubLLkIu"
481
+ },
482
+ "outputs": [],
483
+ "source": [
484
+ "print(f\"Prompt: {prompt}\\n\")\n",
485
+ "for idx in logits.argsort()[::-1]:\n",
486
+ " display(images[idx])\n",
487
+ " print(f\"Score: {logits[idx]:.2f}\\n\")"
488
+ ]
489
+ }
490
+ ],
491
+ "metadata": {
492
+ "accelerator": "GPU",
493
+ "colab": {
494
+ "collapsed_sections": [],
495
+ "machine_shape": "hm",
496
+ "name": "DALL·E mini - Inference pipeline.ipynb",
497
+ "provenance": []
498
+ },
499
+ "kernelspec": {
500
+ "display_name": "Python 3 (ipykernel)",
501
+ "language": "python",
502
+ "name": "python3"
503
+ },
504
+ "language_info": {
505
+ "codemirror_mode": {
506
+ "name": "ipython",
507
+ "version": 3
508
+ },
509
+ "file_extension": ".py",
510
+ "mimetype": "text/x-python",
511
+ "name": "python",
512
+ "nbconvert_exporter": "python",
513
+ "pygments_lexer": "ipython3",
514
+ "version": "3.9.7"
515
+ },
516
+ "widgets": {
517
+ "application/vnd.jupyter.widget-state+json": {
518
+ "2a02378499fc414299f17a2d5dcac867": {
519
+ "model_module": "@jupyter-widgets/controls",
520
+ "model_module_version": "1.5.0",
521
+ "model_name": "FloatProgressModel",
522
+ "state": {
523
+ "_dom_classes": [],
524
+ "_model_module": "@jupyter-widgets/controls",
525
+ "_model_module_version": "1.5.0",
526
+ "_model_name": "FloatProgressModel",
527
+ "_view_count": null,
528
+ "_view_module": "@jupyter-widgets/controls",
529
+ "_view_module_version": "1.5.0",
530
+ "_view_name": "ProgressView",
531
+ "bar_style": "",
532
+ "description": "",
533
+ "description_tooltip": null,
534
+ "layout": "IPY_MODEL_5e81a141422f432395055f5cafb07016",
535
+ "max": 8,
536
+ "min": 0,
537
+ "orientation": "horizontal",
538
+ "style": "IPY_MODEL_5f476a929da84fa985b2e980459da7b9",
539
+ "value": 5
540
+ }
541
+ },
542
+ "427d47d9423441d286ae80a637ae35a0": {
543
+ "model_module": "@jupyter-widgets/controls",
544
+ "model_module_version": "1.5.0",
545
+ "model_name": "HTMLModel",
546
+ "state": {
547
+ "_dom_classes": [],
548
+ "_model_module": "@jupyter-widgets/controls",
549
+ "_model_module_version": "1.5.0",
550
+ "_model_name": "HTMLModel",
551
+ "_view_count": null,
552
+ "_view_module": "@jupyter-widgets/controls",
553
+ "_view_module_version": "1.5.0",
554
+ "_view_name": "HTMLView",
555
+ "description": "",
556
+ "description_tooltip": null,
557
+ "layout": "IPY_MODEL_f3b643a0ca2444fd959fff9b45d79d27",
558
+ "placeholder": "​",
559
+ "style": "IPY_MODEL_82b87345233549d699ce3fd8080fa988",
560
+ "value": " 5/8 [04:25&lt;02:39, 53.09s/it]"
561
+ }
562
+ },
563
+ "5e81a141422f432395055f5cafb07016": {
564
+ "model_module": "@jupyter-widgets/base",
565
+ "model_module_version": "1.2.0",
566
+ "model_name": "LayoutModel",
567
+ "state": {
568
+ "_model_module": "@jupyter-widgets/base",
569
+ "_model_module_version": "1.2.0",
570
+ "_model_name": "LayoutModel",
571
+ "_view_count": null,
572
+ "_view_module": "@jupyter-widgets/base",
573
+ "_view_module_version": "1.2.0",
574
+ "_view_name": "LayoutView",
575
+ "align_content": null,
576
+ "align_items": null,
577
+ "align_self": null,
578
+ "border": null,
579
+ "bottom": null,
580
+ "display": null,
581
+ "flex": null,
582
+ "flex_flow": null,
583
+ "grid_area": null,
584
+ "grid_auto_columns": null,
585
+ "grid_auto_flow": null,
586
+ "grid_auto_rows": null,
587
+ "grid_column": null,
588
+ "grid_gap": null,
589
+ "grid_row": null,
590
+ "grid_template_areas": null,
591
+ "grid_template_columns": null,
592
+ "grid_template_rows": null,
593
+ "height": null,
594
+ "justify_content": null,
595
+ "justify_items": null,
596
+ "left": null,
597
+ "margin": null,
598
+ "max_height": null,
599
+ "max_width": null,
600
+ "min_height": null,
601
+ "min_width": null,
602
+ "object_fit": null,
603
+ "object_position": null,
604
+ "order": null,
605
+ "overflow": null,
606
+ "overflow_x": null,
607
+ "overflow_y": null,
608
+ "padding": null,
609
+ "right": null,
610
+ "top": null,
611
+ "visibility": null,
612
+ "width": null
613
+ }
614
+ },
615
+ "5f476a929da84fa985b2e980459da7b9": {
616
+ "model_module": "@jupyter-widgets/controls",
617
+ "model_module_version": "1.5.0",
618
+ "model_name": "ProgressStyleModel",
619
+ "state": {
620
+ "_model_module": "@jupyter-widgets/controls",
621
+ "_model_module_version": "1.5.0",
622
+ "_model_name": "ProgressStyleModel",
623
+ "_view_count": null,
624
+ "_view_module": "@jupyter-widgets/base",
625
+ "_view_module_version": "1.2.0",
626
+ "_view_name": "StyleView",
627
+ "bar_color": null,
628
+ "description_width": ""
629
+ }
630
+ },
631
+ "73413668398b45dfa8484a2c2be778ec": {
632
+ "model_module": "@jupyter-widgets/base",
633
+ "model_module_version": "1.2.0",
634
+ "model_name": "LayoutModel",
635
+ "state": {
636
+ "_model_module": "@jupyter-widgets/base",
637
+ "_model_module_version": "1.2.0",
638
+ "_model_name": "LayoutModel",
639
+ "_view_count": null,
640
+ "_view_module": "@jupyter-widgets/base",
641
+ "_view_module_version": "1.2.0",
642
+ "_view_name": "LayoutView",
643
+ "align_content": null,
644
+ "align_items": null,
645
+ "align_self": null,
646
+ "border": null,
647
+ "bottom": null,
648
+ "display": null,
649
+ "flex": null,
650
+ "flex_flow": null,
651
+ "grid_area": null,
652
+ "grid_auto_columns": null,
653
+ "grid_auto_flow": null,
654
+ "grid_auto_rows": null,
655
+ "grid_column": null,
656
+ "grid_gap": null,
657
+ "grid_row": null,
658
+ "grid_template_areas": null,
659
+ "grid_template_columns": null,
660
+ "grid_template_rows": null,
661
+ "height": null,
662
+ "justify_content": null,
663
+ "justify_items": null,
664
+ "left": null,
665
+ "margin": null,
666
+ "max_height": null,
667
+ "max_width": null,
668
+ "min_height": null,
669
+ "min_width": null,
670
+ "object_fit": null,
671
+ "object_position": null,
672
+ "order": null,
673
+ "overflow": null,
674
+ "overflow_x": null,
675
+ "overflow_y": null,
676
+ "padding": null,
677
+ "right": null,
678
+ "top": null,
679
+ "visibility": null,
680
+ "width": null
681
+ }
682
+ },
683
+ "7be07ba7cfe642a596509c756dcefddc": {
684
+ "model_module": "@jupyter-widgets/controls",
685
+ "model_module_version": "1.5.0",
686
+ "model_name": "HTMLModel",
687
+ "state": {
688
+ "_dom_classes": [],
689
+ "_model_module": "@jupyter-widgets/controls",
690
+ "_model_module_version": "1.5.0",
691
+ "_model_name": "HTMLModel",
692
+ "_view_count": null,
693
+ "_view_module": "@jupyter-widgets/controls",
694
+ "_view_module_version": "1.5.0",
695
+ "_view_name": "HTMLView",
696
+ "description": "",
697
+ "description_tooltip": null,
698
+ "layout": "IPY_MODEL_73413668398b45dfa8484a2c2be778ec",
699
+ "placeholder": "​",
700
+ "style": "IPY_MODEL_e7d108a4b168442fb2048f58ddeb0a18",
701
+ "value": " 62%"
702
+ }
703
+ },
704
+ "82b87345233549d699ce3fd8080fa988": {
705
+ "model_module": "@jupyter-widgets/controls",
706
+ "model_module_version": "1.5.0",
707
+ "model_name": "DescriptionStyleModel",
708
+ "state": {
709
+ "_model_module": "@jupyter-widgets/controls",
710
+ "_model_module_version": "1.5.0",
711
+ "_model_name": "DescriptionStyleModel",
712
+ "_view_count": null,
713
+ "_view_module": "@jupyter-widgets/base",
714
+ "_view_module_version": "1.2.0",
715
+ "_view_name": "StyleView",
716
+ "description_width": ""
717
+ }
718
+ },
719
+ "cb157fd4e37041d1beae29eaa729c8ff": {
720
+ "model_module": "@jupyter-widgets/base",
721
+ "model_module_version": "1.2.0",
722
+ "model_name": "LayoutModel",
723
+ "state": {
724
+ "_model_module": "@jupyter-widgets/base",
725
+ "_model_module_version": "1.2.0",
726
+ "_model_name": "LayoutModel",
727
+ "_view_count": null,
728
+ "_view_module": "@jupyter-widgets/base",
729
+ "_view_module_version": "1.2.0",
730
+ "_view_name": "LayoutView",
731
+ "align_content": null,
732
+ "align_items": null,
733
+ "align_self": null,
734
+ "border": null,
735
+ "bottom": null,
736
+ "display": null,
737
+ "flex": null,
738
+ "flex_flow": null,
739
+ "grid_area": null,
740
+ "grid_auto_columns": null,
741
+ "grid_auto_flow": null,
742
+ "grid_auto_rows": null,
743
+ "grid_column": null,
744
+ "grid_gap": null,
745
+ "grid_row": null,
746
+ "grid_template_areas": null,
747
+ "grid_template_columns": null,
748
+ "grid_template_rows": null,
749
+ "height": null,
750
+ "justify_content": null,
751
+ "justify_items": null,
752
+ "left": null,
753
+ "margin": null,
754
+ "max_height": null,
755
+ "max_width": null,
756
+ "min_height": null,
757
+ "min_width": null,
758
+ "object_fit": null,
759
+ "object_position": null,
760
+ "order": null,
761
+ "overflow": null,
762
+ "overflow_x": null,
763
+ "overflow_y": null,
764
+ "padding": null,
765
+ "right": null,
766
+ "top": null,
767
+ "visibility": null,
768
+ "width": null
769
+ }
770
+ },
771
+ "cef76449b8d74217ae36c56be3990eec": {
772
+ "model_module": "@jupyter-widgets/controls",
773
+ "model_module_version": "1.5.0",
774
+ "model_name": "HBoxModel",
775
+ "state": {
776
+ "_dom_classes": [],
777
+ "_model_module": "@jupyter-widgets/controls",
778
+ "_model_module_version": "1.5.0",
779
+ "_model_name": "HBoxModel",
780
+ "_view_count": null,
781
+ "_view_module": "@jupyter-widgets/controls",
782
+ "_view_module_version": "1.5.0",
783
+ "_view_name": "HBoxView",
784
+ "box_style": "",
785
+ "children": [
786
+ "IPY_MODEL_7be07ba7cfe642a596509c756dcefddc",
787
+ "IPY_MODEL_2a02378499fc414299f17a2d5dcac867",
788
+ "IPY_MODEL_427d47d9423441d286ae80a637ae35a0"
789
+ ],
790
+ "layout": "IPY_MODEL_cb157fd4e37041d1beae29eaa729c8ff"
791
+ }
792
+ },
793
+ "e7d108a4b168442fb2048f58ddeb0a18": {
794
+ "model_module": "@jupyter-widgets/controls",
795
+ "model_module_version": "1.5.0",
796
+ "model_name": "DescriptionStyleModel",
797
+ "state": {
798
+ "_model_module": "@jupyter-widgets/controls",
799
+ "_model_module_version": "1.5.0",
800
+ "_model_name": "DescriptionStyleModel",
801
+ "_view_count": null,
802
+ "_view_module": "@jupyter-widgets/base",
803
+ "_view_module_version": "1.2.0",
804
+ "_view_name": "StyleView",
805
+ "description_width": ""
806
+ }
807
+ },
808
+ "f3b643a0ca2444fd959fff9b45d79d27": {
809
+ "model_module": "@jupyter-widgets/base",
810
+ "model_module_version": "1.2.0",
811
+ "model_name": "LayoutModel",
812
+ "state": {
813
+ "_model_module": "@jupyter-widgets/base",
814
+ "_model_module_version": "1.2.0",
815
+ "_model_name": "LayoutModel",
816
+ "_view_count": null,
817
+ "_view_module": "@jupyter-widgets/base",
818
+ "_view_module_version": "1.2.0",
819
+ "_view_name": "LayoutView",
820
+ "align_content": null,
821
+ "align_items": null,
822
+ "align_self": null,
823
+ "border": null,
824
+ "bottom": null,
825
+ "display": null,
826
+ "flex": null,
827
+ "flex_flow": null,
828
+ "grid_area": null,
829
+ "grid_auto_columns": null,
830
+ "grid_auto_flow": null,
831
+ "grid_auto_rows": null,
832
+ "grid_column": null,
833
+ "grid_gap": null,
834
+ "grid_row": null,
835
+ "grid_template_areas": null,
836
+ "grid_template_columns": null,
837
+ "grid_template_rows": null,
838
+ "height": null,
839
+ "justify_content": null,
840
+ "justify_items": null,
841
+ "left": null,
842
+ "margin": null,
843
+ "max_height": null,
844
+ "max_width": null,
845
+ "min_height": null,
846
+ "min_width": null,
847
+ "object_fit": null,
848
+ "object_position": null,
849
+ "order": null,
850
+ "overflow": null,
851
+ "overflow_x": null,
852
+ "overflow_y": null,
853
+ "padding": null,
854
+ "right": null,
855
+ "top": null,
856
+ "visibility": null,
857
+ "width": null
858
+ }
859
+ }
860
+ }
861
+ }
862
+ },
863
+ "nbformat": 4,
864
+ "nbformat_minor": 0
865
+ }