from __future__ import annotations

import os
import gc
from tqdm import tqdm
import wandb

import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LinearLR, SequentialLR, ConstantLR

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
from diffrhythm.dataset.custom_dataset_align2f5 import LanceDiffusionDataset

from torch.utils.data import DataLoader, DistributedSampler

from ema_pytorch import EMA

from diffrhythm.model import CFM
from diffrhythm.model.utils import exists, default

import time

# from apex.optimizers.fused_adam import FusedAdam

# trainer


class Trainer:
    def __init__(
        self,
        model: CFM,
        args,
        epochs,
        learning_rate,
        #dataloader,
        num_warmup_updates=20000,
        save_per_updates=1000,
        checkpoint_path=None,
        batch_size=32,
        batch_size_type: str = "sample",
        max_samples=32,
        grad_accumulation_steps=1,
        max_grad_norm=1.0,
        noise_scheduler: str | None = None,
        duration_predictor: torch.nn.Module | None = None,
        wandb_project="test_e2-tts",
        wandb_run_name="test_run",
        wandb_resume_id: str = None,
        last_per_steps=None,
        accelerate_kwargs: dict = dict(),
        ema_kwargs: dict = dict(),
        bnb_optimizer: bool = False,
        reset_lr: bool = False,
        use_style_prompt: bool = False,
        grad_ckpt: bool = False
    ):
        self.args = args

        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False, )

        logger = "wandb" if wandb.api.api_key else None
        #logger = None
        print(f"Using logger: {logger}")
        # print("-----------1-------------")
        import tbe.common
        # print("-----------2-------------")
        self.accelerator = Accelerator(
            log_with=logger,
            kwargs_handlers=[ddp_kwargs],
            gradient_accumulation_steps=grad_accumulation_steps,
            **accelerate_kwargs,
        )
        # print("-----------3-------------")

        if logger == "wandb":
            if exists(wandb_resume_id):
                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name, "id": wandb_resume_id}}
            else:
                init_kwargs = {"wandb": {"resume": "allow", "name": wandb_run_name}}
            self.accelerator.init_trackers(
                project_name=wandb_project,
                init_kwargs=init_kwargs,
                config={
                    "epochs": epochs,
                    "learning_rate": learning_rate,
                    "num_warmup_updates": num_warmup_updates,
                    "batch_size": batch_size,
                    "batch_size_type": batch_size_type,
                    "max_samples": max_samples,
                    "grad_accumulation_steps": grad_accumulation_steps,
                    "max_grad_norm": max_grad_norm,
                    "gpus": self.accelerator.num_processes,
                    "noise_scheduler": noise_scheduler,
                },
            )

        self.precision = self.accelerator.state.mixed_precision
        self.precision = self.precision.replace("no", "fp32")
        print("!!!!!!!!!!!!!!!!!", self.precision)

        self.model = model
        #self.model = torch.compile(model)

        #self.dataloader = dataloader

        if self.is_main:
            self.ema_model = EMA(model, include_online_model=False, **ema_kwargs)

            self.ema_model.to(self.accelerator.device)
            if self.accelerator.state.distributed_type in ["DEEPSPEED", "FSDP"]:
                self.ema_model.half()

        self.epochs = epochs
        self.num_warmup_updates = num_warmup_updates
        self.save_per_updates = save_per_updates
        self.last_per_steps = default(last_per_steps, save_per_updates * grad_accumulation_steps)
        self.checkpoint_path = default(checkpoint_path, "ckpts/test_e2-tts")

        self.max_samples = max_samples
        self.grad_accumulation_steps = grad_accumulation_steps
        self.max_grad_norm = max_grad_norm

        self.noise_scheduler = noise_scheduler

        self.duration_predictor = duration_predictor

        self.reset_lr = reset_lr

        self.use_style_prompt = use_style_prompt
        
        self.grad_ckpt = grad_ckpt

        if bnb_optimizer:
            import bitsandbytes as bnb

            self.optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=learning_rate)
        else:
            self.optimizer = AdamW(model.parameters(), lr=learning_rate)
        #self.optimizer = FusedAdam(model.parameters(), lr=learning_rate)

        #self.model = torch.compile(self.model)
        if self.accelerator.state.distributed_type == "DEEPSPEED":
            self.accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = batch_size
        
        self.get_dataloader()
        self.get_scheduler()
        # self.get_constant_scheduler()

        self.model, self.optimizer, self.scheduler, self.train_dataloader = self.accelerator.prepare(self.model, self.optimizer, self.scheduler, self.train_dataloader)

    def get_scheduler(self):
        warmup_steps = (
            self.num_warmup_updates * self.accelerator.num_processes
        )  # consider a fixed warmup steps while using accelerate multi-gpu ddp
        total_steps = len(self.train_dataloader) * self.epochs / self.grad_accumulation_steps
        decay_steps = total_steps - warmup_steps
        warmup_scheduler = LinearLR(self.optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
        decay_scheduler = LinearLR(self.optimizer, start_factor=1.0, end_factor=1e-8, total_iters=decay_steps)
        # constant_scheduler = ConstantLR(self.optimizer, factor=1, total_iters=decay_steps)
        self.scheduler = SequentialLR(
            self.optimizer, schedulers=[warmup_scheduler, decay_scheduler], milestones=[warmup_steps]
        )

    def get_constant_scheduler(self):
        total_steps = len(self.train_dataloader) * self.epochs / self.grad_accumulation_steps
        self.scheduler = ConstantLR(self.optimizer, factor=1, total_iters=total_steps)

    def get_dataloader(self):
        prompt_path = self.args.prompt_path.split('|')
        lrc_path = self.args.lrc_path.split('|')
        latent_path = self.args.latent_path.split('|')
        ldd = LanceDiffusionDataset(*LanceDiffusionDataset.init_data(self.args.dataset_path), \
                                        max_frames=self.args.max_frames, min_frames=self.args.min_frames, \
                                        align_lyrics=self.args.align_lyrics, lyrics_slice=self.args.lyrics_slice, \
                                        use_style_prompt=self.args.use_style_prompt, parse_lyrics=self.args.parse_lyrics,
                                        lyrics_shift=self.args.lyrics_shift, downsample_rate=self.args.downsample_rate, \
                                        skip_empty_lyrics=self.args.skip_empty_lyrics, tokenizer_type=self.args.tokenizer_type, precision=self.precision, \
                                        start_time=time.time(), pure_prob=self.args.pure_prob)
        
        # start_time = time.time()
        self.train_dataloader = DataLoader(
            dataset=ldd,
            batch_size=self.args.batch_size,      # 每个批次的样本数
            shuffle=True,      # 是否随机打乱数据
            num_workers=4,     # 用于加载数据的子进程数
            pin_memory=True,   # 加速GPU训练
            collate_fn=ldd.custom_collate_fn,
            persistent_workers=True
        )


    @property
    def is_main(self):
        return self.accelerator.is_main_process

    def save_checkpoint(self, step, last=False):
        self.accelerator.wait_for_everyone()
        if self.is_main:
            checkpoint = dict(
                model_state_dict=self.accelerator.unwrap_model(self.model).state_dict(),
                optimizer_state_dict=self.accelerator.unwrap_model(self.optimizer).state_dict(),
                ema_model_state_dict=self.ema_model.state_dict(),
                scheduler_state_dict=self.scheduler.state_dict(),
                step=step,
            )
            if not os.path.exists(self.checkpoint_path):
                os.makedirs(self.checkpoint_path)
            if last:
                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_last.pt")
                print(f"Saved last checkpoint at step {step}")
            else:
                self.accelerator.save(checkpoint, f"{self.checkpoint_path}/model_{step}.pt")

    def load_checkpoint(self):
        if (
            not exists(self.checkpoint_path)
            or not os.path.exists(self.checkpoint_path)
            or not os.listdir(self.checkpoint_path)
        ):
            return 0

        self.accelerator.wait_for_everyone()
        if "model_last.pt" in os.listdir(self.checkpoint_path):
            latest_checkpoint = "model_last.pt"
        else:
            latest_checkpoint = sorted(
                [f for f in os.listdir(self.checkpoint_path) if f.endswith(".pt")],
                key=lambda x: int("".join(filter(str.isdigit, x))),
            )[-1]
        
        checkpoint = torch.load(f"{self.checkpoint_path}/{latest_checkpoint}", map_location="cpu")

        ### **1. 过滤 `ema_model` 的不匹配参数**
        if self.is_main:
            ema_dict = self.ema_model.state_dict()
            ema_checkpoint_dict = checkpoint["ema_model_state_dict"]

            filtered_ema_dict = {
                k: v for k, v in ema_checkpoint_dict.items()
                if k in ema_dict and ema_dict[k].shape == v.shape  # 仅加载 shape 匹配的参数
            }

            print(f"Loading {len(filtered_ema_dict)} / {len(ema_checkpoint_dict)} ema_model params")
            self.ema_model.load_state_dict(filtered_ema_dict, strict=False)

        ### **2. 过滤 `model` 的不匹配参数**
        model_dict = self.accelerator.unwrap_model(self.model).state_dict()
        checkpoint_model_dict = checkpoint["model_state_dict"]

        filtered_model_dict = {
            k: v for k, v in checkpoint_model_dict.items()
            if k in model_dict and model_dict[k].shape == v.shape  # 仅加载 shape 匹配的参数
        }

        print(f"Loading {len(filtered_model_dict)} / {len(checkpoint_model_dict)} model params")
        self.accelerator.unwrap_model(self.model).load_state_dict(filtered_model_dict, strict=False)

        ### **3. 加载优化器、调度器和步数**
        if "step" in checkpoint:
            if self.scheduler and not self.reset_lr:
                self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
            step = checkpoint["step"]
        else:
            step = 0

        del checkpoint
        gc.collect()
        print("Checkpoint loaded at step", step)
        return step

    def train(self, resumable_with_seed: int = None):
        train_dataloader = self.train_dataloader

        start_step = self.load_checkpoint()
        global_step = start_step

        if resumable_with_seed > 0:
            orig_epoch_step = len(train_dataloader)
            skipped_epoch = int(start_step // orig_epoch_step)
            skipped_batch = start_step % orig_epoch_step
            skipped_dataloader = self.accelerator.skip_first_batches(train_dataloader, num_batches=skipped_batch)
        else:
            skipped_epoch = 0

        for epoch in range(skipped_epoch, self.epochs):
            self.model.train()
            if resumable_with_seed > 0 and epoch == skipped_epoch:
                progress_bar = tqdm(
                    skipped_dataloader,
                    desc=f"Epoch {epoch+1}/{self.epochs}",
                    unit="step",
                    disable=not self.accelerator.is_local_main_process,
                    initial=skipped_batch,
                    total=orig_epoch_step,
                    smoothing=0.15
                )
            else:
                progress_bar = tqdm(
                    train_dataloader,
                    desc=f"Epoch {epoch+1}/{self.epochs}",
                    unit="step",
                    disable=not self.accelerator.is_local_main_process,
                    smoothing=0.15
                )

            for batch in progress_bar:
                with self.accelerator.accumulate(self.model):
                    text_inputs = batch["lrc"]
                    mel_spec = batch["latent"].permute(0, 2, 1)
                    mel_lengths = batch["latent_lengths"]
                    style_prompt = batch["prompt"]
                    style_prompt_lens = batch["prompt_lengths"]
                    start_time = batch["start_time"]

                    loss, cond, pred = self.model(
                        mel_spec, text=text_inputs, lens=mel_lengths, noise_scheduler=self.noise_scheduler,
                        style_prompt=style_prompt if self.use_style_prompt else None,
                        style_prompt_lens=style_prompt_lens if self.use_style_prompt else None,
                        grad_ckpt=self.grad_ckpt, start_time=start_time
                    )
                    self.accelerator.backward(loss)

                    if self.max_grad_norm > 0 and self.accelerator.sync_gradients:
                        self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)

                    self.optimizer.step()
                    self.scheduler.step()
                    self.optimizer.zero_grad()

                if self.is_main:
                    self.ema_model.update()

                global_step += 1

                if self.accelerator.is_local_main_process:
                    self.accelerator.log({"loss": loss.item(), "lr": self.scheduler.get_last_lr()[0]}, step=global_step)

                progress_bar.set_postfix(step=str(global_step), loss=loss.item())

                if global_step % (self.save_per_updates * self.grad_accumulation_steps) == 0:
                    self.save_checkpoint(global_step)

                if global_step % self.last_per_steps == 0:
                    self.save_checkpoint(global_step, last=True)

        self.save_checkpoint(global_step, last=True)

        self.accelerator.end_training()