import random
import numpy as np
from pathlib import Path
from ResizeRight.resize_right import resize
from einops import rearrange

import torch
import torchvision as thv
from torch.utils.data import Dataset

from utils import util_sisr
from utils import util_image
from utils import util_common

from basicsr.data.realesrgan_dataset import RealESRGANDataset
from .ffhq_degradation_dataset import FFHQDegradationDataset

def get_transforms(transform_type, out_size, sf):
    """Build the torchvision preprocessing pipeline for a dataset.

    Every pipeline ends with ``ToTensor()`` followed by a per-channel
    ``Normalize`` with mean 0.5 and std 0.5 on three channels.

    Args:
        transform_type: 'default' prepends ``util_image.SpatialAug()``,
            'face' adds nothing before the tensor conversion, and 'bicubic'
            prepends ``util_sisr.Bicubic(1/sf)``.
        out_size: accepted for interface compatibility; none of the
            pipelines built here read it.
        sf: scale factor; only used by the 'bicubic' pipeline.

    Returns:
        A ``thv.transforms.Compose`` instance.

    Raises:
        ValueError: if ``transform_type`` is not one of the names above.
    """
    if transform_type == 'default':
        prefix = [util_image.SpatialAug()]
    elif transform_type == 'face':
        prefix = []
    elif transform_type == 'bicubic':
        prefix = [util_sisr.Bicubic(1/sf)]
    else:
        # Report the offending value using the actual parameter name.
        raise ValueError(f'Unexpected transform_type {transform_type}')

    # Shared tail: convert to a tensor, then normalize each channel.
    return thv.transforms.Compose(prefix + [
        thv.transforms.ToTensor(),
        thv.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])

def create_dataset(dataset_config):
    """Instantiate the dataset selected by ``dataset_config['type']``.

    Args:
        dataset_config: mapping with a 'type' entry naming the dataset and a
            'params' entry holding its constructor arguments.

    Returns:
        The constructed dataset object.

    Raises:
        NotImplementedError: if the 'type' value matches no known dataset.
    """
    # 'gfpgan' and 'realesrgan' receive the params mapping as a single
    # positional argument; the datasets defined in this module take the
    # params as keyword arguments. Lambdas defer the class lookups until a
    # type actually matches.
    builders = (
        ('gfpgan', lambda params: FFHQDegradationDataset(params)),
        ('face', lambda params: BaseDatasetFace(**params)),
        ('bicubic', lambda params: DatasetBicubic(**params)),
        ('folder', lambda params: BaseDataFolder(**params)),
        ('realesrgan', lambda params: RealESRGANDataset(params)),
    )
    for type_name, build in builders:
        if dataset_config['type'] == type_name:
            return build(dataset_config['params'])
    raise NotImplementedError(dataset_config['type'])

