File size: 11,079 Bytes
6325697 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
import abc
import torch
from lib.utils import utils
class RaySampler(metaclass=abc.ABCMeta):
    """Abstract base for strategies that pick sample depths (z values) along rays.

    Concrete samplers store a ``near``/``far`` depth range here and implement
    :meth:`get_z_vals`.
    """

    def __init__(self, near, far):
        # Depth range shared by the sampler; subclasses may override it per ray.
        self.near, self.far = near, far

    @abc.abstractmethod
    def get_z_vals(self, ray_dirs, cam_loc, model):
        """Return sample depths for the given rays; must be implemented by subclasses."""
class UniformSampler(RaySampler):
    """Samples uniformly in the range [near, far].

    Depths are evenly spaced between ``near`` and ``far``. In training mode
    each depth is then redrawn uniformly inside the interval bounded by the
    midpoints to its neighbours (stratified jitter).

    Args:
        scene_bounding_sphere: radius ``R`` of the scene bounding sphere.
        near: near depth used for every ray.
        N_samples: number of depths per ray.
        take_sphere_intersection: if True, each ray's far depth is the second
            column of ``utils.get_sphere_intersections(cam_loc, ray_dirs, r=R)``
            instead of the fixed ``far``.
        far: fixed far depth; ``-1`` (the default) or ``None`` means ``2 * R``.
    """

    def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1):
        if far is None or far == -1:
            far = 2.0 * scene_bounding_sphere  # default far is 2*R
        super().__init__(near, far)
        self.N_samples = N_samples
        self.scene_bounding_sphere = scene_bounding_sphere
        self.take_sphere_intersection = take_sphere_intersection

    def get_z_vals(self, ray_dirs, cam_loc, model):
        """Return per-ray sample depths.

        Args:
            ray_dirs: ray directions; ``ray_dirs.shape[0]`` is the number of
                rays and its device is used for all allocated tensors.
            cam_loc: camera location(s); only used (forwarded to
                ``utils.get_sphere_intersections``) when sphere intersection
                is enabled.
            model: only ``model.training`` is read, to enable the jitter.

        Returns:
            Tensor of shape ``(num_rays, N_samples)`` with depths in
            ``[near, far]``.
        """
        num_rays = ray_dirs.shape[0]
        device = ray_dirs.device

        near = self.near * torch.ones(num_rays, 1, device=device)
        if self.take_sphere_intersection:
            sphere_intersections = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)
            far = sphere_intersections[:, 1:]
        else:
            far = self.far * torch.ones(num_rays, 1, device=device)

        # Built on the CPU and moved, so values match regardless of device.
        t_vals = torch.linspace(0., 1., steps=self.N_samples).to(device)
        z_vals = near * (1. - t_vals) + far * t_vals

        if model.training:
            # Intervals between neighbouring samples.
            mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
            upper = torch.cat([mids, z_vals[..., -1:]], -1)
            lower = torch.cat([z_vals[..., :1], mids], -1)
            # Drawn from the CPU RNG and then moved, so seeded runs produce the
            # same jitter whichever device the rays live on.
            t_rand = torch.rand(z_vals.shape).to(z_vals.device)
            z_vals = lower + (upper - lower) * t_rand
        return z_vals
