feat(app): migrate demo to Gradio by @AK391 (#179)
- README.md +3 -2
- app/gradio/app.py +53 -0
- app/gradio/app_gradio.py +0 -179
- app/gradio/backend.py +33 -0
- app/gradio/requirements.txt +0 -4
- app/streamlit/app.py +0 -2
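
In short, this change replaces the in-process JAX/Flax demo (app_gradio.py) with a thin Gradio Blocks front end (app/gradio/app.py) that delegates image generation to a backend server through the new client module (app/gradio/backend.py). A rough sketch of the new flow, using the names from the diffs below (assumes app/gradio is on the import path; the backend URL is a placeholder):

# Sketch of the architecture after this PR: the UI no longer loads the
# model itself; it POSTs the prompt to a generation backend over HTTP.
import os

from backend import get_images_from_backend  # added in app/gradio/backend.py

os.environ.setdefault("BACKEND_SERVER", "http://localhost:8000")  # placeholder URL
backend_url = os.environ["BACKEND_SERVER"] + "/generate"

# Returns {"images": [PIL.Image, ...], "version": "..."} on success.
result = get_images_from_backend("snowy mountains by the sea", backend_url)
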
README.md CHANGED
@@ -3,8 +3,9 @@ title: DALL·E mini
 emoji: 🥑
 colorFrom: yellow
 colorTo: green
-sdk:
-
+sdk: gradio
+sdk_version: 3.0b6
+app_file: app/gradio/app.py
 pinned: True
 license: apache-2.0
 ---

app/gradio/app.py ADDED
@@ -0,0 +1,53 @@
+#!/usr/bin/env python
+# coding: utf-8
+import os
+
+import gradio as gr
+from backend import get_images_from_backend
+
+block = gr.Blocks(css=".container { max-width: 800px; margin: auto; }")
+backend_url = os.environ["BACKEND_SERVER"] + "/generate"
+
+
+def infer(prompt):
+    response = get_images_from_backend(prompt, backend_url)
+    return response["images"]
+
+
+with block:
+    gr.Markdown("<h1><center>DALL·E mini</center></h1>")
+    gr.Markdown(
+        "DALL·E mini is an AI model that generates images from any prompt you give!"
+    )
+    with gr.Group():
+        with gr.Box():
+            with gr.Row().style(mobile_collapse=False, equal_height=True):
+
+                text = gr.Textbox(
+                    label="Enter your prompt", show_label=False, max_lines=1
+                ).style(
+                    border=(True, False, True, True),
+                    margin=False,
+                    rounded=(True, False, False, True),
+                    container=False,
+                )
+                btn = gr.Button("Run").style(
+                    margin=False,
+                    rounded=(False, True, True, False),
+                )
+        gallery = gr.Gallery(label="Generated images", show_label=False).style(
+            grid=[3], height="auto"
+        )
+    btn.click(infer, inputs=text, outputs=gallery)
+
+    gr.Markdown(
+        """___
+   <p style='text-align: center'>
+   Created by Boris Dayma et al. 2021-2022
+   <br/>
+   <a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
+   </p>"""
+    )
+
+
+block.launch()

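To try the new demo locally, something like the following should work (not part of this PR). Note that app.py reads BACKEND_SERVER at import time, so the variable must be set before the script starts; the URL below is a placeholder for a compatible generation backend.

# Minimal local-launch sketch. Runs app.py from its own directory so that
# the sibling `backend` module is importable, and passes the backend URL
# through the environment.
import os
import subprocess
import sys

env = dict(os.environ, BACKEND_SERVER="http://localhost:8000")  # placeholder URL
subprocess.run([sys.executable, "app.py"], cwd="app/gradio", env=env, check=True)
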
app/gradio/app_gradio.py DELETED
@@ -1,179 +0,0 @@
-#!/usr/bin/env python
-# coding: utf-8
-
-# Uncomment to run on cpu
-# import os
-# os.environ["JAX_PLATFORM_NAME"] = "cpu"
-
-import random
-
-import gradio as gr
-import jax
-import numpy as np
-from flax.jax_utils import replicate
-from flax.training.common_utils import shard
-from PIL import Image, ImageDraw, ImageFont
-
-# ## CLIP Scoring
-from transformers import BartTokenizer, CLIPProcessor, FlaxCLIPModel
-from vqgan_jax.modeling_flax_vqgan import VQModel
-
-from dalle_mini.model import CustomFlaxBartForConditionalGeneration
-
-DALLE_REPO = "flax-community/dalle-mini"
-DALLE_COMMIT_ID = "4d34126d0df8bc4a692ae933e3b902a1fa8b6114"
-
-VQGAN_REPO = "flax-community/vqgan_f16_16384"
-VQGAN_COMMIT_ID = "90cc46addd2dd8f5be21586a9a23e1b95aa506a9"
-
-tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
-model = CustomFlaxBartForConditionalGeneration.from_pretrained(
-    DALLE_REPO, revision=DALLE_COMMIT_ID
-)
-vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
-
-
-def captioned_strip(images, caption=None, rows=1):
-    increased_h = 0 if caption is None else 48
-    w, h = images[0].size[0], images[0].size[1]
-    img = Image.new("RGB", (len(images) * w // rows, h * rows + increased_h))
-    for i, img_ in enumerate(images):
-        img.paste(img_, (i // rows * w, increased_h + (i % rows) * h))
-
-    if caption is not None:
-        draw = ImageDraw.Draw(img)
-        font = ImageFont.truetype(
-            "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
-        )
-        draw.text((20, 3), caption, (255, 255, 255), font=font)
-    return img
-
-
-def custom_to_pil(x):
-    x = np.clip(x, 0.0, 1.0)
-    x = (255 * x).astype(np.uint8)
-    x = Image.fromarray(x)
-    if not x.mode == "RGB":
-        x = x.convert("RGB")
-    return x
-
-
-def generate(input, rng, params):
-    return model.generate(
-        **input,
-        max_length=257,
-        num_beams=1,
-        do_sample=True,
-        prng_key=rng,
-        eos_token_id=50000,
-        pad_token_id=50000,
-        params=params,
-    )
-
-
-def get_images(indices, params):
-    return vqgan.decode_code(indices, params=params)
-
-
-p_generate = jax.pmap(generate, "batch")
-p_get_images = jax.pmap(get_images, "batch")
-
-bart_params = replicate(model.params)
-vqgan_params = replicate(vqgan.params)
-
-clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-print("Initialize FlaxCLIPModel")
-processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
-print("Initialize CLIPProcessor")
-
-
-def hallucinate(prompt, num_images=64):
-    prompt = [prompt] * jax.device_count()
-    inputs = tokenizer(
-        prompt,
-        return_tensors="jax",
-        padding="max_length",
-        truncation=True,
-        max_length=128,
-    ).data
-    inputs = shard(inputs)
-
-    all_images = []
-    for i in range(num_images // jax.device_count()):
-        key = random.randint(0, 1e7)
-        rng = jax.random.PRNGKey(key)
-        rngs = jax.random.split(rng, jax.local_device_count())
-        indices = p_generate(inputs, rngs, bart_params).sequences
-        indices = indices[:, :, 1:]
-
-        images = p_get_images(indices, vqgan_params)
-        images = np.squeeze(np.asarray(images), 1)
-        for image in images:
-            all_images.append(custom_to_pil(image))
-    return all_images
-
-
-def clip_top_k(prompt, images, k=8):
-    inputs = processor(text=prompt, images=images, return_tensors="np", padding=True)
-    outputs = clip(**inputs)
-    logits = outputs.logits_per_text
-    scores = np.array(logits[0]).argsort()[-k:][::-1]
-    return [images[score] for score in scores]
-
-
-def compose_predictions(images, caption=None):
-    increased_h = 0 if caption is None else 48
-    w, h = images[0].size[0], images[0].size[1]
-    img = Image.new("RGB", (len(images) * w, h + increased_h))
-    for i, img_ in enumerate(images):
-        img.paste(img_, (i * w, increased_h))
-
-    if caption is not None:
-        draw = ImageDraw.Draw(img)
-        font = ImageFont.truetype(
-            "/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40
-        )
-        draw.text((20, 3), caption, (255, 255, 255), font=font)
-    return img
-
-
-def top_k_predictions(prompt, num_candidates=32, k=8):
-    images = hallucinate(prompt, num_images=num_candidates)
-    images = clip_top_k(prompt, images, k=k)
-    return images
-
-
-def run_inference(prompt, num_images=32, num_preds=8):
-    images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
-    predictions = captioned_strip(images)
-    output_title = f"""
-    <b>{prompt}</b>
-    """
-    return (output_title, predictions)
-
-
-outputs = [
-    gr.outputs.HTML(label=""),  # To be used as title
-    gr.outputs.Image(label=""),
-]
-
-description = """
-DALL·E-mini is an AI model that generates images from any prompt you give! Generate images from text:
-"""
-gr.Interface(
-    run_inference,
-    inputs=[gr.inputs.Textbox(label="What do you want to see?")],
-    outputs=outputs,
-    title="DALL·E mini",
-    description=description,
-    article="<p style='text-align: center'> Created by Boris Dayma et al. 2021 | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a> | <a href='https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA'>Report</a></p>",
-    layout="vertical",
-    theme="huggingface",
-    examples=[
-        ["an armchair in the shape of an avocado"],
-        ["snowy mountains by the sea"],
-    ],
-    allow_flagging=False,
-    live=False,
-    # server_port=8999
-).launch(share=True)

app/gradio/backend.py ADDED
@@ -0,0 +1,33 @@
+# Client requests to Dalle-Mini Backend server
+
+import base64
+from io import BytesIO
+
+import requests
+from PIL import Image
+
+
+class ServiceError(Exception):
+    def __init__(self, status_code):
+        self.status_code = status_code
+
+
+def get_images_from_backend(prompt, backend_url):
+    r = requests.post(backend_url, json={"prompt": prompt})
+    if r.status_code == 200:
+        json = r.json()
+        images = json["images"]
+        images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
+        version = json.get("version", "unknown")
+        return {"images": images, "version": version}
+    else:
+        raise ServiceError(r.status_code)
+
+
+def get_model_version(url):
+    r = requests.get(url)
+    if r.status_code == 200:
+        version = r.json()["version"]
+        return version
+    else:
+        raise ServiceError(r.status_code)

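For reference, a usage sketch of the new client module (the prompt is one of the repo's own examples; the endpoint is whatever BACKEND_SERVER points at, a placeholder here):

# Fetch images through backend.py and handle a failing backend.
import os

from backend import ServiceError, get_images_from_backend

backend_url = os.environ.get("BACKEND_SERVER", "http://localhost:8000") + "/generate"
try:
    response = get_images_from_backend(
        "an armchair in the shape of an avocado", backend_url
    )
    for i, image in enumerate(response["images"]):
        image.save(f"sample_{i}.png")  # PIL images, decoded from base64 by the client
    print("model version:", response["version"])
except ServiceError as err:
    print("backend returned HTTP", err.status_code)
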
app/gradio/requirements.txt DELETED
@@ -1,4 +0,0 @@
-# Requirements for huggingface spaces
-gradio>=2.2.3
-flax
-transformers

app/streamlit/app.py CHANGED
@@ -1,8 +1,6 @@
 #!/usr/bin/env python
 # coding: utf-8
 
-from datetime import datetime
-
 import streamlit as st
 from backend import ServiceError, get_images_from_backend
 