|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import unittest |
|
|
|
import torch |
|
from omegaconf import DictConfig, OmegaConf |
|
from pytorch3d.implicitron.models.generic_model import GenericModel |
|
from pytorch3d.implicitron.models.renderer.base import EvaluationMode |
|
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args |
|
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras |
|
from tests.common_testing import get_pytorch3d_dir |
|
|
|
from .common_resources import provide_resnet34 |
|
|
|
# Directory containing the yaml configs shipped with the implicitron trainer;
# test_all_gm_configs globs the repro_* files from here.
IMPLICITRON_CONFIGS_DIR = (
    get_pytorch3d_dir() / "projects" / "implicitron_trainer" / "configs"
)
|
|
|
|
|
class TestGenericModel(unittest.TestCase):
    """Smoke tests: GenericModel variants run forward/backward on random data."""

    @classmethod
    def setUpClass(cls) -> None:
        # Fetch the pretrained ResNet-34 weights once for the whole class.
        provide_resnet34()

    def setUp(self):
        # Deterministic RNG so the random cameras/images are reproducible.
        torch.manual_seed(42)

    def test_gm(self):
        """A default-configured GenericModel trains and evaluates."""
        device = torch.device("cuda:0")
        expand_args_fields(GenericModel)
        gm = GenericModel(render_image_height=80, render_image_width=80)
        gm.to(device)
        self._one_model_test(gm, device)

    def test_all_gm_configs(self):
        """Every shipped repro_{single,multi}seq config builds a working model."""
        device = torch.device("cuda:0")
        # Collect all matching configs, skipping the shared *_base.yaml files.
        yaml_files = [
            path
            for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml")
            for path in IMPLICITRON_CONFIGS_DIR.glob(pattern)
            if not path.name.endswith("_base.yaml")
        ]

        for yaml_file in yaml_files:
            with self.subTest(name=yaml_file.stem):
                cfg = _load_model_config_from_yaml(str(yaml_file))
                # Render small images to keep the test quick.
                cfg.render_image_height = 80
                cfg.render_image_width = 80
                gm = GenericModel(**cfg)
                gm.to(device)
                self._one_model_test(
                    gm,
                    device,
                    eval_test=True,
                    bw_test=True,
                )

    def _one_model_test(
        self,
        model,
        device,
        n_train_cameras: int = 5,
        eval_test: bool = True,
        bw_test: bool = True,
    ):
        """Feed random inputs through ``model``: one training forward pass
        (optionally with backward) and, if requested, one no-grad evaluation
        pass whose rendered image shape is checked."""
        rot, trans = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
        cameras = PerspectiveCameras(R=rot, T=trans, device=device)

        batch = n_train_cameras
        height = model.render_image_height
        width = model.render_image_width

        # Random stand-ins for a real training batch.
        inputs = {
            "camera": cameras,
            "fg_probability": _random_input_tensor(batch, 1, height, width, True, device),
            "depth_map": _random_input_tensor(batch, 1, height, width, False, device)
            + 0.1,
            "mask_crop": _random_input_tensor(batch, 1, height, width, True, device),
            "sequence_name": ["sequence"] * batch,
            "image_rgb": _random_input_tensor(batch, 3, height, width, False, device),
        }

        model.train()
        train_preds = model(**inputs, evaluation_mode=EvaluationMode.TRAINING)
        # The training objective must come out as a finite scalar.
        self.assertTrue(train_preds["objective"].isfinite().item())

        if bw_test:
            train_preds["objective"].backward()

        if eval_test:
            model.eval()
            with torch.no_grad():
                eval_preds = model(**inputs, evaluation_mode=EvaluationMode.EVALUATION)
            # Evaluation renders one image at the configured resolution.
            self.assertEqual(
                eval_preds["images_render"].shape,
                (1, 3, model.render_image_height, model.render_image_width),
            )

    def test_idr(self):
        """GenericModel with the IDR feature field and SDF renderer trains."""
        device = torch.device("cuda:0")
        args = get_default_args(GenericModel)
        args.renderer_class_type = "SignedDistanceFunctionRenderer"
        args.implicit_function_class_type = "IdrFeatureField"
        args.implicit_function_IdrFeatureField_args.n_harmonic_functions_xyz = 6

        gm = GenericModel(**args)
        gm.to(device)

        num_cameras = 2
        rot, trans = look_at_view_transform(azim=torch.rand(num_cameras) * 360)
        cameras = PerspectiveCameras(R=rot, T=trans, device=device)

        # Optional model inputs this test deliberately leaves unset.
        unset_inputs = {
            "depth_map": None,
            "mask_crop": None,
            "sequence_name": None,
        }

        render_shape = (gm.render_image_height, gm.render_image_width)
        target_image_rgb = torch.rand((num_cameras, 3) + render_shape, device=device)
        fg_probability = torch.rand((num_cameras, 1) + render_shape, device=device)

        train_preds = gm(
            camera=cameras,
            evaluation_mode=EvaluationMode.TRAINING,
            image_rgb=target_image_rgb,
            fg_probability=fg_probability,
            **unset_inputs,
        )
        self.assertGreater(train_preds["objective"].item(), 0)

    def test_viewpool(self):
        """GenericModel with view pooling over ResNet features trains."""
        device = torch.device("cuda:0")
        args = get_default_args(GenericModel)
        args.view_pooler_enabled = True
        args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
        args.image_feature_extractor_ResNetFeatureExtractor_args.add_masks = False
        gm = GenericModel(**args)
        gm.to(device)

        num_cameras = 2
        rot, trans = look_at_view_transform(azim=torch.rand(num_cameras) * 360)
        cameras = PerspectiveCameras(R=rot, T=trans, device=device)

        # Optional model inputs this test deliberately leaves unset.
        unset_inputs = {
            "fg_probability": None,
            "depth_map": None,
            "mask_crop": None,
        }

        target_image_rgb = torch.rand(
            (num_cameras, 3, gm.render_image_height, gm.render_image_width),
            device=device,
        )
        train_preds = gm(
            camera=cameras,
            evaluation_mode=EvaluationMode.TRAINING,
            image_rgb=target_image_rgb,
            sequence_name=["a"] * num_cameras,
            **unset_inputs,
        )
        self.assertGreater(train_preds["objective"].item(), 0)
