from typing import Literal, Union, Optional

import torch, gc, os
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection, T5TokenizerFast
from transformers import (
    AutoModel,
    CLIPModel,
    CLIPProcessor,
)
from huggingface_hub import hf_hub_download
from diffusers import (
    UNet2DConditionModel,
    SchedulerMixin,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    FluxPipeline,
    AutoencoderKL,
    FluxTransformer2DModel,
)
import copy
from diffusers.schedulers import (
    DDIMScheduler,
    DDPMScheduler,
    LMSDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    FlowMatchEulerDiscreteScheduler,
)
from diffusers import LCMScheduler, AutoencoderTiny
import sys
sys.path.append('.')
from .flux_utils import *

# Hub repositories the SD v1.x / v2.x tokenizers are always loaded from by
# load_diffusers_model, regardless of the model path being loaded.
TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"

# Scheduler names understood by create_noise_scheduler (the name is
# lower-cased and spaces are replaced by underscores before matching).
AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]

# SDXL loads a CLIPTextModel as its first text encoder and a
# CLIPTextModelWithProjection as its second.
SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]

# Passed as `cache_dir=` to the model loaders in this module; None keeps the
# library's default cache location.
DIFFUSERS_CACHE_DIR = None  # if you want to change the cache dir, change this


def load_diffusers_model(
    pretrained_model_name_or_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
    """Load tokenizer, text encoder and UNet from a diffusers-format SD v1/v2 model.

    The VAE is not needed and is not loaded. The tokenizer is always taken
    from the reference SD v1.4 / v2.1 repository, not from
    ``pretrained_model_name_or_path``.

    Args:
        pretrained_model_name_or_path: Hub id or local directory of the model.
        v2: Treat the model as Stable Diffusion 2.x.
        clip_skip: How many final text-encoder layers to skip, counted from 1
            (1 keeps every layer). ``None`` means clip skip 2 for v2 and no
            skip for v1.
        weight_dtype: Value passed as ``torch_dtype`` to every loader.

    Returns:
        ``(tokenizer, text_encoder, unet)``.
    """
    if v2:
        tokenizer_source = TOKENIZER_V2_MODEL_NAME
        # Arithmetic is based on 24 encoder layers; the default is clip skip 2.
        kept_layers = 23 if clip_skip is None else 24 - (clip_skip - 1)
    else:
        tokenizer_source = TOKENIZER_V1_MODEL_NAME
        kept_layers = 12 if clip_skip is None else 12 - (clip_skip - 1)

    tokenizer = CLIPTokenizer.from_pretrained(
        tokenizer_source,
        subfolder="tokenizer",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    # The clip skip is applied by building the encoder with fewer hidden layers.
    text_encoder = CLIPTextModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        num_hidden_layers=kept_layers,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    unet = UNet2DConditionModel.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="unet",
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )
    return tokenizer, text_encoder, unet


def load_checkpoint_model(
    checkpoint_path: str,
    v2: bool = False,
    clip_skip: Optional[int] = None,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
    """Load tokenizer, text encoder and UNet from a single-file SD v1/v2 checkpoint.

    Args:
        checkpoint_path: Path to a ``.ckpt`` / ``.safetensors`` checkpoint.
        v2: Treat the checkpoint as Stable Diffusion 2.x. This enables
            ``upcast_attention`` and sets the encoder depth used for
            ``clip_skip`` to 24 layers (12 otherwise).
        clip_skip: If given, keep only the first ``N - (clip_skip - 1)``
            text-encoder layers (1 keeps every layer). ``None`` keeps all
            layers. Unlike ``load_diffusers_model``, there is no implicit clip
            skip 2 for v2 here.
        weight_dtype: Value passed as ``torch_dtype`` to the pipeline loader.

    Returns:
        ``(tokenizer, text_encoder, unet)``. The rest of the loaded pipeline
        (VAE, scheduler, ...) is discarded.

    Raises:
        ValueError: If ``clip_skip`` is outside ``[1, N]``.
    """
    total_layers = 24 if v2 else 12
    kept_layers = None
    if clip_skip is not None:
        kept_layers = total_layers - (clip_skip - 1)
        # Validate before the (expensive) checkpoint load.
        if not 1 <= kept_layers <= total_layers:
            raise ValueError(
                f"clip_skip must be between 1 and {total_layers}, got {clip_skip}"
            )

    # `from_ckpt` was deprecated in favour of `from_single_file` and has been
    # removed from the diffusers releases that provide FluxPipeline (imported
    # by this module). Prefer the new entry point; fall back for old installs.
    single_file_loader = getattr(StableDiffusionPipeline, "from_single_file", None)
    if single_file_loader is None:
        single_file_loader = StableDiffusionPipeline.from_ckpt
    pipe = single_file_loader(
        checkpoint_path,
        upcast_attention=bool(v2),
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    if kept_layers is not None:
        # Editing `config.num_hidden_layers` alone does nothing on an already
        # built model: the encoder's forward iterates its `layers` list.
        # Truncate the list so the skip really applies, and keep the config in
        # sync so it describes the module.
        encoder = text_encoder.text_model.encoder
        encoder.layers = encoder.layers[:kept_layers]
        text_encoder.config.num_hidden_layers = kept_layers

    del pipe

    return tokenizer, text_encoder, unet


def load_models(
    pretrained_model_name_or_path: str,
    scheduler_name: AVAILABLE_SCHEDULERS,
    v2: bool = False,
    v_pred: bool = False,
    weight_dtype: torch.dtype = torch.float32,
    clip_skip: Optional[int] = None,
) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]:
    """Load an SD v1/v2 model (tokenizer, text encoder, UNet) plus a noise scheduler.

    Paths ending in ``.ckpt`` / ``.safetensors`` are loaded as single-file
    checkpoints. Anything else is treated as a diffusers-format model. The
    VAE is not loaded.

    Args:
        pretrained_model_name_or_path: Checkpoint file, hub id or directory.
        scheduler_name: Name passed to ``create_noise_scheduler``.
        v2: Treat the model as Stable Diffusion 2.x.
        v_pred: Build the scheduler with ``prediction_type="v_prediction"``
            instead of ``"epsilon"``.
        weight_dtype: dtype the model weights are loaded in.
        clip_skip: Forwarded to the underlying loader. ``None`` keeps each
            loader's default. For v2, the diffusers loader defaults to clip
            skip 2 and the checkpoint loader keeps all layers.

    Returns:
        ``(tokenizer, text_encoder, unet, scheduler)``.
    """
    if pretrained_model_name_or_path.endswith((".ckpt", ".safetensors")):
        loader = load_checkpoint_model
    else:  # diffusers
        loader = load_diffusers_model
    tokenizer, text_encoder, unet = loader(
        pretrained_model_name_or_path,
        v2=v2,
        clip_skip=clip_skip,
        weight_dtype=weight_dtype,
    )

    # The VAE is not needed.

    scheduler = create_noise_scheduler(
        scheduler_name,
        prediction_type="v_prediction" if v_pred else "epsilon",
    )

    return tokenizer, text_encoder, unet, scheduler


def load_diffusers_model_xl(
    pretrained_model_name_or_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
    """Load SDXL tokenizers, text encoders and UNet from a diffusers-format model.

    Args:
        pretrained_model_name_or_path: Hub id or local directory of the model.
        weight_dtype: Value passed as ``torch_dtype`` to every loader.

    Returns:
        ``([tokenizer, tokenizer_2], [text_encoder, text_encoder_2], unet)``.
    """
    source = pretrained_model_name_or_path
    shared_kwargs = {"torch_dtype": weight_dtype, "cache_dir": DIFFUSERS_CACHE_DIR}

    first_tokenizer = CLIPTokenizer.from_pretrained(
        source, subfolder="tokenizer", **shared_kwargs
    )
    # The second tokenizer is requested with pad_token_id=0, same as OpenCLIP.
    second_tokenizer = CLIPTokenizer.from_pretrained(
        source, subfolder="tokenizer_2", **shared_kwargs, pad_token_id=0
    )
    first_encoder = CLIPTextModel.from_pretrained(
        source, subfolder="text_encoder", **shared_kwargs
    )
    second_encoder = CLIPTextModelWithProjection.from_pretrained(
        source, subfolder="text_encoder_2", **shared_kwargs
    )
    unet = UNet2DConditionModel.from_pretrained(
        source, subfolder="unet", **shared_kwargs
    )

    return [first_tokenizer, second_tokenizer], [first_encoder, second_encoder], unet


def load_checkpoint_model_xl(
    checkpoint_path: str,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
    """Load SDXL tokenizers, text encoders and UNet from a single-file checkpoint.

    Args:
        checkpoint_path: Path to a ``.ckpt`` / ``.safetensors`` SDXL checkpoint.
        weight_dtype: Value passed as ``torch_dtype`` to the pipeline loader.

    Returns:
        ``([tokenizer, tokenizer_2], [text_encoder, text_encoder_2], unet)``.
        The rest of the loaded pipeline (VAE, scheduler, ...) is discarded.
    """
    pipe = StableDiffusionXLPipeline.from_single_file(
        checkpoint_path,
        torch_dtype=weight_dtype,
        cache_dir=DIFFUSERS_CACHE_DIR,
    )

    unet = pipe.unet
    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
    text_encoders = [pipe.text_encoders] if False else [pipe.text_encoder, pipe.text_encoder_2]
    # Pad the second tokenizer with id 0 ("same as open clip"), matching what
    # load_diffusers_model_xl requests for tokenizer_2. This was previously
    # assigned to the text-encoder module, where it has no effect on
    # tokenization.
    tokenizers[1].pad_token_id = 0

    del pipe

    return tokenizers, text_encoders, unet


def load_models_xl_(
    pretrained_model_name_or_path: str,
    scheduler_name: AVAILABLE_SCHEDULERS,
    weight_dtype: torch.dtype = torch.float32,
) -> tuple[
    list[CLIPTokenizer],
    list[SDXL_TEXT_ENCODER_TYPE],
    UNet2DConditionModel,
    SchedulerMixin,
]:
    """Load SDXL tokenizers, text encoders and UNet, plus a noise scheduler.

    Paths ending in ``.ckpt`` / ``.safetensors`` are loaded as single-file
    checkpoints. Anything else is treated as a diffusers-format model.

    Returns:
        ``(tokenizers, text_encoders, unet, scheduler)``.
    """
    is_single_file = pretrained_model_name_or_path.endswith((".ckpt", ".safetensors"))
    component_loader = load_checkpoint_model_xl if is_single_file else load_diffusers_model_xl
    tokenizers, text_encoders, unet = component_loader(
        pretrained_model_name_or_path, weight_dtype
    )
    return tokenizers, text_encoders, unet, create_noise_scheduler(scheduler_name)


def create_noise_scheduler(
    scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
    prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
) -> SchedulerMixin:
    """Build a fresh noise scheduler using Stable Diffusion's training beta schedule.

    Every scheduler gets ``scaled_linear`` betas from 0.00085 to 0.012 over
    1000 training timesteps. It is unclear which scheduler works best. The
    original implementation offered DDIM, DDPM and LMS.

    Args:
        scheduler_name: One of ``"ddim"``, ``"ddpm"``, ``"lms"``,
            ``"euler_a"``. Matching is case-insensitive, and spaces are
            treated as underscores.
        prediction_type: Forwarded to the scheduler constructor.

    Returns:
        The constructed scheduler.

    Raises:
        ValueError: If the (normalized) name is not recognized.
    """
    key = scheduler_name.lower().replace(" ", "_")
    shared = dict(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        # NOTE(review): original author was unsure this is right for DDIM.
        prediction_type=prediction_type,
    )
    # Builders are lambdas so only the selected scheduler is constructed.
    builders = {
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
        "ddim": lambda: DDIMScheduler(clip_sample=False, **shared),
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
        "ddpm": lambda: DDPMScheduler(clip_sample=False, **shared),
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
        "lms": lambda: LMSDiscreteScheduler(**shared),
        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
        "euler_a": lambda: EulerAncestralDiscreteScheduler(**shared),
    }
    build = builders.get(key)
    if build is None:
        raise ValueError(f"Unknown scheduler name: {key}")
    return build()


def load_models_xl(params):
    """
    Load all required SDXL models for training and assemble an SDXL pipeline.

    Args:
        params: Dictionary of model parameters. Keys read here:
            'device', 'weight_dtype', 'pretrained_model_name_or_path',
            'encoder' ('dinov2-small' selects DINOv2; anything else selects
            TinyCLIP), 'distilled' (a path, a value containing 'dmd2', or
            'None'/None to skip) and 'max_denoising_steps'.

    Returns:
        tuple: ``(models, pipe)``. ``models`` is a dict with keys 'unet',
        'vae', 'clip_model', 'clip_processor', 'tokenizers', 'text_encoders'
        and 'noise_scheduler'. ``pipe`` is a StableDiffusionXLPipeline built
        from those components.
    """
    device = params['device']
    weight_dtype = params['weight_dtype']

    # Load SDXL components (UNet, text encoders, tokenizers). Passing
    # weight_dtype loads them directly in the target precision instead of
    # materialising a float32 copy that is cast right below.
    scheduler_name = 'ddim'
    tokenizers, text_encoders, unet, noise_scheduler = load_models_xl_(
        params['pretrained_model_name_or_path'],
        scheduler_name=scheduler_name,
        weight_dtype=weight_dtype,
    )

    # Move text encoders to device and freeze them
    for text_encoder in text_encoders:
        text_encoder.to(device, dtype=weight_dtype)
        text_encoder.requires_grad_(False)
        text_encoder.eval()

    # Set up UNet (frozen)
    unet.to(device, dtype=weight_dtype)
    unet.requires_grad_(False)
    unet.eval()

    # Load tiny VAE for efficiency
    vae = AutoencoderTiny.from_pretrained(
        "madebyollin/taesdxl",
        torch_dtype=weight_dtype
    )
    vae = vae.to(device, dtype=weight_dtype)
    vae.requires_grad_(False)

    # Load appropriate image encoder (CLIP or DinoV2)
    if params['encoder'] == 'dinov2-small':
        clip_model = AutoModel.from_pretrained(
            'facebook/dinov2-small',
            torch_dtype=weight_dtype
        )
        clip_processor = None
    else:
        clip_model = CLIPModel.from_pretrained(
            "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M",
            torch_dtype=weight_dtype
        )
        clip_processor = CLIPProcessor.from_pretrained("wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M")
    clip_model = clip_model.to(device, dtype=weight_dtype)
    clip_model.requires_grad_(False)

    # If using a distilled (e.g. DMD) UNet checkpoint, load it. Both the string
    # 'None' and an actual None mean "no distilled checkpoint"; None used to
    # crash on the substring tests below.
    distilled = params['distilled']
    if distilled is not None and distilled != 'None':
        if '.safetensors' in distilled:
            # NOTE(review): `load_file` is not imported in the visible part of
            # this module; presumably it comes from `from .flux_utils import *`
            # (safetensors.torch.load_file) — confirm.
            state_dict = load_file(distilled, device=device)
        else:
            if 'dmd2' in distilled:
                repo_name = "tianweiy/DMD2"
                ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"
                ckpt_path = hf_hub_download(repo_name, ckpt_name)
            else:
                ckpt_path = distilled
            # map_location='cpu' lets checkpoints saved from a GPU load on any
            # machine; load_state_dict then copies into the UNet's own
            # device/dtype. SECURITY: torch.load may unpickle arbitrary
            # objects — only load trusted files.
            state_dict = torch.load(ckpt_path, map_location='cpu')
        unet.load_state_dict(state_dict)
        del state_dict  # release the extra copy of the weights

        # Set up LCM scheduler for DMD
        noise_scheduler = LCMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            beta_schedule="scaled_linear",
            num_train_timesteps=1000,
            prediction_type="epsilon",
            original_inference_steps=1000
        )

    noise_scheduler.set_timesteps(params['max_denoising_steps'])
    pipe = StableDiffusionXLPipeline(
        vae=vae,
        text_encoder=text_encoders[0],
        text_encoder_2=text_encoders[1],
        tokenizer=tokenizers[0],
        tokenizer_2=tokenizers[1],
        unet=unet,
        scheduler=noise_scheduler,
    )
    pipe.set_progress_bar_config(disable=True)
    return {
        'unet': unet,
        'vae': vae,
        'clip_model': clip_model,
        'clip_processor': clip_processor,
        'tokenizers': tokenizers,
        'text_encoders': text_encoders,
        'noise_scheduler': noise_scheduler
    }, pipe


def load_models_flux(params):
    """
    Load all required FLUX models and assemble a FluxPipeline.

    Args:
        params: Dictionary of model parameters. Keys read here:
            'pretrained_model_name_or_path', 'weight_dtype', 'device' and
            'encoder' ('dinov2-small' selects DINOv2; anything else selects
            TinyCLIP).

    Returns:
        tuple: ``(models, pipe)``. ``models`` is a dict with keys
        'transformer', 'vae', 'clip_model', 'clip_processor', 'tokenizers',
        'text_encoders' and 'noise_scheduler'. ``pipe`` is the FluxPipeline
        built from those components.
    """
    model_path = params['pretrained_model_name_or_path']
    weight_dtype = params['weight_dtype']
    device = params['device']

    # Load the tokenizers
    tokenizer_one = CLIPTokenizer.from_pretrained(
        model_path,
        subfolder="tokenizer",
        torch_dtype=weight_dtype, device_map=device
    )
    tokenizer_two = T5TokenizerFast.from_pretrained(
        model_path,
        subfolder="tokenizer_2",
        torch_dtype=weight_dtype, device_map=device
    )

    # Load the scheduler. (The previous unused deepcopy of it was dropped.)
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        model_path,
        subfolder="scheduler",
        torch_dtype=weight_dtype, device=device
    )

    # Resolve the concrete text encoder classes, then load the text encoders
    text_encoder_cls_one = import_model_class_from_model_name_or_path(model_path)
    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        model_path, subfolder="text_encoder_2"
    )
    text_encoder_one, text_encoder_two = load_text_encoders(
        model_path, text_encoder_cls_one, text_encoder_cls_two, weight_dtype
    )

    # Load VAE and transformer
    vae = AutoencoderKL.from_pretrained(
        model_path,
        subfolder="vae",
        torch_dtype=weight_dtype, device_map='auto'
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        model_path,
        subfolder="transformer",
        torch_dtype=weight_dtype
    )

    # We only train the additional adapter LoRA layers; everything loaded
    # here stays frozen.
    transformer.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder_one.requires_grad_(False)
    text_encoder_two.requires_grad_(False)

    vae.to(device)
    transformer.to(device)
    text_encoder_one.to(device)
    text_encoder_two.to(device)

    # Load appropriate image encoder (CLIP or DinoV2)
    if params['encoder'] == 'dinov2-small':
        clip_model = AutoModel.from_pretrained(
            'facebook/dinov2-small',
            torch_dtype=weight_dtype
        )
        clip_processor = None
    else:
        clip_model = CLIPModel.from_pretrained(
            "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M",
            torch_dtype=weight_dtype
        )
        clip_processor = CLIPProcessor.from_pretrained("wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M")
    clip_model = clip_model.to(device, dtype=weight_dtype)
    clip_model.requires_grad_(False)

    # Keyword arguments keep the wiring correct even if FluxPipeline's
    # positional parameter order changes between diffusers releases.
    pipe = FluxPipeline(
        scheduler=noise_scheduler,
        vae=vae,
        text_encoder=text_encoder_one,
        tokenizer=tokenizer_one,
        text_encoder_2=text_encoder_two,
        tokenizer_2=tokenizer_two,
        transformer=transformer,
    )
    pipe.set_progress_bar_config(disable=True)

    return {
        'transformer': transformer,
        'vae': vae,
        'clip_model': clip_model,
        'clip_processor': clip_processor,
        'tokenizers': [tokenizer_one, tokenizer_two],
        'text_encoders': [text_encoder_one, text_encoder_two],
        'noise_scheduler': noise_scheduler
    }, pipe

def save_checkpoint(networks, save_path, weight_dtype):
    """
    Save network weights and perform cleanup.

    Saving is best-effort. A failure for one network is reported and the
    remaining networks are still saved. No exception propagates to the
    caller. The success message is printed only when every network saved.

    Args:
        networks: Dictionary mapping an index/name to a LoRA network exposing
            ``save_weights(path, dtype=...)``.
        save_path: Directory to save the checkpoints into (created if missing).
            Each network is written to ``slider_<key>.pt``.
        weight_dtype: Data type passed to ``save_weights``.
    """
    print("Saving checkpoint...")

    try:
        # Create save directory if it doesn't exist
        os.makedirs(save_path, exist_ok=True)

        failed = []
        for net_idx, network in networks.items():
            save_name = os.path.join(save_path, f"slider_{net_idx}.pt")
            try:
                network.save_weights(
                    save_name,
                    dtype=weight_dtype,
                )
            except Exception as e:
                # Best-effort: report and keep saving the other networks.
                print(f"Error saving network {net_idx}: {e}")
                failed.append(net_idx)

        if failed:
            print(f"Checkpoint saved with errors; failed networks: {failed}")
        else:
            print("Checkpoint saved successfully.")

    except Exception as e:
        print(f"Error during checkpoint saving: {e}")

    finally:
        # Clean up memory exactly once, whether or not saving succeeded.
        torch.cuda.empty_cache()
        gc.collect()