import os
import sys
from typing import Any, List

import numpy as np
import torch
from PIL import Image

from utils import (
    tensor_to_pil,
    pil_to_tensor,
    pad_image,
    postprocess_image,
    preprocess_image,
    downloadModels,
    examples,
)

sys.path.append(os.path.dirname("./ComfyUI/"))
from ComfyUI.nodes import (
    CheckpointLoaderSimple,
    VAEDecode,
    VAEEncode,
    KSampler,
    EmptyLatentImage,
    CLIPTextEncode,
)
from ComfyUI.comfy_extras.nodes_compositing import JoinImageWithAlpha
from ComfyUI.comfy_extras.nodes_mask import InvertMask, MaskToImage

from ComfyUI.comfy import samplers

from ComfyUI.custom_nodes.layerdiffuse.layered_diffusion import (
    LayeredDiffusionFG,
    LayeredDiffusionDecode,
    LayeredDiffusionCond,
)
import gradio as gr
from briarmbg import BriaRMBG

# Device for the BriaRMBG background-removal model (the only model this file
# moves explicitly). The ComfyUI-loaded checkpoint is not moved here;
# presumably ComfyUI manages its placement — confirm.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch model files before anything below tries to load them. downloadModels
# lives in utils; presumably it places files where ComfyUI looks — confirm.
downloadModels()

# Load the SDXL checkpoint once at import time. The rest of this file indexes
# the result as ckpt[0] = model, ckpt[1] = CLIP, ckpt[2] = VAE.
with torch.inference_mode():
    ckpt_load_checkpoint = CheckpointLoaderSimple().load_checkpoint
    ckpt = ckpt_load_checkpoint(
        ckpt_name="juggernautXL_version6Rundiffusion.safetensors"
    )

# Bind each ComfyUI node's entry method once, so predict() can call the nodes
# as plain functions with keyword arguments.
cliptextencode = CLIPTextEncode().encode
emptylatentimage_generate = EmptyLatentImage().generate
ksampler_sample = KSampler().sample
vae_decode = VAEDecode().decode
vae_encode = VAEEncode().encode
# LayerDiffuse patches: FG = text-only transparent foreground generation,
# Cond = generation conditioned on an input image latent.
ld_fg_apply_layered_diffusion = LayeredDiffusionFG().apply_layered_diffusion
ld_cond_apply_layered_diffusion = LayeredDiffusionCond().apply_layered_diffusion

# Post-processing nodes used to turn the LayerDiffuse output into RGBA + mask.
ld_decode = LayeredDiffusionDecode().decode
mask_to_image = MaskToImage().mask_to_image
invert_mask = InvertMask().invert
join_image_with_alpha = JoinImageWithAlpha().join_image_with_alpha
# Background-removal model from the Hugging Face Hub, used by predict() when
# the "Remove Background" option is enabled.
rmbg_model = BriaRMBG.from_pretrained("briaai/RMBG-1.4").to(device)


def _remove_background(image: Image.Image) -> Image.Image:
    """Cut the foreground out of *image* with the BriaRMBG model.

    Runs ``rmbg_model`` (on ``device``) to predict a foreground mask, then
    pastes *image* through that mask onto a fully transparent RGBA canvas.

    Args:
        image: The (already padded and resized) input image.

    Returns:
        A new RGBA image sized like the predicted mask, transparent outside
        it. ``postprocess_image`` receives ``image.size``, so the mask
        presumably matches the input size — TODO confirm.
    """
    orig_im_size = image.size
    model_input = preprocess_image(np.array(image), [1024, 1024]).to(device)
    result = rmbg_model(model_input)
    mask_array = postprocess_image(result[0][0], orig_im_size)

    pil_mask = Image.fromarray(mask_array)
    no_bg_image = Image.new("RGBA", pil_mask.size, (0, 0, 0, 0))
    no_bg_image.paste(image, mask=pil_mask)
    return no_bg_image


def _sample_and_decode(
    *,
    model,
    positive,
    negative,
    latent_image,
    seed,
    sampler_name,
    scheduler,
    steps,
    cfg,
    denoise,
):
    """Run the ComfyUI KSampler, then decode its latent with the checkpoint VAE.

    Returns:
        ``(samples, images)``: the first output of ``KSampler.sample`` (the
        sampled latent) and the first output of ``VAEDecode.decode``.
    """
    sampled = ksampler_sample(
        steps=steps,
        cfg=cfg,
        sampler_name=sampler_name,
        scheduler=scheduler,
        seed=seed,
        model=model,
        positive=positive,
        negative=negative,
        latent_image=latent_image,
        denoise=denoise,
    )
    decoded = vae_decode(samples=sampled[0], vae=ckpt[2])
    return sampled[0], decoded[0]


