# lib/model/ray_sampler.py (vid2avatar baseline)
import abc
import torch
from lib.utils import utils
class RaySampler(metaclass=abc.ABCMeta):
    def __init__(self, near, far):
self.near = near
self.far = far
@abc.abstractmethod
def get_z_vals(self, ray_dirs, cam_loc, model):
pass
class UniformSampler(RaySampler):
"""Samples uniformly in the range [near, far]
"""
def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1):
super().__init__(near, 2.0 * scene_bounding_sphere if far == -1 else far) # default far is 2*R
self.N_samples = N_samples
self.scene_bounding_sphere = scene_bounding_sphere
self.take_sphere_intersection = take_sphere_intersection
def get_z_vals(self, ray_dirs, cam_loc, model):
if not self.take_sphere_intersection:
near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0], 1).cuda()
else:
sphere_intersections = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)
near = self.near * torch.ones(ray_dirs.shape[0], 1).cuda()
far = sphere_intersections[:,1:]
t_vals = torch.linspace(0., 1., steps=self.N_samples).cuda()
z_vals = near * (1. - t_vals) + far * (t_vals)
if model.training:
# get intervals between samples
mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
upper = torch.cat([mids, z_vals[..., -1:]], -1)
lower = torch.cat([z_vals[..., :1], mids], -1)
# stratified samples in those intervals
t_rand = torch.rand(z_vals.shape).cuda()
z_vals = lower + (upper - lower) * t_rand
return z_vals
class ErrorBoundSampler(RaySampler):
def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra,
eps, beta_iters, max_total_iters,
inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=0.0):
super().__init__(near, 2.0 * scene_bounding_sphere)
self.N_samples = N_samples
self.N_samples_eval = N_samples_eval
self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=inverse_sphere_bg)
self.N_samples_extra = N_samples_extra
self.eps = eps
self.beta_iters = beta_iters
self.max_total_iters = max_total_iters
self.scene_bounding_sphere = scene_bounding_sphere
self.add_tiny = add_tiny
self.inverse_sphere_bg = inverse_sphere_bg
if inverse_sphere_bg:
N_samples_inverse_sphere = 32
self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0)
def get_z_vals(self, ray_dirs, cam_loc, model, cond, smpl_tfs, eval_mode, smpl_verts):
beta0 = model.density.get_beta().detach()
# Start with uniform sampling
z_vals = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model)
samples, samples_idx = z_vals, None
# Get maximum beta from the upper bound (Lemma 2)
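        # (i.e. beta = sqrt( sum_k dists_k^2 / (4 * log(1 + eps)) ); a beta this large
        # keeps the opacity error of the current samples within eps and is used as the
        # upper end of the bisection below)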
dists = z_vals[:, 1:] - z_vals[:, :-1]
bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1)
beta = torch.sqrt(bound)
total_iters, not_converge = 0, True
# VolSDF Algorithm 1
while not_converge and total_iters < self.max_total_iters:
points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1)
points_flat = points.reshape(-1, 3)
            # Calculating the SDF only for the newly sampled points
model.implicit_network.eval()
with torch.no_grad():
samples_sdf = model.sdf_func_with_smpl_deformer(points_flat, cond, smpl_tfs, smpl_verts=smpl_verts)[0]
model.implicit_network.train()
if samples_idx is not None:
sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]),
samples_sdf.reshape(-1, samples.shape[1])], -1)
sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1)
else:
sdf = samples_sdf
# Calculating the bound d* (Theorem 1)
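            # For each section, d* lower-bounds the distance to the surface: when the
            # triangle with side lengths (dist, |sdf_i|, |sdf_{i+1}|) is obtuse, the bound
            # is the smaller endpoint |sdf|; otherwise it is the triangle height
            # 2*area/dist (via Heron's formula). Sections whose endpoints have opposite
            # SDF signs may cross the surface, so their d* is zeroed out below.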
d = sdf.reshape(z_vals.shape)
dists = z_vals[:, 1:] - z_vals[:, :-1]
a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs()
first_cond = a.pow(2) + b.pow(2) <= c.pow(2)
second_cond = a.pow(2) + c.pow(2) <= b.pow(2)
d_star = torch.zeros(z_vals.shape[0], z_vals.shape[1] - 1).cuda()
d_star[first_cond] = b[first_cond]
d_star[second_cond] = c[second_cond]
s = (a + b + c) / 2.0
area_before_sqrt = s * (s - a) * (s - b) * (s - c)
mask = ~first_cond & ~second_cond & (b + c - a > 0)
d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask])
d_star = (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star # Fixing the sign
            # Updating beta using a bisection line search
curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star)
beta[curr_error <= self.eps] = beta0
beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta
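            # Bisect between beta0 (the target) and the current upper bound, keeping
            # beta_max as the smallest beta found so far whose error bound is <= eps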
for j in range(self.beta_iters):
beta_mid = (beta_min + beta_max) / 2.
curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star)
beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps]
beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps]
beta = beta_max
# Upsample more points
density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1))
dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1)
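            # Discrete volume rendering: free_energy_i = sigma_i * delta_i,
            # alpha_i = 1 - exp(-free_energy_i), transmittance T_i = exp(-sum_{j<i} free_energy_j)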
free_energy = dists * density
shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1)
alpha = 1 - torch.exp(-free_energy)
transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))
            weights = alpha * transmittance  # probability that the ray hits something in this interval
# Check if we are done and this is the last sampling
total_iters += 1
not_converge = beta.max() > beta0
if not_converge and total_iters < self.max_total_iters:
                # Sample more points proportionally to the current error bound
N = self.N_samples_eval
bins = z_vals
error_per_section = torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:,:-1] ** 2.) / (4 * beta.unsqueeze(-1) ** 2)
error_integral = torch.cumsum(error_per_section, dim=-1)
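                # bound_opacity upper-bounds each section's contribution to the opacity
                # approximation error, so new samples are placed where this bound is largest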
bound_opacity = (torch.clamp(torch.exp(error_integral),max=1.e6) - 1.0) * transmittance[:,:-1]
pdf = bound_opacity + self.add_tiny
pdf = pdf / torch.sum(pdf, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
else:
                # Sample the final sample set to be used in the volume rendering integral
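                # the importance distribution is now the current rendering weights
                # (NeRF-style hierarchical sampling) rather than the error bound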
N = self.N_samples
bins = z_vals
pdf = weights[..., :-1]
pdf = pdf + 1e-5 # prevent nans
pdf = pdf / torch.sum(pdf, -1, keepdim=True)
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins))
# Invert CDF
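            # Inverse-transform sampling: evenly spaced u is used for the intermediate
            # upsampling iterations and at eval time, uniform random u only for the final
            # training-time draw; each u is mapped through the piecewise-linear inverse CDF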
if (not_converge and total_iters < self.max_total_iters) or (not model.training):
u = torch.linspace(0., 1., steps=N).cuda().unsqueeze(0).repeat(cdf.shape[0], 1)
else:
u = torch.rand(list(cdf.shape[:-1]) + [N]).cuda()
u = u.contiguous()
inds = torch.searchsorted(cdf, u, right=True)
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = (cdf_g[..., 1] - cdf_g[..., 0])
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
t = (u - cdf_g[..., 0]) / denom
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
            # Adding the new samples if we have not converged yet
if not_converge and total_iters < self.max_total_iters:
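                # merge the new samples into z_vals; samples_idx records where each merged
                # value came from so the cached SDF values can be reused next iteration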
z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], -1), -1)
z_samples = samples
near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0],1).cuda()
if self.inverse_sphere_bg: # if inverse sphere then need to add the far sphere intersection
far = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:,1:]
if self.N_samples_extra > 0:
if model.training:
sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra]
else:
sampling_idx = torch.linspace(0, z_vals.shape[1]-1, self.N_samples_extra).long()
z_vals_extra = torch.cat([near, far, z_vals[:,sampling_idx]], -1)
else:
z_vals_extra = torch.cat([near, far], -1)
z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1)
        # pick one random depth per ray as a near-surface sample (z_samples_eik, used for the eikonal regularization)
idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],)).cuda()
z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1))
if self.inverse_sphere_bg:
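            # background depths are sampled in [0, 1] and rescaled by 1/R, parameterizing
            # the region outside the scene bounding sphere (NeRF++-style inverted sphere)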
z_vals_inverse_sphere = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model)
z_vals_inverse_sphere = z_vals_inverse_sphere * (1./self.scene_bounding_sphere)
z_vals = (z_vals, z_vals_inverse_sphere)
return z_vals, z_samples_eik
def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star):
density = model.density(sdf.reshape(z_vals.shape), beta=beta)
shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), dists * density[:, :-1]], dim=-1)
integral_estimation = torch.cumsum(shifted_free_energy, dim=-1)
error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) / (4 * beta ** 2)
error_integral = torch.cumsum(error_per_section, dim=-1)
bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1])
return bound_opacity.max(-1)[0]
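

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the training pipeline):
# draws uniform depth samples for a batch of rays. The dummy object below only
# provides the `training` flag read by UniformSampler.get_z_vals; a CUDA device
# is required because the samplers allocate tensors with `.cuda()`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import types

    sampler = UniformSampler(scene_bounding_sphere=3.0, near=0.0, N_samples=64)
    dummy_model = types.SimpleNamespace(training=False)

    cam_loc = torch.zeros(1024, 3).cuda()  # ray origins at the camera center
    ray_dirs = torch.nn.functional.normalize(torch.randn(1024, 3), dim=-1).cuda()

    z_vals = sampler.get_z_vals(ray_dirs, cam_loc, dummy_model)
    print(z_vals.shape)  # torch.Size([1024, 64])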