|
import abc |
|
import torch |
|
from lib.utils import utils |
|
|
|
class RaySampler(metaclass=abc.ABCMeta):
    """Abstract base for strategies that choose sample depths along camera rays.

    Args:
        near: Lower end of the depth range; kept as ``self.near``.
        far: Upper end of the depth range; kept as ``self.far``.
    """

    def __init__(self, near, far):
        self.near, self.far = near, far

    @abc.abstractmethod
    def get_z_vals(self, ray_dirs, cam_loc, model):
        """Return sample depths for the given rays (implemented by subclasses)."""
        ...
|
|
|
class UniformSampler(RaySampler):
    """Samples uniformly in the range [near, far].

    Depths are evenly spaced between ``near`` and ``far``. While ``model`` is in
    training mode, each depth is jittered uniformly inside its stratum (the
    interval bounded by the midpoints to its neighbouring depths).

    Args:
        scene_bounding_sphere: Radius of the scene bounding sphere. It sets the
            default ``far`` (its diameter) and is the radius passed to
            ``utils.get_sphere_intersections`` when ``take_sphere_intersection``
            is set.
        near: Near depth shared by all rays.
        N_samples: Number of depths generated per ray.
        take_sphere_intersection: If true, the per-ray far bound is read from
            ``utils.get_sphere_intersections`` instead of the constant ``far``.
        far: Constant far depth; the sentinel ``-1`` selects
            ``2 * scene_bounding_sphere``.
    """

    def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1):
        # far == -1 is a sentinel meaning "use the bounding sphere's diameter".
        super().__init__(near, 2.0 * scene_bounding_sphere if far == -1 else far)
        self.N_samples = N_samples
        self.scene_bounding_sphere = scene_bounding_sphere
        self.take_sphere_intersection = take_sphere_intersection

    def get_z_vals(self, ray_dirs, cam_loc, model):
        """Return a ``(num_rays, N_samples)`` tensor of depths along each ray.

        Args:
            ray_dirs: Ray directions; only ``ray_dirs.shape[0]`` (the number of
                rays) and ``ray_dirs.device`` are used here, except when sphere
                intersection is enabled.
            cam_loc: Camera origins; only used for the sphere intersection.
            model: Any object with a ``training`` flag; jitter is applied only
                when it is true.

        All tensors are allocated on ``ray_dirs.device``. Previously they were
        forced onto the default CUDA device via ``.cuda()``, which broke CPU use
        and non-default GPUs.
        """
        device = ray_dirs.device
        num_rays = ray_dirs.shape[0]

        near = self.near * torch.ones(num_rays, 1, device=device)
        if self.take_sphere_intersection:
            sphere_intersections = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)
            # Column 1 onward of the intersections result; presumably the exit
            # distance of each ray from the sphere -- confirm against utils.
            far = sphere_intersections[:, 1:]
        else:
            far = self.far * torch.ones(num_rays, 1, device=device)

        t_vals = torch.linspace(0., 1., steps=self.N_samples, device=device)
        z_vals = near * (1. - t_vals) + far * t_vals

        if model.training:
            # Stratified jitter: draw one uniform sample inside each stratum.
            mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
            upper = torch.cat([mids, z_vals[..., -1:]], -1)
            lower = torch.cat([z_vals[..., :1], mids], -1)
            t_rand = torch.rand(z_vals.shape, device=z_vals.device)
            z_vals = lower + (upper - lower) * t_rand

        return z_vals
