# Imports
import gradio as gr
import threading
import requests
import random
import spaces
import torch
import uuid
import json
import os
import numpy as np

from huggingface_hub import hf_hub_download
from diffusers import DiffusionPipeline
from transformers import pipeline
from PIL import Image

# Pre-Initialize
# Resolve the compute device. "auto" means: use CUDA when torch reports it is
# available, otherwise fall back to the CPU.
DEVICE = "auto"
if DEVICE == "auto":
    if torch.cuda.is_available():
        DEVICE = "cuda"
    else:
        DEVICE = "cpu"
print(f"[SYSTEM] | Using {DEVICE} type compute device.")

# Variables
HF_TOKEN = os.environ.get("HF_TOKEN")

# 2**53 - 1, i.e. JavaScript's Number.MAX_SAFE_INTEGER; presumably chosen so a
# seed survives a round-trip through the browser UI without losing precision.
MAX_SEED = 2**53 - 1
DEFAULT_INPUT = ""
DEFAULT_NEGATIVE_INPUT = "(bad, ugly, amputation, abstract, blur, deformed, distorted, disfigured, disconnected, mutation, mutated, low quality, lowres), unfinished, text, signature, watermark, (limbs, legs, feet, arms, hands), (porn, nude, naked, nsfw)"
DEFAULT_MODEL = "Default"
DEFAULT_HEIGHT = 1024
DEFAULT_WIDTH = 1024

# JSON request headers (not referenced elsewhere in this file; presumably meant
# for `requests` calls). The Authorization header is only added when a token is
# configured: with HF_TOKEN unset, an f-string would send the bogus credential
# "Bearer None", which servers reject instead of treating the call as anonymous.
headers = {"Content-Type": "application/json"}
if HF_TOKEN:
    headers["Authorization"] = f"Bearer {HF_TOKEN}"

# Gradio UI styling: narrow the app container, centre <h1>, hide the footer.
css = '''
.gradio-container{max-width: 560px !important}
h1{text-align:center}
footer {
    visibility: hidden
}
'''

# Image classifier applied to the first generated image in generate(); its
# label/score list is returned to the UI.
repo_nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection")

# Base pipeline shared by the "Default" and "Pixel" dropdown entries. Three LoRAs
# are registered under adapter names; generate() chooses which ones are active
# via set_adapters() on every request ("default_base" for Default,
# "pixel_base" + "pixel_base_2" for Pixel). Pipelines are created here without a
# device and moved to DEVICE per request inside generate().
repo_default = DiffusionPipeline.from_pretrained("fluently/Fluently-XL-Final", torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False)
repo_default.load_lora_weights("ehristoforu/dalle-3-xl-v2", adapter_name="default_base")
repo_default.load_lora_weights("artificialguybr/PixelArtRedmond", adapter_name="pixel_base")
repo_default.load_lora_weights("nerijs/pixel-art-xl", adapter_name="pixel_base_2")

# FLUX.1-dev pipeline (bfloat16) for the "Pro" entry, with the FLUX.1-Turbo-Alpha
# LoRA weights downloaded from the Hub. The LoRA is loaded without an
# adapter_name and generate() never calls set_adapters() on this pipeline, so it
# presumably stays active for every "Pro" request — confirm against diffusers'
# default-adapter behavior.
repo_pro = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, use_safetensors=True)
repo_pro.load_lora_weights(hf_hub_download("alimama-creative/FLUX.1-Turbo-Alpha", "diffusion_pytorch_model.safetensors"))

# Dropdown label -> pipeline. The keys populate the "Model" dropdown in the UI.
# "Default" and "Pixel" map to the same pipeline object (repo_default); they
# differ only in the adapters/presets generate() applies.
repo_customs = {
    "Default": repo_default,
    "Realistic": DiffusionPipeline.from_pretrained("ehristoforu/Visionix-alpha", torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False),
    "Anime": DiffusionPipeline.from_pretrained("cagliostrolab/animagine-xl-3.1", torch_dtype=torch.float16, use_safetensors=True, add_watermarker=False),
    "Pixel": repo_default,
    "Pro": repo_pro,
}

