|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
import torch |
|
from pytorch3d.renderer.implicit import HarmonicEmbedding |
|
from torch.distributions import MultivariateNormal |
|
|
|
from .common_testing import TestCaseMixin |
|
|
|
|
|
class TestHarmonicEmbedding(TestCaseMixin, unittest.TestCase): |
|
def setUp(self) -> None: |
|
super().setUp() |
|
torch.manual_seed(1) |
|
|
|
def test_correct_output_dim(self): |
|
embed_fun = HarmonicEmbedding(n_harmonic_functions=2, append_input=False) |
|
|
|
output_dim = 3 * (2 * 2 + int(False)) |
|
self.assertEqual( |
|
output_dim, |
|
embed_fun.get_output_dim_static( |
|
input_dims=3, n_harmonic_functions=2, append_input=False |
|
), |
|
) |
|
self.assertEqual(output_dim, embed_fun.get_output_dim()) |
|
|
|
def test_correct_frequency_range(self): |
|
embed_fun_log = HarmonicEmbedding(n_harmonic_functions=3) |
|
embed_fun_lin = HarmonicEmbedding(n_harmonic_functions=3, logspace=False) |
|
self.assertClose(embed_fun_log._frequencies, torch.FloatTensor((1.0, 2.0, 4.0))) |
|
self.assertClose(embed_fun_lin._frequencies, torch.FloatTensor((1.0, 2.5, 4.0))) |
|
|
|
def test_correct_embed_out(self): |
|
n_harmonic_functions = 2 |
|
x = torch.randn((1, 5)) |
|
D = 5 * n_harmonic_functions * 2 |
|
|
|
embed_fun = HarmonicEmbedding( |
|
n_harmonic_functions=n_harmonic_functions, append_input=False |
|
) |
|
embed_out = embed_fun(x) |
|
|
|
self.assertEqual(embed_out.shape, (1, D)) |
|
|
|
|
|
sum_squares = embed_out[0, : D // 2] ** 2 + embed_out[0, D // 2 :] ** 2 |
|
self.assertClose(sum_squares, torch.ones((D // 2))) |
|
|
|
|
|
embed_fun = HarmonicEmbedding( |
|
n_harmonic_functions=n_harmonic_functions, append_input=True |
|
) |
|
embed_out_appended_input = embed_fun(x) |
|
self.assertClose( |
|
embed_out_appended_input.shape, torch.tensor((1, D + x.shape[-1])) |
|
) |
|
|
|
self.assertClose(embed_out_appended_input[..., -x.shape[-1] :], x) |
|
self.assertClose(embed_out_appended_input[..., : -x.shape[-1]], embed_out) |
|
|
|
def test_correct_embed_out_with_diag_cov(self): |
|
n_harmonic_functions = 2 |
|
x = torch.randn((1, 3)) |
|
diag_cov = torch.randn((1, 3)) |
|
D = 3 * n_harmonic_functions * 2 |
|
|
|
embed_fun = HarmonicEmbedding( |
|
n_harmonic_functions=n_harmonic_functions, append_input=False |
|
) |
|
embed_out = embed_fun(x, diag_cov=diag_cov) |
|
|
|
self.assertEqual(embed_out.shape, (1, D)) |
|
|
|
|
|
scale_factor = ( |
|
-0.5 * diag_cov[..., None] * torch.pow(embed_fun._frequencies[None, :], 2) |
|
) |
|
scale_factor = torch.exp(scale_factor).reshape(1, -1).tile((1, 2)) |
|
|
|
|
|
|
|
|
|
embed_out_without_cov = embed_out / scale_factor |
|
sum_squares = ( |
|
embed_out_without_cov[0, : D // 2] ** 2 |
|
+ embed_out_without_cov[0, D // 2 :] ** 2 |
|
) |
|
self.assertClose(sum_squares, torch.ones((D // 2))) |
|
|
|
|
|
embed_fun = HarmonicEmbedding( |
|
n_harmonic_functions=n_harmonic_functions, append_input=True |
|
) |
|
embed_out_appended_input = embed_fun(x, diag_cov=diag_cov) |
|
self.assertClose( |
|
embed_out_appended_input.shape, torch.tensor((1, D + x.shape[-1])) |
|
) |
|
|
|
self.assertClose(embed_out_appended_input[..., -x.shape[-1] :], x) |
|
self.assertClose(embed_out_appended_input[..., : -x.shape[-1]], embed_out) |
|
|
|
def test_correct_behavior_between_ipe_and_its_estimation_from_harmonic_embedding( |
|
self, |
|
): |
|
""" |
|
Check that the HarmonicEmbedding with integrated_position_encoding (IPE) set to |
|
True is coherent with the HarmonicEmbedding. |
|
|
|
What is the idea behind this test? |
|
|
|
We wish to produce an IPE that is the expectation |
|
of our lifted multivariate gaussian, modulated by the sine and cosine of |
|
the coordinates. These expectation has a closed-form |
|
(see equations 11, 12, 13, 14 of [1]). |
|
|
|
We sample N elements from the multivariate gaussian defined by its mean and covariance |
|
and compute the HarmonicEmbedding. The expected value of those embeddings should be |
|
equal to our IPE. |
|
|
|
Inspired from: |
|
https://github.com/google/mipnerf/blob/84c969e0a623edd183b75693aed72a7e7c22902d/internal/mip_test.py#L359 |
|
|
|
References: |
|
[1] `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_. |
|
""" |
|
num_dims = 3 |
|
n_harmonic_functions = 6 |
|
mean = torch.randn(num_dims) |
|
diag_cov = torch.rand(num_dims) |
|
|
|
he_fun = HarmonicEmbedding( |
|
n_harmonic_functions=n_harmonic_functions, logspace=True, append_input=False |
|
) |
|
ipe_fun = HarmonicEmbedding( |
|
n_harmonic_functions=n_harmonic_functions, |
|
append_input=False, |
|
) |
|
|
|
embedding_ipe = ipe_fun(mean, diag_cov=diag_cov) |
|
|
|
rand_mvn = MultivariateNormal(mean, torch.eye(num_dims) * diag_cov) |
|
|
|
|
|
|
|
num_samples = 100000 |
|
embedding_he = he_fun(rand_mvn.sample_n(num_samples)) |
|
self.assertClose(embedding_he.mean(0), embedding_ipe, rtol=1e-2, atol=1e-2) |
|
|