import gc
import random
from typing import List, Optional

import torch
import numpy as np
from pydantic import BaseModel
from PIL import Image
from diffusers import (
    FluxPipeline,
    FluxImg2ImgPipeline,
    FluxInpaintPipeline,
    FluxControlNetPipeline,
    StableDiffusionXLPipeline,
    StableDiffusionXLImg2ImgPipeline,
    StableDiffusionXLInpaintPipeline,
    StableDiffusionXLControlNetPipeline,
    StableDiffusionXLControlNetImg2ImgPipeline,
    StableDiffusionXLControlNetInpaintPipeline,
    AutoPipelineForText2Image,
    AutoPipelineForImage2Image,
    AutoPipelineForInpainting,
)
from diffusers.schedulers import *
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from controlnet_aux.processor import Processor
from photomaker import (
    PhotoMakerStableDiffusionXLPipeline,
    PhotoMakerStableDiffusionXLControlNetPipeline,
    analyze_faces
)
from sd_embed.embedding_funcs import get_weighted_text_embeddings_sdxl, get_weighted_text_embeddings_flux1

from .init_sys import device, models, refiner, safety_checker, feature_extractor, controlnet_models, face_detector


# Models
class ControlNetReq(BaseModel):
    # ControlNet settings attached to a generation request (SDReq.controlnet_config).
    # The three lists are paired by position: get_controlnet_images zips
    # `controlnets` with `control_images`, and the conditioning scales are
    # forwarded to the pipeline as a single list.
    # NOTE(review): nothing here checks that the lists have equal length — confirm callers guarantee it.
    controlnets: List[str] # names like "canny", "depth", "pose", "scribble" (some with "_xs"/"_fl" variants); get_controlnet_images defines the accepted set
    control_images: List[Image.Image]  # raw images, paired position-by-position with `controlnets`
    controlnet_conditioning_scale: List[float]  # presumably one weight per entry in `controlnets`
    
    class Config:
        # Lets pydantic accept the PIL Image field, which has no built-in pydantic validator.
        arbitrary_types_allowed=True


class SDReq(BaseModel):
    # Text-to-image generation request. SDImg2ImgReq and SDInpaintReq subclass it,
    # so isinstance(req, SDReq) is True for every request type.
    # The defaults target FLUX.1-dev: 8 steps (matching the 8-step Hyper-FLUX LoRA
    # that load_loras adds when fast_generation is on), guidance 3.5 and the
    # "euler_fl" scheduler name.
    model: str = "black-forest-labs/FLUX.1-dev"  # repo_id matched against the preloaded `models` list by get_pipe
    prompt: str = ""
    negative_prompt: Optional[str] = ""  # empty string = no negative prompt; only SDXL pipelines use it
    fast_generation: Optional[bool] = True  # also load ByteDance's Hyper-SD LoRA (see load_loras)
    loras: Optional[list] = []  # dicts with 'repo_id' and 'weight' keys, consumed by load_loras
    embeddings: Optional[list] = []  # SDXL textual-inversion specs, consumed by load_xl_embeddings
    resize_mode: Optional[str] = "resize_and_fill" # resize_only, crop_and_resize, resize_and_fill
    scheduler: Optional[str] = "euler_fl"  # short name resolved by load_scheduler
    height: int = 1024
    width: int = 1024
    num_images_per_prompt: int = 1
    num_inference_steps: int = 8
    guidance_scale: float = 3.5
    seed: Optional[int] = 0  # base seed; gen_img turns it into one generator per image
    refiner: bool = False  # run the shared `refiner` pipeline over the generated images
    vae: bool = True  # NOTE(review): not read anywhere in this module
    controlnet_config: Optional[ControlNetReq] = None
    photomaker_images: Optional[List[Image.Image]] = None  # face reference images; makes get_pipe pick a PhotoMaker pipeline
    clip_skip: Optional[int] = None  # forwarded as `clip_skip` to SDXL pipelines; None keeps the pipeline default
    
    class Config:
        # Lets pydantic accept the PIL Image fields, which have no built-in pydantic validator.
        arbitrary_types_allowed=True


