ddpm / app.py
lanzhiwang's picture
debug 12
33d25b3
raw
history blame
1.82 kB
import gradio as gr
from diffusers import DiffusionPipeline
import torch
from diffusers import DDPMScheduler, UNet2DModel
from PIL import Image
import numpy as np
def erzeuge(prompt, num_inference_steps=1000):
    """Sample new images from the module-level diffusion ``pipeline``.

    The loaded checkpoint is an unconditional DDPM model, so the text typed
    into the UI cannot steer generation; ``prompt`` is accepted only because
    the Gradio textbox is wired to this handler, and it is otherwise ignored.

    Args:
        prompt: Text from the UI textbox (unused by the model).
        num_inference_steps: Number of denoising steps to run. The default,
            1000, matches the pipeline's own default; smaller values trade
            quality for speed.

    Returns:
        The pipeline's ``.images`` list (one image per sample), which is the
        list form ``gr.Gallery`` expects as output.
    """
    del prompt  # unconditional model: the text has no effect on sampling
    # Pass everything by keyword: DDPMPipeline.__call__ takes ``batch_size``
    # as its first positional parameter, so a positional prompt string would
    # be misread as a batch size and sampling would fail.
    result = pipeline(batch_size=1, num_inference_steps=num_inference_steps)
    return result.images
# def erzeuge_komplex(prompt):
# scheduler = DDPMScheduler.from_pretrained("google/ddpm-cat-256")
# model = UNet2DModel.from_pretrained("google/ddpm-cat-256").to("cuda")
# scheduler.set_timesteps(50)
# sample_size = model.config.sample_size
# noise = torch.randn((1, 3, sample_size, sample_size)).to("cuda")
# input = noise
# for t in scheduler.timesteps:
# with torch.no_grad():
# noisy_residual = model(input, t).sample
# prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
# input = prev_noisy_sample
# image = (input / 2 + 0.5).clamp(0, 1)
# image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
# image = Image.fromarray((image * 255).round().astype("uint8"))
# return image
# pipeline = DiffusionPipeline.from_pretrained("google/ddpm-cat-256")
# Load the DDPM checkpoint once at import time so every request reuses the
# same weights instead of reloading them per call.
pipeline = DiffusionPipeline.from_pretrained("google/ddpm-celebahq-256")
# Use the GPU automatically when one is present instead of toggling a
# commented-out line by hand; CPU-only hosts keep the original behavior.
if torch.cuda.is_available():
    pipeline.to("cuda")
# Gradio front end: a one-line prompt box and a button in a compact row,
# with a gallery underneath that shows whatever ``erzeuge`` returns.
with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row(variant="compact"):
            text = gr.Textbox(
                label="Deine Beschreibung:",
                show_label=False,
                max_lines=1,
                placeholder="Bildbeschreibung",
            )
            btn = gr.Button("erzeuge Bild")
        gallery = gr.Gallery(
            label="Erzeugtes Bild",
            show_label=False,
            elem_id="gallery",
        )
    # Clicking the button and pressing Enter in the textbox run the same
    # handler with the same input/output wiring; register both, click first.
    for _event in (btn.click, text.submit):
        _event(erzeuge, inputs=[text], outputs=[gallery])
if __name__ == "__main__":
    # Enable request queueing before launching: each generation runs many
    # denoising steps and can take minutes (especially on CPU), and some
    # Gradio versions drop un-queued requests that exceed their timeout.
    # ``queue()`` returns the Blocks object, so ``launch()`` chains directly.
    demo.queue().launch()