import html
import json
import random
from typing import List

import spaces
import gradio as gr
from huggingface_hub import ModelCard

from src.tasks.images.sd import gen_img, ControlNetReq, SDReq, SDImg2ImgReq, SDInpaintReq


# Models selectable in the FLUX tab; the first entry is the dropdown default.
models = ["black-forest-labs/FLUX.1-dev"]

# LoRA gallery entries. Elsewhere in this module each entry is read for the
# keys "image", "title", "repo" and "trigger_word".
# NOTE(review): the path is relative to the process working directory.
# JSON is UTF-8 by spec, so decode explicitly instead of using the locale default.
with open("data/images/loras/flux.json", "r", encoding="utf-8") as f:
    loras = json.load(f)


def flux_tab():
    """
    Build the FLUX image-generation tab and wire up its event handlers.

    Must be called inside an active ``gr.Blocks`` context. It takes no
    arguments: the model choices come from the module-level ``models`` list
    and the LoRA gallery entries from the module-level ``loras`` list (each
    entry is read for its ``'image'`` and ``'title'`` keys).

    Note: the LoRA slider / remove-button pairs, the per-tab image inputs and
    strength sliders, and the numeric generation sliders are created through
    ``globals()``, so they become module-level names (``lora_slider_0``,
    ``inpaint_image``, ``image_seed``, ...). The event wiring at the end of
    this function refers to them by those names.
    """
    with gr.Row():
        with gr.Column():
            with gr.Group():
                model = gr.Dropdown(label="Models", choices=models, value=models[0], interactive=True)
                prompt = gr.Textbox(lines=5, label="Prompt")
                negative_prompt = gr.Textbox(label="Negative Prompt")
                fast_generation = gr.Checkbox(label="Fast Generation (Hyper-SD) 🧪")

            with gr.Accordion("Loras", open=True):  # Lora Gallery
                lora_gallery = gr.Gallery(
                    label="Gallery",
                    value=[(lora['image'], lora['title']) for lora in loras],
                    allow_preview=False,
                    columns=[3],
                    type="pil"
                )

                with gr.Group():
                    with gr.Column():
                        with gr.Row():
                            custom_lora = gr.Textbox(label="Custom Lora", info="Enter a Huggingface repo path")
                            selected_lora = gr.Textbox(label="Selected Lora", info="Choose from the gallery or enter a custom LoRA")

                        custom_lora_info = gr.HTML(visible=False)
                        add_lora = gr.Button(value="Add LoRA")

                        # List of {"repo_id", "trigger_word"} dicts for the LoRAs the user enabled.
                        enabled_loras = gr.State(value=[])
                        with gr.Group():
                            with gr.Row():
                                for i in range(6):  # only support max 6 loras due to inference time
                                    with gr.Column():
                                        with gr.Column(scale=2):
                                            globals()[f"lora_slider_{i}"] = gr.Slider(label=f"LoRA {i+1}", minimum=0, maximum=1, step=0.01, value=0.8, visible=False, interactive=True)
                                        with gr.Column():
                                            globals()[f"lora_remove_{i}"] = gr.Button(value="Remove LoRA", visible=False)

            with gr.Accordion("Embeddings", open=False):  # Embeddings
                gr.Label("To be implemented")

            with gr.Accordion("Image Options"):  # Image Options
                with gr.Tabs():
                    # Tab name -> upload label. The tab name also prefixes the
                    # global names of that tab's image input and strength slider.
                    image_tabs = {
                        "img2img": "Upload Image",
                        "inpaint": "Upload Image",
                        "canny": "Upload Image",
                        "pose": "Upload Image",
                        "depth": "Upload Image",
                    }

                    for image_option, label in image_tabs.items():
                        with gr.Tab(image_option):
                            if image_option in ['inpaint', 'scribble']:
                                # Editor so the user can paint; inpaint uses a fixed white brush for the mask.
                                globals()[f"{image_option}_image"] = gr.ImageEditor(
                                    label=label,
                                    image_mode='RGB',
                                    layers=False,
                                    brush=gr.Brush(colors=["#FFFFFF"], color_mode="fixed") if image_option == 'inpaint' else gr.Brush(),
                                    interactive=True,
                                    type="pil",
                                )
                            else:
                                globals()[f"{image_option}_image"] = gr.Image(label=label, type="pil")

                            # Image Strength (Co-relates to controlnet strength, strength for img2img n inpaint)
                            globals()[f"{image_option}_strength"] = gr.Slider(label="Strength", minimum=0, maximum=1, step=0.01, value=1.0, interactive=True)

                    resize_mode = gr.Radio(
                        label="Resize Mode",
                        choices=["crop and resize", "resize only", "resize and fill"],
                        value="resize and fill",
                        interactive=True
                    )

        with gr.Column():
            with gr.Group():
                output_images = gr.Gallery(
                    label="Output Images",
                    value=[],
                    allow_preview=True,
                    type="pil",
                    interactive=False,
                )
                generate_images = gr.Button(value="Generate Images", variant="primary")

            with gr.Accordion("Advance Settings", open=True):
                with gr.Row():
                    scheduler = gr.Dropdown(
                        label="Scheduler",
                        choices=[
                            "fm_euler"
                        ],
                        value="fm_euler",
                        interactive=True
                    )

                # (label, global name, min, max, step, default, visible).
                # Built once, so the random default seed is drawn only once.
                slider_options = [
                    ("Height", "image_height", 64, 1024, 64, 1024, True),
                    ("Width", "image_width", 64, 1024, 64, 1024, True),
                    ("Num Images Per Prompt", "image_num_images_per_prompt", 1, 4, 1, 1, True),
                    ("Num Inference Steps", "image_num_inference_steps", 1, 100, 1, 20, True),
                    ("Clip Skip", "image_clip_skip", 0, 2, 1, 2, False),
                    ("Guidance Scale", "image_guidance_scale", 0, 20, 0.5, 3.5, True),
                    ("Seed", "image_seed", 0, 100000, 1, random.randint(0, 100000), True),
                ]
                with gr.Row():
                    for column in range(2):
                        with gr.Column():
                            # Even-indexed options fill the left column, odd-indexed the right one.
                            for label, var_name, min_val, max_val, step, value, visible in slider_options[column::2]:
                                globals()[var_name] = gr.Slider(label=label, minimum=min_val, maximum=max_val, step=step, value=value, visible=visible, interactive=True)

                with gr.Row():
                    refiner = gr.Checkbox(
                        label="Refiner 🧪",
                        value=False,
                    )
                    vae = gr.Checkbox(
                        label="VAE",
                        value=True,
                    )

    # Events
    # Base Options
    fast_generation.change(update_fast_generation, [model, fast_generation], [image_guidance_scale, image_num_inference_steps])  # Fast Generation # type: ignore

    # Lora Gallery
    lora_gallery.select(selected_lora_from_gallery, None, selected_lora)
    # update_selected_lora returns (selected_lora update, custom_lora_info update).
    custom_lora.change(update_selected_lora, custom_lora, [selected_lora, custom_lora_info])
    add_lora.click(add_to_enabled_loras, [model, selected_lora, enabled_loras], [selected_lora, custom_lora_info, enabled_loras])
    enabled_loras.change(update_lora_sliders, enabled_loras, [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5, lora_remove_0, lora_remove_1, lora_remove_2, lora_remove_3, lora_remove_4, lora_remove_5])  # type: ignore

    for i in range(6):
        # Bind i as a default argument so each button removes its own slot.
        globals()[f"lora_remove_{i}"].click(
            lambda current, index=i: remove_from_enabled_loras(current, index),
            [enabled_loras],
            [enabled_loras]
        )

    # Generate Image
    generate_images.click(
        generate_image,  # type: ignore
        [
            model, prompt, negative_prompt, fast_generation, enabled_loras,
            lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5,  # type: ignore
            img2img_image, inpaint_image, canny_image, pose_image, depth_image,  # type: ignore
            img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength,  # type: ignore
            resize_mode,
            scheduler, image_height, image_width, image_num_images_per_prompt,  # type: ignore
            image_num_inference_steps, image_guidance_scale, image_seed,  # type: ignore
            refiner, vae
        ],
        [output_images]
    )