|
|
|
class ErrorBoundSampler(RaySampler):
    """Iteratively refine per-ray depth samples until an opacity error bound holds.

    Sampling starts from ``N_samples_eval`` uniform depths per ray. The loop
    repeats the following steps:

    * Evaluate the SDF at the samples.
    * Bisect a per-ray ``beta`` so that the bound returned by
      :meth:`get_error_bound` is at most ``eps``.
    * While that ``beta`` still exceeds the model's own ``beta``, insert
      ``N_samples_eval`` more depths where the bound is large.

    Finally, ``N_samples`` depths are drawn from the resulting rendering weights.

    Args:
        scene_bounding_sphere: Bounding sphere radius; ``far`` is set to its
            diameter.
        near: Near depth for all rays.
        N_samples: Number of depths drawn from the final weights.
        N_samples_eval: Number of initial uniform depths, and of depths added
            per refinement step.
        N_samples_extra: Number of already-evaluated depths also appended to the
            output (0 disables this).
        eps: Threshold on the error bound.
        beta_iters: Number of bisection steps on ``beta`` per iteration.
        max_total_iters: Maximum number of refinement iterations.
        inverse_sphere_bg: If true, the far bound comes from the sphere
            intersection and an additional inverse-sphere sampler is created.
        N_samples_inverse_sphere: Depths for the inverse-sphere sampler; a
            non-positive value (the default) falls back to 32.
        add_tiny: Constant added to the refinement pdf.
    """

    def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra,
                 eps, beta_iters, max_total_iters,
                 inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=0.0):
        super().__init__(near, 2.0 * scene_bounding_sphere)

        self.N_samples = N_samples
        self.N_samples_eval = N_samples_eval
        self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval,
                                              take_sphere_intersection=inverse_sphere_bg)
        self.N_samples_extra = N_samples_extra

        self.eps = eps
        self.beta_iters = beta_iters
        self.max_total_iters = max_total_iters
        self.scene_bounding_sphere = scene_bounding_sphere
        self.add_tiny = add_tiny

        self.inverse_sphere_bg = inverse_sphere_bg
        if inverse_sphere_bg:
            # Honour an explicit positive count; keep 32 for the default (0) so
            # existing callers that never set this argument are unaffected.
            if N_samples_inverse_sphere <= 0:
                N_samples_inverse_sphere = 32
            self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0)

    def get_z_vals(self, ray_dirs, cam_loc, model, cond, smpl_tfs, eval_mode, smpl_verts):
        """Return refined depths per ray and one random depth per ray.

        Args:
            ray_dirs: Ray directions, presumably ``(num_rays, 3)`` (points are
                flattened with ``reshape(-1, 3)``) -- confirm with callers.
            cam_loc: Camera origins, broadcast against the samples.
            model: Must provide ``density`` (with ``get_beta``),
                ``implicit_network`` and ``sdf_func_with_smpl_deformer``.
            cond, smpl_tfs, smpl_verts: Forwarded to
                ``model.sdf_func_with_smpl_deformer``.
            eval_mode: Unused here; kept for interface compatibility.

        Returns:
            ``(z_vals, z_samples_eik)``.

            * ``z_vals`` holds the sorted final samples, plus ``near``, ``far``
              and up to ``N_samples_extra`` previously evaluated depths.
              When ``inverse_sphere_bg`` is set it is instead a tuple of that
              tensor and the scaled inverse-sphere depths.
            * ``z_samples_eik`` holds one depth per ray, drawn at random from
              the sorted ``z_vals``.
        """
        device = ray_dirs.device
        beta0 = model.density.get_beta().detach()

        # Start with uniform samples along each ray.
        z_vals = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model)
        samples, samples_idx = z_vals, None

        # Initial per-ray upper value for beta, used as the top of the bisection.
        dists = z_vals[:, 1:] - z_vals[:, :-1]
        bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1)
        beta = torch.sqrt(bound)

        total_iters, not_converge = 0, True

        while not_converge and total_iters < self.max_total_iters:
            points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1)
            points_flat = points.reshape(-1, 3)

            # Only the newly added samples are evaluated; earlier SDF values are reused.
            samples_sdf = self._query_sdf(model, points_flat, cond, smpl_tfs, smpl_verts)
            if samples_idx is not None:
                # Re-order [old, new] SDF values to match the sorted z_vals.
                sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]),
                                       samples_sdf.reshape(-1, samples.shape[1])], -1)
                sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1)
            else:
                sdf = samples_sdf

            d = sdf.reshape(z_vals.shape)
            dists = z_vals[:, 1:] - z_vals[:, :-1]
            d_star = self._compute_d_star(d, dists)

            # Per ray, bisect beta in [beta0, beta] so the error bound is <= eps.
            curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star)
            beta[curr_error <= self.eps] = beta0
            beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta
            for _ in range(self.beta_iters):
                beta_mid = (beta_min + beta_max) / 2.
                curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star)
                beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps]
                beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps]
            beta = beta_max

            # Volume-rendering weights with the per-ray beta; the last interval
            # is treated as (effectively) infinite.
            density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1))
            dists = torch.cat([dists, dists.new_full((dists.shape[0], 1), 1e10)], -1)
            free_energy = dists * density
            shifted_free_energy = torch.cat([free_energy.new_zeros((dists.shape[0], 1)), free_energy[:, :-1]], dim=-1)
            alpha = 1 - torch.exp(-free_energy)
            transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))
            weights = alpha * transmittance

            total_iters += 1
            not_converge = bool(beta.max() > beta0)
            refine = not_converge and total_iters < self.max_total_iters

            bins = z_vals
            if refine:
                # Sample more points in proportion to the current error bound.
                N = self.N_samples_eval
                error_per_section = torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:, :-1] ** 2.) / (4 * beta.unsqueeze(-1) ** 2)
                error_integral = torch.cumsum(error_per_section, dim=-1)
                bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * transmittance[:, :-1]
                pdf = bound_opacity + self.add_tiny
            else:
                # Sample the final set used in the volume rendering integral.
                N = self.N_samples
                pdf = weights[..., :-1] + 1e-5
            pdf = pdf / torch.sum(pdf, -1, keepdim=True)
            cdf = torch.cumsum(pdf, -1)
            cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

            # Deterministic samples while refining or at inference; random otherwise.
            samples = self._sample_pdf(bins, cdf, N, deterministic=refine or not model.training)

            if refine:
                z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], -1), -1)

        z_samples = samples

        num_rays = ray_dirs.shape[0]
        near = self.near * torch.ones(num_rays, 1, device=device)
        far = self.far * torch.ones(num_rays, 1, device=device)
        if self.inverse_sphere_bg:
            far = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:, 1:]

        if self.N_samples_extra > 0:
            if model.training:
                sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra]
            else:
                sampling_idx = torch.linspace(0, z_vals.shape[1] - 1, self.N_samples_extra).long()
            z_vals_extra = torch.cat([near, far, z_vals[:, sampling_idx]], -1)
        else:
            z_vals_extra = torch.cat([near, far], -1)

        z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1)

        # One random depth per ray; by its name presumably used for an eikonal
        # term by the caller -- confirm.
        idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],), device=z_vals.device)
        z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1))

        if self.inverse_sphere_bg:
            z_vals_inverse_sphere = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model)
            z_vals_inverse_sphere = z_vals_inverse_sphere * (1. / self.scene_bounding_sphere)
            z_vals = (z_vals, z_vals_inverse_sphere)

        return z_vals, z_samples_eik

    @staticmethod
    def _query_sdf(model, points_flat, cond, smpl_tfs, smpl_verts):
        """Evaluate the SDF without gradients, with implicit_network temporarily in eval mode."""
        network = model.implicit_network
        was_training = network.training
        network.eval()
        try:
            with torch.no_grad():
                return model.sdf_func_with_smpl_deformer(points_flat, cond, smpl_tfs, smpl_verts=smpl_verts)[0]
        finally:
            # Restore the caller's mode instead of forcing train mode, so that
            # inference does not leave the network in training mode.
            network.train(was_training)

    @staticmethod
    def _compute_d_star(d, dists):
        """Per-interval distance estimate from the triangle (dists, |d_i|, |d_i+1|).

        If the triangle's angle at one sample is obtuse, that sample's |SDF| is
        used. Otherwise the height over side ``dists`` (2 * Heron area / base)
        is used. Intervals whose endpoint SDFs do not share a strictly equal
        sign get 0.
        """
        a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs()
        first_cond = a.pow(2) + b.pow(2) <= c.pow(2)
        second_cond = a.pow(2) + c.pow(2) <= b.pow(2)
        d_star = torch.zeros_like(a)
        d_star[first_cond] = b[first_cond]
        d_star[second_cond] = c[second_cond]
        s = (a + b + c) / 2.0
        area_before_sqrt = s * (s - a) * (s - b) * (s - c)
        mask = ~first_cond & ~second_cond & (b + c - a > 0)
        d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask])
        return (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star

    @staticmethod
    def _sample_pdf(bins, cdf, n_samples, deterministic):
        """Draw ``n_samples`` depths per ray by inverting the piecewise-linear CDF over ``bins``."""
        if deterministic:
            u = torch.linspace(0., 1., steps=n_samples, device=cdf.device).unsqueeze(0).repeat(cdf.shape[0], 1)
        else:
            u = torch.rand(list(cdf.shape[:-1]) + [n_samples], device=cdf.device)
        u = u.contiguous()

        inds = torch.searchsorted(cdf, u, right=True)
        below = (inds - 1).clamp(min=0)
        above = inds.clamp(max=cdf.shape[-1] - 1)
        inds_g = torch.stack([below, above], -1)

        matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
        cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
        bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

        # Guard against near-empty CDF segments.
        denom = cdf_g[..., 1] - cdf_g[..., 0]
        denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
        t = (u - cdf_g[..., 0]) / denom
        return bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star):
        """Return, per ray, the maximum over intervals of the opacity error bound for ``beta``."""
        density = model.density(sdf.reshape(z_vals.shape), beta=beta)
        free_energy = dists * density[:, :-1]
        shifted_free_energy = torch.cat([free_energy.new_zeros((dists.shape[0], 1)), free_energy], dim=-1)
        integral_estimation = torch.cumsum(shifted_free_energy, dim=-1)
        error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) / (4 * beta ** 2)
        error_integral = torch.cumsum(error_per_section, dim=-1)
        # exp() is clamped at 1e6 to avoid overflow.
        bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1])
        return bound_opacity.max(-1)[0]
|
|
|
|
|
|