File size: 1,057 Bytes
6325697
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch


class PointInSpace:
    """Sample points near given centers plus points drawn uniformly from a cube.

    Local samples are the input points perturbed by zero-mean Gaussian noise
    scaled by ``local_sigma``. Global samples are drawn uniformly from
    ``[-global_sigma, global_sigma)`` along every coordinate.
    """

    def __init__(self, global_sigma=0.5, local_sigma=0.01):
        # Half-width of the uniform sampling range used for global samples.
        self.global_sigma = global_sigma
        # Default scale of the Gaussian noise added to each input point.
        self.local_sigma = local_sigma

    def get_points(self, pc_input=None, local_sigma=None, global_ratio=0.125):
        """Sample one point near each input point plus extra uniform points.

        Args:
            pc_input (tensor): sampling centers. shape: [B, N, D]
            local_sigma (float or tensor, optional): scale of the Gaussian
                noise for this call. Falls back to ``self.local_sigma`` when
                ``None``.
            global_ratio (float): number of uniform samples per batch element,
                expressed as a fraction of N (truncated with ``int``).
        Returns:
            samples (tensor): local samples followed by global samples along
                dim 1. shape: [B, N + int(N * global_ratio), D]
        Raises:
            ValueError: if ``pc_input`` is not provided.
        """
        if pc_input is None:
            raise ValueError("pc_input must be a tensor of shape [B, N, D]")

        batch_size, sample_size, dim = pc_input.shape

        sigma = self.local_sigma if local_sigma is None else local_sigma
        sample_local = pc_input + torch.randn_like(pc_input) * sigma

        num_global = int(sample_size * global_ratio)
        # Draw the uniform samples in the input's dtype (and device) so the
        # concatenation neither promotes the result to another dtype nor mixes
        # precisions between the local and global parts.
        uniform = torch.rand(
            batch_size,
            num_global,
            dim,
            dtype=pc_input.dtype,
            device=pc_input.device,
        )
        # Map [0, 1) onto [-global_sigma, global_sigma).
        sample_global = uniform * (self.global_sigma * 2) - self.global_sigma

        return torch.cat([sample_local, sample_global], dim=1)