# Functions
def update_fast_generation(model, fast_generation):
    """Adjust sampler settings when the "Fast Generation" checkbox changes.

    Args:
        model: Selected model repo id. It is passed by the event and is
            currently unused.
        fast_generation: New state of the checkbox.

    Returns:
        Two updates, for (guidance scale slider, inference steps slider).
        When the box is checked, guidance becomes 3.5 and steps become 8.
        When it is unchecked, both sliders are left unchanged.
    """
    if not fast_generation:
        # The event has two outputs. A bare ``None`` would not provide one
        # value per output, so return empty (no-op) updates instead.
        return gr.update(), gr.update()
    return gr.update(value=3.5), gr.update(value=8)


def selected_lora_from_gallery(evt: gr.SelectData):
    """Copy the position of the clicked gallery item into "Selected Lora".

    Args:
        evt: Gradio select event. Gradio injects it because of the
            ``gr.SelectData`` annotation. ``evt.index`` is the clicked item's
            position in the LoRA gallery.

    Returns:
        A ``gr.update`` that sets the textbox value to that index.
    """
    clicked_position = evt.index
    return gr.update(value=clicked_position)


def update_selected_lora(custom_lora):
    """Mirror the "Custom Lora" textbox into "Selected Lora" and preview its card.

    When ``custom_lora`` looks like an ``"owner/name"`` Hugging Face repo id,
    its model card is loaded. From it, the trigger word (``instance_prompt``)
    and the first widget's output image are read, and an HTML info panel is
    rendered. Otherwise, or when the card cannot be loaded, the panel is
    cleared and hidden.

    Args:
        custom_lora: Raw text typed into the "Custom Lora" textbox.

    Returns:
        A pair of updates: ``(selected_lora, custom_lora_info)``.
    """
    hidden_info = (
        gr.update(  # selected_lora
            value=custom_lora,
        ),
        gr.update(  # custom_lora_info
            value="",
            visible=False
        )
    )

    link = custom_lora.split("/")
    # Both halves must be non-empty: a partially typed "owner/" is not a repo id.
    if len(link) != 2 or not all(link):
        return hidden_info

    try:
        model_card = ModelCard.load(custom_lora)
    except (OSError, ValueError):
        # This handler fires on every edit of the textbox, so unknown or
        # half-typed repo ids are expected. Hide the panel instead of raising.
        # NOTE(review): assumes huggingface_hub's HTTP/network errors subclass
        # OSError and invalid-id errors subclass ValueError; verify for the
        # pinned version.
        return hidden_info

    trigger_word = model_card.data.get("instance_prompt", "") or ""
    widgets = model_card.data.get("widget") or [{}]
    preview_path = ((widgets[0] or {}).get("output") or {}).get("url")

    custom_lora_info_css = """
        <style>
            .custom-lora-info {
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', sans-serif;
                background: linear-gradient(135deg, #4a90e2, #7b61ff);
                color: white;
                padding: 16px;
                border-radius: 8px;
                box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
                margin: 16px 0;
            }
            .custom-lora-header {
                font-size: 18px;
                font-weight: 600;
                margin-bottom: 12px;
            }
            .custom-lora-content {
                display: flex;
                align-items: center;
                background-color: rgba(255, 255, 255, 0.1);
                border-radius: 6px;
                padding: 12px;
            }
            .custom-lora-image {
                width: 80px;
                height: 80px;
                object-fit: cover;
                border-radius: 6px;
                margin-right: 16px;
            }
            .custom-lora-text h3 {
                margin: 0 0 8px 0;
                font-size: 16px;
                font-weight: 600;
            }
            .custom-lora-text small {
                font-size: 14px;
                opacity: 0.9;
            }
            .custom-trigger-word {
                background-color: rgba(255, 255, 255, 0.2);
                padding: 2px 6px;
                border-radius: 4px;
                font-weight: 600;
            }
        </style>
        """

    # The repo id is user input and the card metadata comes from the Hub, so
    # escape both before interpolating them into HTML.
    if preview_path:
        image_url = html.escape(f"https://huggingface.co/{custom_lora}/resolve/main/{preview_path}")
        preview_img = f'<img class="custom-lora-image" src="{image_url}" alt="LoRA preview">'
    else:
        # No widget output image: omit the <img> rather than linking ".../None".
        preview_img = ""

    if trigger_word:
        trigger_text = "Using: <span class='custom-trigger-word'>" + html.escape(str(trigger_word)) + "</span> as the trigger word"
    else:
        trigger_text = "No trigger word found. If there's a trigger word, include it in your prompt"

    title = html.escape(link[1].replace("-", " ").replace("_", " "))

    custom_lora_info_html = f"""
        <div class="custom-lora-info">
            <div class="custom-lora-header">Custom LoRA: {html.escape(custom_lora)}</div>
            <div class="custom-lora-content">
                {preview_img}
                <div class="custom-lora-text">
                    <h3>{title}</h3>
                    <small>{trigger_text}</small>
                </div>
            </div>
        </div>
        """

    return (
        gr.update(  # selected_lora
            value=custom_lora,
        ),
        gr.update(  # custom_lora_info
            value=f"{custom_lora_info_css}{custom_lora_info_html}",
            visible=True
        )
    )


