# -*- coding: utf-8 -*-
"""
Created on Tue Apr 26 21:02:31 2022

@author: pc
"""

import pickle
import numpy as np
import torch
import gradio as gr 
import sys
import subprocess
import os
from typing import Tuple
import PIL.Image
from huggingface_hub import hf_hub_download

# Fetch the StyleGAN3 sources and make them importable. This presumably is
# required so that pickle.load below can import the classes referenced by the
# generator pickle. Clone only when the checkout is missing: re-cloning into an
# existing directory fails, and a failed first clone should stop the app here
# rather than surface later as an obscure unpickling error.
if not os.path.isdir("stylegan3"):
    subprocess.run(
        ["git", "clone", "https://github.com/NVlabs/stylegan3"],
        check=True,
    )

sys.path.append("stylegan3")


# Markdown shown under the app title. The link target must not be quoted,
# otherwise the quotes become part of the URL and the link breaks.
DESCRIPTION = '''This model generates healthy MR Brain Images.


[Example](https://huggingface.co/spaces/SerdarHelli/Brain-MR-Image-Generation-GAN/blob/main/ex.png)
'''
network_pkl = "brainmrigan.pkl"


# NOTE(security): pickle.load can execute arbitrary code; only ever point
# network_pkl at a trusted file. 'G_ema' is presumably the EMA-averaged
# generator, as in StyleGAN training pickles.
with open(network_pkl, "rb") as f:
    G = pickle.load(f)["G_ema"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
G.eval()  # inference mode
G.to(device)

def predict(Seed, noise_mode, truncation_psi):
    """Generate one synthetic brain MR image with the loaded generator ``G``.

    Args:
        Seed: Seed for the latent vector ``z``. Gradio sliders may deliver
            numeric values as floats, which ``np.random.RandomState`` rejects,
            so the value is cast to ``int`` (``None`` is passed through).
        noise_mode: Noise mode forwarded to the generator; the UI offers
            'const', 'random' and 'none'.
        truncation_psi: Truncation strength forwarded to the generator.

    Returns:
        A 512x512 ``PIL.Image.Image`` built from channel 0 of the generated
        image.
    """
    seed = None if Seed is None else int(Seed)
    z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)
    # All-zero class label; presumably the model is unconditional — confirm.
    label = torch.zeros([1, G.c_dim], device=device)

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)

    # NCHW -> NHWC, then rescale (output presumably in [-1, 1]) to uint8.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

    # Take the single batch item's first channel as a grayscale image.
    return PIL.Image.fromarray(img[0].cpu().numpy()[:, :, 0]).resize((512, 512))



# Noise modes offered to the user; forwarded verbatim to the generator.
noises = ['const', 'random', 'none']
interface = gr.Interface(
    fn=predict,
    title="Brain MR Image Generation with StyleGAN-2",
    description=DESCRIPTION,
    article="Author: S.Serdar Helli and Burhan Arat",
    inputs=[
        # Explicit step=1 so every integer seed is selectable; without it
        # Gradio may derive a much coarser step from the slider's range.
        gr.inputs.Slider(minimum=0, maximum=2**16, step=1, label='Seed'),
        gr.inputs.Radio(choices=noises, default='const', label='Noise Mode'),
        gr.inputs.Slider(0, 2, step=0.05, default=1, label='Truncation psi'),
    ],
    # predict returns a PIL image, so declare the matching output type.
    outputs=gr.outputs.Image(type="pil", label="Output"),
)


interface.launch(debug=True)