class SDImg2ImgReq(SDReq):
    # Image-to-image request: every SDReq field plus the init image.
    # SDInpaintReq subclasses this class and this class subclasses SDReq, so any
    # isinstance() dispatch on request types must test the most specific class first.
    image: Image.Image  # init image; passed through resize_images before it reaches the pipeline
    strength: float = 1.0  # forwarded to the pipeline's `strength` argument
    
    class Config:
        # Presumably redundant with the setting inherited from SDReq; kept for explicitness.
        arbitrary_types_allowed=True


class SDInpaintReq(SDImg2ImgReq):
    # Inpainting request: every SDImg2ImgReq field (image, strength, ...) plus the mask.
    # Because it subclasses SDImg2ImgReq, isinstance(req, SDImg2ImgReq) is also True for it.
    mask_image: Image.Image  # inpainting mask image
    
    class Config:
        # Presumably redundant with the inherited setting; kept for explicitness.
        arbitrary_types_allowed=True


# Helper functions
def get_controlnet(controlnet_config: ControlNetReq):
    """Resolve the requested controlnet names to the preloaded ControlNet models.

    For each name in ``controlnet_config.controlnets``, in request order, the
    first entry of ``controlnet_models`` whose ``"layers"`` list contains that
    name is used.

    Returns:
        ``(controlnet, control_mode)``: two lists aligned index-by-index with
        the requested names. They hold the matching ``"controlnet"`` object and
        the name's index within that entry's ``"layers"``. The index is used as
        ``control_mode`` for Flux ControlNet pipelines.

    Raises:
        ValueError: if a requested name is not provided by any loaded entry.
    """
    control_mode = []
    controlnet = []

    # Walk the request, not the registry, so both outputs line up with
    # control_images and controlnet_conditioning_scale (which follow request order).
    for name in controlnet_config.controlnets:
        for entry in controlnet_models:
            layers = entry["layers"]
            if name in layers:
                control_mode.append(layers.index(name))
                controlnet.append(entry["controlnet"])
                break
        else:
            raise ValueError(f"Invalid Controlnet: {name}")

    return controlnet, control_mode


def get_pipe(request: SDReq | SDImg2ImgReq | SDInpaintReq):
    """Build the task-specific pipeline for *request* from a preloaded base model.

    The base pipeline is the entry of ``models`` whose ``repo_id`` equals
    ``request.model``. It is converted with ``from_pipe`` into one of:
      - a PhotoMaker SDXL pipeline (the ControlNet variant if a controlnet
        config is present) when ``request.photomaker_images`` is set. This is
        supported for text-to-image requests only.
      - otherwise, the AutoPipeline for the request's task: inpainting,
        image-to-image or text-to-image.

    Returns:
        dict with ``"pipeline"`` (the converted pipeline) and ``"control_mode"``
        (from get_controlnet, or None). When a controlnet config is present it
        also holds ``"controlnet"`` (the ControlNet model list).

    Raises:
        ValueError: if no loaded model matches ``request.model`` or the request
            type is not supported.
    """
    for m in models:
        if m["repo_id"] != request.model:
            continue

        if request.controlnet_config:
            controlnet, control_mode = get_controlnet(request.controlnet_config)
        else:
            controlnet, control_mode = None, None

        pipe_args = {
            "pipeline": m["pipeline"],
            "control_mode": control_mode,
        }
        # Only `controlnet` is a from_pipe argument; `control_mode` is a call-time
        # argument and is returned to the caller instead of being passed here.
        from_pipe_kwargs = {}
        if request.controlnet_config:
            pipe_args["controlnet"] = controlnet
            from_pipe_kwargs["controlnet"] = controlnet

        base = m["pipeline"]
        if not request.photomaker_images:
            # Most specific first: SDInpaintReq subclasses SDImg2ImgReq, which subclasses SDReq.
            if isinstance(request, SDInpaintReq):
                pipe_args["pipeline"] = AutoPipelineForInpainting.from_pipe(base, **from_pipe_kwargs)
            elif isinstance(request, SDImg2ImgReq):
                pipe_args["pipeline"] = AutoPipelineForImage2Image.from_pipe(base, **from_pipe_kwargs)
            elif isinstance(request, SDReq):
                pipe_args["pipeline"] = AutoPipelineForText2Image.from_pipe(base, **from_pipe_kwargs)
            else:
                raise ValueError(f"Unknown request type: {type(request)}")
        elif isinstance(request, SDReq) and not isinstance(request, SDImg2ImgReq):
            # PhotoMaker pipelines are text-to-image only.
            photomaker_cls = (
                PhotoMakerStableDiffusionXLControlNetPipeline
                if request.controlnet_config
                else PhotoMakerStableDiffusionXLPipeline
            )
            pipe_args["pipeline"] = photomaker_cls.from_pipe(base, **from_pipe_kwargs)
        else:
            raise ValueError(f"Invalid request type: {type(request)}")

        return pipe_args

    raise ValueError(f"Model not found: {request.model}")


