|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
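This script is used as a CI test and shows how to chain TTS and ASR models: it synthesizes a fixed set of sentences with a spectrogram generator and a vocoder, transcribes the resulting audio with an ASR model, and fails if the word error rate exceeds a tolerance.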
|
""" |
|
|
|
from argparse import ArgumentParser |
|
from math import ceil |
|
from pathlib import Path |
|
|
|
import librosa |
|
import soundfile |
|
import torch |
|
|
|
from nemo.collections.asr.metrics.wer import word_error_rate |
|
from nemo.collections.asr.models import EncDecCTCModel |
|
from nemo.collections.common.parts.preprocessing import parsers |
|
from nemo.collections.tts.models.base import SpectrogramGenerator, Vocoder |
|
from nemo.utils import logging |
|
|
|
# Sentences that are synthesized by the TTS models and then transcribed by the
# ASR model; the transcripts are compared against these texts to compute WER.
LIST_OF_TEST_STRINGS = [
    "Hey, this is a test of the speech synthesis system.",
    "roupell received the announcement with a cheerful countenance.",
    "with thirteen dollars, eighty-seven cents when considerably greater resources were available to him.",
    "Two other witnesses were able to offer partial descriptions of a man they saw in the southeast corner window.",
    "'just to steady their legs a little' in other words, to add his weight to that of the hanging bodies.",
    "The discussion above has already set forth examples of his expression of hatred for the United States.",
    (
        "At two:thirty-eight p.m., Eastern Standard Time, Lyndon Baines Johnson took the oath of office "
        "as the thirty-sixth President of the United States."
    ),
    (
        "or, quote, other high government officials in the nature of a complaint coupled with "
        "an expressed or implied determination to use a means."
    ),
    "As for my return entrance visa please consider it separately. End quote.",
    "it appears that Marina Oswald also complained that her husband was not able to provide more material things for her.",
    "appeared in The Dallas Times Herald on November fifteen, nineteen sixty-three.",
    "The only exit from the office in the direction Oswald was moving was through the door to the front stairway.",
]
|
|
|
|
|
def main():
    """Synthesize the test sentences, transcribe them, and check the WER.

    Steps performed:
      1. Parse command-line options (model names, WER tolerance, trim/debug flags).
      2. Load the pretrained ASR model, TTS spectrogram generator and vocoder.
      3. Tokenize every entry of ``LIST_OF_TEST_STRINGS`` for the TTS model and
         build the matching ASR reference text with the ASR text parser.
      4. Generate spectrograms, vocode them in chunks, and write each clip to
         ``<index>.wav`` in the current working directory.
      5. Transcribe the written files and compute the word error rate.

    Raises:
        ValueError: If the measured WER is greater than ``--wer_tolerance``.
    """
    # Named differently from the ASR text parser below so neither shadows the other.
    arg_parser = ArgumentParser()
    arg_parser.add_argument(
        "--asr_model",
        type=str,
        default="QuartzNet15x5Base-En",
        choices=[x.pretrained_model_name for x in EncDecCTCModel.list_available_models()],
    )
    arg_parser.add_argument(
        "--tts_model_spec",
        type=str,
        default="tts_en_tacotron2",
        choices=[x.pretrained_model_name for x in SpectrogramGenerator.list_available_models()],
    )
    arg_parser.add_argument(
        "--tts_model_vocoder",
        type=str,
        default="tts_en_waveglow_88m",
        choices=[x.pretrained_model_name for x in Vocoder.list_available_models()],
    )
    arg_parser.add_argument("--wer_tolerance", type=float, default=1.0, help="used by test")
    arg_parser.add_argument("--trim", action="store_true")
    arg_parser.add_argument("--debug", action="store_true")
    args = arg_parser.parse_args()

    # Inference only: turn autograd off for everything that follows.
    torch.set_grad_enabled(False)

    if args.debug:
        logging.set_verbosity(logging.DEBUG)

    logging.info(f"Using NGC cloud ASR model {args.asr_model}")
    asr_model = EncDecCTCModel.from_pretrained(model_name=args.asr_model)
    logging.info(f"Using NGC cloud TTS Spectrogram Generator model {args.tts_model_spec}")
    tts_model_spec = SpectrogramGenerator.from_pretrained(model_name=args.tts_model_spec)
    logging.info(f"Using NGC cloud TTS Vocoder model {args.tts_model_vocoder}")
    tts_model_vocoder = Vocoder.from_pretrained(model_name=args.tts_model_vocoder)

    use_cuda = torch.cuda.is_available()
    models = [asr_model, tts_model_spec, tts_model_vocoder]
    if use_cuda:
        models = [m.cuda() for m in models]
    for m in models:
        m.eval()
    asr_model, tts_model_spec, tts_model_vocoder = models

    vocabulary = asr_model.decoder.vocabulary
    text_parser = parsers.make_parser(
        labels=vocabulary, name="en", unk_id=-1, blank_id=-1, do_normalize=True,
    )
    # Maps a label id produced by the text parser back to its vocabulary entry.
    labels_map = dict(enumerate(vocabulary))

    tts_input = []
    asr_references = []
    longest_tts_input = 0
    for test_str in LIST_OF_TEST_STRINGS:
        tts_parsed_input = tts_model_spec.parse(test_str)
        longest_tts_input = max(longest_tts_input, len(tts_parsed_input[0]))
        tts_input.append(tts_parsed_input.squeeze())

        # Round-trip the text through the ASR parser so the reference only
        # contains symbols from the ASR vocabulary.
        asr_references.append("".join(labels_map[c] for c in text_parser(test_str)))

    # Right-pad every token sequence to the longest one so they can be stacked.
    # NOTE(review): 68 is the pad token id hard-coded by this script; presumably
    # it matches the default spectrogram model's vocabulary - confirm before
    # selecting a different --tts_model_spec.
    pad_token_id = 68
    tts_input = [
        torch.nn.functional.pad(text, (0, longest_tts_input - len(text)), value=pad_token_id)
        for text in tts_input
    ]
    logging.debug(tts_input)

    tts_input = torch.stack(tts_input)
    if use_cuda:
        tts_input = tts_input.cuda()
    specs = tts_model_spec.generate_spectrogram(tokens=tts_input)

    # Vocode in at most ``num_chunks`` batches. Iterating over the real chunk
    # starts (instead of a fixed ``range(4)``) avoids passing an empty batch to
    # the vocoder when the number of spectrograms does not fill four chunks.
    num_chunks = 4
    step = max(1, ceil(len(specs) / num_chunks))
    audio = []
    for start in range(0, len(specs), step):
        # Iterating the returned batch yields one waveform per spectrogram.
        audio.extend(tts_model_vocoder.convert_spectrogram_to_audio(spec=specs[start : start + step]))

    logging.debug(f"args.trim: {args.trim}")
    sample_rate = 22050  # sample rate used when writing the generated clips
    trim_top_db = 40  # threshold passed to librosa.effects.trim
    audio_file_paths = []
    for i, aud in enumerate(audio):
        aud = aud.cpu().numpy()
        if args.trim:
            # librosa.effects.trim returns (trimmed_signal, interval); keep the signal.
            aud = librosa.effects.trim(aud, top_db=trim_top_db)[0]
        wav_path = str(Path(f"{i}.wav"))
        soundfile.write(wav_path, aud, samplerate=sample_rate)
        audio_file_paths.append(wav_path)

    hypotheses = asr_model.transcribe(audio_file_paths)
    for i, (reference, hypothesis) in enumerate(zip(asr_references, hypotheses)):
        logging.debug(f"{i}")
        logging.debug(f"ref:'{reference}'")
        logging.debug(f"hyp:'{hypothesis}'")
    wer_value = word_error_rate(hypotheses=hypotheses, references=asr_references)
    if wer_value > args.wer_tolerance:
        raise ValueError(f"Got WER of {wer_value}. It was higher than {args.wer_tolerance}")
    logging.info(f'Got WER of {wer_value}. Tolerance was {args.wer_tolerance}')
|
|
|
|
|
# Run the TTS -> ASR chain only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|