|
|
|
|
|
def _random_input_tensor( |
|
N: int, |
|
C: int, |
|
H: int, |
|
W: int, |
|
is_binary: bool, |
|
device: torch.device, |
|
) -> torch.Tensor: |
|
T = torch.rand(N, C, H, W, device=device) |
|
if is_binary: |
|
T = (T > 0.5).float() |
|
return T |
|
|
|
|
|
def _load_model_config_from_yaml(config_path, strict=True) -> DictConfig:
    """Return the GenericModel defaults overlaid with the yaml at ``config_path``.

    NOTE(review): ``strict`` is accepted for interface compatibility but is
    currently unused by the implementation.
    """
    return _load_model_config_from_yaml_rec(get_default_args(GenericModel), config_path)
|
|
|
|
|
def _load_model_config_from_yaml_rec(cfg: DictConfig, config_path: str) -> DictConfig:
    """Recursively merge the GenericModel config found in ``config_path`` into ``cfg``.

    The yaml's hydra-style ``defaults`` list is processed first (skipping the
    ``_self_`` and ``default_config`` entries), so settings in ``config_path``
    itself take precedence over those of the files it includes.

    Args:
        cfg: Base config; ``OmegaConf.merge`` returns a new config, so the
            argument itself is not mutated.
        config_path: Path to an implicitron_trainer yaml file.

    Returns:
        ``cfg`` with all applicable ``model_GenericModel_args`` sections merged in.
    """
    cfg_loaded = OmegaConf.load(config_path)
    cfg_model_loaded = None
    if "model_factory_ImplicitronModelFactory_args" in cfg_loaded:
        factory_args = cfg_loaded.model_factory_ImplicitronModelFactory_args
        if "model_GenericModel_args" in factory_args:
            cfg_model_loaded = factory_args.model_GenericModel_args
    defaults = cfg_loaded.pop("defaults", None)
    if defaults is not None:
        for default_name in defaults:
            if default_name in ("_self_", "default_config"):
                continue
            # Entries may or may not carry a .yaml extension; normalize it.
            default_name = os.path.splitext(default_name)[0]
            defpath = os.path.join(os.path.dirname(config_path), default_name + ".yaml")
            cfg = _load_model_config_from_yaml_rec(cfg, defpath)
    # Merge this file's own model args last (single merge point replaces the
    # previously duplicated merge in both the if- and elif-branches).
    if cfg_model_loaded is not None:
        cfg = OmegaConf.merge(cfg, cfg_model_loaded)
    return cfg
|
|