import os
import glob

import cv2
import hydra
import numpy as np
import torch

from lib.utils import utils


class Dataset(torch.utils.data.Dataset):
    def __init__(self, metainfo, split):
        root = os.path.join("../data", metainfo.data_dir)
        root = hydra.utils.to_absolute_path(root)

        self.start_frame = metainfo.start_frame
        self.end_frame = metainfo.end_frame
        self.skip_step = 1
        self.images, self.img_sizes = [], []
        self.training_indices = list(range(self.start_frame, self.end_frame, self.skip_step))

        # images: only store the image paths to avoid OOM
        img_dir = os.path.join(root, "image")
        self.img_paths = sorted(glob.glob(f"{img_dir}/*.png"))
        self.img_paths = [self.img_paths[i] for i in self.training_indices]
        self.img_size = cv2.imread(self.img_paths[0]).shape[:2]
        self.n_images = len(self.img_paths)

        # coarse projected SMPL masks, only used for pixel sampling
        mask_dir = os.path.join(root, "mask")
        self.mask_paths = sorted(glob.glob(f"{mask_dir}/*.png"))
        self.mask_paths = [self.mask_paths[i] for i in self.training_indices]

        # per-sequence SMPL shape and per-frame pose/translation
        self.shape = np.load(os.path.join(root, "mean_shape.npy"))
        self.poses = np.load(os.path.join(root, "poses.npy"))[self.training_indices]
        self.trans = np.load(os.path.join(root, "normalize_trans.npy"))[self.training_indices]

        # cameras: decompose P = world_mat @ scale_mat into intrinsics and pose
        camera_dict = np.load(os.path.join(root, "cameras_normalize.npz"))
        scale_mats = [camera_dict["scale_mat_%d" % idx].astype(np.float32) for idx in self.training_indices]
        world_mats = [camera_dict["world_mat_%d" % idx].astype(np.float32) for idx in self.training_indices]

        self.scale = 1 / scale_mats[0][0, 0]

        self.intrinsics_all = []
        self.pose_all = []
        for scale_mat, world_mat in zip(scale_mats, world_mats):
            P = world_mat @ scale_mat
            P = P[:3, :4]
            intrinsics, pose = utils.load_K_Rt_from_P(None, P)
            self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
            self.pose_all.append(torch.from_numpy(pose).float())
        assert len(self.intrinsics_all) == len(self.pose_all)

        # other properties
        self.num_sample = split.num_sample
        self.sampling_strategy = "weighted"

    def __len__(self):
        return self.n_images

    def __getitem__(self, idx):
        # load and normalize the RGB image: BGR -> RGB, [0, 255] -> [0, 1]
        img = cv2.imread(self.img_paths[idx])
        img = img[:, :, ::-1] / 255

        # load the coarse SMPL mask: BGR -> grayscale -> boolean mask
        mask = cv2.imread(self.mask_paths[idx])
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) > 0

        img_size = self.img_size

        # per-pixel (x, y) coordinates, shape (H, W, 2)
        uv = np.mgrid[:img_size[0], :img_size[1]].astype(np.int32)
        uv = np.flip(uv, axis=0).copy().transpose(1, 2, 0).astype(np.float32)

        # packed SMPL parameters: 1 scale + 3 translation + 72 pose
        # (24 joints x 3 axis-angle) + 10 shape betas = 86
        smpl_params = torch.zeros([86]).float()
        smpl_params[0] = torch.from_numpy(np.asarray(self.scale)).float()
        smpl_params[1:4] = torch.from_numpy(self.trans[idx]).float()
        smpl_params[4:76] = torch.from_numpy(self.poses[idx]).float()
        smpl_params[76:] = torch.from_numpy(self.shape).float()

        if self.num_sample > 0:
            # training: sample num_sample pixels, weighted towards the subject
            # mask; index_outside marks samples falling outside the mask
            data = {
                "rgb": img,
                "uv": uv,
                "object_mask": mask,
            }
            samples, index_outside = utils.weighted_sampling(data, img_size, self.num_sample)
            inputs = {
                "uv": samples["uv"].astype(np.float32),
                "intrinsics": self.intrinsics_all[idx],
                "pose": self.pose_all[idx],
                "smpl_params": smpl_params,
                "index_outside": index_outside,
                "idx": idx,
            }
            images = {"rgb": samples["rgb"].astype(np.float32)}
            return inputs, images
        else:
            # evaluation: return every pixel of the image
            inputs = {
                "uv": uv.reshape(-1, 2).astype(np.float32),
                "intrinsics": self.intrinsics_all[idx],
                "pose": self.pose_all[idx],
                "smpl_params": smpl_params,
                "idx": idx,
            }
            images = {
                "rgb": img.reshape(-1, 3).astype(np.float32),
                "img_size": self.img_size,
            }
            return inputs, images


class ValDataset(torch.utils.data.Dataset):
    # validation: renders one randomly chosen frame per call
    def __init__(self, metainfo, split):
        self.dataset = Dataset(metainfo, split)
        self.img_size = self.dataset.img_size
        self.total_pixels = np.prod(self.img_size)
        self.pixel_per_batch = split.pixel_per_batch

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        image_id = int(np.random.choice(len(self.dataset), 1))
        self.data = self.dataset[image_id]
        inputs, images = self.data

        inputs = {
            "uv": inputs["uv"],
            "intrinsics": inputs["intrinsics"],
            "pose": inputs["pose"],
            "smpl_params": inputs["smpl_params"],
            "image_id": image_id,
            "idx": inputs["idx"],
        }
        images = {
            "rgb": images["rgb"],
            "img_size": images["img_size"],
            "pixel_per_batch": self.pixel_per_batch,
            "total_pixels": self.total_pixels,
        }
        return inputs, images


class TestDataset(torch.utils.data.Dataset):
    # test: iterates over every frame in order
    def __init__(self, metainfo, split):
        self.dataset = Dataset(metainfo, split)
        self.img_size = self.dataset.img_size
        self.total_pixels = np.prod(self.img_size)
        self.pixel_per_batch = split.pixel_per_batch

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        inputs, images = data

        inputs = {
            "uv": inputs["uv"],
            "intrinsics": inputs["intrinsics"],
            "pose": inputs["pose"],
            "smpl_params": inputs["smpl_params"],
            "idx": inputs["idx"],
        }
        images = {
            "rgb": images["rgb"],
            "img_size": images["img_size"],
        }
        return inputs, images, self.pixel_per_batch, self.total_pixels, idx
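

# Minimal usage sketch (an illustration, not part of the original module): in
# the real pipeline `metainfo` and `split` come from a hydra/OmegaConf config;
# the field names below mirror the attributes accessed above, and the values
# (sequence name, frame range, sample counts) are placeholder assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace

    metainfo = SimpleNamespace(
        data_dir="my_sequence",  # hypothetical folder under ../data
        start_frame=0,
        end_frame=10,
    )
    split = SimpleNamespace(
        num_sample=512,          # pixels sampled per training image
        pixel_per_batch=2048,    # chunk size used by ValDataset/TestDataset
    )

    dataset = Dataset(metainfo, split)
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
    inputs, images = next(iter(loader))
    # with num_sample > 0: uv is (1, num_sample, 2), rgb is (1, num_sample, 3)
    print(inputs["uv"].shape, images["rgb"].shape)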