def add_to_enabled_loras(model, selected_lora, enabled_loras):
    """Add the currently selected LoRA to the enabled list.

    ``selected_lora`` is either a gallery index (set by clicking the gallery;
    parsed with ``int``) or an ``"owner/name"`` Hugging Face repo id.
    - An index is looked up in the module-level ``loras`` list.
    - A repo id has its model card loaded to read the ``instance_prompt``
      trigger word.

    Nothing is added when the selection is invalid, for example an
    out-of-range index or a malformed repo id. Nothing is added either when
    the tab's 6 LoRA slots are already full.

    Args:
        model: Selected base model. It is passed by the event and unused.
        selected_lora: Gallery index or repo id from the "Selected Lora" box.
        enabled_loras: Current list of ``{"repo_id", "trigger_word"}`` dicts.

    Returns:
        Updates for (selected_lora textbox, custom_lora_info HTML), followed by
        the new enabled-LoRA list. The list is returned as a plain value (not
        wrapped in ``gr.update``) so it is stored into the ``gr.State`` output.
        The input list itself is not mutated.
    """
    max_loras = 6  # the tab renders exactly six slider / remove-button slots
    updated = list(enabled_loras)

    if len(updated) < max_loras:
        try:
            index = int(selected_lora)
        except (TypeError, ValueError):
            index = None

        if index is not None:
            # Index of the lora in the gallery; ignore stale / out-of-range values.
            if 0 <= index < len(loras):
                lora_info = loras[index]
                updated.append({
                    "repo_id": lora_info["repo"],
                    "trigger_word": lora_info["trigger_word"]
                })
        else:
            link = str(selected_lora).split("/")
            if len(link) == 2 and all(link):
                model_card = ModelCard.load(selected_lora)
                trigger_word = model_card.data.get("instance_prompt", "")
                updated.append({
                    "repo_id": selected_lora,
                    "trigger_word": trigger_word
                })

    return (
        gr.update(  # selected_lora
            value=""
        ),
        gr.update(  # custom_lora_info
            value="",
            visible=False
        ),
        updated  # enabled_loras
    )


