from PIL import Image, ImageDraw, ImageFont

import gradio as gr
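
# Note: this script uses the Gradio 2.x API (gr.inputs / gr.outputs, and the
# layout/theme arguments of gr.Interface); later Gradio releases changed these interfaces.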

def compose_predictions(images, caption=None):
    # Paste the candidate images side by side; when a caption is given, reserve a
    # 48 px band at the top of the canvas and draw the caption text into it.
    increased_h = 0 if caption is None else 48
    w, h = images[0].size[0], images[0].size[1]
    img = Image.new("RGB", (len(images) * w, h + increased_h))
    for i, img_ in enumerate(images):
        img.paste(img_, (i * w, increased_h))

    if caption is not None:
        draw = ImageDraw.Draw(img)
        font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
        draw.text((20, 3), caption, (255, 255, 255), font=font)
    return img

def compose_predictions_grid(images):
    # Arrange the images in a fixed 4-column grid (assumes len(images) is a
    # multiple of 4 so every row is full).
    cols = 4
    rows = len(images) // cols
    w, h = images[0].size[0], images[0].size[1]
    img = Image.new("RGB", (w * cols, h * rows))
    for i, img_ in enumerate(images):
        row = i // cols
        col = i % cols
        img.paste(img_, (w * col, h * row))
    return img
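
# Illustrative usage of the helpers above (hypothetical values, not part of the
# app): eight 256x256 thumbnails become a 2048x256 strip or a 1024x512 grid.
#
#   thumbs = [Image.new("RGB", (256, 256), "gray") for _ in range(8)]
#   strip = compose_predictions(thumbs)       # 2048x256, no caption band
#   grid = compose_predictions_grid(thumbs)   # 1024x512, 4 columns x 2 rows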

def top_k_predictions_real(prompt, num_candidates=32, k=8):
    # Generate `num_candidates` images for the prompt and keep the `k` best ones
    # according to CLIP. `hallucinate` and `clip_top_k` are expected to be
    # defined (or imported) elsewhere in the project.
    images = hallucinate(prompt, num_images=num_candidates)
    images = clip_top_k(prompt, images, k=k)
    return images
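
# The HTML produced in run_inference describes re-ranking candidates with the
# pre-trained CLIP model at https://huggingface.co/openai/clip-vit-base-patch32.
# The function below is a minimal, hypothetical sketch of what such a `clip_top_k`
# step could look like with the `transformers` library; the project's actual
# implementation is defined elsewhere and may differ.
def clip_top_k_sketch(prompt, images, k=8):
    import torch
    from transformers import CLIPModel, CLIPProcessor

    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # logits_per_image has shape (num_images, num_texts); a higher score means a
    # closer text-image match, so keep the k best-scoring candidates.
    scores = outputs.logits_per_image[:, 0]
    best = scores.argsort(descending=True)[:k].tolist()
    return [images[i] for i in best]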

def top_k_predictions(prompt, num_candidates=32, k=8):
    # Demo stand-in: instead of running the model, return `k` pre-generated
    # sample images from disk.
    images = []
    for i in range(k):
        image = Image.open(f"sample_images/image_{i}.jpg")
        images.append(image)
    return images

def run_inference(prompt, num_images=32, num_preds=8):
    images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
    predictions = compose_predictions(images)
    output_title = f"""
    <p style="font-size:22px; font-weight:bold">Best predictions</p>
    <p>We asked our model to generate 32 candidates for your prompt:</p>

    <pre>

    <b>{prompt}</b>
    </pre>
    <p>We then used a pre-trained <a href="https://huggingface.co/openai/clip-vit-base-patch32">CLIP model</a> to score them according to the
    similarity of the text and the image representations.</p>

    <p>This is the result:</p>
    """
    output_description = """
    <p>Read more about the process <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA">in our report</a>.</p>
    <p style='text-align: center'>Created with <a href="https://github.com/borisdayma/dalle-mini">DALL·E mini</a></p>
    """
    return (output_title, predictions, output_description)

outputs = [
    gr.outputs.HTML(label=""),
    gr.outputs.Image(label=''),
    gr.outputs.HTML(label=""),
]
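
# The three outputs above line up with the tuple returned by run_inference:
# (output_title HTML, composed predictions image, output_description HTML).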

description = """
Welcome to our demo of DALL·E mini. This project was created on TPU v3-8s during the 🤗 Flax / JAX Community Week.
It reproduces the essential characteristics of OpenAI's DALL·E, at a fraction of the size.

Please write what you would like the model to generate, or select one of the examples below.
"""

gr.Interface(run_inference,
    inputs=[gr.inputs.Textbox(label='Prompt')],
    outputs=outputs,
    title='DALL·E mini',
    description=description,
    article="<p style='text-align: center'>DALL·E mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
    layout='vertical',
    theme='huggingface',
    examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
    allow_flagging=False,
    live=False,
    server_port=8999
).launch(share=True)  # share=True also exposes the locally served demo through a public Gradio link