def load_scheduler(pipeline, scheduler):
    """Create a scheduler for *pipeline* from a short scheduler name.

    The new scheduler is built with ``from_config`` on the pipeline's current
    scheduler config, plus the per-name overrides in the table below (Karras
    sigmas, SDE algorithm). The caller assigns the result to the pipeline.

    Raises:
        ValueError: if *scheduler* is not a known name.
    """
    schedulers = {
        "dpmpp_2m": (DPMSolverMultistepScheduler, {}),
        "dpmpp_2m_k": (DPMSolverMultistepScheduler, {"use_karras_sigmas": True}),
        "dpmpp_2m_sde": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++"}),
        "dpmpp_2m_sde_k": (DPMSolverMultistepScheduler, {"algorithm_type": "sde-dpmsolver++", "use_karras_sigmas": True}),
        "dpmpp_sde": (DPMSolverSinglestepScheduler, {}),
        "dpmpp_sde_k": (DPMSolverSinglestepScheduler, {"use_karras_sigmas": True}),
        "dpm2": (KDPM2DiscreteScheduler, {}),
        "dpm2_k": (KDPM2DiscreteScheduler, {"use_karras_sigmas": True}),
        "dpm2_a": (KDPM2AncestralDiscreteScheduler, {}),
        "dpm2_a_k": (KDPM2AncestralDiscreteScheduler, {"use_karras_sigmas": True}),
        "euler": (EulerDiscreteScheduler, {}),
        "euler_a": (EulerAncestralDiscreteScheduler, {}),
        "heun": (HeunDiscreteScheduler, {}),
        "lms": (LMSDiscreteScheduler, {}),
        "lms_k": (LMSDiscreteScheduler, {"use_karras_sigmas": True}),
        "deis": (DEISMultistepScheduler, {}),
        "unipc": (UniPCMultistepScheduler, {}),
        "fm_euler": (FlowMatchEulerDiscreteScheduler, {}),
        # "_fl" marks Flux variants (as in the controlnet names "canny_fl"/"depth_fl").
        # SDReq's default scheduler is "euler_fl", so it must resolve here.
        "euler_fl": (FlowMatchEulerDiscreteScheduler, {}),
    }

    if scheduler not in schedulers:
        raise ValueError(f"Unknown scheduler: {scheduler}")

    scheduler_class, kwargs = schedulers[scheduler]
    return scheduler_class.from_config(pipeline.scheduler.config, **kwargs)