def predict(
    prompt: str,
    negative_prompt: str,
    input_image: Image.Image,
    remove_bg: bool,
    cond_mode: str,
    seed: int,
    sampler_name: str,
    scheduler: str,
    steps: int,
    cfg: float,
    denoise: float,
):
    """Generate a 1024x1024 image with LayerDiffuse on the SDXL checkpoint.

    Without *input_image*, the model is patched with the "SDXL, Conv
    Injection" LayerDiffuse config. The result is an RGBA image whose alpha
    comes from the LayerDiffuse decoder's mask.

    With *input_image*, the image is first padded and resized to 1024x1024,
    and optionally has its background removed. It is then VAE-encoded and
    used as the *cond_mode* condition for LayerDiffuse. The VAE-decoded RGB
    result is returned.

    Args:
        prompt: Text for the positive conditioning.
        negative_prompt: Text for the negative conditioning.
        input_image: Optional PIL image used as the condition.
        remove_bg: Strip the input image's background with BriaRMBG first.
        cond_mode: LayerDiffuse conditional config name
            (e.g. "SDXL, Foreground" or "SDXL, Background").
        seed: Sampler seed; -1 picks a random one.
        sampler_name: KSampler sampler name.
        scheduler: KSampler scheduler name.
        steps: Number of sampling steps.
        cfg: Classifier-free guidance scale.
        denoise: KSampler denoise strength.

    Returns:
        ``(image, mask, seed)``. In conditional mode, ``image`` and ``mask``
        are the same RGB image. ``seed`` is the seed actually used.

    Raises:
        gr.Error: Wraps any exception raised by the pipeline, so its message
            reaches the UI. The original exception is chained as the cause.
    """
    if seed == -1:
        # Draw from [0, 2**53). Browser numbers are IEEE-754 doubles, so larger
        # seeds would be rounded when displayed and could not be reproduced.
        # dtype=int64 avoids overflow where NumPy's default int is 32-bit.
        seed = int(np.random.randint(0, 2**53, dtype=np.int64))
    try:
        with torch.inference_mode():
            positive = cliptextencode(text=prompt, clip=ckpt[1])[0]
            negative = cliptextencode(text=negative_prompt, clip=ckpt[1])[0]
            empty_latent = emptylatentimage_generate(
                width=1024, height=1024, batch_size=1
            )[0]
            sampler_kwargs = {
                "latent_image": empty_latent,
                "seed": seed,
                "sampler_name": sampler_name,
                "scheduler": scheduler,
                "steps": steps,
                "cfg": cfg,
                "denoise": denoise,
            }

            if input_image is not None:
                input_image = pad_image(input_image).resize((1024, 1024))
                if remove_bg:
                    input_image = _remove_background(input_image)

                img_tensor = pil_to_tensor(input_image)
                img_latent = vae_encode(pixels=img_tensor[0], vae=ckpt[2])
                # The conditional patch yields (model, positive, negative)
                # conditioned on the input image latent.
                cond_patch = ld_cond_apply_layered_diffusion(
                    config=cond_mode,
                    weight=1,
                    model=ckpt[0],
                    cond=positive,
                    uncond=negative,
                    latent=img_latent[0],
                )
                _, images = _sample_and_decode(
                    model=cond_patch[0],
                    positive=cond_patch[1],
                    negative=cond_patch[2],
                    **sampler_kwargs,
                )
                # A LayeredDiffusionDecode result was never used in this
                # branch, so that call is skipped. Both the image and mask
                # outputs receive the VAE-decoded RGB image.
                rgb_img = tensor_to_pil(images)
                return (rgb_img[0], rgb_img[0], seed)

            fg_model = ld_fg_apply_layered_diffusion(
                config="SDXL, Conv Injection", weight=1, model=ckpt[0]
            )[0]
            samples, images = _sample_and_decode(
                model=fg_model,
                positive=positive,
                negative=negative,
                **sampler_kwargs,
            )
            # LayeredDiffusionDecode takes the sampled latent plus the decoded
            # pixels and yields (images, mask).
            ld_result = ld_decode(
                sd_version="SDXL",
                sub_batch_size=16,
                samples=samples,
                images=images,
            )
            ld_images, ld_mask = ld_result[0], ld_result[1]
            mask_image = mask_to_image(mask=ld_mask)
            # The mask is inverted before JoinImageWithAlpha. Presumably that
            # node treats its alpha input as 1 = transparent — confirm.
            inverted_mask = invert_mask(mask=ld_mask)
            rgba = join_image_with_alpha(image=ld_images, alpha=inverted_mask[0])
            rgba_img = tensor_to_pil(rgba[0])
            mask_img = tensor_to_pil(mask_image[0])
            return (rgba_img[0], mask_img[0], seed)
    except Exception as e:
        # gr.Error expects a message string; fall back to the exception type
        # for message-less exceptions, and keep the cause for server logs.
        raise gr.Error(str(e) or type(e).__name__) from e


def flatten(l: List[List[Any]]) -> List[Any]:
    """Concatenate the inner sequences of *l* into a single flat list.

    Only one level of nesting is removed: ``flatten([[1, 2], [3]])`` is
    ``[1, 2, 3]``, while deeper nested lists are kept as elements.

    Args:
        l: A sequence of sequences.

    Returns:
        A new list holding every inner element, in order.
    """
    return [item for sublist in l for item in sublist]


