# genstereo/dataset/dataset_sg.py
import json
import os
import random

import cv2
import h5py
import numpy as np
import torch
import torchvision.transforms as transforms
from jaxtyping import Float
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from transformers import CLIPImageProcessor

from genstereo.dataset.EXRloader import load_exr
# from EXRloader import load_exr  # alternative when running this file directly

def convert_left_to_right(left_embed, disparity, left_image, random_ratio=None):
    """Reference (per-pixel loop) implementation of left-to-right warping.

    Kept for readability; `convert_left_to_right_torch` below is the
    vectorized version used by the dataset. `random_ratio` is accepted for
    interface parity but not used here.
    """
    # Get the height and width from the left embedding
    _, height, width = left_embed.shape
    # Initialize outputs: mask starts all-invalid (1) and is cleared to 0
    # wherever a left pixel lands in the right view
    right_embed = torch.zeros_like(left_embed)
    converted_right_image = torch.zeros_like(left_image)
    mask = torch.ones((height, width), dtype=torch.uint8, device=left_embed.device)
    # Round the disparity and convert to int
    disparity_rounded = torch.round(disparity).squeeze(0).long()  # [h, w]
    # Shift every pixel left by its disparity
    for y in range(height):
        for x in range(width):
            new_x = x - disparity_rounded[y, x]
            if 0 <= new_x < width:
                right_embed[:, y, new_x] = left_embed[:, y, x]
                converted_right_image[:, y, new_x] = left_image[:, y, x]
                mask[y, new_x] = 0  # mark as valid
    return right_embed, mask, converted_right_image, disparity
def convert_left_to_right_torch(left_embed, disparity, left_image, random_ratio=None, dataset_name=None):
    """
    Convert left features to right features based on disparity values.
    Args:
        left_embed (torch.Tensor): [c, h, w] tensor of left feature embeddings.
        disparity (torch.Tensor): [1, h, w] tensor of disparity values.
        left_image (torch.Tensor): [3, h, w] tensor of the left image.
        random_ratio (float, optional): if set, roughly this fraction of valid
            pixels is kept and the rest is randomly masked out.
        dataset_name (str, optional): dataset identifier; some datasets encode
            missing ground truth as disparity 0 and need an extra validity check.
    Returns:
        right_embed (torch.Tensor): [c, h, w] tensor of right feature embeddings.
        mask (torch.Tensor): [h, w] binary mask (1 = invalid, 0 = valid).
        converted_right_image (torch.Tensor): [3, h, w] tensor of the warped right image.
        disparity (torch.Tensor): [1, h, w] disparity with invalid regions zeroed.
    """
    # Get the height and width from the left embedding
    _, height, width = left_embed.shape
    # Initialize outputs: mask starts all-invalid (1) and is cleared wherever
    # a left pixel lands; hole pixels are zeroed at the end regardless
    right_embed = torch.zeros_like(left_embed)
    converted_right_image = -torch.ones_like(left_image)
    mask = torch.ones((height, width), dtype=torch.uint8, device=left_embed.device)
    # Round the disparity and convert to int
    disparity_rounded = torch.round(disparity).squeeze(0).long()  # [h, w]
    # Iterate over columns, warping all rows of each column at once. Columns
    # are processed left to right, so when two sources map to the same target
    # the larger-x (larger-disparity, i.e. nearer) pixel wins.
    for x in range(width):
        new_x = x - disparity_rounded[:, x]
        if dataset_name in ('InStereo2K', 'DrivingStereo'):
            # These datasets encode missing ground truth as disparity 0
            valid_indices = (new_x >= 0) & (new_x < width) & (disparity_rounded[:, x] > 0)
        else:
            valid_indices = (new_x >= 0) & (new_x < width)
        valid_new_x = new_x[valid_indices]
        valid_y = torch.arange(height, device=left_embed.device)[valid_indices]
        right_embed[:, valid_y, valid_new_x] = left_embed[:, valid_y, x]
        converted_right_image[:, valid_y, valid_new_x] = left_image[:, valid_y, x]
        mask[valid_y, valid_new_x] = 0  # mark as valid
    # Optionally drop pixels at random: each pixel stays valid with probability
    # random_ratio, so the Bernoulli mask (1 with probability 1 - random_ratio)
    # is OR-ed into the invalid mask
    if random_ratio is not None:
        random_mask = torch.bernoulli(torch.full((height, width), 1 - random_ratio, device=left_embed.device)).byte()
        mask |= random_mask
    # Zero out invalid regions in all outputs
    right_embed[:, mask == 1] = 0
    converted_right_image[:, mask == 1] = 0
    disparity[:, mask == 1] = 0
    return right_embed, mask, converted_right_image, disparity
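
# Minimal usage sketch (hypothetical shapes, not part of the training pipeline):
#   embed = torch.randn(10, 512, 512)        # coordinate embedding, [c, h, w]
#   disp = torch.rand(1, 512, 512) * 64      # disparity in pixels, [1, h, w]
#   img = torch.rand(3, 512, 512) * 2 - 1    # left image in [-1, 1]
#   warped, mask, conv_right, disp = convert_left_to_right_torch(embed, disp, img)
#   # warped: [10, 512, 512]; mask: [512, 512] (1 = hole); conv_right: [3, 512, 512]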
class StereoGenDataset(Dataset):
    def __init__(self,
                 json_files,
                 img_size,
                 img_scale=(1.0, 1.0),  # currently unused
                 img_ratio=(0.9, 1.0),  # currently unused
                 drop_ratio=0.1,
                 debug=False,
                 use_coords=True,
                 use_wapred=True):  # "warped"; typo kept to preserve the argument name
        """
        Args:
            json_files (list): Paths to JSON files, each listing samples.
            img_size (int): Side length of the square crops fed to the model.
            drop_ratio (float): Probability of applying extra random masking.
            debug (bool): If True, dump intermediate images to ./check_dataset/.
            use_coords (bool): Include coordinate embeddings in the conditioning.
            use_wapred (bool): Include the warped image in the conditioning.
        """
        super().__init__()
        self.data = []
        for json_file in json_files:
            with open(json_file, 'r') as f:
                previous_length = len(self.data)
                self.data += json.load(f)
                added_length = len(self.data) - previous_length
                print(f"Loaded {added_length} samples from {json_file}")
        # self.data = self.data[:10]  # optionally truncate for quick debugging
        self.img_size = img_size
        self.embedder = self.get_embedder(2)
        self.drop_ratio = drop_ratio
        self.transform = transforms.Compose([
            transforms.ToTensor(),  # convert PIL image to tensor, scaled to [0, 1]
        ])
        self.transform_pixels = transforms.Compose([
            transforms.ToTensor(),               # [0, 1]
            transforms.Normalize([0.5], [0.5]),  # normalize to [-1, 1]
        ])
        self.clip_image_processor = CLIPImageProcessor()
        self.debug = debug
        self.use_coords = use_coords
        self.use_wapred = use_wapred
def __len__(self):
return len(self.data)
    def crop(self, img: Image.Image) -> Image.Image:
        """Center-crop to a square, then resize to (img_size, img_size)."""
        W, H = img.size
        if W < H:
            left, right = 0, W
            top, bottom = np.ceil((H - W) / 2.), np.floor((H - W) / 2.) + W
        else:
            left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H
            top, bottom = 0, H
        img = img.crop((left, top, right, bottom))
        img = img.resize((self.img_size, self.img_size), Image.BILINEAR)
        return img
    def crop_and_resize_disp(self, disparity_left):
        """Center-crop the disparity map to a square and resize to img_size.

        Disparity is measured in pixels, so its values are rescaled by the
        same ratio as the spatial resize.
        """
        h, w = disparity_left.shape[:2]
        min_side = min(h, w)
        # Center-crop to a square
        start_x = (w - min_side) // 2
        start_y = (h - min_side) // 2
        cropped_disparity = disparity_left[start_y:start_y + min_side, start_x:start_x + min_side]
        # Resize, rescaling disparity values accordingly
        ratio = self.img_size / min_side
        resized_disparity = cv2.resize(cropped_disparity, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) * ratio
        return resized_disparity
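
    # Example (assumed sizes): a 1024x768 map is center-cropped to 768x768 and
    # resized to img_size=512, so disparity values are scaled by 512/768 = 2/3.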
    def random_crop_and_resize(self, image_left: Image.Image, image_right: Image.Image, disparity_left: np.ndarray):
"""
Randomly crop and resize stereo image pairs and their disparity maps.
Args:
image_left (Image.Image): Left image (PIL).
image_right (Image.Image): Right image (PIL).
disparity_left (np.ndarray): Left disparity map.
Returns:
tuple: Resized left image, right image, and disparity map.
"""
# Get the dimensions of the image and disparity map
W, H = image_left.size
h_disp, w_disp = disparity_left.shape[:2]
# Ensure the image and disparity map have the same dimensions
assert W == w_disp and H == h_disp, "Image and disparity dimensions must match."
assert isinstance(image_left, Image.Image) and isinstance(image_right, Image.Image), \
"Inputs must be PIL images."
assert isinstance(disparity_left, np.ndarray), "Disparity must be a NumPy array."
        # Determine crop size: crop at up to 3x the target size (scale
        # augmentation), falling back to the shorter side for small images
        if min(W, H) > 3 * self.img_size:
            crop_size = 3 * self.img_size
        elif min(W, H) > 2 * self.img_size:
            crop_size = 2 * self.img_size
        elif min(W, H) >= self.img_size:
            crop_size = self.img_size
        else:
            crop_size = min(W, H)
# Calculate random crop coordinates
max_x = W - crop_size
max_y = H - crop_size
left = random.randint(0, max(max_x, 0))
top = random.randint(0, max(max_y, 0))
right = left + crop_size
bottom = top + crop_size
# Perform cropping
image_left_cropped = image_left.crop((left, top, right, bottom))
image_right_cropped = image_right.crop((left, top, right, bottom))
disparity_cropped = disparity_left[top:bottom, left:right]
# Resize images and disparity map if necessary
if crop_size != self.img_size:
image_left_resized = image_left_cropped.resize((self.img_size, self.img_size), Image.BILINEAR)
image_right_resized = image_right_cropped.resize((self.img_size, self.img_size), Image.BILINEAR)
ratio = self.img_size / crop_size
disparity_resized = cv2.resize(disparity_cropped, (self.img_size, self.img_size), interpolation=cv2.INTER_LINEAR) * ratio
else:
image_left_resized = image_left_cropped
image_right_resized = image_right_cropped
disparity_resized = disparity_cropped
return image_left_resized, image_right_resized, disparity_resized
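
    # Example (assumed sizes): for a 1920x1080 pair with img_size=512, the
    # shorter side (1080) exceeds 2*512, so a 1024x1024 crop is taken and
    # resized to 512x512, halving all disparity values.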
    class Embedder:
        """NeRF-style positional embedding (Fourier features) for coordinates."""
        def __init__(self, **kwargs) -> None:
            self.kwargs = kwargs
            self.create_embedding_fn()

        def create_embedding_fn(self) -> None:
            embed_fns = []
            d = self.kwargs['input_dims']
            out_dim = 0
            if self.kwargs['include_input']:
                embed_fns.append(lambda x: x)
                out_dim += d
            max_freq = self.kwargs['max_freq_log2']
            N_freqs = self.kwargs['num_freqs']
            if self.kwargs['log_sampling']:
                freq_bands = 2. ** torch.linspace(0., max_freq, steps=N_freqs)
            else:
                freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, steps=N_freqs)
            for freq in freq_bands:
                for p_fn in self.kwargs['periodic_fns']:
                    embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
                    out_dim += d
            self.embed_fns = embed_fns
            self.out_dim = out_dim

        def embed(self, inputs) -> Tensor:
            return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
def get_embedder(self, multires):
embed_kwargs = {
'include_input' : True,
'input_dims' : 2,
'max_freq_log2' : multires-1,
'num_freqs' : multires,
'log_sampling' : True,
'periodic_fns' : [torch.sin, torch.cos],
}
embedder_obj = self.Embedder(**embed_kwargs)
embed = lambda x, eo=embedder_obj : eo.embed(x)
return embed
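
    # With multires=2, a 2D coordinate embeds to
    # 2 (identity) + 2 freqs * 2 fns (sin, cos) * 2 dims = 10 channels,
    # matching the [10, H, W] embedding used in __getitem__ below.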
    def getdata(self, idx):
        try:
            # Toggle for remapping dataset paths between storage locations;
            # set to False to use the paths from the JSON files as-is
            if True:
                image_left_path = self.data[idx]['image_left'].replace('/home/f.qiao/Active', '/storage1/jacobsn/Active/user_f.qiao')
                image_right_path = self.data[idx]['image_right'].replace('/home/f.qiao/Active', '/storage1/jacobsn/Active/user_f.qiao')
                if 'depth_left' in self.data[idx]:
                    self.data[idx]['depth_left'] = self.data[idx]['depth_left'].replace('/home/f.qiao/Active', '/storage1/jacobsn/Active/user_f.qiao')
                elif 'disparity_left' in self.data[idx]:
                    self.data[idx]['disparity_left'] = self.data[idx]['disparity_left'].replace('/home/f.qiao/Active', '/storage1/jacobsn/Active/user_f.qiao')
            else:
                image_left_path = self.data[idx]['image_left']
                image_right_path = self.data[idx]['image_right']
            image_left = Image.open(image_left_path).convert('RGB')
            image_right = Image.open(image_right_path).convert('RGB')
            disparity_left = None
            dataset_name = self.data[idx]["dataset"]
            if dataset_name == 'TartanAir':
                depth_left_path = self.data[idx]['depth_left']
                # disparity = f * b / depth; 80 = 320 (focal, px) * 0.25 (baseline, m)
                disparity_left = 80. / np.load(depth_left_path)
            elif dataset_name == 'IRS':
                depth_left_path = self.data[idx]['depth_left']
                disparity_left = load_exr(depth_left_path)
            elif dataset_name == 'DrivingStereo':
                # ground-truth disparity, stored as uint16 scaled by 256
                disparity_left_path = self.data[idx]['disparity_left']
                disparity_left = np.array(Image.open(disparity_left_path), dtype=np.float32) / 256.0
                # pseudo disparity
                # disparity_left_path = self.data[idx]['disparity_left'].replace('train-disparity-map', 'train-disparity-map-pseudo').replace('.png', '.npy')
                # disparity_left = np.load(disparity_left_path)
            elif dataset_name == 'VKITTI2':
                depth_left_path = self.data[idx]['depth_left']
                # depth is stored in centimeters; convert to meters
                depth_left = cv2.imread(depth_left_path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR) / 100.
                disparity_left = 0.532725 * 725.0087 / (depth_left + 1e-5)  # f = 725.0087 px, b = 0.532725 m
            elif dataset_name == 'InStereo2K':
                disparity_left_path = self.data[idx]['disparity_left']
                disparity_left = Image.open(disparity_left_path)
                # disparity is stored scaled by 100
                disparity_left = np.array(disparity_left).astype(np.float32) / 100
            elif dataset_name == 'Sintel':
                # Sintel packs disparity into the R, G, B channels
                disparity_left_path = self.data[idx]['disparity_left']
                f_in = np.array(Image.open(disparity_left_path))
                d_r = f_in[:, :, 0].astype('float64')
                d_g = f_in[:, :, 1].astype('float64')
                d_b = f_in[:, :, 2].astype('float64')
                disparity_left = d_r * 4 + d_g / (2 ** 6) + d_b / (2 ** 14)
            elif dataset_name == 'crestereo':
                disparity_left_path = self.data[idx]['disparity_left']
                # disparity is stored as uint16 scaled by 32
                disparity_left = cv2.imread(disparity_left_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 32
            elif dataset_name == 'Spring':
                disparity_left_path = self.data[idx]['disparity_left']
                with h5py.File(disparity_left_path, "r") as f:
                    disparity_left = np.array(f["disparity"][()]).astype(np.float32)
                # Spring stores disparity at twice the image resolution; subsample
                disparity_left = np.ascontiguousarray(disparity_left, dtype=np.float32)[::2, ::2]
            elif dataset_name == 'Falling_Things':
                depth_left_path = self.data[idx]['depth_left']
                depth_left = np.array(Image.open(depth_left_path), dtype=np.float32)
                # 460896 = 6 cm (baseline) * 768.1605834960938 px (focal) * 100
                # (depth appears to be stored in 0.1 mm units)
                disparity_left = 460896 / depth_left
            elif dataset_name == 'SimStereo':
                depth_left_path = self.data[idx]['disparity_left'].replace('left', 'right')
                disparity_left = np.load(depth_left_path)
            elif dataset_name == 'PLT-D3':
                depth_left_path = self.data[idx]['depth_left'].replace('left', 'right')
                disparity_left = 0.12 * 800 / np.load(depth_left_path)['arr_0']  # b = 0.12 m, f = 800 px
            elif dataset_name in ('DynamicReplica', 'InfinigenSV', 'skdataset', 'DIML'):
                disparity_left_path = self.data[idx]['disparity_left']
                disparity_left = np.load(disparity_left_path)
            elif dataset_name == 'UnrealStereo4K':
                disparity_left_path = self.data[idx]['disparity_left']
                # copy-on-write memory map avoids loading the full-resolution array eagerly
                disparity_left = np.load(disparity_left_path, mmap_mode='c')
            else:
                # unknown dataset: disparity_left stays None, which __getitem__ skips
                print(f"Dataset {self.data[idx]['dataset']} is not supported.")
            return image_left, image_right, disparity_left, dataset_name
except Exception as e:
bad_file_path = self.data[idx]['image_left'] # Capture the bad file path
print(f"Error loading data from {bad_file_path}: {e}")
return None, None, None, None
    def __getitem__(self, idx):
        # 1. Load images and the disparity map
        image_left, image_right, disparity_left, dataset_name = self.getdata(idx)
        # Skip to the next sample if loading failed
        if image_left is None or image_right is None or disparity_left is None:
            print(f"Data at index {idx} is invalid. Skipping.")
            return self.__getitem__((idx + 1) % len(self.data))  # try next index
        # 2. Crop and resize
        image_left, image_right, disparity_left = self.random_crop_and_resize(image_left, image_right, disparity_left)
        # 3. Generate a pixel-coordinate grid
        grid: Float[Tensor, 'H W C'] = torch.stack(torch.meshgrid(
            torch.arange(self.img_size), torch.arange(self.img_size), indexing='xy'), dim=-1
        )  # [img_size, img_size, 2]
        # 4. Embed normalized coordinates
        coords = torch.stack((grid[..., 0] / self.img_size, grid[..., 1] / self.img_size), dim=-1)
        embed = self.embedder(coords)
        embed = embed.permute(2, 0, 1)  # h w c -> c h w, e.g. [10, 512, 512]
        # 5. Convert to PyTorch tensors
        image_left_tensor = self.transform_pixels(image_left)
        image_right_tensor = self.transform_pixels(image_right)
        disparity_left_tensor = torch.tensor(disparity_left, dtype=torch.float32).unsqueeze(0)  # add a channel dimension
        # 6. Warp left to right; with probability drop_ratio, also drop random pixels
        use_random_mask = random.random() < self.drop_ratio
        random_ratio = random.random() if use_random_mask else None
        warped_embed, mask, converted_right, disparity_left_tensor = convert_left_to_right_torch(
            embed, disparity_left_tensor, image_left_tensor, random_ratio, dataset_name=dataset_name)
        if self.debug:
            save_folder = "./check_dataset/"
            os.makedirs(save_folder, exist_ok=True)
            cv2.imwrite(f"{save_folder}/{dataset_name}_{idx}_mask.png", mask.numpy() * 255)
            cv2.imwrite(f"{save_folder}/{dataset_name}_{idx}_right.png", (image_right_tensor.permute(1, 2, 0).numpy()[:, :, ::-1] / 2 + 0.5) * 255)
            cv2.imwrite(f"{save_folder}/{dataset_name}_{idx}_converted_right.png", (converted_right.permute(1, 2, 0).numpy()[:, :, ::-1] / 2 + 0.5) * 255)
        # 7. Assemble source/target conditioning: coordinate embedding, mask
        #    channel, and (optionally) the image / warped image
        if self.use_coords and self.use_wapred:
            src_coords_embed = torch.cat(
                [embed, torch.zeros_like(mask.unsqueeze(0), device=mask.device), image_left_tensor], dim=0)
            trg_coords_embed = torch.cat([warped_embed, mask.unsqueeze(0), converted_right], dim=0)
        elif self.use_coords and not self.use_wapred:
            src_coords_embed = torch.cat([embed, torch.zeros_like(mask.unsqueeze(0), device=mask.device)], dim=0)
            trg_coords_embed = torch.cat([warped_embed, mask.unsqueeze(0)], dim=0)
        else:
            src_coords_embed = torch.cat([image_left_tensor, torch.zeros_like(mask.unsqueeze(0), device=mask.device)], dim=0)
            trg_coords_embed = torch.cat([converted_right, mask.unsqueeze(0)], dim=0)
        # 8. CLIP-preprocess the left image for the image encoder
        clip_image = self.clip_image_processor(
            images=image_left, return_tensors="pt"
        ).pixel_values[0]
        sample = {
            'source': image_left_tensor,
            'correspondence': disparity_left_tensor,
            'target': image_right_tensor,
            'src_coords_embed': src_coords_embed,
            'trg_coords_embed': trg_coords_embed,
            'clip_images': clip_image,
            'converted_right': converted_right,
            'mask': mask.unsqueeze(0),
        }
        return sample
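
    # Sample shapes for img_size=512 (the channel counts below assume both
    # use_coords and use_wapred are True):
    #   source / target:       [3, 512, 512]   images in [-1, 1]
    #   correspondence:        [1, 512, 512]   disparity in pixels
    #   src/trg_coords_embed:  [14, 512, 512]  10 coord channels + 1 mask + 3 image
    #   clip_images:           typically [3, 224, 224] from CLIPImageProcessor
    #   mask:                  [1, 512, 512]   1 = hole / invalid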
if __name__ == "__main__":
    # Load the dataset from JSON files (uncomment entries as needed)
    json_files = [
        # "./data/tartanair/TartanAir_dataset_paths.json",
        # "./data/IRS/IRS_dataset_paths.json",
        # "./data/DrivingStereo/DrivingStereo_dataset_paths.json",
        # "./data/VKITTI2/VKITTI2_dataset_paths_2.json",
        # "./data/InStereo2K/InStereo2K_dataset_paths_20.json",
        # "./data/Sintel/Sintel_dataset_paths_20.json",
        # "./data/crestereo/crestereo_dataset_paths.json",
        # "./data/Spring/Spring_dataset_paths_10.json",
        # "./data/Falling_Things/Falling_Things_dataset_paths.json",
        # "./data/SimStereo/SimStereo_dataset_paths.json",
        # "./data/DynamicReplica/DynamicReplica_dataset_paths.json",
        # "./data/PLT-D3/PLT-D3_dataset_paths_10.json",
        # "./data/InfinigenSV/InfinigenSV_dataset_paths_2.json",
        "./data/skdataset/sk_dataset_paths.json",
        # "./data/DIML_Outdoor/DIML_Outdoor_dataset_paths.json",
        # "./data/UnrealStereo4K/UnrealStereo4K_dataset_paths_5.json",
    ]
    dataset = StereoGenDataset(json_files, img_size=512)
    print(f"Number of entries in the dataset: {len(dataset)}")
    # Sample up to 20 random entries and exercise __getitem__
    sampled_indices = random.sample(range(len(dataset)), min(20, len(dataset)))
    for idx in sampled_indices:
        data_entry = dataset[idx]
        print(f"Processing dataset entry at index {idx}")
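
    # A minimal DataLoader sketch (assumed settings, not the project's training config):
    # loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)
    # batch = next(iter(loader))
    # print(batch['src_coords_embed'].shape)  # e.g. torch.Size([4, 14, 512, 512])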