def load_loras(pipeline, loras, fast_generation):
    """Load the request's LoRAs (plus the optional Hyper-SD speed-up LoRA) and activate them.

    Args:
        pipeline: diffusers pipeline exposing ``load_lora_weights`` and ``set_adapters``.
        loras: dicts with ``'repo_id'`` (passed to ``load_lora_weights``) and
            ``'weight'`` (the adapter weight). ``None`` is treated as empty.
        fast_generation: when true, also load ByteDance's Hyper-SD LoRA. Flux
            pipelines get the 8-step FLUX.1-dev file at weight 0.125; all others
            get the 2-step SDXL file at weight 1.0.

    User LoRAs are registered as adapters ``lora_0``, ``lora_1``, ... and the
    speed-up LoRA as ``hyper_lora``. ``set_adapters`` is skipped when nothing
    was loaded.
    """
    adapter_names = []
    adapter_weights = []

    for i, lora in enumerate(loras or []):
        adapter_name = f"lora_{i}"
        pipeline.load_lora_weights(lora['repo_id'], adapter_name=adapter_name)
        adapter_names.append(adapter_name)
        adapter_weights.append(lora['weight'])

    if fast_generation:
        # The Flux img2img/inpaint/ControlNet pipelines do not subclass FluxPipeline,
        # so checking FluxPipeline alone would hand them the SDXL LoRA.
        is_flux = isinstance(
            pipeline, (FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxControlNetPipeline)
        )
        hyper_lora = hf_hub_download(
            "ByteDance/Hyper-SD",
            "Hyper-FLUX.1-dev-8steps-lora.safetensors" if is_flux else "Hyper-SDXL-2steps-lora.safetensors",
        )
        pipeline.load_lora_weights(hyper_lora, adapter_name="hyper_lora")
        adapter_names.append("hyper_lora")
        adapter_weights.append(0.125 if is_flux else 1.0)

    # Nothing loaded: there are no adapters to activate.
    if adapter_names:
        pipeline.set_adapters(adapter_names, adapter_weights)