def update_lora_sliders(enabled_loras):
    """Show one weight slider and one remove button per enabled LoRA.

    Args:
        enabled_loras: List of ``{"repo_id", "trigger_word"}`` dicts.

    Returns:
        Exactly 12 updates: six for the sliders, then six for the remove
        buttons. Slots backed by an enabled LoRA are shown, and each slider is
        labelled with the repo id and its trigger word. The remaining slots
        are hidden. Entries beyond the sixth are ignored, so the count always
        matches the event's outputs.
    """
    slot_count = 6  # number of slider / remove-button pairs built by flux_tab
    shown = list(enabled_loras)[:slot_count]

    sliders = [
        gr.update(
            label=lora.get("repo_id", ""),
            info=f"Trigger Word: {lora.get('trigger_word', '')}",
            visible=True,
            interactive=True
        )
        for lora in shown
    ]
    remove_buttons = [gr.update(visible=True, interactive=True) for _ in shown]

    hidden_count = slot_count - len(shown)
    sliders += [gr.update(visible=False) for _ in range(hidden_count)]
    remove_buttons += [gr.update(visible=False) for _ in range(hidden_count)]

    return *sliders, *remove_buttons


def remove_from_enabled_loras(enabled_loras, index):
    """Return the enabled-LoRA list without the entry at ``index``.

    The result is a new list returned as a plain value, so it is stored into
    the ``gr.State`` output. The input list is not mutated. An index outside
    ``[0, len(enabled_loras))`` removes nothing; this covers a click on a slot
    that no longer holds a LoRA.

    Args:
        enabled_loras: Current list of ``{"repo_id", "trigger_word"}`` dicts.
        index: Slot number of the remove button that was clicked.

    Returns:
        The filtered list.
    """
    return [lora for position, lora in enumerate(enabled_loras) if position != index]