class ErrorBoundSampler(RaySampler):
    """Error-bounded ray sampler following VolSDF Algorithm 1.

    Starts from ``N_samples_eval`` uniform depths per ray. It then repeats the
    following, up to ``max_total_iters`` times:

    * evaluate the SDF at the new depths;
    * bisect (``beta_iters`` steps) for a per-ray beta whose opacity error
      bound is at most ``eps``;
    * while that beta still exceeds the model's beta, add ``N_samples_eval``
      depths drawn from the error-bound PDF.

    The final ``N_samples`` depths are drawn from the volume-rendering weights.

    Args:
        scene_bounding_sphere: radius ``R`` of the scene bounding sphere; the
            default far depth is ``2 * R``.
        near: near depth for every ray.
        N_samples: number of depths in the final sample set.
        N_samples_eval: number of uniform starting depths, and of depths added
            per refinement iteration.
        N_samples_extra: number of already-evaluated depths merged into the
            output (0 disables).
        eps: target upper bound on the opacity error.
        beta_iters: bisection steps per beta line search.
        max_total_iters: maximum number of refinement iterations.
        inverse_sphere_bg: if True, the far depth comes from the bounding-sphere
            intersection and a second uniform sampler over [0, 1] is created
            for the background.
        N_samples_inverse_sphere: number of background depths; values <= 0
            (including the default 0) fall back to 32.
        add_tiny: constant added to the error-bound PDF before normalisation.
    """

    def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra,
                 eps, beta_iters, max_total_iters,
                 inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=0.0):
        super().__init__(near, 2.0 * scene_bounding_sphere)
        self.N_samples = N_samples
        self.N_samples_eval = N_samples_eval
        self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval,
                                              take_sphere_intersection=inverse_sphere_bg)
        self.N_samples_extra = N_samples_extra
        self.eps = eps
        self.beta_iters = beta_iters
        self.max_total_iters = max_total_iters
        self.scene_bounding_sphere = scene_bounding_sphere
        self.add_tiny = add_tiny
        self.inverse_sphere_bg = inverse_sphere_bg
        if inverse_sphere_bg:
            # An explicit positive count is honoured; 32 is only the fallback
            # for the default (0) so existing callers keep their behaviour.
            if N_samples_inverse_sphere <= 0:
                N_samples_inverse_sphere = 32
            self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0)

    def get_z_vals(self, ray_dirs, cam_loc, model, cond, smpl_tfs, eval_mode, smpl_verts):
        """Sample depths along each ray with the VolSDF error-bound scheme.

        Args:
            ray_dirs: ray directions; presumably (num_rays, 3) since sample
                points are reshaped to (-1, 3) — TODO confirm.
            cam_loc: ray origins, broadcast against ``ray_dirs`` (presumably
                (num_rays, 3) — TODO confirm).
            model: provides ``density`` (with ``get_beta()``),
                ``implicit_network``, ``sdf_func_with_smpl_deformer`` and
                ``training``.
            cond, smpl_tfs, smpl_verts: forwarded to
                ``model.sdf_func_with_smpl_deformer``.
            eval_mode: unused; kept for interface compatibility.

        Returns:
            ``(z_vals, z_samples_eik)``.

            ``z_vals`` is the sorted union of the final samples, the near/far
            depths and, if ``N_samples_extra > 0``, a subset of the evaluated
            depths. When ``inverse_sphere_bg`` is set it is instead the tuple
            ``(z_vals, z_vals_inverse_sphere)``, with the background depths
            scaled by ``1 / R``.

            ``z_samples_eik`` holds one randomly chosen depth of ``z_vals`` per
            ray, with shape (num_rays, 1).
        """
        beta0 = model.density.get_beta().detach()

        # Start with uniform sampling.
        z_vals = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model)
        device = z_vals.device
        samples, samples_idx = z_vals, None

        # Maximum beta from the upper bound (Lemma 2).
        dists = z_vals[:, 1:] - z_vals[:, :-1]
        bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1)
        beta = torch.sqrt(bound)

        total_iters, not_converge = 0, True
        # VolSDF Algorithm 1
        while not_converge and total_iters < self.max_total_iters:
            points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1)
            points_flat = points.reshape(-1, 3)

            # SDF is evaluated only at the newly added depths; values from
            # earlier iterations are merged back in sorted-depth order.
            samples_sdf = self._query_sdf(model, points_flat, cond, smpl_tfs, smpl_verts)
            if samples_idx is not None:
                sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]),
                                       samples_sdf.reshape(-1, samples.shape[1])], -1)
                sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1)
            else:
                sdf = samples_sdf

            # Bound d* (Theorem 1).
            d = sdf.reshape(z_vals.shape)
            dists = z_vals[:, 1:] - z_vals[:, :-1]
            d_star = self._compute_d_star(dists, d)

            # Update beta with a per-ray line search.
            beta = self._line_search_beta(beta, beta0, model, sdf, z_vals, dists, d_star)

            # Opacity terms under the searched beta. The last sample gets an
            # effectively infinite interval length (1e10).
            density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1))
            dists = torch.cat([dists, torch.full((dists.shape[0], 1), 1e10, device=device)], -1)
            free_energy = dists * density
            shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1, device=device), free_energy[:, :-1]], dim=-1)
            transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))

            # Check if we are done and whether this is the last sampling.
            total_iters += 1
            not_converge = beta.max() > beta0
            refine = bool(not_converge) and total_iters < self.max_total_iters

            if refine:
                # Sample more points proportionally to the current error bound.
                n_draw = self.N_samples_eval
                error_integral = self._error_integral(beta.unsqueeze(-1), dists[:, :-1], d_star)
                bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * transmittance[:, :-1]
                pdf = bound_opacity + self.add_tiny
            else:
                # Final sample set used in the volume rendering integral.
                n_draw = self.N_samples
                alpha = 1 - torch.exp(-free_energy)
                weights = alpha * transmittance  # probability of the ray hitting something here
                pdf = weights[..., :-1] + 1e-5  # prevent nans
            cdf = self._pdf_to_cdf(pdf)

            # Evenly spaced quantiles while refining or at eval time; random otherwise.
            samples = self._invert_cdf(z_vals, cdf, n_draw, refine or not model.training)

            # Add the new samples if we have not converged.
            if refine:
                z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], -1), -1)

        z_samples = samples
        num_rays = ray_dirs.shape[0]
        near = self.near * torch.ones(num_rays, 1, device=device)
        if self.inverse_sphere_bg:
            # Far depth is the second sphere-intersection column
            # (presumably where the ray exits the bounding sphere).
            far = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:, 1:]
        else:
            far = self.far * torch.ones(num_rays, 1, device=device)

        if self.N_samples_extra > 0:
            if model.training:
                sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra]
            else:
                sampling_idx = torch.linspace(0, z_vals.shape[1] - 1, self.N_samples_extra).long()
            z_vals_extra = torch.cat([near, far, z_vals[:, sampling_idx]], -1)
        else:
            z_vals_extra = torch.cat([near, far], -1)
        z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1)

        # One random depth per ray from the final set (CPU RNG, then moved).
        idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],)).to(z_vals.device)
        z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1))

        if self.inverse_sphere_bg:
            z_vals_inverse_sphere = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model)
            z_vals_inverse_sphere = z_vals_inverse_sphere * (1. / self.scene_bounding_sphere)
            z_vals = (z_vals, z_vals_inverse_sphere)

        return z_vals, z_samples_eik

    @staticmethod
    def _query_sdf(model, points_flat, cond, smpl_tfs, smpl_verts):
        """Evaluate the SDF at ``points_flat`` without tracking gradients.

        ``model.implicit_network`` is put in eval mode for the query. Its
        previous mode is restored afterwards (also if the query raises),
        rather than being forced into training mode.
        """
        network = model.implicit_network
        was_training = network.training
        network.eval()
        try:
            with torch.no_grad():
                return model.sdf_func_with_smpl_deformer(points_flat, cond, smpl_tfs, smpl_verts=smpl_verts)[0]
        finally:
            network.train(was_training)

    @staticmethod
    def _compute_d_star(dists, d):
        """Compute the per-interval d* bound of VolSDF Theorem 1.

        ``dists`` are interval lengths and ``d`` the SDF values at the depths.
        With ``a = dists``, ``b = |d_left|`` and ``c = |d_right|``:

        * when the triangle (a, b, c) is right or obtuse at one endpoint, d*
          is that endpoint's |SDF|;
        * otherwise, when a valid triangle exists, d* is its height over base
          ``a`` (Heron's formula).

        d* is zeroed where the two endpoint SDF signs are not the same
        non-zero sign.
        """
        a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs()
        first_cond = a.pow(2) + b.pow(2) <= c.pow(2)
        second_cond = a.pow(2) + c.pow(2) <= b.pow(2)
        d_star = torch.zeros(d.shape[0], d.shape[1] - 1, device=dists.device)
        d_star[first_cond] = b[first_cond]
        d_star[second_cond] = c[second_cond]
        s = (a + b + c) / 2.0
        area_before_sqrt = s * (s - a) * (s - b) * (s - c)
        mask = ~first_cond & ~second_cond & (b + c - a > 0)
        d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask])
        return (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star  # Fixing the sign

    def _line_search_beta(self, beta, beta0, model, sdf, z_vals, dists, d_star):
        """Bisect each ray's beta within [beta0, beta] toward an error bound <= eps.

        Rays whose bound already holds at ``beta0`` are set to ``beta0``.
        ``beta`` is modified in place. The upper end of the final bracket is
        returned.
        """
        curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star)
        beta[curr_error <= self.eps] = beta0
        beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta
        for _ in range(self.beta_iters):
            beta_mid = (beta_min + beta_max) / 2.
            curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star)
            # Two explicit comparisons (not ~mask) so NaN errors move neither end.
            beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps]
            beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps]
        return beta_max

    @staticmethod
    def _error_integral(beta, dists, d_star):
        """Return the cumulative sum of exp(-d*/beta) * dists**2 / (4 * beta**2)."""
        error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) / (4 * beta ** 2)
        return torch.cumsum(error_per_section, dim=-1)

    @staticmethod
    def _pdf_to_cdf(pdf):
        """Normalise ``pdf`` along the last axis and return its CDF with a leading 0."""
        pdf = pdf / torch.sum(pdf, -1, keepdim=True)
        cdf = torch.cumsum(pdf, -1)
        return torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))

    @staticmethod
    def _invert_cdf(bins, cdf, n_samples, deterministic):
        """Draw ``n_samples`` depths per ray by inverting a piecewise-linear CDF.

        ``bins`` and ``cdf`` share their last dimension. If ``deterministic``
        is true, evenly spaced quantiles in [0, 1] are used; otherwise
        uniform random ones are used.
        """
        device = cdf.device
        if deterministic:
            u = torch.linspace(0., 1., steps=n_samples).to(device).unsqueeze(0).repeat(cdf.shape[0], 1)
        else:
            # CPU RNG, then moved, so seeded runs reproduce the same samples.
            u = torch.rand(list(cdf.shape[:-1]) + [n_samples]).to(device)
        u = u.contiguous()

        inds = torch.searchsorted(cdf, u, right=True)
        below = torch.clamp(inds - 1, min=0)
        above = torch.clamp(inds, max=cdf.shape[-1] - 1)
        inds_g = torch.stack([below, above], -1)  # (batch, n_samples, 2)

        matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
        cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
        bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

        denom = cdf_g[..., 1] - cdf_g[..., 0]
        denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
        t = (u - cdf_g[..., 0]) / denom
        return bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star):
        """Return, per ray, the maximum of the opacity error bound for ``beta``.

        Args:
            beta: scalar beta, or per-ray betas with a trailing singleton axis.
            model: ``model.density(sdf, beta=beta)`` supplies the density.
            sdf: SDF values, reshaped to ``z_vals.shape``.
            z_vals: sample depths (only the shape is used here).
            dists: interval lengths between consecutive depths.
            d_star: per-interval d* bound.

        Returns:
            Tensor with one value per ray.
        """
        density = model.density(sdf.reshape(z_vals.shape), beta=beta)
        shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1, device=dists.device), dists * density[:, :-1]], dim=-1)
        integral_estimation = torch.cumsum(shifted_free_energy, dim=-1)
        error_integral = self._error_integral(beta, dists, d_star)
        bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1])
        return bound_opacity.max(-1)[0]
|