import os

import torch
import PIL.Image
import numpy as np
import gradio as gr
from yarg import get

from models.stylegan_generator import StyleGANGenerator
from models.stylegan2_generator import StyleGAN2Generator

from utils.constants import VALID_CHOICES, ENABLE_GPU, MODEL_NAME, OUTPUT_LIST, description, title, css, article
from utils.image_manip import tensor_to_pil, concat_images

def get_generator(model_name):
    """Construct the generator registered under ``model_name``.

    Args:
        model_name: ``'stylegan_ffhq'`` or ``'stylegan2_ffhq'``. The same
            string is forwarded to the selected generator's constructor.

    Returns:
        The new generator instance, passed through ``.cuda()`` when
        ``ENABLE_GPU`` is truthy.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    generator_classes = {
        'stylegan_ffhq': StyleGANGenerator,
        'stylegan2_ffhq': StyleGAN2Generator,
    }
    generator_cls = generator_classes.get(model_name)
    if generator_cls is None:
        raise ValueError('Model name not recognized')

    model = generator_cls(model_name)
    return model.cuda() if ENABLE_GPU else model

# Generator for the configured MODEL_NAME, built once when the module loads.
generator = get_generator(MODEL_NAME)

# Attribute name -> boundary direction, read from
# boundaries/<MODEL_NAME>/boundary_<attribute>.npy with singleton dims squeezed.
# The path is handed straight to np.load so NumPy opens *and closes* the file;
# the previous np.load(open(...)) left one file handle open per boundary.
boundaries = {
    boundary: np.squeeze(
        np.load(os.path.join('boundaries', MODEL_NAME, 'boundary_%s.npy' % boundary))
    )
    for boundary in VALID_CHOICES
}

@torch.no_grad()
def inference(seed, coef, nb_images, list_choices):
    """Generate images and their edited counterparts along attribute boundaries.

    Args:
        seed: Seed for NumPy's global RNG, set right before sampling so the
            same seed gives the same latent codes (assumes ``easy_sample``
            draws from ``np.random`` -- TODO confirm).
        coef: Scale applied to every selected boundary before it is added to
            the latent codes; negative values move in the opposite direction.
        nb_images: Number of latent codes to sample.
        list_choices: Attribute names (keys of ``boundaries``) to apply. An
            empty selection (or ``None``) leaves the codes unchanged.

    Returns:
        The result of ``concat_images`` on the original and edited images.
    """
    # Gradio sliders deliver floats; np.random.seed rejects float seeds and
    # sample counts must be integers.
    np.random.seed(int(seed))
    latent_codes = generator.easy_sample(int(nb_images))

    # No device shuffling here: get_generator() already moves the generator
    # when ENABLE_GPU is set, and latent_codes is a NumPy array (see .copy()
    # below), which has no .cuda() method.
    generated_images = tensor_to_pil(generator.easy_synthesize(latent_codes))

    # Shift every sampled code by coef along each selected boundary; the
    # boundary vector broadcasts across the batch dimension.
    new_latent_codes = latent_codes.copy()
    for choice in list_choices or []:
        new_latent_codes += boundaries[choice] * coef

    modified_generated_images = tensor_to_pil(generator.easy_synthesize(new_latent_codes))

    return concat_images(generated_images, modified_generated_images)

# https://huggingface.co/spaces/osanseviero/6DRepNet/blob/main/app.py

# Input widgets, listed in the positional order inference() expects:
# seed, coef, nb_images, list_choices.
seed_input = gr.inputs.Slider(
    minimum=0,
    maximum=1000,
    step=1,
    default=644,
    label="Random seed to use for the generation",
)
coef_input = gr.inputs.Slider(
    minimum=-3,
    maximum=3,
    step=0.1,
    default=1,
    label="Modification scale",
)
count_input = gr.inputs.Slider(
    minimum=1,
    maximum=8,
    step=1,
    default=2,
    label="Number of images to generate",
)
attributes_input = gr.inputs.CheckboxGroup(
    VALID_CHOICES,
    default=[],
    type="value",
    label="Select attributes to modify",
    optional=False,
)

# Wire the widgets to inference() and start the app.
iface = gr.Interface(
    fn=inference,
    inputs=[seed_input, coef_input, count_input, attributes_input],
    outputs=OUTPUT_LIST,
    title=title,
    description=description,
    article=article,
    css=css,
    theme="peach",
    layout="horizontal",
)
iface.launch()