import argparse
from typing import Any, Dict

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from diffusers import AutoencoderDC


def remap_qkv_(key: str, state_dict: Dict[str, Any]):
    """Split a fused QKV conv weight into separate ``to_q`` / ``to_k`` / ``to_v`` entries, in place.

    The entry stored under ``key`` is removed from ``state_dict`` and cut into three equal chunks
    along dim 0 (query, key, value, in that order). Each chunk is stored under
    ``<parent>.to_q.weight``, ``<parent>.to_k.weight`` and ``<parent>.to_v.weight``, where
    ``<parent>`` is everything in ``key`` before ``.qkv.conv.weight``.

    Args:
        key: Name of the fused weight; expected to end with ``.qkv.conv.weight``.
        state_dict: Mapping that is modified in place.
    """
    qkv = state_dict.pop(key)
    q, k, v = torch.chunk(qkv, 3, dim=0)
    parent_module, _, _ = key.rpartition(".qkv.conv.weight")
    # Only drop the two trailing kernel dims. A bare `.squeeze()` would also remove a channel
    # dimension that happens to have size 1.
    # NOTE(review): assumes the weight is laid out as (out, in, 1, 1) — confirm for new checkpoints.
    state_dict[f"{parent_module}.to_q.weight"] = q.squeeze(-1).squeeze(-1)
    state_dict[f"{parent_module}.to_k.weight"] = k.squeeze(-1).squeeze(-1)
    state_dict[f"{parent_module}.to_v.weight"] = v.squeeze(-1).squeeze(-1)


def remap_proj_conv_(key: str, state_dict: Dict[str, Any]):
    """Move a ``.proj.conv.weight`` entry to ``<parent>.to_out.weight``, in place.

    ``<parent>`` is everything in ``key`` before ``.proj.conv.weight``. The two trailing kernel
    dimensions are squeezed away when they have size 1.

    Args:
        key: Name of the projection weight; expected to end with ``.proj.conv.weight``.
        state_dict: Mapping that is modified in place.
    """
    parent_module, _, _ = key.rpartition(".proj.conv.weight")
    weight = state_dict.pop(key)
    # Only drop the two trailing kernel dims. A bare `.squeeze()` would also remove a channel
    # dimension that happens to have size 1.
    # NOTE(review): assumes the weight is laid out as (out, in, 1, 1) — confirm for new checkpoints.
    state_dict[f"{parent_module}.to_out.weight"] = weight.squeeze(-1).squeeze(-1)


# Substring renames that translate original DC-AE checkpoint keys into the key names loaded into
# `AutoencoderDC`. `convert_ae` applies every entry to every key with `str.replace`, in the
# insertion order below, so a later entry operates on the output of the earlier ones. Do not
# reorder entries without re-checking the resulting key names.
# `get_ae_config` overrides some of these entries in place for the non-SANA checkpoints.
AE_KEYS_RENAME_DICT = {
    # common (applies to both encoder and decoder keys)
    "main.": "",
    "op_list.": "",
    "context_module": "attn",
    "local_module": "conv_out",
    # NOTE: The two `aggreg` entries below work because the multiscale tuples in the available
    # configs have at most one element. With more scales there would be more layers, and a loop
    # over the scale index would be needed to handle them.
    "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
    "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
    "depth_conv.conv": "conv_depth",
    "inverted_conv.conv": "conv_inverted",
    "point_conv.conv": "conv_point",
    "point_conv.norm": "norm",
    "conv.conv.": "conv.",
    "conv1.conv": "conv1",
    "conv2.conv": "conv2",
    "conv2.norm": "norm",
    "proj.norm": "norm_out",
    # encoder-only keys
    "encoder.project_in.conv": "encoder.conv_in",
    "encoder.project_out.0.conv": "encoder.conv_out",
    "encoder.stages": "encoder.down_blocks",
    # decoder-only keys
    "decoder.project_in.conv": "decoder.conv_in",
    "decoder.project_out.0": "decoder.norm_out",
    "decoder.project_out.2.conv": "decoder.conv_out",
    "decoder.stages": "decoder.up_blocks",
}

# Rename overrides that `get_ae_config` merges into `AE_KEYS_RENAME_DICT` for the `in` / `mix`
# checkpoints. Compared with the default entries, the targets keep a trailing `.conv` component.
# The three families currently share the same overrides, but each gets its own dict object so
# they can diverge independently.
_IN_OUT_CONV_OVERRIDES = (
    # encoder
    ("encoder.project_in.conv", "encoder.conv_in.conv"),
    # decoder
    ("decoder.project_out.2.conv", "decoder.conv_out.conv"),
)

AE_F32C32_KEYS = dict(_IN_OUT_CONV_OVERRIDES)