def _inpaint_inputs(editor_value):
    """Extract ``(image, mask_image)`` from the inpaint ``gr.ImageEditor`` value.

    ``editor_value`` is ``None`` or the editor's dict with ``"background"`` and
    ``"layers"`` keys. The result is ``(None, None)`` when there is no usable
    background, i.e. it is missing or it is an RGB image whose pixels are all
    ``(0, 0, 0)``. Otherwise the result is the background plus the first drawn
    layer as the mask (``None`` if no layer was drawn).
    """
    if not editor_value:
        return None, None
    background = editor_value.get("background")
    if background is None:
        return None, None
    # For an RGB image, getbbox() is None exactly when every channel of every
    # pixel is zero. This replaces a Python loop over every pixel.
    if background.mode == "RGB" and background.getbbox() is None:
        return None, None
    layers = editor_value.get("layers") or []
    return background, (layers[0] if layers else None)


@spaces.GPU
def generate_image(
    model, prompt, negative_prompt, fast_generation, enabled_loras,
    lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5,
    img2img_image, inpaint_image, canny_image, pose_image, depth_image,
    img2img_strength, inpaint_strength, canny_strength, pose_strength, depth_strength,
    resize_mode,
    scheduler, image_height, image_width, image_num_images_per_prompt,
    image_num_inference_steps, image_guidance_scale, image_seed,
    refiner, vae
):
    """Build a generation request from the tab's inputs and run ``gen_img``.

    The request type is chosen by the first image input that is set, in this
    priority order:
    1. img2img -> ``SDImg2ImgReq``.
    2. Inpaint with a usable (non-black) background -> ``SDInpaintReq``.
    3. Any of canny / pose / depth -> a ``ControlNetReq`` attached to a plain
       ``SDReq``.
    With no image inputs, a plain ``SDReq`` is used.

    Enabled LoRAs are paired positionally with ``lora_slider_0`` to
    ``lora_slider_5`` for their weights.

    Returns:
        A ``gr.update`` for the output gallery with the images from ``gen_img``.
    """
    base_args = SDReq(
        model=model,
        prompt=prompt,
        negative_prompt=negative_prompt,
        fast_generation=fast_generation,
        loras=None,
        resize_mode=resize_mode,
        scheduler=scheduler,
        height=int(image_height),
        width=int(image_width),
        # Counts must be integers: they size batches and step schedules.
        num_images_per_prompt=int(image_num_images_per_prompt),
        num_inference_steps=int(image_num_inference_steps),
        guidance_scale=float(image_guidance_scale),
        seed=int(image_seed),
        refiner=refiner,
        vae=vae,
        controlnet_config=None,
    )

    if enabled_loras:
        lora_weights = [lora_slider_0, lora_slider_1, lora_slider_2, lora_slider_3, lora_slider_4, lora_slider_5]
        base_args.loras = [
            {"repo_id": lora["repo_id"], "weight": weight}
            for lora, weight in zip(enabled_loras, lora_weights)
            if lora.get("repo_id", None)
        ]

    inpaint_source, inpaint_mask = _inpaint_inputs(inpaint_image)
    control_inputs = [
        ("canny_fl", canny_image, canny_strength),
        ("pose_fl", pose_image, pose_strength),
        ("depth_fl", depth_image, depth_strength),
    ]
    active_controls = [entry for entry in control_inputs if entry[1] is not None]

    if img2img_image is not None:
        base_args = SDImg2ImgReq(
            **base_args.__dict__,
            image=img2img_image,
            strength=float(img2img_strength)
        )
    elif inpaint_source is not None:
        base_args = SDInpaintReq(
            **base_args.__dict__,
            image=inpaint_source,
            mask_image=inpaint_mask,
            strength=float(inpaint_strength)
        )
    elif active_controls:
        base_args.controlnet_config = ControlNetReq(
            controlnets=[name for name, _, _ in active_controls],
            control_images=[image for _, image, _ in active_controls],
            controlnet_conditioning_scale=[float(strength) for _, _, strength in active_controls]
        )

    images = gen_img(base_args)

    return gr.update(
        value=images,
        interactive=True
    )