# Functions
def save_image(img, seed):
    """Write *img* as a PNG into the working directory and return its file name.

    The name has the form ``"<seed>-<uuid4>.png"``; the random UUID keeps
    images produced with the same seed from overwriting one another.
    """
    filename = "{}-{}.png".format(seed, uuid.uuid4())
    img.save(filename)
    return filename
    
def get_seed(seed):
    """Turn a user-supplied seed into an integer.

    Args:
        seed: Usually the text of the "Seed" textbox, but ``None`` (the
            default in ``generate``) and plain ints from API callers are
            accepted too. Surrounding whitespace is ignored.

    Returns:
        The seed as an ``int`` when it is a non-negative decimal number;
        otherwise (blank, ``None``, negative, or non-numeric input) a fresh
        random seed in ``[0, MAX_SEED]``.
    """
    # Normalise to text so None / int inputs don't crash on .strip().
    text = "" if seed is None else str(seed).strip()
    # isdecimal() rather than isdigit(): isdigit() accepts characters such as
    # "²" that int() cannot parse, which would raise ValueError.
    if text.isdecimal():
        return int(text)
    return random.randint(0, MAX_SEED)

def _model_preset(model, repo):
    """Activate the LoRA adapters for *model* on *repo* and return its defaults.

    Args:
        model: Dropdown label. "Realistic", "Anime", "Pixel" and "Pro" have
            their own presets; anything else (including "Default" and None)
            uses the default preset.
        repo: The pipeline selected for *model* from ``repo_customs``.

    Returns:
        ``(steps, guidance)`` to use when the caller leaves them falsy.
    """
    if model in ("Realistic", "Anime"):
        return 25, 7
    if model == "Pixel":
        # "Pixel" shares its pipeline with "Default", so the adapter selection
        # must be re-applied on every request.
        repo.set_adapters(["pixel_base", "pixel_base_2"], adapter_weights=[1, 1])
        return 10, 1.5
    if model == "Pro":
        return 8, 3.5
    repo.set_adapters(["default_base"], adapter_weights=[0.7])
    return 25, 7


@spaces.GPU(duration=30)
def generate(input=DEFAULT_INPUT, filter_input="", negative_input=DEFAULT_NEGATIVE_INPUT, model=DEFAULT_MODEL, height=DEFAULT_HEIGHT, width=DEFAULT_WIDTH, steps=1, guidance=0, number=1, seed=None, height_buffer=DEFAULT_HEIGHT, width_buffer=DEFAULT_WIDTH):
    """Generate images for a prompt and return them with NSFW scores and a pixel buffer.

    Args:
        input: Positive prompt.
        filter_input: Extra negative terms placed before *negative_input*.
        negative_input: Negative prompt; empty falls back to
            DEFAULT_NEGATIVE_INPUT. Neither negative field is sent to "Pro".
        model: Key of ``repo_customs``; empty falls back to "Default".
        height, width: Output size in pixels.
        steps, guidance: Sampling steps / guidance scale; a falsy value (0)
            selects the model's preset.
        number: Number of images to generate.
        seed: Seed as a string or int; None, blank or non-numeric input picks
            a random seed.
        height_buffer, width_buffer: Size the first image is resized to before
            it is serialised into the pixel buffer.

    Returns:
        ``(image_paths, nsfw_scores, buffer_json)``: paths of the saved PNGs,
        a ``{label: score}`` dict (scores rounded to 3 places) from the NSFW
        classifier run on the first image, and a JSON list of the first
        image's RGBA channel values, flattened row by row.
    """
    repo = repo_customs[model or "Default"]
    filter_input = filter_input or ""
    negative_input = negative_input or DEFAULT_NEGATIVE_INPUT
    # Normalise to text first: get_seed() calls .strip(), which would fail on
    # the default None or on an int passed through the API.
    seed = get_seed("" if seed is None else str(seed))

    # Gradio sliders can deliver floats; the pipelines and PIL's resize need ints.
    height, width = int(height), int(width)
    height_buffer, width_buffer = int(height_buffer), int(width_buffer)
    number = int(number)

    print(input, filter_input, negative_input, model, height, width, steps, guidance, number, seed)

    steps_set, guidance_set = _model_preset(model, repo)

    # 0 means "use the model's preset".
    if not steps:
        steps = steps_set
    if not guidance:
        guidance = guidance_set
    steps = int(steps)

    print(steps, guidance)

    repo.to(DEVICE)

    parameters = {
        "prompt": input,
        "height": height,
        "width": width,
        "num_inference_steps": steps,
        "guidance_scale": guidance,
        "num_images_per_prompt": number,
        "generator": torch.Generator().manual_seed(seed),
        "output_type": "pil",
    }

    if model != "Pro":
        # Join with a separator so the filter's last term doesn't fuse with the
        # negative prompt's first term (e.g. "dog" + "ugly" -> "dogugly").
        parameters["negative_prompt"] = ", ".join(
            part for part in (filter_input, negative_input) if part
        )

    images = repo(**parameters).images
    image_paths = [save_image(img, seed) for img in images]

    print(image_paths)

    nsfw_prediction = repo_nsfw_classifier(image_paths[0])

    print(nsfw_prediction)

    # Flatten the first image (as RGBA, resized to the buffer size) into a
    # JSON list of channel values.
    buffer_image = images[0].convert("RGBA").resize((width_buffer, height_buffer))
    pixel_data = np.array(buffer_image).flatten().tolist()
    buffer_json = json.dumps(pixel_data)

    nsfw_scores = {item["label"]: round(item["score"], 3) for item in nsfw_prediction}
    return image_paths, nsfw_scores, buffer_json

