from typing import List
from PIL import Image
import numpy as np
import math
import random
import cv2
from typing import List

import torch
import einops
from pytorch_lightning import seed_everything
from transparent_background import Remover

from dataset.opencv_transforms.functional import to_tensor, center_crop
from vtdm.model import create_model
from vtdm.util import tensor2vid

# Shared background remover, instantiated once when this module is imported and
# reused by prepare_white_image() and MultiViewGenerator.video_pipeline().
remover = Remover(jit=False)


def pil_to_cv2(pil_image: Image.Image) -> np.ndarray:
    """Return ``pil_image`` as a NumPy array in OpenCV's BGR channel order.

    The image is first materialised with ``np.array`` and then run through
    ``cv2.COLOR_RGB2BGR``, so its pixel data is expected to be RGB-ordered.
    """
    rgb_pixels = np.array(pil_image)
    return cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)

def prepare_white_image(input_image: Image.Image) -> Image.Image:
    """Strip the background from ``input_image`` and pad it to a square.

    The module-level ``remover`` produces an RGBA cut-out, which is then
    pasted, centred, onto a new square RGBA canvas whose side equals the
    longer edge of the cut-out.

    NOTE(review): despite the function name, the canvas is filled with
    ``(0, 0, 0, 0)`` rather than white - confirm this is intended.
    """
    # Background removal; the result carries an alpha channel.
    cutout = remover.process(input_image, type='rgba')

    # Square canvas large enough to hold the cut-out along both axes.
    cutout_w, cutout_h = cutout.size
    side = max(cutout_w, cutout_h)
    canvas = Image.new('RGBA', (side, side), (0, 0, 0, 0))

    # Centre the cut-out; integer division biases odd padding to the top/left.
    top_left = ((side - cutout_w) // 2, (side - cutout_h) // 2)
    canvas.paste(cutout, top_left)

    return canvas


class MultiViewGenerator:
    """Run a ``vtdm`` denoising model on a single image to produce a frame clip.

    The model is created from ``config_path``, initialised from
    ``checkpoint_path`` and kept on the GPU in half precision under
    ``self.models["denoising_model"]``. A CUDA device is required: tensors are
    moved with ``.cuda()`` and sampling runs under CUDA autocast.
    """

    def __init__(self, checkpoint_path, config_path="inference.yaml"):
        """Build the denoising model and load its weights.

        Args:
            checkpoint_path: Checkpoint handed to the model's
                ``init_from_ckpt``.
            config_path: Configuration handed to ``create_model``.
        """
        # Load the checkpoint while the model is on the CPU, then move it to
        # the GPU and cast to fp16.
        denoising_model = create_model(config_path).cpu()
        denoising_model.init_from_ckpt(checkpoint_path)
        self.models = {"denoising_model": denoising_model.cuda().half()}

    def denoising(self, frames, args):
        """Sample one clip conditioned on ``frames``.

        Args:
            frames: Tensor unpacked as ``(C, T, H, W)``; the latent noise is
                drawn at ``H // 8`` x ``W // 8`` for each of the ``T`` frames.
            args: Mapping that must provide ``"elevation"``; the value is
                truncated to an integer.

        Returns:
            Whatever ``tensor2vid`` returns for the decoded samples
            (callers in this class treat it as a sequence of frame arrays).
        """
        model = self.models["denoising_model"]
        device = frames.device

        with torch.no_grad():
            # Only the temporal and spatial sizes are needed below.
            _, T, H, W = frames.shape

            batch = {
                "video": frames.unsqueeze(0),
                # int() truncates toward zero, matching a float -> int64 cast.
                "elevation": torch.tensor(
                    [int(args["elevation"])], dtype=torch.int64, device=device
                ),
                # Fixed conditioning ids used for every run.
                "fps_id": torch.tensor([7], dtype=torch.int64, device=device),
                "motion_bucket_id": torch.tensor(
                    [127], dtype=torch.int64, device=device
                ),
            }
            batch = model.add_custom_cond(batch, infer=True)

            with torch.autocast(device_type="cuda", dtype=torch.float16):
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    force_uc_zero_embeddings=["cond_frames", "cond_frames_without_noise"],
                )

            additional_model_inputs = {
                # Two rows - presumably the conditional and unconditional
                # halves of the guidance batch; confirm against the sampler.
                "image_only_indicator": torch.zeros(2, T, device=model.device),
                "num_video_frames": batch["num_video_frames"],
            }

            def denoiser(input, sigma, c):
                # Adapter with the (input, sigma, cond) signature the sampler
                # calls, forwarding the extra video-specific inputs.
                return model.denoiser(
                    model.model, input, sigma, c, **additional_model_inputs
                )

            with torch.autocast(device_type="cuda", dtype=torch.float16):
                randn = torch.randn([T, 4, H // 8, W // 8], device=model.device)
                samples = model.sampler(denoiser, randn, cond=c, uc=uc)

            samples = model.decode_first_stage(samples.half())
            samples = einops.rearrange(samples, "(b t) c h w -> b c t h w", t=T)

        return tensor2vid(samples)

    def video_pipeline(self, frames, args) -> List[Image.Image]:
        """Run ``args["num_iter"]`` sampling passes and strip backgrounds.

        After every pass the last generated frame is written into
        ``frames[:, 0]``, so ``frames`` is modified in place. From the second
        pass onwards the first returned frame is dropped before appending -
        presumably because it repeats the frame the pass was seeded with.

        Args:
            frames: Tensor passed to :meth:`denoising`; mutated in place.
            args: Mapping providing ``"num_iter"`` plus whatever
                :meth:`denoising` reads.

        Returns:
            One background-removed RGBA image per collected frame.
        """
        out_list = []

        for _ in range(args["num_iter"]):
            with torch.no_grad():
                results = self.denoising(frames, args)

            # Keep every frame of the first pass; skip the leading frame of
            # each later pass.
            out_list.extend(results if not out_list else results[1:])

            # Feed the newest frame back as the first conditioning frame,
            # rescaled from to_tensor's output to the (x - 0.5) * 2 range
            # used for the model input.
            last_frame = to_tensor(out_list[-1])
            frames[:, 0] = (last_frame - 0.5) * 2.0

        return [
            remover.process(Image.fromarray(frame), type='rgba')
            for frame in out_list
        ]

    def process(self, white_image: Image.Image, args) -> List[Image.Image]:
        """Preprocess ``white_image`` into a clip tensor and run the pipeline.

        The image is resized so that it covers ``args["input_resolution"]``
        (height, width), centre-cropped to exactly that size, normalised with
        ``(x - 0.5) * 2`` and repeated ``args["clip_size"]`` times along a new
        frame axis.

        Args:
            white_image: Input image, typically from ``prepare_white_image``.
            args: Mapping providing ``"clip_size"``, ``"input_resolution"``
                and the keys read by :meth:`video_pipeline`.

        Returns:
            The list of images produced by :meth:`video_pipeline`.
        """
        target_h = args["input_resolution"][0]
        target_w = args["input_resolution"][1]

        # Every frame of the clip starts as the same image, so preprocess it
        # once and replicate the resulting tensor instead of repeating the
        # resize/crop/convert work clip_size times.
        img = pil_to_cv2(white_image)
        h, w = img.shape[0:2]

        # Scale by the larger ratio so both sides reach the target size; the
        # overflow on the other axis is removed by the centre crop.
        rate = max(target_h * 1.0 / h, target_w * 1.0 / w)
        img = cv2.resize(img, (math.ceil(w * rate), math.ceil(h * rate)))
        img = center_crop(img, [target_h, target_w])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        frame = (to_tensor(img) - 0.5) * 2.0
        # torch.stack copies, so the in-place update in video_pipeline never
        # aliases the other frames.
        frames = torch.stack([frame] * args["clip_size"], 1).cuda()

        model = self.models["denoising_model"]
        model.num_samples = args["clip_size"]
        model.image_size = args["input_resolution"]

        return self.video_pipeline(frames, args)

    def infer(self, white_image: Image.Image, seed=None) -> List[Image.Image]:
        """Generate a 25-frame, 512x512 clip from ``white_image``.

        Args:
            white_image: Input image, typically from ``prepare_white_image``.
            seed: Optional integer seed passed to ``seed_everything``. When
                omitted, a random seed in ``[0, 65535]`` is drawn, matching
                the previous behaviour.

        Returns:
            The list of images produced by :meth:`process`.
        """
        if seed is None:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        params = {
            "clip_size": 25,
            "input_resolution": [512, 512],
            "num_iter": 1,
            # NOTE(review): "aes" and "mv" are not read anywhere in this
            # class; kept in case the model consumes them elsewhere.
            "aes": 6.0,
            "mv": [0.0, 0.0, 0.0, 10.0],
            "elevation": 0,
        }

        return self.process(white_image, params)