from typing import Callable, Dict, List, Optional, Union
import gc

import numpy as np
import torch
import torch.nn.functional as F

from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import (
    StableVideoDiffusionPipeline,
    _resize_with_antialiasing,
    retrieve_timesteps,
)
from diffusers.utils import logging
from kornia.utils import create_meshgrid


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@torch.no_grad()
def normalize_point_map(point_map, valid_mask):
    """Normalize a point map by its mean depth over valid pixels.

    :param point_map: [T, H, W, 3]
    :param valid_mask: [T, H, W]
    """
    norm_factor = (point_map[..., 2] * valid_mask.float()).mean() / (valid_mask.float().mean() + 1e-8)
    norm_factor = norm_factor.clip(min=1e-3)
    return point_map / norm_factor

def point_map_xy2intrinsic_map(point_map_xy):
    """Fit per-pixel intrinsic maps from normalized xy point coordinates.

    :param point_map_xy: [*, h, w, 2]
    :return: [*, h, w, 4] concatenation of the principal-point map (2) and the focal map (2)
    """
    height, width = point_map_xy.shape[-3], point_map_xy.shape[-2]
    assert height % 2 == 0
    assert width % 2 == 0
    mesh_grid = create_meshgrid(
        height=height,
        width=width,
        normalized_coordinates=True,
        device=point_map_xy.device,
        dtype=point_map_xy.dtype
    )[0] # h,w,2
    assert mesh_grid.abs().min() > 1e-4
    # *,h,w,2
    mesh_grid = mesh_grid.expand_as(point_map_xy)
    nc = point_map_xy.mean(dim=-2).mean(dim=-2) # *, 2
    nc_map = nc[..., None, None, :].expand_as(point_map_xy)
    nf = ((point_map_xy - nc_map) / mesh_grid).mean(dim=-2).mean(dim=-2)
    nf_map = nf[..., None, None, :].expand_as(point_map_xy)
    # sanity check: mesh_grid * nf_map + nc_map should reconstruct point_map_xy

    return torch.cat([nc_map, nf_map], dim=-1)

def robust_min_max(tensor, quantile=0.99):
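    """Return robust (quantile-based) min and max of a [T, H, W] tensor.

    Computes the (1 - quantile) and quantile values per frame and reduces
    them to a global minimum and maximum across frames.
    """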
    T, H, W = tensor.shape
    min_vals = []
    max_vals = []
    for i in range(T):
        min_vals.append(torch.quantile(tensor[i], q=1-quantile, interpolation='nearest').item())
        max_vals.append(torch.quantile(tensor[i], q=quantile, interpolation='nearest').item())
    return min(min_vals), max(max_vals) 

class GeometryCrafterDiffPipeline(StableVideoDiffusionPipeline):

    @torch.inference_mode()
    def encode_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ) -> torch.Tensor:
        """
        :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
        :param chunk_size: the chunk size to encode video
        :return: image_embeddings in shape of [b, 1024]
        """

        video_224 = _resize_with_antialiasing(video.float(), (224, 224))
        video_224 = (video_224 + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        embeddings = []
        for i in range(0, video_224.shape[0], chunk_size):
            emb = self.feature_extractor(
                images=video_224[i : i + chunk_size],
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values.to(video.device, dtype=video.dtype)
            embeddings.append(self.image_encoder(emb).image_embeds)  # [b, 1024]

        embeddings = torch.cat(embeddings, dim=0)  # [t, 1024]
        return embeddings

    @torch.inference_mode()
    def encode_vae_video(
        self,
        video: torch.Tensor,
        chunk_size: int = 14,
    ):
        """
        :param video: [b, c, h, w] in range [-1, 1], the b may contain multiple videos or frames
        :param chunk_size: the chunk size to encode video
        :return: vae latents in shape of [b, c, h, w]
        """
        video_latents = []
        for i in range(0, video.shape[0], chunk_size):
            video_latents.append(
                self.vae.encode(video[i : i + chunk_size]).latent_dist.mode()
            )
        video_latents = torch.cat(video_latents, dim=0)
        return video_latents
    
    @torch.inference_mode()
    def produce_priors(self, prior_model, frame, chunk_size=8):
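        """Run the per-frame prior model and convert its outputs into diffusion-ready priors.

        :param prior_model: per-frame predictor exposing `forward_image(frames)` -> (point_map, mask)
        :param frame: [T, 3, H, W] frames in range [0, 1]
        :param chunk_size: number of frames per forward pass
        :return: disparity [T, H, W] in [-1, 1], valid mask [T, H, W] in {-1, 1},
                 point map [T, 3, H, W] as (x/z, y/z, log z), intrinsic map [T, 4, H, W]
        """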
        T, _, H, W = frame.shape 
        pred_point_maps = []
        pred_masks = []
        for i in range(0, len(frame), chunk_size):
            pred_p, pred_m = prior_model.forward_image(frame[i:i+chunk_size])
            pred_point_maps.append(pred_p)
            pred_masks.append(pred_m)
        pred_point_maps = torch.cat(pred_point_maps, dim=0)
        pred_masks = torch.cat(pred_masks, dim=0)
        
        pred_masks = pred_masks.float() * 2 - 1  # {0, 1} -> {-1, 1}
        
        # T,H,W,3 T,H,W
        pred_point_maps = normalize_point_map(pred_point_maps, pred_masks > 0)

        pred_disps = 1.0 / pred_point_maps[..., 2].clamp_min(1e-3)  # disparity from predicted depth
        pred_disps = pred_disps * (pred_masks > 0)
        min_disparity, max_disparity = robust_min_max(pred_disps)
        pred_disps = ((pred_disps - min_disparity) / (max_disparity - min_disparity + 1e-4)).clamp(0, 1)
        pred_disps = pred_disps * 2 - 1  # [0, 1] -> [-1, 1]

        pred_point_maps[..., :2] = pred_point_maps[..., :2] / (pred_point_maps[..., 2:3] + 1e-7)
        pred_point_maps[..., 2] = torch.log(pred_point_maps[..., 2] + 1e-7) * (pred_masks > 0) # [x/z, y/z, log(z)]

        pred_intr_maps = point_map_xy2intrinsic_map(pred_point_maps[..., :2]).permute(0, 3, 1, 2)  # T,4,H,W
        pred_point_maps = pred_point_maps.permute(0, 3, 1, 2)  # T,3,H,W
        
        return pred_disps, pred_masks, pred_point_maps, pred_intr_maps
    
    @torch.inference_mode()
    def encode_point_map(self, point_map_vae, disparity, valid_mask, point_map, intrinsic_map, chunk_size=8):
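        """Encode the geometric priors into the diffusion latent space.

        :param point_map_vae: prior VAE exposing `encode(geometry_channels, latent_dist)`
        :param disparity: [T, H, W] normalized disparity in [-1, 1]
        :param valid_mask: [T, H, W] validity mask in {-1, 1}
        :param point_map: [T, 3, H, W] point map as (x/z, y/z, log z)
        :param intrinsic_map: [T, 4, H, W] principal-point and focal maps
        :param chunk_size: number of frames per forward pass
        :return: prior latents [T, c, h, w], scaled by the VAE scaling factor
        """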
        T, _, H, W = point_map.shape
        latents = []

        pseudo_image = disparity[:, None].repeat(1, 3, 1, 1)  # replicate disparity to 3 channels for the SVD VAE
        intrinsic_map = torch.norm(intrinsic_map[:, 2:4], p=2, dim=1, keepdim=False)  # scalar focal proxy from the focal-map channels

        for i in range(0, T, chunk_size):
            latent_dist = self.vae.encode(pseudo_image[i : i + chunk_size].to(self.vae.dtype)).latent_dist
            latent_dist = point_map_vae.encode(                
                torch.cat([
                    intrinsic_map[i:i+chunk_size, None],
                    point_map[i:i+chunk_size, 2:3], 
                    disparity[i:i+chunk_size, None], 
                    valid_mask[i:i+chunk_size, None]], dim=1),
                latent_dist
            )
            if isinstance(latent_dist, DiagonalGaussianDistribution):
                latent = latent_dist.mode()
            else:
                latent = latent_dist
            
            assert isinstance(latent, torch.Tensor)    
            latents.append(latent)
        latents = torch.cat(latents, dim=0)
        latents = latents * self.vae.config.scaling_factor
        return latents

    @torch.no_grad()
    def decode_point_map(
        self,
        point_map_vae,
        latents,
        chunk_size=8,
        force_projection=True,
        force_fixed_focal=True,
        use_extract_interp=False,
        need_resize=False,
        height=None,
        width=None,
    ):
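        """Decode denoised latents into point maps and validity masks.

        :param point_map_vae: prior VAE exposing `decode(latents, num_frames)` -> (intrinsic map, depth map, valid mask)
        :param latents: [T, c, h, w] latents, already divided by the VAE scaling factor
        :param chunk_size: number of frames per decode pass
        :param force_projection: replace the per-pixel focal map with its masked average
        :param force_fixed_focal: share one focal across all frames instead of one per frame
        :param use_extract_interp: use nearest-exact instead of bilinear interpolation for depth/mask resizing
        :param need_resize: resize outputs to (height, width)
        :return: point maps [T, H, W, 3] and validity masks [T, H, W]
        """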
        T = latents.shape[0]
        rec_intrinsic_maps = []
        rec_depth_maps = []
        rec_valid_masks = []
        for i in range(0, T, chunk_size):
            lat = latents[i:i+chunk_size] 
            rec_imap, rec_dmap, rec_vmask = point_map_vae.decode(  
                lat,           
                num_frames=lat.shape[0],
            )
            rec_intrinsic_maps.append(rec_imap)
            rec_depth_maps.append(rec_dmap)
            rec_valid_masks.append(rec_vmask)
        
        rec_intrinsic_maps = torch.cat(rec_intrinsic_maps, dim=0)
        rec_depth_maps = torch.cat(rec_depth_maps, dim=0)
        rec_valid_masks = torch.cat(rec_valid_masks, dim=0)
        
        if need_resize:
            if use_extract_interp:
                rec_depth_maps = F.interpolate(rec_depth_maps, (height, width), mode='nearest-exact')
                rec_valid_masks = F.interpolate(rec_valid_masks, (height, width), mode='nearest-exact')
            else:
                rec_depth_maps = F.interpolate(rec_depth_maps, (height, width), mode='bilinear', align_corners=False)
                rec_valid_masks = F.interpolate(rec_valid_masks, (height, width), mode='bilinear', align_corners=False)
            rec_intrinsic_maps = F.interpolate(rec_intrinsic_maps, (height, width), mode='bilinear', align_corners=False)

        H, W = rec_intrinsic_maps.shape[-2], rec_intrinsic_maps.shape[-1]
        mesh_grid = create_meshgrid(
            H, W, 
            normalized_coordinates=True
        ).to(rec_intrinsic_maps.device, rec_intrinsic_maps.dtype, non_blocking=True)
        # 1,h,w,2
        rec_intrinsic_maps = torch.cat([
            rec_intrinsic_maps * W / np.sqrt(W**2 + H**2),
            rec_intrinsic_maps * H / np.sqrt(W**2 + H**2),
        ], dim=1)  # t,2,h,w
        mesh_grid = mesh_grid.permute(0,3,1,2)
        rec_valid_masks = rec_valid_masks.squeeze(1) > 0

        if force_projection:
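            # Collapse the per-pixel focal map into a shared focal so the output obeys a
            # consistent pinhole projection.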
            if force_fixed_focal:
                # one focal shared by all frames, averaged over valid pixels
                nfx = (rec_intrinsic_maps[:, 0, :, :] * rec_valid_masks.float()).mean() / (rec_valid_masks.float().mean() + 1e-4)
                nfy = (rec_intrinsic_maps[:, 1, :, :] * rec_valid_masks.float()).mean() / (rec_valid_masks.float().mean() + 1e-4)
                rec_intrinsic_maps = torch.stack([nfx, nfy])[None, :, None, None].repeat(T, 1, 1, 1)
            else:
                # one focal per frame, averaged over that frame's valid pixels
                nfx = (rec_intrinsic_maps[:, 0, :, :] * rec_valid_masks.float()).mean(dim=[-1, -2]) / (rec_valid_masks.float().mean(dim=[-1, -2]) + 1e-4)
                nfy = (rec_intrinsic_maps[:, 1, :, :] * rec_valid_masks.float()).mean(dim=[-1, -2]) / (rec_valid_masks.float().mean(dim=[-1, -2]) + 1e-4)
                rec_intrinsic_maps = torch.stack([nfx, nfy], dim=-1)[:, :, None, None]  # t,2,1,1

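        # Rebuild the point map: z = exp(clamped log-depth), xy = focal * normalized grid * z.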
        rec_point_maps = torch.cat([rec_intrinsic_maps * mesh_grid, rec_depth_maps], dim=1).permute(0,2,3,1)
        xy, z = rec_point_maps.split([2, 1], dim=-1)
        z = torch.clamp_max(z, 10) # for numerical stability
        z = torch.exp(z)
        rec_point_maps = torch.cat([xy * z, z], dim=-1)

        return rec_point_maps, rec_valid_masks


    @torch.no_grad()
    def __call__(
        self,
        video: Union[np.ndarray, torch.Tensor],
        point_map_vae,
        prior_model,
        height: int = 320,
        width: int = 640,
        num_inference_steps: int = 5,
        guidance_scale: float = 1.0,
        window_size: Optional[int] = 14,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        overlap: int = 4,
        force_projection: bool = True,
        force_fixed_focal: bool = True,
        use_extract_interp: bool = False,
        track_time: bool = False,
    ):
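        """Run the full video-to-point-map diffusion pipeline.

        :param video: [t, h, w, c] np.ndarray or [t, c, h, w] torch.Tensor, values in [0, 1]
        :param point_map_vae: prior VAE used to encode/decode the geometric priors
        :param prior_model: per-frame point-map predictor used to produce priors
        :param height, width: processing resolution, must be multiples of 64
        :param window_size: number of frames denoised per sliding window
        :param overlap: number of frames shared between consecutive windows
        :param force_projection, force_fixed_focal, use_extract_interp: see `decode_point_map`
        :return: point maps [t, h, w, 3] and validity masks [t, h, w] at the original resolution
        """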
        
        # 0. Convert the input video to a tensor and resolve the working resolution
        if isinstance(video, np.ndarray):
            video = torch.from_numpy(video.transpose(0, 3, 1, 2))
        else:
            assert isinstance(video, torch.Tensor)
        height = height or video.shape[-2]
        width = width or video.shape[-1]
        original_height = video.shape[-2]
        original_width = video.shape[-1]
        num_frames = video.shape[0]
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else 8
        if num_frames <= window_size:
            window_size = num_frames
            overlap = 0
        stride = window_size - overlap

        # 1. Check inputs. Raise error if not correct
        assert height % 64 == 0 and width % 64 == 0
        if original_height != height or original_width != width:
            need_resize = True
        else:
            need_resize = False

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        self._guidance_scale = guidance_scale

        if track_time:
            start_event = torch.cuda.Event(enable_timing=True)
            prior_event = torch.cuda.Event(enable_timing=True)
            encode_event = torch.cuda.Event(enable_timing=True)
            denoise_event = torch.cuda.Event(enable_timing=True)
            decode_event = torch.cuda.Event(enable_timing=True)
            start_event.record()

        # 3. Compute per-frame geometric priors
        pred_disparity, pred_valid_mask, pred_point_map, pred_intrinsic_map = self.produce_priors(
            prior_model, 
            video.to(device=device, dtype=torch.float32),
            chunk_size=decode_chunk_size
        ) # T,H,W  T,H,W  T,3,H,W  T,4,H,W

        if need_resize:
            pred_disparity = F.interpolate(pred_disparity.unsqueeze(1), (height, width), mode='bilinear', align_corners=False).squeeze(1)
            pred_valid_mask = F.interpolate(pred_valid_mask.unsqueeze(1), (height, width), mode='bilinear', align_corners=False).squeeze(1)
            pred_point_map = F.interpolate(pred_point_map, (height, width), mode='bilinear', align_corners=False)
            pred_intrinsic_map = F.interpolate(pred_intrinsic_map, (height, width), mode='bilinear', align_corners=False)


        if track_time:
            prior_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = start_event.elapsed_time(prior_event)
            print(f"Elapsed time for computing per-frame prior: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()


        # 4. Encode video frames (CLIP embeddings) and geometric priors (latents)
        if need_resize:
            video = F.interpolate(video, (height, width), mode="bicubic", align_corners=False, antialias=True).clamp(0, 1)
        video = video.to(device=device, dtype=self.dtype)
        video = video * 2.0 - 1.0  # [0,1] -> [-1,1], in [t, c, h, w]

        video_embeddings = self.encode_video(video, chunk_size=decode_chunk_size).unsqueeze(0)
        prior_latents = self.encode_point_map(
            point_map_vae,
            pred_disparity, 
            pred_valid_mask, 
            pred_point_map, 
            pred_intrinsic_map, 
            chunk_size=decode_chunk_size
        ).unsqueeze(0).to(video_embeddings.dtype) # 1,T,C,H,W

        # 5. Encode video frames using the VAE

        needs_upcasting = (
            self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        )
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        video_latents = self.encode_vae_video(
            video.to(self.vae.dtype),
            chunk_size=decode_chunk_size,
        ).unsqueeze(0).to(video_embeddings.dtype)  # [1, t, c, h, w]

        torch.cuda.empty_cache()

        if track_time:
            encode_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = prior_event.elapsed_time(encode_event)
            print(f"Elapsed time for encode prior and frames: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # 6. Get Added Time IDs
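        # Positional args follow SVD's _get_add_time_ids(fps, motion_bucket_id, noise_aug_strength, ...);
        # 7 and 127 are the StableVideoDiffusionPipeline defaults for fps and motion_bucket_id.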
        added_time_ids = self._get_add_time_ids(
            7,
            127,
            noise_aug_strength,
            video_embeddings.dtype,
            batch_size,
            1,
            False,
        )  # [1 or 2, 3]
        added_time_ids = added_time_ids.to(device)

        # 7. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, None, None
        )
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        # 8. Prepare latent variables
        # num_channels_latents = self.unet.config.in_channels - prior_latents.shape[1]
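        # SVD's prepare_latents allocates num_channels_latents // 2 channels for the noisy sample,
        # so 8 here yields 4-channel latents; the video and prior latents are concatenated later.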
        num_channels_latents = 8
        latents_init = self.prepare_latents(
            batch_size,
            window_size,
            num_channels_latents,
            height,
            width,
            video_embeddings.dtype,
            device,
            generator,
            latents,
        )  # [1, t, c, h, w]
        latents_all = None

        idx_start = 0
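        # Linear cross-fade weights used to blend latents in the overlap between consecutive windows.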
        if overlap > 0:
            weights = torch.linspace(0, 1, overlap, device=device)
            weights = weights.view(1, overlap, 1, 1, 1)
        else:
            weights = None

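        # Sliding-window denoising: each window covers up to `window_size` frames and shares
        # `overlap` frames with the previous window for temporal consistency.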
        while idx_start < num_frames - overlap:
            idx_end = min(idx_start + window_size, num_frames)
            self.scheduler.set_timesteps(num_inference_steps, device=device)
            # 9. Denoising loop
            latents = latents_init[:, : idx_end - idx_start].clone()
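            # Roll the initial noise so the next window starts from the noise of its overlapping frames.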
            latents_init = torch.cat(
                [latents_init[:, -overlap:], latents_init[:, :stride]], dim=1
            )

            video_latents_current = video_latents[:, idx_start:idx_end]
            prior_latents_current = prior_latents[:, idx_start:idx_end]
            video_embeddings_current = video_embeddings[:, idx_start:idx_end]

            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
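                    # On the first step of a new window, initialize the overlapping latents from the
                    # previous window's result plus noise scaled to the current sigma.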
                    if latents_all is not None and i == 0:
                        latents[:, :overlap] = (
                            latents_all[:, -overlap:]
                            + latents[:, :overlap]
                            / self.scheduler.init_noise_sigma
                            * self.scheduler.sigmas[i]
                        )

                    latent_model_input = latents

                    latent_model_input = self.scheduler.scale_model_input(
                        latent_model_input, t
                    )  # [1 or 2, t, c, h, w]
                    latent_model_input = torch.cat(
                        [latent_model_input, video_latents_current, prior_latents_current], dim=2
                    )
                    noise_pred = self.unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=video_embeddings_current,
                        added_time_ids=added_time_ids,
                        return_dict=False,
                    )[0]
                    # Classifier-free guidance: a second UNet pass with zeroed conditioning
                    # latents and embeddings gives the unconditional prediction.
                    if self.do_classifier_free_guidance:
                        latent_model_input = latents
                        latent_model_input = self.scheduler.scale_model_input(
                            latent_model_input, t
                        )
                        latent_model_input = torch.cat(
                            [latent_model_input, torch.zeros_like(latent_model_input), torch.zeros_like(latent_model_input)],
                            dim=2,
                        )
                        noise_pred_uncond = self.unet(
                            latent_model_input,
                            t,
                            encoder_hidden_states=torch.zeros_like(
                                video_embeddings_current
                            ),
                            added_time_ids=added_time_ids,
                            return_dict=False,
                        )[0]
                        noise_pred = noise_pred_uncond + self.guidance_scale * (
                            noise_pred - noise_pred_uncond
                        )
                    latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(
                            self, i, t, callback_kwargs
                        )

                        latents = callback_outputs.pop("latents", latents)

                    if i == len(timesteps) - 1 or (
                        (i + 1) > num_warmup_steps
                        and (i + 1) % self.scheduler.order == 0
                    ):
                        progress_bar.update()

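            # Stitch the window into the running sequence: cross-fade the overlapping frames
            # with the linear weights, then append the remaining frames.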
            if latents_all is None:
                latents_all = latents.clone()
            else:
                if overlap > 0:
                    latents_all[:, -overlap:] = latents[
                        :, :overlap
                    ] * weights + latents_all[:, -overlap:] * (1 - weights)
                latents_all = torch.cat([latents_all, latents[:, overlap:]], dim=1)

            idx_start += stride

        latents_all = 1 / self.vae.config.scaling_factor * latents_all.squeeze(0).to(torch.float32)

        if track_time:
            denoise_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = encode_event.elapsed_time(denoise_event)
            print(f"Elapsed time for denoise latent: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        point_map, valid_mask = self.decode_point_map(
            point_map_vae, 
            latents_all, 
            chunk_size=decode_chunk_size, 
            force_projection=force_projection,
            force_fixed_focal=force_fixed_focal,
            use_extract_interp=use_extract_interp, 
            need_resize=need_resize, 
            height=original_height, 
            width=original_width)
        

        if track_time:
            decode_event.record()
            torch.cuda.synchronize()
            elapsed_time_ms = denoise_event.elapsed_time(decode_event)
            print(f"Elapsed time for decode latent: {elapsed_time_ms} ms")
        else:
            gc.collect()
            torch.cuda.empty_cache()

        self.maybe_free_model_hooks()
        # t,h,w,3   t,h,w
        return point_map, valid_mask