def load_xl_embeddings(pipeline, embeddings):
    """Load SDXL textual-inversion embeddings into both text encoders of *pipeline*.

    Each entry of *embeddings* is a dict with:
      - ``'repo_id'``: Hugging Face Hub repo holding the embedding file,
      - ``'filename'``: the safetensors file inside that repo (``hf_hub_download``
        requires it),
      - ``'token'``: the prompt token bound to the embedding.

    The file must contain a ``'clip_g'`` tensor, loaded into
    ``text_encoder_2``/``tokenizer_2``, and a ``'clip_l'`` tensor, loaded into
    ``text_encoder``/``tokenizer``. ``None`` is treated as no embeddings.

    Raises:
        ValueError: if an entry has no ``'filename'``.
    """
    for embedding in embeddings or []:
        filename = embedding.get('filename')
        if not filename:
            raise ValueError(f"Embedding {embedding.get('repo_id')!r} is missing a 'filename'")

        state_dict = load_file(hf_hub_download(embedding['repo_id'], filename))
        pipeline.load_textual_inversion(state_dict['clip_g'], token=embedding['token'], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
        pipeline.load_textual_inversion(state_dict["clip_l"], token=embedding['token'], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)


def resize_images(images: List[Image.Image], height: int, width: int, resize_mode: str):
    """Return a new list with every image in *images* sized to ``width`` x ``height``.

    Modes:
        "resize_only": plain ``Image.resize`` to the target size (aspect ratio
            is not kept).
        "crop_and_resize": center-crop to the target aspect ratio, then resize
            with LANCZOS, so nothing is stretched.
        "resize_and_fill": ``Image.resize`` to the target size with LANCZOS.
            NOTE(review): despite the name, this stretches and does not pad/fill.

    The input images are not modified.

    Raises:
        ValueError: for an unknown ``resize_mode``.
    """
    size = (width, height)  # PIL sizes are (width, height)

    if resize_mode == "resize_only":
        return [image.resize(size) for image in images]
    if resize_mode == "crop_and_resize":
        from PIL import ImageOps  # local import: only this mode needs it

        return [ImageOps.fit(image, size, method=Image.Resampling.LANCZOS) for image in images]
    if resize_mode == "resize_and_fill":
        return [image.resize(size, Image.Resampling.LANCZOS) for image in images]

    raise ValueError(f"Unknown resize mode: {resize_mode}")


def get_controlnet_images(controlnets: List[str], control_images: List[Image.Image], height: int, width: int, resize_mode: str):
    """Turn raw control images into ControlNet conditioning maps.

    The images are first passed through resize_images, then paired
    position-by-position with *controlnets*. Surplus entries on either side are
    ignored, as with ``zip``. Each image is run through the controlnet_aux
    annotator for its controlnet family, and the result is returned as a PIL
    image.

    Raises:
        ValueError: for a controlnet name with no known annotator.
    """
    annotator_by_name = {
        "canny": "canny",
        "canny_xs": "canny",
        "canny_fl": "canny",
        "depth": "depth_midas",
        "depth_xs": "depth_midas",
        "depth_fl": "depth_midas",
        "pose": "openpose_full",
        "pose_fl": "openpose_full",
        "scribble": "scribble",
    }

    sized_images = resize_images(control_images, height, width, resize_mode)
    conditioning = []
    for name, picture in zip(controlnets, sized_images):
        if name not in annotator_by_name:
            raise ValueError(f"Invalid Controlnet: {name}")
        annotate = Processor(annotator_by_name[name])
        conditioning.append(annotate(picture, to_pil=True))

    return conditioning


def check_image_safety(images: List[Image.Image]):
    """Return one NSFW flag per image in *images*, computed by the shared safety checker.

    Only the checker's per-image flags are returned; its blacked-out image
    copies are discarded.
    """
    checker_input = feature_extractor(images, return_tensors="pt")
    _, has_nsfw_concepts = safety_checker(
        # Assumes diffusers' StableDiffusionSafetyChecker, which indexes `images`
        # per flagged entry and replaces it with a zero array of the same shape.
        # It therefore needs a flat list with one array per image; a nested list
        # breaks on any positive hit.
        images=[np.array(image) for image in images],
        # Move the pixels to wherever the checker lives, in its dtype, rather than
        # assuming an fp32 model on "cuda".
        clip_input=checker_input.pixel_values.to(device=safety_checker.device, dtype=safety_checker.dtype),
    )

    return has_nsfw_concepts


def get_prompt_attention(pipeline, prompt, negative_prompt):
    """Encode prompts with sd_embed's weighted-prompt helpers for *pipeline*.

    Returns:
        ``(prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
        negative_pooled_prompt_embeds)``. Flux pipelines get no negative
        embeddings, so both negative slots are None for them.

    Raises:
        ValueError: if the pipeline is not one of the supported Flux or SDXL types.
    """
    flux_types = (FluxPipeline, FluxImg2ImgPipeline, FluxInpaintPipeline, FluxControlNetPipeline)
    # The SDXL img2img/inpaint/ControlNet pipelines do not subclass
    # StableDiffusionXLPipeline, so each one is listed explicitly.
    sdxl_types = (
        StableDiffusionXLPipeline,
        StableDiffusionXLImg2ImgPipeline,
        StableDiffusionXLInpaintPipeline,
        StableDiffusionXLControlNetPipeline,
        StableDiffusionXLControlNetImg2ImgPipeline,
        StableDiffusionXLControlNetInpaintPipeline,
    )

    if isinstance(pipeline, flux_types):
        prompt_embeds, pooled_prompt_embeds = get_weighted_text_embeddings_flux1(pipeline, prompt)
        return prompt_embeds, None, pooled_prompt_embeds, None

    if isinstance(pipeline, sdxl_types):
        # sd_embed takes the negative prompt as text; map an Optional None to "".
        prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = (
            get_weighted_text_embeddings_sdxl(pipeline, prompt, negative_prompt or "")
        )
        return prompt_embeds, prompt_neg_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

    raise ValueError(f"Invalid pipeline type: {type(pipeline)}")


def get_photomaker_images(photomaker_images: List[Image.Image], height: int, width: int, resize_mode: str = "resize_and_fill"):
    """Prepare PhotoMaker ID inputs from reference face images.

    Each image is passed through resize_images and kept as a PIL image for
    ``input_id_images``. It is also run through ``analyze_faces`` with the
    shared ``face_detector``, and the embedding of the first detected face is
    collected.

    Args:
        photomaker_images: reference images, each expected to show a face.
        height, width: target size forwarded to resize_images.
        resize_mode: forwarded to resize_images. Defaults to SDReq's default
            mode so callers that omit it keep working.

    Returns:
        ``(input_id_images, id_embeds)``: the resized PIL images, and the face
        embeddings combined with ``torch.stack`` (one row per image). The
        second element is None when no images were given.

    Raises:
        ValueError: if no face is detected in one of the images.
    """
    input_id_images = []
    face_embeds = []

    for image in resize_images(photomaker_images, height, width, resize_mode):
        input_id_images.append(image)
        # PhotoMaker's reference usage gives analyze_faces an RGB->BGR flipped
        # numpy array. Convert to RGB first so RGBA/greyscale inputs still yield
        # three channels.
        bgr = np.array(image.convert("RGB"))[:, :, ::-1]
        faces = analyze_faces(face_detector, bgr)
        if len(faces) == 0:
            raise ValueError("No face detected in the image")
        # Detected faces expose their identity vector under 'embedding' (singular).
        face_embeds.append(torch.from_numpy(faces[0]['embedding']))

    id_embeds = torch.stack(face_embeds) if face_embeds else None
    return input_id_images, id_embeds


def cleanup(pipeline, loras = None, embeddings = None):
    """Undo per-request pipeline state and release cached memory.

    Args:
        pipeline: the pipeline the request ran on.
        loras: truthy if LoRA adapters were loaded; they are disabled and unloaded.
        embeddings: truthy if textual-inversion embeddings were loaded; their
            tokens are removed from the default tokenizer/text encoder and, for
            non-Flux pipelines, from ``tokenizer_2``/``text_encoder_2`` as well.
    """
    if loras:
        pipeline.disable_lora()
        pipeline.unload_lora_weights()

    if embeddings:
        # Default tokenizer/text_encoder.
        pipeline.unload_textual_inversion()
        # load_xl_embeddings also registers each token with tokenizer_2/text_encoder_2,
        # which the call above does not touch. Left in place, those tokens would
        # leak into later requests (re-loading the same token likely fails).
        # Flux is skipped: its second tokenizer is not a CLIP tokenizer
        # (NOTE(review): confirm if Flux ever receives embeddings).
        tokenizer_2 = getattr(pipeline, "tokenizer_2", None)
        text_encoder_2 = getattr(pipeline, "text_encoder_2", None)
        if (
            tokenizer_2 is not None
            and text_encoder_2 is not None
            and not type(pipeline).__name__.startswith("Flux")
        ):
            pipeline.unload_textual_inversion(tokenizer=tokenizer_2, text_encoder=text_encoder_2)

    gc.collect()
    torch.cuda.empty_cache()


def _make_generators(seed, count):
    """Return ``count`` torch generators on ``device`` for one batch.

    A seed of None, 0 or -1 means "pick at random", so every generator gets
    its own random seed. Any other value seeds image ``i`` with ``seed + i``,
    which makes the whole batch reproducible.
    """
    if seed in (None, 0, -1):
        return [
            torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1))
            for _ in range(count)
        ]
    return [torch.Generator(device=device).manual_seed(seed + i) for i in range(count)]


