import glob
import os

import cv2
import hydra
import numpy as np
import torch

from lib.utils import utils


class Dataset(torch.utils.data.Dataset):
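    """Per-frame RGB images, foreground masks, SMPL parameters, and
    normalized cameras for one monocular video sequence."""
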
    def __init__(self, metainfo, split):
        root = os.path.join("../data", metainfo.data_dir)
        root = hydra.utils.to_absolute_path(root)

        self.start_frame = metainfo.start_frame
        self.end_frame = metainfo.end_frame
        self.skip_step = 1
        self.images, self.img_sizes = [], []
        self.training_indices = list(range(metainfo.start_frame, metainfo.end_frame, self.skip_step))

        # RGB frames; the image size is read from the first frame
        img_dir = os.path.join(root, "image")
        self.img_paths = sorted(glob.glob(f"{img_dir}/*.png"))
        self.img_paths = [self.img_paths[i] for i in self.training_indices]
        self.img_size = cv2.imread(self.img_paths[0]).shape[:2]
        self.n_images = len(self.img_paths)

        # Foreground masks
        mask_dir = os.path.join(root, "mask")
        self.mask_paths = sorted(glob.glob(f"{mask_dir}/*.png"))
        self.mask_paths = [self.mask_paths[i] for i in self.training_indices]

        # SMPL parameters: shared mean shape, per-frame pose and translation
        self.shape = np.load(os.path.join(root, "mean_shape.npy"))
        self.poses = np.load(os.path.join(root, "poses.npy"))[self.training_indices]
        self.trans = np.load(os.path.join(root, "normalize_trans.npy"))[self.training_indices]

        # Normalized cameras
        camera_dict = np.load(os.path.join(root, "cameras_normalize.npz"))
        scale_mats = [camera_dict["scale_mat_%d" % idx].astype(np.float32) for idx in self.training_indices]
        world_mats = [camera_dict["world_mat_%d" % idx].astype(np.float32) for idx in self.training_indices]

        self.scale = 1 / scale_mats[0][0, 0]
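
        # Compose each frame's world_mat with its scale_mat and split the
        # resulting projection into intrinsics and camera pose
        # (utils.load_K_Rt_from_P)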
        self.intrinsics_all = []
        self.pose_all = []
        for scale_mat, world_mat in zip(scale_mats, world_mats):
            P = world_mat @ scale_mat
            P = P[:3, :4]
            intrinsics, pose = utils.load_K_Rt_from_P(None, P)
            self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
            self.pose_all.append(torch.from_numpy(pose).float())
        assert len(self.intrinsics_all) == len(self.pose_all)

        # Pixels sampled per frame during training (<= 0 returns every pixel)
        self.num_sample = split.num_sample
        self.sampling_strategy = "weighted"

    def __len__(self):
        return self.n_images

    def __getitem__(self, idx):
        # BGR -> RGB, scaled to [0, 1]
        img = cv2.imread(self.img_paths[idx])
        img = img[:, :, ::-1] / 255

        # Binary foreground mask
        mask = cv2.imread(self.mask_paths[idx])
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) > 0

        img_size = self.img_size

        # Per-pixel coordinates in (x, y) order: mgrid yields (row, col)
        # planes, which the flip swaps before stacking into (H, W, 2)
        uv = np.mgrid[:img_size[0], :img_size[1]].astype(np.int32)
        uv = np.flip(uv, axis=0).copy().transpose(1, 2, 0).astype(np.float32)
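
        # Pack the 86-dim SMPL parameter vector:
        # [scale (1), translation (3), pose (72), shape (10)]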
        smpl_params = torch.zeros([86]).float()
        smpl_params[0] = torch.from_numpy(np.asarray(self.scale)).float()
        smpl_params[1:4] = torch.from_numpy(self.trans[idx]).float()
        smpl_params[4:76] = torch.from_numpy(self.poses[idx]).float()
        smpl_params[76:] = torch.from_numpy(self.shape).float()

        if self.num_sample > 0:
            # Training: sample num_sample pixels, weighted toward the subject mask
            data = {
                "rgb": img,
                "uv": uv,
                "object_mask": mask,
            }
            samples, index_outside = utils.weighted_sampling(data, img_size, self.num_sample)
            inputs = {
                "uv": samples["uv"].astype(np.float32),
                "intrinsics": self.intrinsics_all[idx],
                "pose": self.pose_all[idx],
                "smpl_params": smpl_params,
                "index_outside": index_outside,
                "idx": idx,
            }
            images = {"rgb": samples["rgb"].astype(np.float32)}
            return inputs, images
        else:
            # Validation/test: return every pixel of the frame
            inputs = {
                "uv": uv.reshape(-1, 2).astype(np.float32),
                "intrinsics": self.intrinsics_all[idx],
                "pose": self.pose_all[idx],
                "smpl_params": smpl_params,
                "idx": idx,
            }
            images = {
                "rgb": img.reshape(-1, 3).astype(np.float32),
                "img_size": self.img_size,
            }
            return inputs, images


class ValDataset(torch.utils.data.Dataset):
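    """Validation wrapper: renders a single randomly chosen frame per call."""
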
    def __init__(self, metainfo, split):
        self.dataset = Dataset(metainfo, split)
        self.img_size = self.dataset.img_size

        self.total_pixels = np.prod(self.img_size)
        self.pixel_per_batch = split.pixel_per_batch

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        # Validate on one randomly chosen frame
        image_id = int(np.random.randint(len(self.dataset)))
        self.data = self.dataset[image_id]
        inputs, images = self.data

        inputs = {
            "uv": inputs["uv"],
            "intrinsics": inputs["intrinsics"],
            "pose": inputs["pose"],
            "smpl_params": inputs["smpl_params"],
            "image_id": image_id,
            "idx": inputs["idx"],
        }
        images = {
            "rgb": images["rgb"],
            "img_size": images["img_size"],
            "pixel_per_batch": self.pixel_per_batch,
            "total_pixels": self.total_pixels,
        }
        return inputs, images


class TestDataset(torch.utils.data.Dataset):
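    """Test wrapper: exposes every frame of the split, one per index."""
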
    def __init__(self, metainfo, split):
        self.dataset = Dataset(metainfo, split)
        self.img_size = self.dataset.img_size

        self.total_pixels = np.prod(self.img_size)
        self.pixel_per_batch = split.pixel_per_batch

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        inputs, images = data

        inputs = {
            "uv": inputs["uv"],
            "intrinsics": inputs["intrinsics"],
            "pose": inputs["pose"],
            "smpl_params": inputs["smpl_params"],
            "idx": inputs["idx"],
        }
        images = {
            "rgb": images["rgb"],
            "img_size": images["img_size"],
        }
        return inputs, images, self.pixel_per_batch, self.total_pixels, idx
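

if __name__ == "__main__":
    # Hedged smoke-test sketch: assumes a prepared sequence directory under
    # ../data/<data_dir> containing the files loaded above. SimpleNamespace
    # stands in for the Hydra `metainfo` and `split` configs;
    # "example_sequence", the frame range, and the sample counts are
    # illustrative placeholders, not values from the original configs.
    from types import SimpleNamespace

    metainfo = SimpleNamespace(data_dir="example_sequence", start_frame=0, end_frame=10)
    split = SimpleNamespace(num_sample=512, pixel_per_batch=2048)

    dataset = Dataset(metainfo, split)
    inputs, images = dataset[0]
    print(dataset.n_images, inputs["uv"].shape, images["rgb"].shape)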