import random

import gradio as gr
import torch
from diffusers import (
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    AutoPipelineForInpainting,
)
from huggingface_hub import hf_hub_download
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from sd_embed.embedding_funcs import get_weighted_text_embeddings_flux1

from .common_helpers import (
    ControlNetReq,
    BaseReq,
    BaseImg2ImgReq,
    BaseInpaintReq,
    cleanup,
    get_controlnet_images,
    resize_images,
)
from modules.pipelines.flux_pipelines import device, models, flux_vae, controlnet
from modules.pipelines.common_pipelines import refiner


def get_control_mode(controlnet_config: ControlNetReq):
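    """Map requested ControlNet layer names to the mode indices expected by the
    Flux ControlNet Union model (canny=0, tile=1, depth=2, blur=3, pose=4,
    gray=5, low_quality=6)."""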
    control_mode = []
    layers = ["canny", "tile", "depth", "blur", "pose", "gray", "low_quality"]

    for c in controlnet_config.controlnets:
        if c in layers:
            control_mode.append(layers.index(c))

    return control_mode


def get_pipe(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
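    """Resolve the requested model into pipeline kwargs: pick the pipeline mode
    (text2img / img2img / inpaint), then attach the ControlNet, VAE, scheduler,
    and any requested LoRA adapters."""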
    for m in models:
        if m['repo_id'] == request.model:
            pipe_args = {
                "pipeline": m['pipeline'],
            }

            # Set ControlNet config. `control_mode` is a call-time pipeline
            # argument (supplied in gen_img), not a from_pipe kwarg, so only
            # the ControlNet model itself is attached here.
            if request.controlnet_config:
                pipe_args["controlnet"] = [controlnet]

            # Choose Pipeline Mode
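            # (Checked most-specific first; if the request classes subclass one
            # another, BaseInpaintReq -> BaseImg2ImgReq -> BaseReq, this order
            # is required for isinstance to dispatch correctly.)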
            if isinstance(request, BaseInpaintReq):
                pipe_args['pipeline'] = AutoPipelineForInpainting.from_pipe(**pipe_args)
            elif isinstance(request, BaseImg2ImgReq):
                pipe_args['pipeline'] = AutoPipelineForImage2Image.from_pipe(**pipe_args)
            elif isinstance(request, BaseReq):
                pipe_args['pipeline'] = AutoPipelineForText2Image.from_pipe(**pipe_args)

            # Enable or disable the external VAE
            if request.vae:
                pipe_args["pipeline"].vae = flux_vae
            else:
                pipe_args["pipeline"].vae = None

            # Set scheduler (Flux pipelines use flow-matching Euler sampling)
            pipe_args["pipeline"].scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe_args["pipeline"].scheduler.config)

            # Load and activate LoRA adapters
            if request.loras:
                for i, lora in enumerate(request.loras):
                    pipe_args["pipeline"].load_lora_weights(lora['repo_id'], adapter_name=f"lora_{i}")
                adapter_names = [f"lora_{i}" for i in range(len(request.loras))]
                adapter_weights = [lora['weight'] for lora in request.loras]

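                # Optionally stack the Hyper-SD 8-step LoRA on top of the user
                # LoRAs (at its recommended 0.125 scale) to speed up generation.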
                if request.fast_generation:
                    hyper_lora = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
                    hyper_weight = 0.125
                    pipe_args["pipeline"].load_lora_weights(hyper_lora, adapter_name="hyper_lora")
                    adapter_names.append("hyper_lora")
                    adapter_weights.append(hyper_weight)

                pipe_args["pipeline"].set_adapters(adapter_names, adapter_weights)

            return pipe_args

    # No matching model: fail loudly instead of implicitly returning None.
    raise ValueError(f"Model '{request.model}' is not in the available Flux models.")


def get_prompt_attention(pipeline, prompt):
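    """Return (prompt_embeds, pooled_prompt_embeds) computed with sd_embed's
    weighted prompt parsing, which supports attention syntax such as (word:1.2)."""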
    return get_weighted_text_embeddings_flux1(pipeline, prompt)


# Gen Function
def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
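    """Generate images for a text2img, img2img, or inpaint request, optionally
    passing the results through the shared refiner pipeline."""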
    pipe_args = get_pipe(request)
    pipeline = pipe_args["pipeline"]
    try:
        positive_prompt_embeds, positive_prompt_pooled = get_prompt_attention(pipeline, request.prompt)

        # Common args (uses the module-level `device` imported from flux_pipelines,
        # rather than shadowing it with a local redefinition)
        args = {
            'prompt_embeds': positive_prompt_embeds,
            'pooled_prompt_embeds': positive_prompt_pooled,
            'height': request.height,
            'width': request.width,
            'num_images_per_prompt': request.num_images_per_prompt,
            'num_inference_steps': request.num_inference_steps,
            'guidance_scale': request.guidance_scale,
            'generator': [
                # Honor an explicit seed (offset per image); otherwise draw a random one.
                torch.Generator(device=device).manual_seed(request.seed + i)
                if request.seed not in (None, 0, -1)
                else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1))
                for i in range(request.num_images_per_prompt)
            ],
        }

        if request.controlnet_config:
            args['control_mode'] = get_control_mode(request.controlnet_config)
            args['control_images'] = get_controlnet_images(request.controlnet_config, request.height, request.width, request.resize_mode)
            args['controlnet_conditioning_scale'] = request.controlnet_config.controlnet_conditioning_scale

        if isinstance(request, (BaseImg2ImgReq, BaseInpaintReq)):
            args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)[0]
            args['strength'] = request.strength

        if isinstance(request, BaseInpaintReq):
            args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)[0]

        # Generate
        images = pipeline(**args).images

        # Refiner
        if request.refiner:
            images = refiner(image=images, prompt=request.prompt, num_inference_steps=40, denoising_start=0.7).images

        return images
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e
    finally:
        # finally runs on both success and failure, so cleanup happens exactly once.
        cleanup(pipeline, request.loras)
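
# Minimal usage sketch (hypothetical field values; BaseReq's exact constructor
# signature lives in common_helpers):
#
#     request = BaseReq(
#         model="black-forest-labs/FLUX.1-dev",
#         prompt="a watercolor fox in a snowy forest",
#         height=1024,
#         width=1024,
#         num_images_per_prompt=1,
#         num_inference_steps=28,
#         guidance_scale=3.5,
#         seed=42,
#     )
#     images = gen_img(request)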