import torch class PointInSpace: def __init__(self, global_sigma=0.5, local_sigma=0.01): self.global_sigma = global_sigma self.local_sigma = local_sigma def get_points(self, pc_input=None, local_sigma=None, global_ratio=0.125): """Sample one point near each of the given point + 1/8 uniformly. Args: pc_input (tensor): sampling centers. shape: [B, N, D] Returns: samples (tensor): sampled points. shape: [B, N + N / 8, D] """ batch_size, sample_size, dim = pc_input.shape if local_sigma is None: sample_local = pc_input + (torch.randn_like(pc_input) * self.local_sigma) else: sample_local = pc_input + (torch.randn_like(pc_input) * local_sigma) sample_global = ( torch.rand(batch_size, int(sample_size * global_ratio), dim, device=pc_input.device) * (self.global_sigma * 2) ) - self.global_sigma sample = torch.cat([sample_local, sample_global], dim=1) return sample