AE_F64C128_KEYS = dict(_IN_OUT_CONV_OVERRIDES)

AE_F128C512_KEYS = dict(_IN_OUT_CONV_OVERRIDES)

# Handlers for keys that need more than a rename. `convert_ae` runs these after the substring
# renames: every key that contains the substring on the left is passed, together with the state
# dict, to the function on the right, which rewrites the state dict in place.
AE_SPECIAL_KEYS_REMAP = {
    "qkv.conv.weight": remap_qkv_,
    "proj.conv.weight": remap_proj_conv_,
}


def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Peel the ``"model"`` / ``"module"`` / ``"state_dict"`` wrappers off a loaded checkpoint.

    The wrappers are tried in that order, and each check is made against the dict reached so far.
    Nested wrappers such as ``{"model": {"module": {...}}}`` are therefore unwrapped, and a
    wrapper key that only exists at an outer level is never looked up in an inner dict.

    Args:
        saved_dict: The object returned by the checkpoint loader.

    Returns:
        The innermost dict reached, or ``saved_dict`` itself when no wrapper key is present.
    """
    state_dict = saved_dict
    for wrapper_key in ("model", "module", "state_dict"):
        # The isinstance guard stops unwrapping when a wrapper-named entry is not itself a dict.
        if isinstance(state_dict, dict) and wrapper_key in state_dict:
            state_dict = state_dict[wrapper_key]
    return state_dict


def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
    """Rename the ``old_key`` entry of ``state_dict`` to ``new_key``, in place.

    The value is popped before it is re-inserted, so ``old_key == new_key`` is safe. An existing
    ``new_key`` entry is overwritten.

    Raises:
        KeyError: If ``old_key`` is not present.
    """
    state_dict[new_key] = state_dict.pop(old_key)


def convert_ae(config_name: str, dtype: torch.dtype):
    """Download an original DC-AE checkpoint and load it into a diffusers ``AutoencoderDC``.

    Steps:
      1. Resolve the model config (and the rename rules that go with it) for ``config_name``.
      2. Download ``model.safetensors`` from the ``mit-han-lab/<config_name>`` Hub repository.
      3. Rename every key by applying the substring rules in order.
      4. Run the special handlers on keys that need reshaping or splitting.
      5. Load the result into a freshly built ``AutoencoderDC`` with ``strict=True``.

    Args:
        config_name: Checkpoint name; must be accepted by ``get_ae_config``.
        dtype: Dtype the model is cast to before the weights are loaded.

    Returns:
        The ``AutoencoderDC`` instance with the converted weights loaded.

    Raises:
        ValueError: Propagated from ``get_ae_config`` for an unknown ``config_name``.
    """
    # `get_ae_config` patches the module-level rename table in place for some checkpoints.
    # Capture the rules for this conversion, then restore the table so a later conversion in
    # the same process does not inherit overrides meant for a different checkpoint.
    pristine_renames = dict(AE_KEYS_RENAME_DICT)
    try:
        config = get_ae_config(config_name)
        rename_rules = list(AE_KEYS_RENAME_DICT.items())
    finally:
        AE_KEYS_RENAME_DICT.clear()
        AE_KEYS_RENAME_DICT.update(pristine_renames)

    hub_id = f"mit-han-lab/{config_name}"
    ckpt_path = hf_hub_download(hub_id, "model.safetensors")
    original_state_dict = get_state_dict(load_file(ckpt_path))

    ae = AutoencoderDC(**config).to(dtype=dtype)

    # Iterate over a snapshot of the keys because the dict is mutated inside the loop.
    for key in list(original_state_dict.keys()):
        new_key = key
        # Rules are cumulative: each one sees the output of the previous one.
        for replace_key, rename_key in rename_rules:
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(original_state_dict, key, new_key)

    for key in list(original_state_dict.keys()):
        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
            if special_key in key:
                handler_fn_inplace(key, original_state_dict)
                # The handler has removed `key` from the dict; trying further handlers on it
                # could only fail.
                break

    ae.load_state_dict(original_state_dict, strict=True)
    return ae


def get_ae_config(name: str):
    """Return the ``AutoencoderDC`` constructor kwargs for the checkpoint called ``name``.

    Side effect: for every checkpoint except ``dc-ae-f32c32-sana-1.0`` the family-specific rename
    overrides are merged into the module-level ``AE_KEYS_RENAME_DICT``.

    Args:
        name: One of the supported DC-AE checkpoint names.

    Returns:
        A dict of keyword arguments for ``AutoencoderDC``.

    Raises:
        ValueError: If ``name`` is not a supported checkpoint name.
    """

    def staged_config(latent_channels, num_stages, scaling_factor):
        # Shared layout of the `in` / `mix` checkpoints: three ResBlock stages followed by
        # EfficientViT stages. Only the stage count, latent width and scaling factor differ.
        num_vit_stages = num_stages - 3
        widths = (128, 256, 512, 512, 1024, 1024, 2048, 2048)[:num_stages]
        stage_types = ("ResBlock",) * 3 + ("EfficientViTBlock",) * num_vit_stages
        return {
            "latent_channels": latent_channels,
            "encoder_block_types": list(stage_types),
            "decoder_block_types": list(stage_types),
            "encoder_block_out_channels": list(widths),
            "decoder_block_out_channels": list(widths),
            "encoder_layers_per_block": [0, 4, 8] + [2] * num_vit_stages,
            "decoder_layers_per_block": [0, 5, 10] + [2] * num_vit_stages,
            "encoder_qkv_multiscales": ((),) * num_stages,
            "decoder_qkv_multiscales": ((),) * num_stages,
            "decoder_norm_types": ["batch_norm"] * 3 + ["rms_norm"] * num_vit_stages,
            "decoder_act_fns": ["relu"] * 3 + ["silu"] * num_vit_stages,
            "scaling_factor": scaling_factor,
        }

    if name == "dc-ae-f32c32-sana-1.0":
        sana_types = ("ResBlock",) * 3 + ("EfficientViTBlock",) * 3
        sana_widths = (128, 256, 512, 512, 1024, 1024)
        sana_scales = ((),) * 3 + ((5,),) * 3
        return {
            "latent_channels": 32,
            "encoder_block_types": sana_types,
            "decoder_block_types": sana_types,
            "encoder_block_out_channels": sana_widths,
            "decoder_block_out_channels": sana_widths,
            "encoder_qkv_multiscales": sana_scales,
            "decoder_qkv_multiscales": sana_scales,
            "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
            "decoder_layers_per_block": [3] * 6,
            "downsample_block_type": "conv",
            "upsample_block_type": "interpolate",
            "decoder_norm_types": "rms_norm",
            "decoder_act_fns": "silu",
            "scaling_factor": 0.41407,
        }

    if name in ("dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"):
        AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
        factor = 0.3189 if name == "dc-ae-f32c32-in-1.0" else 0.4552
        return staged_config(32, 6, factor)

    if name in ("dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"):
        AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
        factor = 0.2889 if name == "dc-ae-f64c128-in-1.0" else 0.4538
        return staged_config(128, 7, factor)

    if name in ("dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"):
        AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
        factor = 0.4883 if name == "dc-ae-f128c512-in-1.0" else 0.3620
        return staged_config(512, 8, factor)

    raise ValueError("Invalid config name provided.")


def get_args():
    """Parse the command-line arguments of the conversion script.

    Returns:
        An ``argparse.Namespace`` with ``config_name``, ``output_path`` and ``dtype``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_name",
        type=str,
        default="dc-ae-f32c32-sana-1.0",
        choices=[
            "dc-ae-f32c32-sana-1.0",
            "dc-ae-f32c32-in-1.0",
            "dc-ae-f32c32-mix-1.0",
            "dc-ae-f64c128-in-1.0",
            "dc-ae-f64c128-mix-1.0",
            "dc-ae-f128c512-in-1.0",
            "dc-ae-f128c512-mix-1.0",
        ],
        help="The DCAE checkpoint to convert",
    )
    parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
    parser.add_argument(
        "--dtype",
        type=str,
        default="fp32",
        # Reject unknown values here instead of failing later with a KeyError at the dtype lookup.
        # Keep this list in sync with the keys of DTYPE_MAPPING / VARIANT_MAPPING.
        choices=["fp32", "fp16", "bf16"],
        help="Torch dtype to save the model in.",
    )
    return parser.parse_args()


# Value of the `--dtype` command-line flag -> torch dtype the model is cast to.
DTYPE_MAPPING = dict(
    fp32=torch.float32,
    fp16=torch.float16,
    bf16=torch.bfloat16,
)

# Value of the `--dtype` flag -> `variant` argument passed to `save_pretrained`.
# fp32 maps to None; the reduced-precision flags are passed through unchanged.
VARIANT_MAPPING = {flag: (None if flag == "fp32" else flag) for flag in DTYPE_MAPPING}


if __name__ == "__main__":
    # Script entry point: parse the CLI, convert the requested checkpoint, write it to disk.
    cli = get_args()

    # Both lookups are keyed by the same `--dtype` flag value.
    torch_dtype, save_variant = DTYPE_MAPPING[cli.dtype], VARIANT_MAPPING[cli.dtype]

    converted = convert_ae(cli.config_name, torch_dtype)
    converted.save_pretrained(
        cli.output_path,
        safe_serialization=True,
        max_shard_size="5GB",
        variant=save_variant,
    )