class BaseDatasetFace(Dataset):
    """Face-image dataset built from up to two text files of image paths.

    Args:
        celeba_txt: path passed to ``util_common.readline_txt``, or None to
            skip this list.
        ffhq_txt: same as ``celeba_txt``; its entries follow those of
            ``celeba_txt``.
        out_size: forwarded to ``get_transforms``.
        transform_type: forwarded to ``get_transforms`` ('face' by default).
        sf: forwarded to ``get_transforms``.
        length: value reported by ``__len__``; defaults to the number of
            listed files. NOTE(review): a value larger than the file count
            makes ``__getitem__`` raise IndexError for the extra indices.
    """
    def __init__(self, celeba_txt=None,
                       ffhq_txt=None,
                       out_size=256,
                       transform_type='face',
                       sf=None,
                       length=None):
        super().__init__()
        # Both lists default to None, so only read the ones actually given.
        self.files_names = []
        for txt_path in (celeba_txt, ffhq_txt):
            if txt_path is not None:
                self.files_names += util_common.readline_txt(txt_path)

        if length is None:
            self.length = len(self.files_names)
        else:
            self.length = length

        self.transform = get_transforms(transform_type, out_size, sf)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Load the image at ``index`` as uint8 RGB and apply the transform."""
        im_path = self.files_names[index]
        im = util_image.imread(im_path, chn='rgb', dtype='uint8')
        im = self.transform(im)
        return {'image':im,}

class DatasetBicubic(Dataset):
    """Pairs each ground-truth image with a downscaled low-quality copy.

    The low-quality image is produced with ResizeRight's ``resize`` at scale
    ``1/sf`` (default interpolation; presumably cubic, given the class name).
    With ``up_back`` set, it is resized back up by ``sf``.

    Args:
        files_txt: path passed to ``util_common.readline_txt``; used only
            when ``val_dir`` is None.
        val_dir: directory whose ``*.{ext}`` files form the dataset
            (non-recursive glob). Takes precedence over ``files_txt``.
        ext: file extension (without dot) matched in ``val_dir``.
        sf: downscaling factor; must be set before indexing.
        up_back: if True, upsample the low-quality image back by ``sf``.
        need_gt_path: if True, samples also carry the source path under
            'gt_path'.
        length: value reported by ``__len__``; defaults to the file count.
    """
    def __init__(self, files_txt=None, val_dir=None, ext='png', sf=None,
                 up_back=False, need_gt_path=False, length=None):
        super().__init__()
        if val_dir is not None:
            self.files_names = list(map(str, Path(val_dir).glob(f"*.{ext}")))
        else:
            self.files_names = util_common.readline_txt(files_txt)
        self.sf = sf
        self.up_back = up_back
        self.need_gt_path = need_gt_path
        self.length = len(self.files_names) if length is None else length

    def __len__(self):
        return self.length

    @staticmethod
    def _to_chw_tensor(array_hwc):
        """Reorder an HWC array to CHW and wrap it as a float32 tensor."""
        array_chw = rearrange(array_hwc, 'h w c -> c h w')
        return torch.from_numpy(array_chw).type(torch.float32)

    def __getitem__(self, index):
        """Return {'lq', 'gt'} CHW float32 tensors (plus 'gt_path' if requested)."""
        path = self.files_names[index]
        gt_hwc = util_image.imread(path, chn='rgb', dtype='float32')

        lq_hwc = resize(gt_hwc, scale_factors=1/self.sf)
        if self.up_back:
            lq_hwc = resize(lq_hwc, scale_factors=self.sf)

        sample = {
            'lq': self._to_chw_tensor(lq_hwc),
            'gt': self._to_chw_tensor(gt_hwc),
        }
        if self.need_gt_path:
            sample['gt_path'] = path
        return sample

class BaseDataFolder(Dataset):
    """Dataset over the image files of one directory, with optional GT pairing.

    Args:
        dir_path: directory scanned (non-recursively) for input images.
        dir_path_gt: directory containing ground-truth images with the same
            file names as the inputs, or None to skip loading GT.
        need_gt_path: if True, each sample includes the input path under 'path'.
        length: if given, keep only the first ``length`` files (sorted order).
        ext: one extension string or a list/tuple of extensions (no dot).
        mean, std: forwarded to ``util_image.normalize_np``.

    Raises:
        TypeError: if ``ext`` is neither a str nor a list/tuple.
    """
    def __init__(
            self,
            dir_path,
            dir_path_gt,
            need_gt_path=True,
            length=None,
            ext=('png', 'jpg', 'jpeg', 'JPEG', 'bmp'),
            mean=0.5,
            std=0.5,
            ):
        super().__init__()
        if isinstance(ext, str):
            ext = (ext,)
        elif not isinstance(ext, (list, tuple)):
            raise TypeError(
                f'ext must be a str, list or tuple, got {type(ext).__name__}')

        # Collect into a set: several patterns can match the same file (a
        # repeated extension, or 'jpeg' and 'JPEG' where glob matching is
        # case-insensitive, as on Windows). Sorting makes the order, and hence
        # the ``length`` prefix, independent of filesystem listing order.
        root = Path(dir_path)
        found = set()
        for current_ext in ext:
            found.update(str(x) for x in root.glob(f'*.{current_ext}'))
        files_path = sorted(found)

        self.files_path = files_path if length is None else files_path[:length]
        self.dir_path_gt = dir_path_gt
        self.need_gt_path = need_gt_path
        self.mean = mean
        self.std = std

    def __len__(self):
        return len(self.files_path)

    def __getitem__(self, index):
        """Return a dict with 'image' and 'lq' (same CHW float32 array content),
        plus 'path' if requested and 'gt' when ``dir_path_gt`` is set."""
        im_path = self.files_path[index]
        im = util_image.imread(im_path, chn='rgb', dtype='float32')
        im = util_image.normalize_np(im, mean=self.mean, std=self.std, reverse=False)
        im = rearrange(im, 'h w c -> c h w')
        # Two separate float32 copies, one per key.
        out_dict = {'image': im.astype(np.float32), 'lq': im.astype(np.float32)}

        if self.need_gt_path:
            out_dict['path'] = im_path

        if self.dir_path_gt is not None:
            # The GT image is looked up by the input's file name.
            gt_path = str(Path(self.dir_path_gt) / Path(im_path).name)
            im_gt = util_image.imread(gt_path, chn='rgb', dtype='float32')
            im_gt = util_image.normalize_np(im_gt, mean=self.mean, std=self.std, reverse=False)
            im_gt = rearrange(im_gt, 'h w c -> c h w')
            out_dict['gt'] = im_gt.astype(np.float32)

        return out_dict