from .. import WarpCore
from ..utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
from abc import abstractmethod
from dataclasses import dataclass
import torch
from torch import nn
from torch.utils.data import DataLoader
from gdf import GDF
import numpy as np
from tqdm import tqdm
import wandb

import webdataset as wds
from webdataset.handlers import warn_and_continue
from torch.distributed import barrier
from enum import Enum

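# How the generator's raw output is interpreted when computing the loss:
# predict the added noise (epsilon) or the clean latents (x0).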
class TargetReparametrization(Enum):
    EPSILON = 'epsilon'
    X0 = 'x0'

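# Base diffusion trainer built on WarpCore. Subclasses supply latent encoding/decoding,
# conditioning, the webdataset sources/filters/preprocessors, and a sampling routine.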
class DiffusionCore(WarpCore):
    @dataclass(frozen=True)
    class Config(WarpCore.Config):
        # TRAINING PARAMS
        lr: float = EXPECTED_TRAIN
        grad_accum_steps: int = EXPECTED_TRAIN
        batch_size: int = EXPECTED_TRAIN
        updates: int = EXPECTED_TRAIN
        warmup_updates: int = EXPECTED_TRAIN
        save_every: int = 500
        backup_every: int = 20000
        use_fsdp: bool = True

        # EMA UPDATE
        ema_start_iters: int = None
        ema_iters: int = None
        ema_beta: float = None

        # GDF setting
        gdf_target_reparametrization: TargetReparametrization = None # epsilon or x0
    
    @dataclass() # not frozen: fields are mutable (doesn't support EXPECTED defaults)
    class Info(WarpCore.Info):
        ema_loss: float = None

    @dataclass(frozen=True)
    class Models(WarpCore.Models):
        generator: nn.Module = EXPECTED
        generator_ema: nn.Module = None # optional

    @dataclass(frozen=True)
    class Optimizers(WarpCore.Optimizers):
        generator: any = EXPECTED

    @dataclass(frozen=True)
    class Schedulers(WarpCore.Schedulers):
        generator: any = None

    @dataclass(frozen=True)
    class Extras(WarpCore.Extras):
        gdf: GDF = EXPECTED
        sampling_configs: dict = EXPECTED

    # --------------------------------------------
    info: Info
    config: Config

    @abstractmethod
    def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False):
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def webdataset_path(self, extras: Extras):
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def webdataset_filters(self, extras: Extras):
        raise NotImplementedError("This method needs to be overridden")
    
    @abstractmethod
    def webdataset_preprocessors(self, extras: Extras):
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
        raise NotImplementedError("This method needs to be overridden")
    # -------------

    def setup_data(self, extras: Extras) -> WarpCore.Data:
        # SETUP DATASET
        dataset_path = self.webdataset_path(extras)
        preprocessors = self.webdataset_preprocessors(extras)
        filters = self.webdataset_filters(extras)

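        # Each preprocessor is expected to be a (source_key, transform_fn, output_key) tuple: the pipeline
        # shuffles, decodes images to PIL RGB, selects the source keys, applies the per-key transforms and
        # packs the results back into a dict keyed by output_key.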
        handler = warn_and_continue  # warn and skip corrupt samples instead of raising
        dataset = wds.WebDataset(
            dataset_path, resampled=True, handler=handler
        ).select(filters).shuffle(690, handler=handler).decode(
            "pilrgb", handler=handler
        ).to_tuple(
            *[p[0] for p in preprocessors], handler=handler
        ).map_tuple(
            *[p[1] for p in preprocessors], handler=handler
        ).map(lambda x: {p[2]:x[i] for i, p in enumerate(preprocessors)})

        # SETUP DATALOADER
        real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
        dataloader = DataLoader(
            dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True
        )

        return self.Data(dataset=dataset, dataloader=dataloader, iterator=iter(dataloader))

    def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
        batch = next(data.iterator)

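        # Conditioning and latent encoding run without gradients; gdf.diffuse noises the latents at a sampled
        # timestep and returns (noised latents, noise, target, logSNR, noise-level conditioning, loss weights).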
        with torch.no_grad():
            conditions = self.get_conditions(batch, models, extras)
            latents = self.encode_latents(batch, models, extras)
            noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)

        # FORWARD PASS
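        # The generator runs under bfloat16 autocast; if a target reparametrization is configured, gdf.undiffuse
        # converts the prediction to epsilon or x0 space so the MSE is computed against the matching target.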
        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
            pred = models.generator(noised, noise_cond, **conditions)
            if self.config.gdf_target_reparametrization == TargetReparametrization.EPSILON:
                pred = extras.gdf.undiffuse(noised, logSNR, pred)[1] # convert the prediction to epsilon space for the loss
                target = noise
            elif self.config.gdf_target_reparametrization == TargetReparametrization.X0:
                pred = extras.gdf.undiffuse(noised, logSNR, pred)[0] # convert the prediction to x0 space for the loss
                target = latents
            loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3])
            loss_adjusted = (loss * loss_weight).mean() / self.config.grad_accum_steps

        return loss, loss_adjusted
    
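    # Training loop: gradient accumulation, optional EMA updates, loss/grad-norm logging, and periodic
    # sampling/checkpointing. config.updates counts optimizer steps, so the loop runs updates * grad_accum_steps iterations.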
    def train(self, data: WarpCore.Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
        start_iter = self.info.iter+1
        max_iters = self.config.updates * self.config.grad_accum_steps
        if self.is_main_node:
            print(f"STARTING AT STEP: {start_iter}/{max_iters}")

        pbar = tqdm(range(start_iter, max_iters+1)) if self.is_main_node else range(start_iter, max_iters+1) # <--- DDP
        models.generator.train()
        grad_norm = torch.tensor(0.0)  # placeholder so logging works before the first optimizer step
        for i in pbar:
            # FORWARD PASS
            loss, loss_adjusted = self.forward_pass(data, extras, models)

            # BACKWARD PASS
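            # Optimizer and scheduler steps happen every grad_accum_steps iterations (and on the last one);
            # intermediate iterations backpropagate under no_sync() to skip gradient synchronization across ranks.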
            if i % self.config.grad_accum_steps == 0 or i == max_iters:
                loss_adjusted.backward()
                grad_norm = nn.utils.clip_grad_norm_(models.generator.parameters(), 1.0)
                optimizers_dict = optimizers.to_dict()
                for k in optimizers_dict:
                    optimizers_dict[k].step()
                schedulers_dict = schedulers.to_dict()
                for k in schedulers_dict:
                    schedulers_dict[k].step()
                models.generator.zero_grad(set_to_none=True)
                self.info.total_steps += 1
            else:
                with models.generator.no_sync():
                    loss_adjusted.backward()
            self.info.iter = i

            # UPDATE EMA
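            # Before ema_start_iters the beta passed is 0, which amounts to copying the generator weights into
            # the EMA model; afterwards the EMA tracks the generator with strength ema_beta.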
            if models.generator_ema is not None and i % self.config.ema_iters == 0:
                update_weights_ema(
                    models.generator_ema, models.generator,
                    beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
                )

            # UPDATE LOSS METRICS
            self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01

            if self.is_main_node and self.config.wandb_project is not None and (np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())):
                wandb.alert(
                    title=f"NaN value encountered in training run {self.info.wandb_run_id}", 
                    text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}",
                    wait_duration=60*30
                )

            if self.is_main_node:
                logs = {
                    'loss': self.info.ema_loss, 
                    'raw_loss': loss.mean().item(),
                    'grad_norm': grad_norm.item(),
                    'lr': optimizers.generator.param_groups[0]['lr'],
                    'total_steps': self.info.total_steps,
                }

                pbar.set_postfix(logs)
                if self.config.wandb_project is not None:
                    wandb.log(logs)

            if i == 1 or i % (self.config.save_every*self.config.grad_accum_steps) == 0 or i == max_iters:
                # SAVE AND CHECKPOINT STUFF
                if np.isnan(loss.mean().item()):
                    if self.is_main_node and self.config.wandb_project is not None:
                        tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
                        wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.config.run_id}", text=f"Skipping sampling & checkpoint at {self.info.total_steps} for training run {self.info.wandb_run_id} iters because loss is NaN")
                else:
                    self.save_checkpoints(models, optimizers)
                    if self.is_main_node:
                        create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
                    self.sample(models, data, extras)

    def models_to_save(self):
        return ['generator', 'generator_ema']

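    # Synchronizes all ranks, then saves the training info, the models listed in models_to_save() and the optimizers.
    # Every backup_every optimizer steps it also writes an extra copy suffixed with the step count in thousands.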
    def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
        barrier()
        suffix = '' if suffix is None else suffix
        self.save_info(self.info, suffix=suffix)
        models_dict = models.to_dict()
        optimizers_dict = optimizers.to_dict()
        for key in self.models_to_save():
            model = models_dict[key]
            if model is not None:
                self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
        for key in optimizers_dict:
            optimizer = optimizers_dict[key]
            if optimizer is not None:
                self.save_optimizer(optimizer, f'{key}_optim{suffix}', fsdp_model=models.generator if self.config.use_fsdp else None)
        if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
            self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps//1000}k")
        torch.cuda.empty_cache()