File size: 1,057 Bytes
6325697 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import torch
class PointInSpace:
    """Sample points near given centers plus uniform points in a cube.

    Attributes are plain floats (or anything broadcastable against the
    sampled tensors):

    * ``global_sigma``: half-width of the cube ``[-global_sigma, global_sigma]``
      from which the uniform ("global") samples are drawn.
    * ``local_sigma``: default standard deviation of the Gaussian noise added
      to each input point to produce the "local" samples.
    """

    def __init__(self, global_sigma=0.5, local_sigma=0.01):
        self.global_sigma = global_sigma
        self.local_sigma = local_sigma

    def get_points(self, pc_input=None, local_sigma=None, global_ratio=0.125):
        """Sample one point near each given point plus a share of uniform points.

        Args:
            pc_input (tensor): sampling centers. shape: [B, N, D]
            local_sigma: standard deviation of the Gaussian noise added to
                each center. When ``None``, ``self.local_sigma`` is used.
            global_ratio (float): number of uniform samples per batch item,
                expressed as a fraction of N; ``int(N * global_ratio)``
                uniform points are drawn.

        Returns:
            samples (tensor): local samples followed by uniform samples along
                dim 1. shape: [B, N + int(N * global_ratio), D]

        Raises:
            ValueError: if ``pc_input`` is ``None``.
        """
        if pc_input is None:
            raise ValueError("pc_input must be a tensor of shape [B, N, D], got None")

        batch_size, sample_size, dim = pc_input.shape

        # An explicit per-call sigma overrides the instance default; compare
        # against None (not truthiness) so that 0.0 or a tensor is honoured.
        sigma = self.local_sigma if local_sigma is None else local_sigma
        sample_local = pc_input + torch.randn_like(pc_input) * sigma

        # Draw the uniform samples with the same dtype and device as the
        # input so the concatenation below neither promotes the result dtype
        # nor mixes precisions.
        num_global = int(sample_size * global_ratio)
        unit_uniform = torch.rand(
            batch_size,
            num_global,
            dim,
            device=pc_input.device,
            dtype=pc_input.dtype,
        )
        # Rescale [0, 1) to [-global_sigma, global_sigma).
        sample_global = unit_uniform * (self.global_sigma * 2) - self.global_sigma

        return torch.cat([sample_local, sample_global], dim=1)