import argparse
import json
import os
import re

import librosa
import soundfile as sf
import torch

from nemo.collections.tts.models import AlignerModel
from nemo.collections.tts.parts.utils.tts_dataset_utils import general_padding


"""
G2P disambiguation using an Aligner model's input embedding distances.

Does not handle OOV words and leaves them as graphemes.

The output will have each token's phonemes (or graphemes) bracketed, e.g.
<\"><M AH1 L ER0><, ><M AH1 L ER0><, ><HH IY1 Z>< ><DH AH0>< ><M AE1 N><.\">

Example:
python aligner_heteronym_disambiguation.py \
    --model=<model_path> \
    --manifest=<manifest_path> \
    --out=<output_json_path> \
    --confidence=0.02 \
    --verbose
"""


def get_args():
    """Retrieve arguments for disambiguation."""
    parser = argparse.ArgumentParser(description="G2P disambiguation using Aligner input embedding distances.")

    parser.add_argument('--model', required=True, type=str, help="Path to Aligner model checkpoint (.nemo file).")
    parser.add_argument(
        '--manifest',
        required=True,
        type=str,
        help="Path to data manifest. Each entry should contain the path to the audio file as well as the text in graphemes.",
    )
    parser.add_argument(
        '--out', required=True, type=str, help="Path to output file where disambiguations will be written."
    )
    parser.add_argument(
        '--sr',
        required=False,
        default=22050,
        type=int,
        help="Target sample rate to load the dataset. Should match what the model was trained on.",
    )
    parser.add_argument(
        '--heteronyms',
        required=False,
        type=str,
        default='../../scripts/tts_dataset_files/heteronyms-052722',
        help="Heteronyms file specifying which words should be disambiguated. All other words use their default pronunciation.",
    )
    parser.add_argument(
        '--confidence', required=False, type=float, default=0.0, help="Confidence threshold to keep a disambiguation."
    )
    parser.add_argument(
        '--verbose',
        action='store_true',
        help="If set, logs scores for each disambiguated word to disambiguation_logs.txt.",
    )

    args = parser.parse_args()
    return args


def load_and_prepare_audio(aligner, audio_path, target_sr, device):
    """Loads and resamples audio to target sample rate (if necessary), and preprocesses for Aligner input."""
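    # Load the audio file and resample it if its sample rate differs from the target.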
    audio_data, orig_sr = sf.read(audio_path)
    if orig_sr != target_sr:
        audio_data = librosa.core.resample(audio_data, orig_sr=orig_sr, target_sr=target_sr)

    audio = torch.tensor(audio_data, dtype=torch.float, device=device).unsqueeze(0)
    audio_len = torch.tensor(audio_data.shape[0], device=device).long().unsqueeze(0)
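
    # Generate the spectrogram using the Aligner's audio preprocessor.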
    spec, spec_len = aligner.preprocessor(input_signal=audio, length=audio_len)

    return spec, spec_len


def disambiguate_candidates(aligner, text, spec, spec_len, confidence, device, heteronyms, log_file=None):
    """Retrieves and disambiguates all candidate sentences for a given text.

    Assumes that the max number of candidates per word is a reasonable batch size.

    Note: This could be sped up if multiple words' candidates were batched, but this is conceptually easier.
    """
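    # Grab the default G2P conversion for the full text.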
    aligner_g2p = aligner.tokenizer.g2p
    base_g2p = aligner_g2p(text)
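
    # Tokenize the text into words.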
    words = [word for word, _ in aligner_g2p.word_tokenize_func(text)]

    result_g2p = []
    word_start_idx = 0
    has_heteronym = False

    for word in words:
        g2p_default_len = len(aligner_g2p(word))
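
        # Words in the heteronyms list are disambiguated; everything else falls through to the default conversion.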
        if word in heteronyms:
            has_heteronym = True
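
            # Build one candidate token sequence per dictionary pronunciation of this word.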
            word_candidates = []
            candidate_prons_and_lengths = []

            for pron in aligner_g2p.phoneme_dict[word]:
                candidate = base_g2p[:word_start_idx] + pron + base_g2p[word_start_idx + g2p_default_len :]
                candidate_tokens = aligner.tokenizer.encode_from_g2p(candidate)

                word_candidates.append(candidate_tokens)
                candidate_prons_and_lengths.append((pron, len(pron)))

            num_candidates = len(word_candidates)
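
            # If there is only one candidate pronunciation, use it directly; no disambiguation is needed.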
            if num_candidates == 1:
                has_heteronym = False
                result_g2p.append(f"<{' '.join(candidate_prons_and_lengths[0][0])}>")
                word_start_idx += g2p_default_len
                continue
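
            # Batch the candidates, padding token sequences to the same length.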
            text_len = [len(toks) for toks in word_candidates]
            text_len_in = torch.tensor(text_len, device=device).long()

            max_text_len = max(text_len)
            text_stack = []
            for i in range(num_candidates):
                padded_tokens = general_padding(
                    torch.tensor(word_candidates[i], device=device).long(), text_len[i], max_text_len
                )
                text_stack.append(padded_tokens)
            text_in = torch.stack(text_stack)
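
            # Repeat the spectrogram for each candidate and run Aligner inference on the whole batch.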
            spec_in = spec.repeat([num_candidates, 1, 1])
            spec_len_in = spec_len.repeat([num_candidates])

            with torch.no_grad():
                soft_attn, _ = aligner(spec=spec_in, spec_len=spec_len_in, text=text_in, text_len=text_len_in)
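
            # Compute L2 distances between text embeddings and spectrogram frames, and per-token durations.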
            text_embeddings = aligner.embed(text_in).transpose(1, 2)
            l2_dists = aligner.alignment_encoder.get_dist(keys=text_embeddings, queries=spec_in).sqrt()
            durations = aligner.alignment_encoder.get_durations(soft_attn, text_len_in, spec_len_in).int()
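
            # Pick the candidate whose phonemes have the smallest mean embedding distance to the audio.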
            min_dist = float('inf')
            max_dist = 0.0
            best_candidate = None
            for i in range(num_candidates):
                candidate_mean_dist = aligner.alignment_encoder.get_mean_distance_for_word(
                    l2_dists=l2_dists[i],
                    durs=durations[i],
                    start_token=word_start_idx + (1 if aligner.tokenizer.pad_with_space else 0),
                    num_tokens=candidate_prons_and_lengths[i][1],
                )
                if log_file:
                    log_file.write(f"{candidate_prons_and_lengths[i][0]} -- {candidate_mean_dist}\n")

                if candidate_mean_dist < min_dist:
                    min_dist = candidate_mean_dist
                    best_candidate = candidate_prons_and_lengths[i][0]
                if candidate_mean_dist > max_dist:
                    max_dist = candidate_mean_dist
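
            # Confidence is the spread between the worst and best candidates relative to their mean;
            # below the threshold, fall back to the default pronunciation.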
            disamb_conf = (max_dist - min_dist) / ((max_dist + min_dist) / 2.0)
            if disamb_conf < confidence:
                if log_file:
                    log_file.write(f"Below confidence threshold: {best_candidate} ({disamb_conf})\n")

                has_heteronym = False
                result_g2p.append(f"<{' '.join(aligner_g2p(word))}>")
                word_start_idx += g2p_default_len
                continue

            if log_file:
                log_file.write(f"best candidate: {best_candidate} (confidence: {disamb_conf})\n")

            result_g2p.append(f"<{' '.join(best_candidate)}>")
        else:
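            # Not a heteronym: keep punctuation/whitespace as-is, use the first dictionary pronunciation
            # if the word is in the dictionary, and otherwise run default G2P (OOV stays as graphemes).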
            if re.search("[a-zA-Z]", word) is None:
                result_g2p.append(f"<{word}>")
            elif word in aligner_g2p.phoneme_dict:
                result_g2p.append(f"<{' '.join(aligner_g2p.phoneme_dict[word][0])}>")
            else:
                result_g2p.append(f"<{' '.join(aligner_g2p(word))}>")

        word_start_idx += g2p_default_len

    if log_file and has_heteronym:
        log_file.write(f"{text}\n")
        log_file.write(f"===\n{''.join(result_g2p)}\n===\n")
        log_file.write(f"===============================\n")

    return result_g2p, has_heteronym


def disambiguate_dataset(
    aligner, manifest_path, out_path, sr, heteronyms, confidence, device, verbose, heteronyms_only=True
):
    """Disambiguates the phonemes for all words with ambiguous pronunciations in the given manifest."""
    log_file = open('disambiguation_logs.txt', 'w') if verbose else None

    with open(out_path, 'w') as f_out:
        with open(manifest_path, 'r') as f_in:
            count = 0

            for line in f_in:
                entry = json.loads(line)
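
                # Normalize and preprocess the transcript before G2P.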
                text = aligner.normalizer.normalize(entry['text'], punct_post_process=True)
                text = aligner.tokenizer.text_preprocessing_func(text)
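
                # Load the audio and convert it to a spectrogram.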
                audio_path = entry['audio_filepath']
                spec, spec_len = load_and_prepare_audio(aligner, audio_path, sr, device)
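
                # Get pronunciation candidates and disambiguate them against the audio.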
                disambiguated_text, has_heteronym = disambiguate_candidates(
                    aligner, text, spec, spec_len, confidence, device, heteronyms, log_file
                )
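
                # Skip the entry if the user only wants samples that contain heteronyms.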
                if heteronyms_only and not has_heteronym:
                    continue
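
                # Write the entry back out with its disambiguated text added.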
                entry['disambiguated_text'] = ''.join(disambiguated_text)
                f_out.write(f"{json.dumps(entry)}\n")

                count += 1
                if count % 100 == 0:
                    print(f"Finished {count} entries.")

    print(f"Finished all entries, with a total of {count}.")
    if log_file:
        log_file.close()


def main():
    args = get_args()
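
    # Check file paths from the arguments; quit early if any required file is missing.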
    if not os.path.exists(args.model):
        print("Could not find model checkpoint file: ", args.model)
        quit()
    if not os.path.exists(args.manifest):
        print("Could not find data manifest file: ", args.manifest)
        quit()
    if os.path.exists(args.out):
        print("Output file already exists: ", args.out)
        overwrite = input("Is it okay to overwrite it? (Y/N): ")
        if overwrite.lower() != 'y':
            print("Not overwriting output file, quitting.")
            quit()
    if not os.path.exists(args.heteronyms):
        print("Could not find heteronyms list: ", args.heteronyms)
        quit()
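
    # Read the heteronyms list, one word per line.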
    heteronyms = set()
    with open(args.heteronyms, 'r') as f_het:
        for line in f_het:
            heteronyms.add(line.strip().lower())
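
    # Load the Aligner model onto the available device.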
    print("Restoring Aligner model from checkpoint...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    aligner = AlignerModel.restore_from(args.model, map_location=device)

    print("Beginning disambiguation...")
    disambiguate_dataset(aligner, args.manifest, args.out, args.sr, heteronyms, args.confidence, device, args.verbose)


if __name__ == '__main__':
    main()