# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import pytest from omegaconf import DictConfig from nemo.collections.asr.models import EncDecCTCModel from nemo.collections.asr.modules import SpectrogramAugmentation from nemo.core.classes.common import Serialization def get_class_path(cls): return f"{cls.__module__}.{cls.__name__}" class MockSerializationImpl(Serialization): def __init__(self, cfg: DictConfig): self.cfg = cfg self.value = self.__class__.__name__ class MockSerializationImplV2(MockSerializationImpl): pass class TestSerialization: @pytest.mark.unit def test_from_config_dict_with_cls(self): """Here we test that instantiation works for configs with cls class path in them. Note that just Serialization.from_config_dict can be used to create an object""" config = DictConfig( { 'cls': 'nemo.collections.asr.modules.SpectrogramAugmentation', 'params': {'rect_freq': 50, 'rect_masks': 5, 'rect_time': 120,}, } ) obj = Serialization.from_config_dict(config=config) assert isinstance(obj, SpectrogramAugmentation) @pytest.mark.unit def test_from_config_dict_without_cls(self): """Here we test that instantiation works for configs without cls class path in them. IMPORTANT: in this case, correct class type should call from_config_dict. This should work for Models.""" preprocessor = {'cls': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor', 'params': dict({})} encoder = { 'cls': 'nemo.collections.asr.modules.ConvASREncoder', 'params': { 'feat_in': 64, 'activation': 'relu', 'conv_mask': True, 'jasper': [ { 'filters': 1024, 'repeat': 1, 'kernel': [1], 'stride': [1], 'dilation': [1], 'dropout': 0.0, 'residual': False, 'separable': True, 'se': True, 'se_context_size': -1, } ], }, } decoder = { 'cls': 'nemo.collections.asr.modules.ConvASRDecoder', 'params': { 'feat_in': 1024, 'num_classes': 28, 'vocabulary': [ ' ', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', "'", ], }, } modelConfig = DictConfig( {'preprocessor': DictConfig(preprocessor), 'encoder': DictConfig(encoder), 'decoder': DictConfig(decoder)} ) obj = EncDecCTCModel.from_config_dict(config=modelConfig) assert isinstance(obj, EncDecCTCModel) @pytest.mark.unit def test_config_updated(self): config = DictConfig( { 'cls': 'nemo.collections.asr.modules.SpectrogramAugmentation', 'params': {'rect_freq': 50, 'rect_masks': 5, 'rect_time': 120,}, } ) obj = Serialization.from_config_dict(config=config) new_config = obj.to_config_dict() assert config != new_config assert 'params' not in new_config assert 'cls' not in new_config assert '_target_' in new_config @pytest.mark.unit def test_base_class_instantiation(self): # Target class is V2 impl, calling class is Serialization (base class) config = DictConfig({'target': get_class_path(MockSerializationImplV2)}) obj = Serialization.from_config_dict(config=config) new_config = obj.to_config_dict() assert config == new_config assert isinstance(obj, MockSerializationImplV2) assert obj.value == "MockSerializationImplV2" @pytest.mark.unit def test_self_class_instantiation(self): # Target class is V1 impl, calling class is V1 (same class) config = DictConfig({'target': get_class_path(MockSerializationImpl)}) obj = MockSerializationImpl.from_config_dict(config=config) # Serialization is base class new_config = obj.to_config_dict() assert config == new_config assert isinstance(obj, MockSerializationImpl) assert obj.value == "MockSerializationImpl" @pytest.mark.unit def test_sub_class_instantiation(self): # Target class is V1 impl, calling class is V2 (sub class) config = DictConfig({'target': get_class_path(MockSerializationImpl)}) obj = MockSerializationImplV2.from_config_dict(config=config) # Serialization is base class new_config = obj.to_config_dict() assert config == new_config assert isinstance(obj, MockSerializationImplV2) assert obj.value == "MockSerializationImplV2"