from diffusers import DDPMPipeline
import gradio as gr
from ui import title, description, examples


# Module-level resolution placeholder; initialised to None here.
RES = None

# One row per selectable model: (dropdown label, resolution in px, Hub repo id).
_MODEL_SPECS = [
    ('pokemon', 64, 'mrm8488/ddpm-ema-pokemon-64'),
    ('flowers', 64, 'mrm8488/ddpm-ema-flower-64'),
    ('anime_faces', 128, 'mrm8488/ddpm-ema-anime-v2-128'),
    ('butterflies', 128, 'mrm8488/ddpm-ema-butterflies-128'),
    # Disabled: ('human_faces', 256, 'fusing/ddpm-celeba-hq'),
]
models = [
    {'type': kind, 'res': size, 'id': repo_id}
    for kind, size, repo_id in _MODEL_SPECS
]

# Load every pipeline once, at import time, and attach it to its entry.
for entry in models:
    print(entry)
    pipe = DDPMPipeline.from_pretrained(entry['id'])
    # NOTE(review): every iteration saves into the same directory ('.'), so each
    # save presumably overwrites the previous model's files on disk. The object
    # used afterwards is the in-memory pipeline stored on the entry below.
    pipe.save_pretrained('.')
    entry['pipeline'] = pipe


def predict(type):
    """Generate an image with the DDPM pipeline whose entry matches *type*.

    Args:
        type: The ``'type'`` value of one entry in ``models`` (the dropdown
            choice). The name shadows the builtin but is kept so existing
            callers passing it by keyword keep working.

    Returns:
        The first element of the pipeline output's ``"sample"`` sequence.

    Raises:
        ValueError: If no entry in ``models`` has a matching ``'type'``.
    """
    # Without this declaration the assignment below would only create a
    # function-local and the module-level RES would never change.
    global RES
    for model in models:
        if model['type'] == type:
            RES = model['res']
            # Run the pipeline to generate images.
            # NOTE(review): indexing the output with "sample" matches older
            # diffusers releases; newer ones expose results via `.images` —
            # confirm against the pinned diffusers version.
            image = model['pipeline']()["sample"]
            return image[0]
    # Previously an unknown choice fell through with pipeline=None and failed
    # with "'NoneType' object is not callable"; report the real problem.
    raise ValueError(f"Unknown model type: {type!r}")


# Build and launch the web UI. The input is a dropdown of the loaded model types,
# and the output is one generated PIL image.
# The output image has no fixed `shape`: the models produce different resolutions
# (64px and 128px). In Gradio 3.x `shape` only affected preprocessing of input
# images, and Gradio 4 removed the parameter entirely.
demo = gr.Interface(
    predict,
    inputs=[
        gr.Dropdown(
            choices=[model['type'] for model in models],
            label='Choose a model',
        ),
    ],
    outputs=[gr.Image(type="pil", elem_id="generated_image")],
    title=title,
    description=description,
)
demo.launch()