def cloud():
    """Print a keep-alive log line; bound to the "☁️" button in the UI."""
    message = "[CLOUD] | Space maintained."
    print(message)

@spaces.GPU(duration=0.1)
def gpu():
    """Print a log line from inside a ``spaces.GPU`` call with a 0.1 s duration.

    Bound to the "💻" button in the UI; presumably used to acquire/warm up a
    ZeroGPU slot — confirm against the ``spaces`` package semantics.
    """
    status = "[GPU] | Fetched GPU token."
    print(status)
    
# Initialize
# Build the Gradio UI: a header, a column of generation inputs plus three
# buttons (wired to generate(), cloud() and gpu()), and a column of outputs.
with gr.Blocks(css=css) as main:
    with gr.Column():
        gr.Markdown("🪄 Generate high quality images in all styles.")

    with gr.Column():
        input = gr.Textbox(lines=1, value=DEFAULT_INPUT, label="Input")
        filter_input = gr.Textbox(lines=1, value="", label="Input Filter")
        negative_input = gr.Textbox(lines=1, value=DEFAULT_NEGATIVE_INPUT, label="Input Negative")
        model = gr.Dropdown(choices=list(repo_customs), value="Default", label="Model")
        # step=8: diffusers' Stable Diffusion / SDXL pipelines raise ValueError
        # unless height and width are divisible by 8, so the sliders only offer
        # valid sizes (8..2160 is an exact multiple-of-8 range).
        height = gr.Slider(minimum=8, maximum=2160, step=8, value=DEFAULT_HEIGHT, label="Height")
        width = gr.Slider(minimum=8, maximum=2160, step=8, value=DEFAULT_WIDTH, label="Width")
        steps = gr.Slider(minimum=1, maximum=100, step=1, value=25, label="Steps")
        guidance = gr.Slider(minimum=0, maximum=100, step=0.1, value=5, label = "Guidance")
        number = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number")
        seed = gr.Textbox(lines=1, value="", label="Seed (Blank for random)")
        # Size of the serialised pixel buffer returned as the third output.
        height_buffer = gr.Slider(minimum=1, maximum=2160, step=1, value=DEFAULT_HEIGHT, label="Buffer Height")
        width_buffer = gr.Slider(minimum=1, maximum=2160, step=1, value=DEFAULT_WIDTH, label="Buffer Width")
        submit = gr.Button("▶")
        maintain = gr.Button("☁️")
        get_gpu = gr.Button("💻")

    with gr.Column():
        output = gr.Gallery(columns=1, label="Image")
        output_2 = gr.Label()
        output_3 = gr.Textbox(lines=1, value="", label="Buffer")

    # Input order must match generate()'s positional parameters.
    submit.click(generate, inputs=[input, filter_input, negative_input, model, height, width, steps, guidance, number, seed, height_buffer, width_buffer], outputs=[output, output_2, output_3], queue=False)
    maintain.click(cloud, inputs=[], outputs=[], queue=False)
    get_gpu.click(gpu, inputs=[], outputs=[], queue=False)

main.launch(show_api=True)