def predict_examples(
    prompt,
    negative_prompt,
    input_image=None,
    remove_bg=False,
    cond_mode=None,
    seed=-1,
    cfg=10,
):
    """Adapter for ``gr.Examples``: call ``predict`` with fixed sampler settings.

    Example rows supply only the prompts, the optional input image, the
    background-removal flag, the condition mode and the seed. CFG is optional
    and defaults to 10. The sampler, scheduler, step count and denoise are
    pinned below.
    """
    example_sampling = {
        "sampler_name": "dpmpp_2m_sde_gpu",
        "scheduler": "karras",
        "steps": 30,
        "denoise": 1.0,
    }
    return predict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        input_image=input_image,
        remove_bg=remove_bg,
        cond_mode=cond_mode,
        seed=seed,
        cfg=cfg,
        **example_sampling,
    )


# Page-level CSS: widen the default Gradio container.
css = """
.gradio-container { max-width: 68rem !important; }
"""
# Build the UI. Gradio passes component values positionally, so the order of
# `inputs` below must match predict()'s parameters. Likewise, the Examples
# inputs must match predict_examples()'s leading parameters.
with gr.Blocks(css=css) as blocks:
    gr.Markdown("""# LayerDiffuse (unofficial)
    Using ComfyUI building blocks with custom node by [huchenlei](https://github.com/huchenlei/ComfyUI-layerdiffuse)    
    Models: [LayerDiffusion/layerdiffusion-v1](https://huggingface.co/LayerDiffusion/layerdiffusion-v1/tree/main)  
    Paper: [Transparent Image Layer Diffusion using Latent Transparency](https://huggingface.co/papers/2402.17113)
""")

    with gr.Row():
        # Left column: prompts, optional conditioning image, advanced settings.
        with gr.Column():
            prompt = gr.Text(label="Prompt")
            negative_prompt = gr.Text(label="Negative Prompt")
            button = gr.Button("Generate")
            with gr.Accordion(open=False, label="Input Images (Optional)"):
                with gr.Group():
                    cond_mode = gr.Radio(
                        value="SDXL, Foreground",
                        choices=["SDXL, Foreground", "SDXL, Background"],
                        info="Whether to use input image as foreground or background",
                    )
                    remove_bg = gr.Checkbox(
                        info="Remove background using BriaRMBG",
                        label="Remove Background",
                        value=False,
                    )
                    input_image = gr.Image(
                        label="Input Image",
                        type="pil",
                    )
            with gr.Accordion(open=False, label="Advanced Options"):
                with gr.Group():
                    with gr.Row():
                        # -1 asks predict() to pick a random seed.
                        seed = gr.Slider(
                            label="Seed",
                            value=-1,
                            minimum=-1,
                            # 2**53 - 1 is the largest integer a browser number
                            # (IEEE-754 double) holds exactly; larger seeds would
                            # be rounded in the UI and not be reproducible.
                            maximum=2**53 - 1,
                            step=1,
                        )
                        # Read-only display of the seed predict() actually used.
                        curr_seed = gr.Number(
                            value=-1, interactive=False, scale=0, label=" "
                        )
                sampler_name = gr.Dropdown(
                    choices=samplers.KSampler.SAMPLERS,
                    label="Sampler Name",
                    value="dpmpp_2m_sde_gpu",
                )
                scheduler = gr.Dropdown(
                    choices=samplers.KSampler.SCHEDULERS,
                    label="Scheduler",
                    value="karras",
                )
                steps = gr.Slider(
                    label="Steps", value=20, minimum=1, maximum=50, step=1
                )
                cfg = gr.Number(
                    label="CFG", value=5.0, minimum=0.0, maximum=100.0, step=0.1
                )
                denoise = gr.Number(
                    label="Denoise", value=1.0, minimum=0.0, maximum=1.0, step=0.01
                )

        # Right column: generated image, with the mask tucked in an accordion.
        with gr.Column():
            image = gr.Image()
            with gr.Accordion(label="Mask", open=False):
                mask = gr.Image()

    # Same order as predict()'s parameters.
    inputs = [
        prompt,
        negative_prompt,
        input_image,
        remove_bg,
        cond_mode,
        seed,
        sampler_name,
        scheduler,
        steps,
        cfg,
        denoise,
    ]
    # predict() returns (image, mask, seed).
    outputs = [image, mask, curr_seed]
    button.click(fn=predict, inputs=inputs, outputs=outputs)

    # Examples fill only the first six inputs; predict_examples pins the rest.
    gr.Examples(
        fn=predict_examples,
        examples=examples,
        inputs=[
            prompt,
            negative_prompt,
            input_image,
            remove_bg,
            cond_mode,
            seed,
        ],
        outputs=outputs,
        cache_examples=True,
    )


if __name__ == "__main__":
    # Enable Gradio's request queue with the REST API closed (api_open=False),
    # then serve the app with the "Use via API" page hidden (show_api=False).
    # Blocks.queue() returns the Blocks instance, so the calls chain.
    blocks.queue(api_open=False).launch(show_api=False)