# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
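
"""Unit tests for SpectrogramEnhancerModel: spectrogram (un)normalization round trips,
padding, generator/discriminator forward passes, and .nemo save/restore."""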

import pytest
import torch
from einops import rearrange
from omegaconf import DictConfig

from nemo.collections.tts.models import SpectrogramEnhancerModel
from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor


@pytest.fixture
def enhancer_config():
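    """Minimal SpectrogramEnhancerModel config for unit tests: an 80-band GAN-style
    enhancer with a small generator/discriminator and Adam optimizers for both."""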
    n_bands = 80
    latent_dim = 192
    style_depth = 4
    network_capacity = 16
    fmap_max = 192

    config = {
        "model": {
            "n_bands": n_bands,
            "latent_dim": latent_dim,
            "style_depth": style_depth,
            "network_capacity": network_capacity,
            "mixed_prob": 0.9,
            "fmap_max": fmap_max,
            "generator": {
                "_target_": "nemo.collections.tts.modules.spectrogram_enhancer.Generator",
                "n_bands": n_bands,
                "latent_dim": latent_dim,
                "network_capacity": network_capacity,
                "style_depth": style_depth,
                "fmap_max": fmap_max,
            },
            "discriminator": {
                "_target_": "nemo.collections.tts.modules.spectrogram_enhancer.Discriminator",
                "n_bands": n_bands,
                "network_capacity": network_capacity,
                "fmap_max": fmap_max,
            },
            "spectrogram_min_value": -13.18,
            "spectrogram_max_value": 4.78,
            "consistency_loss_weight": 10.0,
            "gradient_penalty_loss_weight": 10.0,
            "gradient_penalty_loss_every_n_steps": 4,
            "spectrogram_predictor_path": None,
        },
        "generator_opt": {"_target_": "torch.optim.Adam", "lr": 2e-4, "betas": [0.5, 0.9]},
        "discriminator_opt": {"_target_": "torch.optim.Adam", "lr": 2e-4, "betas": [0.5, 0.9]},
    }

    return DictConfig(config)


@pytest.fixture
def enhancer(enhancer_config):
    return SpectrogramEnhancerModel(cfg=enhancer_config.model)


@pytest.fixture
def enhancer_with_fastpitch(enhancer_config_with_fastpitch):
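    """Enhancer built from a FastPitch-aware config (presumably one that sets
    `spectrogram_predictor_path`). The `enhancer_config_with_fastpitch` fixture is
    assumed to be provided elsewhere (e.g. a shared conftest.py); it is not defined
    in this file."""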
    return SpectrogramEnhancerModel(cfg=enhancer_config_with_fastpitch.model)


@pytest.fixture
def sample_input(batch_size=15, max_length=1000):
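    """Random batch of mel spectrograms shaped [batch_size, 80 bands, 1000 frames] with
    random valid lengths; frames past each length are zeroed out by mask_sequence_tensor."""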
    generator = torch.Generator()
    generator.manual_seed(0)

    lengths = torch.randint(max_length // 4, max_length - 7, (batch_size,), generator=generator)
    input_spectrograms = torch.randn((batch_size, 80, 1000), generator=generator)

    input_spectrograms = mask_sequence_tensor(input_spectrograms, lengths)

    return input_spectrograms, lengths


@pytest.mark.unit
def test_pad_spectrograms(enhancer: SpectrogramEnhancerModel, sample_input):
    input_spectrograms, _ = sample_input
    output = enhancer.pad_spectrograms(input_spectrograms)

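    # padding may extend the time axis, but must never shorten the spectrogram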
    assert output.size(-1) >= input_spectrograms.size(-1)


@pytest.mark.unit
def test_spectrogram_norm_unnorm(enhancer: SpectrogramEnhancerModel, sample_input):
    input_spectrograms, lengths = sample_input
    same_input_spectrograms = enhancer.unnormalize_spectrograms(
        enhancer.normalize_spectrograms(input_spectrograms, lengths), lengths
    )
    assert torch.allclose(input_spectrograms, same_input_spectrograms, atol=1e-5)


@pytest.mark.unit
def test_spectrogram_unnorm_norm(enhancer: SpectrogramEnhancerModel, sample_input):
    input_spectrograms, lengths = sample_input
    same_input_spectrograms = enhancer.normalize_spectrograms(
        enhancer.unnormalize_spectrograms(input_spectrograms, lengths), lengths
    )
    assert torch.allclose(input_spectrograms, same_input_spectrograms, atol=1e-5)


@pytest.mark.unit
def test_spectrogram_norm_unnorm_dont_look_at_padding(enhancer: SpectrogramEnhancerModel, sample_input):
    input_spectrograms, lengths = sample_input
    same_input_spectrograms = enhancer.unnormalize_spectrograms(
        enhancer.normalize_spectrograms(input_spectrograms, lengths), lengths
    )
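    # compare only the valid (unpadded) frames of each example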
    for i, length in enumerate(lengths.tolist()):
        assert torch.allclose(input_spectrograms[i, :, :length], same_input_spectrograms[i, :, :length], atol=1e-5)


@pytest.mark.unit
def test_spectrogram_unnorm_norm_dont_look_at_padding(enhancer: SpectrogramEnhancerModel, sample_input):
    input_spectrograms, lengths = sample_input
    same_input_spectrograms = enhancer.normalize_spectrograms(
        enhancer.unnormalize_spectrograms(input_spectrograms, lengths), lengths
    )
    for i, length in enumerate(lengths.tolist()):
        assert torch.allclose(input_spectrograms[i, :, :length], same_input_spectrograms[i, :, :length], atol=1e-5)


@pytest.mark.unit
def test_generator_pass_keeps_size(enhancer: SpectrogramEnhancerModel, sample_input):
    input_spectrograms, lengths = sample_input
    output = enhancer.forward(input_spectrograms=input_spectrograms, lengths=lengths)

    assert output.shape == input_spectrograms.shape


@pytest.mark.unit
def test_discriminator_pass(enhancer: SpectrogramEnhancerModel, sample_input):
    input_spectrograms, lengths = sample_input
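    # the discriminator expects a channel dimension: [B, 80, T] -> [B, 1, 80, T]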
    input_spectrograms = rearrange(input_spectrograms, "b c l -> b 1 c l")
    logits = enhancer.discriminator(x=input_spectrograms, condition=input_spectrograms, lengths=lengths)

    assert logits.shape == lengths.shape


@pytest.mark.unit
def test_nemo_save_load(enhancer: SpectrogramEnhancerModel, tmp_path):
    path = tmp_path / "test-enhancer-save-load.nemo"

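    # round-trip through a .nemo archive; the test passes as long as neither call raises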
    enhancer.save_to(path)
    SpectrogramEnhancerModel.restore_from(path)