# Gen function
def gen_img(
    request: SDReq | SDImg2ImgReq | SDInpaintReq
):
    """Run one generation request end to end and return the generated images.

    The steps are:
      1. Build the task pipeline (get_pipe) and swap in the requested scheduler.
      2. Load LoRAs and embeddings, and prepare ControlNet/PhotoMaker inputs.
      3. Encode the prompt and call the pipeline.
      4. Optionally run the refiner.
    Per-request LoRA/embedding state is always undone afterwards.

    Returns:
        The ``.images`` of the pipeline output, or of the refiner output when
        ``request.refiner`` is set.

    Raises:
        ValueError: wrapping any error raised after the pipeline was built.
            Errors from get_pipe itself propagate unchanged.
    """
    pipeline_args = get_pipe(request)
    pipeline = pipeline_args['pipeline']
    # fast_generation loads the Hyper-SD LoRA even when the request has no LoRAs
    # of its own, so adapters must be unloaded in that case too.
    loras_loaded = bool(request.loras) or bool(request.fast_generation)
    sdxl_types = (
        StableDiffusionXLPipeline,
        StableDiffusionXLImg2ImgPipeline,
        StableDiffusionXLInpaintPipeline,
        StableDiffusionXLControlNetPipeline,
        StableDiffusionXLControlNetImg2ImgPipeline,
        StableDiffusionXLControlNetInpaintPipeline,
    )
    photomaker_types = (PhotoMakerStableDiffusionXLPipeline, PhotoMakerStableDiffusionXLControlNetPipeline)

    try:
        pipeline.scheduler = load_scheduler(pipeline, request.scheduler)

        load_loras(pipeline, request.loras, request.fast_generation)
        load_xl_embeddings(pipeline, request.embeddings)

        cn_config = request.controlnet_config
        control_images = (
            get_controlnet_images(cn_config.controlnets, cn_config.control_images, request.height, request.width, request.resize_mode)
            if cn_config
            else None
        )
        photomaker_images, photomaker_id_embeds = (
            get_photomaker_images(request.photomaker_images, request.height, request.width, request.resize_mode)
            if request.photomaker_images
            else (None, None)
        )

        is_photomaker = isinstance(pipeline, photomaker_types)
        if is_photomaker:
            # PhotoMaker pipelines locate their ID trigger token in the raw prompt
            # and refuse `prompt_embeds` without a matching `class_tokens_mask`, so
            # they get the prompt strings (NOTE(review): per PhotoMaker's input
            # checks; confirm against the installed version).
            args = {
                'prompt': request.prompt,
                'negative_prompt': request.negative_prompt,
            }
        else:
            positive_prompt_embeds, negative_prompt_embeds, positive_prompt_pooled, negative_prompt_pooled = get_prompt_attention(pipeline, request.prompt, request.negative_prompt)
            args = {
                'prompt_embeds': positive_prompt_embeds,
                'pooled_prompt_embeds': positive_prompt_pooled,
            }
            if isinstance(pipeline, sdxl_types):
                args['negative_prompt_embeds'] = negative_prompt_embeds
                args['negative_pooled_prompt_embeds'] = negative_prompt_pooled

        # Common args
        args.update({
            'height': request.height,
            'width': request.width,
            'num_images_per_prompt': request.num_images_per_prompt,
            'num_inference_steps': request.num_inference_steps,
            'guidance_scale': request.guidance_scale,
            'generator': _make_generators(request.seed, request.num_images_per_prompt),
        })

        if isinstance(pipeline, sdxl_types):
            # The request model may not declare clip_skip; None keeps the pipeline default.
            args['clip_skip'] = getattr(request, 'clip_skip', None)

        if cn_config:
            # The per-controlnet weights live on the ControlNet config, not on the request.
            args['controlnet_conditioning_scale'] = cn_config.controlnet_conditioning_scale
            if isinstance(pipeline, FluxControlNetPipeline):
                args['control_mode'] = pipeline_args['control_mode']
                args['control_image'] = control_images
            elif isinstance(request, SDImg2ImgReq):
                # img2img/inpaint ControlNet pipelines (SDInpaintReq subclasses
                # SDImg2ImgReq) take the init image as `image` and the conditioning
                # maps as `control_image`.
                args['control_image'] = control_images
            else:
                # Text-to-image ControlNet pipelines take the conditioning maps as `image`.
                args['image'] = control_images

        if request.photomaker_images and is_photomaker:
            args['input_id_images'] = photomaker_images
            args['id_embeds'] = photomaker_id_embeds
            args['start_merge_step'] = 10

        # SDInpaintReq subclasses SDImg2ImgReq, so inpainting also gets image + strength.
        if isinstance(request, SDImg2ImgReq):
            args['image'] = resize_images([request.image], request.height, request.width, request.resize_mode)
            args['strength'] = request.strength
        if isinstance(request, SDInpaintReq):
            args['mask_image'] = resize_images([request.mask_image], request.height, request.width, request.resize_mode)

        images = pipeline(**args).images

        if request.refiner:
            images = refiner(
                prompt=request.prompt,
                num_inference_steps=40,
                denoising_start=0.7,
                image=images,
            ).images

        return images
    except Exception as e:
        raise ValueError(f"Error generating image: {e}") from e
    finally:
        cleanup(pipeline, loras_loaded, request.embeddings)