import os
import json
import torch
import gc
import numpy as np
from PIL import Image
from diffusers import StableDiffusionXLPipeline
import open_clip
from huggingface_hub import hf_hub_download
from IP_Adapter.ip_adapter import IPAdapterXL
from perform_swap import compute_dataset_embeds_svd, get_modified_images_embeds_composition
from create_grids import create_grids
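# IP_Adapter, perform_swap, and create_grids are local modules expected alongside this script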
import argparse

def save_images(output_dir, image_list):
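    """Save each image in image_list to output_dir as sample_1.png, sample_2.png, ..."""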
    os.makedirs(output_dir, exist_ok=True)
    for i, img in enumerate(image_list):
        img.save(os.path.join(output_dir, f"sample_{i + 1}.png"))

def get_image_embeds(pil_image, model, preprocess, device):
    """Encode one PIL image with the CLIP image encoder; returns a numpy array."""
    image = preprocess(pil_image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        embeds = model.encode_image(image.to(device))
    return embeds.detach().cpu().numpy()

def process_combo(
    image_embeds_base,
    image_names_base,
    concept_embeds,
    concept_names,
    projection_matrices,
    ip_model,
    output_base_dir,
    num_samples=4,
    seed=420,
    prompt=None,
    scale=1.0
):
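    """
    For every base image embedding, iterate over all combinations of concept
    images (one per concept), compose the base embedding with the concepts'
    subspace projections, and save the generated samples to a per-combination
    directory under output_base_dir.
    """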
    for base_embed, base_name in zip(image_embeds_base, image_names_base):
        # Iterate over the Cartesian product of per-concept image indices,
        # i.e. every way of choosing one image from each concept directory
        for combo_indices in np.ndindex(*(len(embeds) for embeds in concept_embeds)):
            concept_combo_names = [concept_names[c][idx] for c, idx in enumerate(combo_indices)]
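            # The output directory name encodes the base image and the chosen concept images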
            combo_dir = os.path.join(
                output_base_dir,
                f"{base_name}_to_" + "_".join(concept_combo_names)
            )
            if os.path.exists(combo_dir):
                print(f"Directory {combo_dir} already exists. Skipping...")
                continue

            # Pair each selected concept embedding with its concept-subspace projection matrix
            projections_data = [
                {
                    "embed": concept_embeds[c][idx],
                    "projection_matrix": projection_matrices[c]
                }
                for c, idx in enumerate(combo_indices)
            ]

            modified_images = get_modified_images_embeds_composition(
                base_embed, projections_data, ip_model, prompt=prompt, scale=scale, num_samples=num_samples, seed=seed
            )
            save_images(combo_dir, modified_images)
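            # Release the generated images and cached GPU memory before the next combination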
            del modified_images
            torch.cuda.empty_cache()
            gc.collect()

def main(config_path, should_create_grids):
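    """
    Load the JSON config, set up the SDXL + IP-Adapter pipeline and the CLIP
    encoder, embed all input images, and render every base/concept combination.
    """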
    with open(config_path, 'r') as f:
        config = json.load(f)

    if "prompt" not in config:
        config["prompt"] = None
    
    if "scale" not in config:
        config["scale"] = 1.0 if config["prompt"] is None else 0.6

    if "seed" not in config:
        config["seed"] = 420

    if "num_samples" not in config:
        config["num_samples"] = 4


    base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"

    # Load the SDXL base pipeline in half precision with the watermarker disabled
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        add_watermarker=False,
    )

    # IP-Adapter checkpoint (ViT-H image-encoder variant for SDXL)
    image_encoder_repo = 'h94/IP-Adapter'
    image_encoder_subfolder = 'models/image_encoder'
    ip_ckpt = hf_hub_download(image_encoder_repo, subfolder="sdxl_models", filename='ip-adapter_sdxl_vit-h.bin')

    device = "cuda:0"
    ip_model = IPAdapterXL(pipe, image_encoder_repo, image_encoder_subfolder, ip_ckpt, device)

    # CLIP ViT-H/14 encoder used to embed the input images (same backbone as the IP-Adapter image encoder)
    model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K')
    model.to(device)
    model.eval()

    # Get base image embeddings
    image_files_base = [
        os.path.join(config["input_dir_base"], f)
        for f in os.listdir(config["input_dir_base"])
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    ]
    image_embeds_base = []
    image_names_base = []
    for path in image_files_base:
        img_name = os.path.basename(path)
        image_names_base.append(img_name)
        image_embeds_base.append(get_image_embeds(Image.open(path).convert("RGB"), model, preprocess, device))

    # Handle n concepts
    concept_dirs = config["input_dirs_concepts"]
    concept_embeds = []
    concept_names = []
    projection_matrices = []

    for concept_dir, embeds_path, rank in zip(concept_dirs, config["all_embeds_paths"], config["ranks"]):
        image_files = [
            os.path.join(concept_dir, f)
            for f in os.listdir(concept_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        ]
        embeds = []
        names = []
        for path in image_files:
            img_name = os.path.basename(path)
            names.append(img_name)
            embeds.append(get_image_embeds(Image.open(path).convert("RGB"), model, preprocess, device))
        concept_embeds.append(embeds)
        concept_names.append(names)

        # SVD of the concept's precomputed dataset embeddings gives a rank-limited
        # projection matrix onto that concept's subspace
        with open(embeds_path, "rb") as f:
            all_embeds_in = np.load(f)
        projection_matrix = compute_dataset_embeds_svd(all_embeds_in, rank)
        projection_matrices.append(projection_matrix)

    # Generate and save every (base image x concept combination) result
    process_combo(
        image_embeds_base,
        image_names_base,
        concept_embeds,
        concept_names,
        projection_matrices,
        ip_model,
        config["output_base_dir"],
        num_samples=config["num_samples"],
        seed=config["seed"],
        prompt=config["prompt"],
        scale=config["scale"],
    )

    # Optionally assemble the generated images into comparison grids
    if should_create_grids:
        create_grids(config)
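
# Example config (illustrative only; all paths and values below are placeholders,
# not taken from the repo). "prompt", "scale", "seed", and "num_samples" are
# optional and default to None, 1.0 (0.6 when a prompt is set), 420, and 4:
# {
#     "input_dir_base": "inputs/base_images",
#     "input_dirs_concepts": ["inputs/concept_a", "inputs/concept_b"],
#     "all_embeds_paths": ["embeds/concept_a.npy", "embeds/concept_b.npy"],
#     "ranks": [30, 30],
#     "output_base_dir": "outputs"
# }
# Run: python <this script> --config path/to/config.json [--create_grids]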

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process images using embeddings and configurations.")
    parser.add_argument("--config", type=str, required=True, help="Path to the configuration JSON file.")
    parser.add_argument("--create_grids", action="store_true", help="Enable grid creation")
    args = parser.parse_args()
    main(args.config, args.create_grids)