{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
"\n",
"Instructions for setting up Colab are as follows:\n",
"1. Open a new Python 3 notebook.\n",
"2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
"3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
"4. Run this cell to set up dependencies.\n",
"5. Restart the runtime (Runtime -> Restart Runtime) for any upgraded packages to take effect\n",
"\"\"\"\n",
"\n",
"NEMO_DIR_PATH = \"NeMo\"\n",
"BRANCH = 'r1.17.0'\n",
"\n",
"! git clone https://github.com/NVIDIA/NeMo\n",
"%cd NeMo\n",
"! python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]\n",
"%cd .."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Speaker Diarization Training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Neural Diarizer in Speaker Diarization Pipeline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Speaker diarization system needs to capture the characteristics of unseen speakers from the given audio recording and generate speaker-homogeneous segments which belong to corresponding speaker labels. During the speaker diarization process, the number of speakers should be estimated, then the audio segments should be assigned to a few of speaker labels. \n",
"\n",
"While clustering algorithms can also assign segments to speaker groups, overlap-aware diarization cannot be done with clustering based diarizer since one segment is only assigned to one speaker label. However, we can use the clustering result to create initial speaker profiles and train a neural model that generates overlap-aware speaker labels by comparing the input audio signal with the initial speaker profiles. In the NeMo speaker diarization toolkit, we refer to such neural modules as **neural diarizer**. \n",
"\n",
"The Multi-scale Diarization Decoder (MSDD) model is a type of neural diarizer we can use in the NeMo speaker diarization pipeline. This tutorial shows how to train MSDD on a small toy dataset. By using MSDD on top of clustering diarizer, we can obtain the following benefits:\n",
"\n",
"- **Improved diarization accuracy**: Compared to clustering diarizer, MSDD could achieve a lower diarization error rate (DER)\n",
"- **Overlap aware diarization**: Speaker diarization results in clustering diarizer do not include speech overlaps\n",
"- **Model training on actual multispeaker dataset**: Unlike training a speaker embedding model, we can train or finetune a neural model on an actual speaker diarization dataset where multiple speakers are recorded in a single audio file. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training and inference of Multi-scale Diarization Decoder\n",
"\n",
"When it comes to the speaker diarization problem, MSDD model employs a divide-and-conquer strategy where a pairwise model is employed for both training and inference. The following figure explains how a pairwise model is employed for training and inference."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here are itemized descriptions of noteworthy features of the MSDD model.\n",
"\n",
"#### Training \n",
"\n",
"- **Oracle VAD, multi-scale segmentation for training** \n",
"In a training setup, we use oracle VAD from ground-truth annotation files (RTTM files) and perform multiscale segmentation. After we obtain timestamps for each and every segment, we feed multi-scale timestamps and raw audio signals into a computational graph where the speaker embedding extractor and neural diarizer is trained. \n",
"\n",
" \n",
"- **MSDD inputs for training process** \n",
"During training, we employ oracle clustering result (ground-truth speaker labels in the annotation file) to calculate the cluster-average embeddings. Subsequently, we calculate binary cross-entropy loss which calculates a loss value for each timestep and each speaker.\n",
"\n",
"\n",
"- **End-to-end training: from raw audio to speaker label** \n",
"The training approach we employ can be considered as end-to-end training since the input to the computational graph is raw audio signal and the outputs are speaker labels. The end-to-end training is depicted in a dotted box in the above figure. We can either freeze the speaker embedding model or train it jointly depending on the tasks. \n",
"\n",
"\n",
"- **Pairwise (two-speaker) unit model** \n",
"While training the MSDD model, we use a two-speaker dataset for a two speaker model. For this pairwise training, we clean the source dataset to have only two speakers by splitting the annotation. \n",
"\n",
"\n",
"- **Split training samples** \n",
"Since we have finite GPU memory for training, we break down the training audio samples into short audio samples. We set step-count, and step-count indicates a unit of decision for speaker label estimation. We set step-count (e.g., `step_count=50`) when we create training datasets and use the step-count for training. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Inference\n",
"\n",
"- **Multi-scale clustering** \n",
"In inference mode, we apply multi-scale clustering for obtaining speaker profiles that are represented by cluster-average embedding. \n",
"\n",
"\n",
"- **Divide-and-conquer approach with pairwise (two-speaker) unit model** \n",
"We retrieve all possible pairs from the estimated number of speakers and average the results. For example, if there are four speakers `(A, B, C, D)`, we extract 6 pairs: `(A,B)`, `(A,C)`, `(A,D)`, `(B,C)`, `(B,D)`, `(C,D)`. Finally, the sigmoid outputs are averaged. In this way, MSDD can deal with a flexible number of speakers using a pairwise model.\n",
"\n",
"\n",
"- **Split inference samples** \n",
"As in the training process, we can also break down the target samples for inference. While we can do inference on whole input audio at once, split inference generally gives an improved performance. It is recommended to use the same step-count you used for training the MSDD model (e.g., `diar_window_length=50`) for your inference configurations."
]
},
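{
"cell_type": "markdown",
"metadata": {},
"source": [
"The divide-and-conquer step above can be sketched in plain Python. This is only an illustration, not NeMo's implementation: `toy_pairwise_model` is a hypothetical stand-in that returns made-up sigmoid-like scores for one decision step, and averaging each speaker's score over the pairs that contain that speaker is an assumption about how the pairwise outputs are combined.\n",
"\n",
"```python\n",
"from itertools import combinations\n",
"\n",
"def toy_pairwise_model(spk_a, spk_b):\n",
"    # Hypothetical stand-in for the two-speaker model: returns fixed\n",
"    # sigmoid-like scores for one decision step (not real MSDD outputs).\n",
"    toy_scores = {'A': 0.9, 'B': 0.2, 'C': 0.7, 'D': 0.1}\n",
"    return toy_scores[spk_a], toy_scores[spk_b]\n",
"\n",
"speakers = ['A', 'B', 'C', 'D']\n",
"pairs = list(combinations(speakers, 2))  # 6 pairs for 4 speakers\n",
"\n",
"# Collect every score a speaker receives across the pairs it appears in.\n",
"collected = {spk: [] for spk in speakers}\n",
"for spk_a, spk_b in pairs:\n",
"    score_a, score_b = toy_pairwise_model(spk_a, spk_b)\n",
"    collected[spk_a].append(score_a)\n",
"    collected[spk_b].append(score_b)\n",
"\n",
"# Average the sigmoid outputs per speaker.\n",
"averaged = {spk: sum(vals) / len(vals) for spk, vals in collected.items()}\n",
"print(pairs)\n",
"print(averaged)\n",
"```\n",
"\n",
"With four speakers, each speaker appears in three of the six pairs, so each averaged value here is the mean of three pairwise scores."
]
},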
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Input and Output of Multi-scale Diarization Decoder"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"While using an MSDD model as neural diarizer has a few benefits, MSDD models require a clustering result to obtain initial speaker profiles as references for performing overlap-aware speaker diarization inference. Here are descriptions for input and output of the MSDD model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Input: Clustering as Initialization\n",
"\n",
"MSDD model is a diarizer model that accepts two different data inputs:\n",
"\n",
" 1. Cluster-average embeddings\n",
" 2. Multi-scale embedding sequence \n",
" \n",
"The two input signals are depicted in the following figure."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By initializing the diarization task with a clustering algorithm, we can estimate the number of speakers and cluster-average embeddings. Thus, the cluster-average embeddings provide the speaker profile of each speaker. The cluster-average embeddings we provide can be regarded as reference signals for providing seed speaker profiles.\n",
"\n",
"Once we obtain the fixed (or estimated) number of speakers, the speaker diarization problem becomes a binary classification task where we need to estimate whether a certain speaker's speech exists or not at a given timestep. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Output: Sigmoid Output and Binary Cross-entropy Loss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"
"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above figure depicts the inputs and outputs of the MSDD model. As an output from MSDD, sigmoid values for each speaker are generated. These sigmoid values are independent from the other speakers and indicate the simulated probability of the corresponding speaker's speech signal at the given step. During the training process, binary cross-entropy (BCE) is calculated for each individual sigmoid value and summed up to calculate the total loss for optimization."
]
},
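{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small worked example of the loss described above, the following sketch computes BCE for each sigmoid value and sums the terms. The sigmoid outputs and labels are made-up toy numbers for two speakers over three steps, and a plain sum is used as the reduction because the text says the values are summed; the reduction used inside NeMo's loss class may differ.\n",
"\n",
"```python\n",
"import math\n",
"\n",
"# Toy values (assumptions): rows are decision steps, columns are the two speakers.\n",
"sigmoid_outputs = [[0.9, 0.1], [0.8, 0.7], [0.2, 0.6]]\n",
"labels = [[1, 0], [1, 1], [0, 1]]  # both speakers are active in the second row (overlap)\n",
"\n",
"def bce(p, y):\n",
"    # Binary cross-entropy for one sigmoid value p and one binary label y.\n",
"    return -(y * math.log(p) + (1 - y) * math.log(1 - p))\n",
"\n",
"# Sum the per-step, per-speaker losses into one total loss.\n",
"total_loss = sum(\n",
"    bce(p, y)\n",
"    for step_probs, step_labels in zip(sigmoid_outputs, labels)\n",
"    for p, y in zip(step_probs, step_labels)\n",
")\n",
"print(round(total_loss, 4))\n",
"```\n",
"\n",
"Because each speaker has its own sigmoid output and its own label, a step where both labels are 1 (overlapping speech) is a valid target, which a single-label clustering assignment cannot express."
]
},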
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example Data Creation\n",
"\n",
"- Please skip this section and go directly to [Prepare Training data for MSDD](#Prepare-Training-data-for-MSDD) section if you have your own speaker diarization dataset. \n",
"\n",
"In this tutorial, we use [NeMo Multispeaker Simulator](https://github.com/NVIDIA/NeMo/blob/main/tutorials/tools/Multispeaker_Simulator.ipynb) and the Librispeech corpus to generate a toy training dataset for demonstration purpose. You can replace the simulated dataset with your own datasets if you have proper speaker annotations (RTTM files) for the dataset. If you do not have access to any speaker diarization datasets, you can use [NeMo Multispeaker Simulator](https://github.com/NVIDIA/NeMo/blob/main/tutorials/tools/Multispeaker_Simulator.ipynb) by generating a good amount of data samples to meet your needs. \n",
"\n",
"For more details regarding data simulator, please follow the descriptions in [NeMo Multispeaker Simulator](https://github.com/NVIDIA/NeMo/blob/main/tutorials/tools/Multispeaker_Simulator.ipynb) and we will not cover configurations and detailed process of data simulation in this tutorial. \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install dependencies for data simulator\n",
"!apt-get install sox libsndfile1 ffmpeg\n",
"!pip install wget\n",
"!pip install unidecode\n",
"!pip install \"matplotlib>=3.3.2\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Data Simulation Step 1: Download Required Resources\n",
"\n",
"We need to download the LibriSpeech corpus and corresponding word alignments for generating synthetic multi-speaker audio sessions. In addition, we need to download necessary data cleaning scripts from NeMo git."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"NEMO_DIR_PATH = \"NeMo\"\n",
"BRANCH = 'r1.17.0'\n",
"\n",
"# download scripts if not already there \n",
"if not os.path.exists('NeMo/scripts'):\n",
" print(\"Downloading necessary scripts\")\n",
" !mkdir -p NeMo/scripts/dataset_processing\n",
" !mkdir -p NeMo/scripts/speaker_tasks\n",
" !wget -P NeMo/scripts/dataset_processing/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/scripts/dataset_processing/get_librispeech_data.py\n",
" !wget -P NeMo/scripts/speaker_tasks/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/scripts/speaker_tasks/create_alignment_manifest.py\n",
" !wget -P NeMo/scripts/speaker_tasks/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/scripts/speaker_tasks/create_msdd_train_dataset.py \n",
" !wget -P NeMo/scripts/speaker_tasks/ https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/scripts/speaker_tasks/pathfiles_to_diarize_manifest.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have downloaded all the necessary scripts for data creation and preparation, we can start the data simulation step by downloading the LibriSpeech corpus."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!mkdir -p LibriSpeech\n",
"!python {NEMO_DIR_PATH}/scripts/dataset_processing/get_librispeech_data.py \\\n",
" --data_root LibriSpeech \\\n",
" --data_sets dev_clean"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can get the forced word alignments data for the LibriSpeech corpus from [this repository.](https://github.com/CorentinJ/librispeech-alignments). Full forced alignments data can be downloaded at [google drive link for alignments data](https://drive.google.com/file/d/1WYfgr31T-PPwMcxuAq09XZfHQO5Mw8fE/view?usp=sharing). We will download only a subset of forced alignment data containing dev-clean part."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!wget -nc https://dldata-public.s3.us-east-2.amazonaws.com/LibriSpeech_Alignments.tar.gz\n",
"!tar -xzf LibriSpeech_Alignments.tar.gz\n",
"!rm -f LibriSpeech_Alignments.tar.gz"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Data Simulation Step 2: Produce Manifest File with Forced Alignments\n",
"\n",
"We will merge the LibriSpeech manifest files and LibriSpeech forced alignments into one manifest file for ease of use when generating synthetic data. Create alignment files by running the following script.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!python NeMo/scripts/speaker_tasks/create_alignment_manifest.py \\\n",
" --input_manifest_filepath LibriSpeech/dev_clean.json \\\n",
" --base_alignment_path LibriSpeech_Alignments \\\n",
" --output_manifest_filepath ./dev-clean-align.json \\\n",
" --ctm_output_directory ./ctm_out \\\n",
" --libri_dataset_split dev-clean"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Data Simulation Step 3: Set data simulation parameters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have downloaded all the sources we need for data creation, we need to download data simulator configurations in `.yaml` format. Download the YAML file and download `data_simulator.py` script from NeMo repository."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from omegaconf import OmegaConf\n",
"import os\n",
"ROOT = os.getcwd()\n",
"conf_dir = os.path.join(ROOT,'conf')\n",
"!mkdir -p $conf_dir\n",
"CONFIG_PATH = os.path.join(conf_dir, 'data_simulator.yaml')\n",
"if not os.path.exists(CONFIG_PATH):\n",
" !wget -P $conf_dir https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/tools/speech_data_simulator/conf/data_simulator.yaml\n",
"\n",
"config = OmegaConf.load(CONFIG_PATH)\n",
"print(OmegaConf.to_yaml(config))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Data Simulation Step 4: Generate Simulated Audio Session\n",
"\n",
"We will generate a set of example sessions with the following specifications:\n",
"\n",
"- 3 example sessions for train \n",
"- 3 example sessions for validation\n",
"- 2-speakers in each session\n",
"- 60 seconds of recordings\n",
"\n",
"We need to setup different seed for train and validation sets."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from nemo.collections.asr.data.data_simulation import MultiSpeakerSimulator\n",
"\n",
"# Generate train set \n",
"ROOT = os.getcwd()\n",
"data_dir = os.path.join(ROOT,'simulated_train')\n",
"config.data_simulator.random_seed=10\n",
"config.data_simulator.manifest_filepath=\"./dev-clean-align.json\"\n",
"config.data_simulator.outputs.output_dir=data_dir\n",
"config.data_simulator.session_config.num_sessions=3\n",
"config.data_simulator.session_config.num_speakers=2\n",
"config.data_simulator.session_config.session_length=60\n",
"config.data_simulator.background_noise.add_bg=False \n",
"\n",
"lg = MultiSpeakerSimulator(cfg=config)\n",
"lg.generate_sessions()\n",
"\n",
"# Generate validation set \n",
"data_dir = os.path.join(ROOT,'simulated_valid')\n",
"config.data_simulator.random_seed=20\n",
"config.data_simulator.outputs.output_dir=data_dir\n",
"\n",
"lg = MultiSpeakerSimulator(cfg=config)\n",
"lg.generate_sessions()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that parameter setting is done, generate the samples by launching `generate_sessions()`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"lg = MultiSpeakerSimulator(cfg=config)\n",
"lg.generate_sessions()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data preparation step 5: Listen to and Visualize Session\n",
"\n",
"Listen to the audio and visualize the corresponding speaker timestamps (recorded in a RTTM file for each session)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"import os\n",
"import wget\n",
"import IPython\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import librosa\n",
"from nemo.collections.asr.parts.utils.speaker_utils import rttm_to_labels, labels_to_pyannote_object\n",
"\n",
"ROOT = os.getcwd()\n",
"data_dir = os.path.join(ROOT,'simulated_train')\n",
"audio = os.path.join(data_dir,'multispeaker_session_0.wav')\n",
"rttm = os.path.join(data_dir,'multispeaker_session_0.rttm')\n",
"\n",
"sr = 16000\n",
"signal, sr = librosa.load(audio,sr=sr) \n",
"\n",
"fig,ax = plt.subplots(1,1)\n",
"fig.set_figwidth(20)\n",
"fig.set_figheight(2)\n",
"plt.plot(np.arange(len(signal)),signal,'gray')\n",
"fig.suptitle('Synthetic Audio Session', fontsize=16)\n",
"plt.xlabel('Time (s)', fontsize=18)\n",
"plt.yticks([], [])\n",
"ax.margins(x=0)\n",
"a,_ = plt.xticks()\n",
"plt.xticks(a[:-1],a[:-1]/sr);\n",
"IPython.display.Audio(audio)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can visually check the ground-truth file of the first sample by running the following commands. We can see that it has plenty of overlap between two speakers. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# display speaker labels for reference\n",
"labels = rttm_to_labels(rttm)\n",
"reference = labels_to_pyannote_object(labels)\n",
"reference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can check that corresponding RTTM files are generated as ground-truth labels for training and evaluation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!cat simulated_train/multispeaker_session_0.rttm | head -10"
]
},
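{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the RTTM content easier to read, here is a minimal parsing sketch. It assumes the standard RTTM field order (onset in the fourth field, duration in the fifth, speaker name in the eighth), and the lines in `toy_rttm` are made up for illustration rather than copied from the simulator output. The overlap computation is a simple pairwise check that is adequate for two speakers.\n",
"\n",
"```python\n",
"# Made-up RTTM lines for illustration; real files come from the simulator.\n",
"toy_rttm = '''SPEAKER toy_session 1 0.00 4.00 <NA> <NA> spk0 <NA> <NA>\n",
"SPEAKER toy_session 1 3.00 3.00 <NA> <NA> spk1 <NA> <NA>\n",
"SPEAKER toy_session 1 7.00 2.00 <NA> <NA> spk0 <NA> <NA>'''\n",
"\n",
"# Convert each line into a (start, end, speaker) tuple.\n",
"segments = []\n",
"for line in toy_rttm.splitlines():\n",
"    fields = line.split()\n",
"    start, duration, speaker = float(fields[3]), float(fields[4]), fields[7]\n",
"    segments.append((start, start + duration, speaker))\n",
"\n",
"# Sum the time during which segments of different speakers intersect.\n",
"overlap = 0.0\n",
"for i, (s1, e1, spk1) in enumerate(segments):\n",
"    for s2, e2, spk2 in segments[i + 1:]:\n",
"        if spk1 != spk2:\n",
"            overlap += max(0.0, min(e1, e2) - max(s1, s2))\n",
"\n",
"print(segments)\n",
"print(overlap)\n",
"```\n",
"\n",
"The same loop can be pointed at the lines of a generated `.rttm` file to get a rough number for how much overlapping speech a session contains."
]
},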
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Data preparation step 6: Check out the created files\n",
"\n",
"The following files are generated from data simulator:\n",
"\n",
"* _wav files_ (one per audio session) - the output audio sessions\n",
"* _rttm files_ (one per audio session) - the speaker timestamps for the corresponding audio session (used for diarization training)\n",
"* _list files_ (one per file type per batch of sessions) - a list of generated files of the given type (e.g., wav, rttm), used primarily for manifest creation\n",
"\n",
"Check if the files we need are generated by running the following commands."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"\\n Training audio files:\")\n",
"!ls simulated_train/*.wav\n",
"print(\"\\n Training audio files:\")\n",
"!ls simulated_train/*.rttm\n",
"print(\"\\n Training RTTM list content:\")\n",
"!cat simulated_train/synthetic_wav.list\n",
"print(\"\\n Training RTTM list content:\")\n",
"!cat simulated_train/synthetic_rttm.list\n",
"\n",
"print(\"\\n Validation audio files:\")\n",
"!ls simulated_valid/*.wav\n",
"print(\"\\n Validation audio files:\")\n",
"!ls simulated_valid/*.rttm\n",
"print(\"\\n Validation RTTM list content:\")\n",
"!cat simulated_valid/synthetic_wav.list\n",
"print(\"\\n Validation RTTM list content:\")\n",
"!cat simulated_valid/synthetic_rttm.list"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Prepare Training Data for MSDD"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have datasets for both train and validation (dev), we can start preparing and cleaning the data samples for training. Make sure you have the following list of files:\n",
"\n",
"**Training set** \n",
"\n",
"- Train audio files `.wav`\n",
"- A train audio list file `.list`\n",
"- Train RTTM files `.rttm`\n",
"- A train RTTM list content `.list`\n",
"\n",
"**Validation set** \n",
"\n",
"- Validation audio files `.wav`\n",
"- A validation audio list file `.list`\n",
"- Validation RTTM files `.rttm`\n",
"- A validation RTTM list file `.list`\n",
"\n",
"\n",
"Based on these files, we need to create manifest files containing data samples we have. If you don't have a `.list` file, you need to create a `.list` file for the `.wav` files and `.rttm` files."
]
},
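{
"cell_type": "markdown",
"metadata": {},
"source": [
"If your own dataset has no `.list` files, a sketch like the following can create them. It assumes that each list file holds one file path per line and that a `.wav` file and its `.rttm` file share the same base name. The function name, the `my_wav.list` and `my_rttm.list` file names, and the session names are hypothetical, and a temporary folder with empty placeholder files stands in for a real dataset.\n",
"\n",
"```python\n",
"import tempfile\n",
"from pathlib import Path\n",
"\n",
"def write_list_files(data_dir):\n",
"    # Write sorted absolute paths, one per line, so that the wav and rttm\n",
"    # lists line up by base name (assumed naming convention).\n",
"    data_dir = Path(data_dir)\n",
"    for ext in ('wav', 'rttm'):\n",
"        paths = sorted(str(p.resolve()) for p in data_dir.glob(f'*.{ext}'))\n",
"        with open(data_dir / f'my_{ext}.list', 'w') as f:\n",
"            for p in paths:\n",
"                print(p, file=f)\n",
"\n",
"# Demonstrate on a temporary folder with empty placeholder files.\n",
"demo_dir = Path(tempfile.mkdtemp())\n",
"for name in ('session_0', 'session_1'):\n",
"    (demo_dir / f'{name}.wav').touch()\n",
"    (demo_dir / f'{name}.rttm').touch()\n",
"\n",
"write_list_files(demo_dir)\n",
"wav_list = (demo_dir / 'my_wav.list').read_text().splitlines()\n",
"rttm_list = (demo_dir / 'my_rttm.list').read_text().splitlines()\n",
"print(wav_list)\n",
"print(rttm_list)\n",
"```\n",
"\n",
"Sorting both lists keeps the audio and RTTM entries in the same order, which makes it easy to verify by eye that every session has both files."
]
},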
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# create a NeMo manifest (.json) file for training dataset\n",
"!python NeMo/scripts/speaker_tasks/pathfiles_to_diarize_manifest.py \\\n",
" --paths2audio_files='simulated_train/synthetic_wav.list' \\\n",
" --paths2rttm_files='simulated_train/synthetic_rttm.list' \\\n",
" --manifest_filepath='simulated_train/msdd_data.json'\n",
"\n",
"# create a NeMo manifest (.json) file for validation (dev) dataset\n",
"!python NeMo/scripts/speaker_tasks/pathfiles_to_diarize_manifest.py \\\n",
" --paths2audio_files='simulated_valid/synthetic_wav.list' \\\n",
" --paths2rttm_files='simulated_valid/synthetic_rttm.list' \\\n",
" --manifest_filepath='simulated_valid/msdd_data.json'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you print the content of the created manifest file, you can see that `.rttm` files in the list and `.wav` files are grouped together in the generated manifest files."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"\\n An example line in training dataset manifest file:\")\n",
"!cat simulated_train/msdd_data.json | head -1\n",
"print(\"\\n An example line in validation Dataset manifest file:\")\n",
"!cat simulated_valid/msdd_data.json | head -1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have input a standard manifest file, we need to break down each audio clip into short audio clips so that we can put several samples in a batch. \n",
"\n",
"Before we generate a manifest file and RTTM files for training MSDD, you have to determine:\n",
"\n",
"- `window`: the window length of the base scale (the shortest scale)\n",
"- `shift`: the hop-length of the base scale (the shortest scale)\n",
"- `step_count`: how many decision steps in one data sample\n",
"\n",
"Note that these numbers should match the parameters in the configurations for your desired MSDD model. If you want to train with new parameters (`window`, `shift` and `step_count`), you need to make new manifest files with the new parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# create a manifest (.json) file for training dataset\n",
"!python NeMo/scripts/speaker_tasks/create_msdd_train_dataset.py \\\n",
" --input_manifest_path='simulated_train/msdd_data.json' \\\n",
" --output_manifest_path='simulated_train/msdd_data.50step.json' \\\n",
" --pairwise_rttm_output_folder='simulated_train/' \\\n",
" --window 0.5 \\\n",
" --shift 0.25 \\\n",
" --step_count 50 \n",
" \n",
"# create a manifest (.json) file for validation (dev) dataset\n",
"!python NeMo/scripts/speaker_tasks/create_msdd_train_dataset.py \\\n",
" --input_manifest_path='simulated_valid/msdd_data.json' \\\n",
" --output_manifest_path='simulated_valid/msdd_data.50step.json' \\\n",
" --pairwise_rttm_output_folder='simulated_valid/' \\\n",
" --window 0.5 \\\n",
" --shift 0.25 \\\n",
" --step_count 50 "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we broke down the training and validation dataset into 50-step samples, let's checkout how the output manifest files look like. We used 0.25 second of shift length so in theory, if there is no silence or pause in the data, the length of data sample should be `step_count*shift` which is `50*0.25=12.5` second in the example we used. However, since there are pauses between the segments in practice, the final lengths of data samples are longer than 12.5 second."
]
},
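{
"cell_type": "markdown",
"metadata": {},
"source": [
"The arithmetic above can be checked with a short sketch. The function name is illustrative only, and the result is the nominal length for gap-free speech given by the `step_count*shift` estimate, not the actual duration written to the manifest.\n",
"\n",
"```python\n",
"def nominal_sample_length(step_count, shift):\n",
"    # Nominal sample length in seconds, assuming speech with no silence\n",
"    # or pauses; actual samples are longer when pauses occur.\n",
"    return step_count * shift\n",
"\n",
"length = nominal_sample_length(step_count=50, shift=0.25)\n",
"print(length)\n",
"```\n",
"\n",
"If you change `shift` or `step_count` when creating the manifests, the same calculation gives the new nominal length to compare against the durations in the output manifest."
]
},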
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"\\nTraining Dataset:\")\n",
"!cat simulated_train/msdd_data.50step.json | tail -5\n",
"print(\"\\nValidation Dataset:\")\n",
"!cat simulated_valid/msdd_data.50step.json | tail -5 "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train an MSDD Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have prepared all the necessary dataset, we can train an MSDD model on the prepared dataset. Download YAML file for training form NeMo repository and load the configuration into `config` variable."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import nemo\n",
"import nemo.collections.asr as nemo_asr\n",
"from omegaconf import OmegaConf\n",
"\n",
"NEMO_ROOT = os.getcwd()\n",
"!mkdir conf \n",
"!wget -P conf https://raw.githubusercontent.com/NVIDIA/NeMo/$BRANCH/examples/speaker_tasks/diarization/conf/neural_diarizer/msdd_5scl_15_05_50Povl_256x3x32x2.yaml\n",
"MODEL_CONFIG = os.path.join(NEMO_ROOT,'conf/msdd_5scl_15_05_50Povl_256x3x32x2.yaml')\n",
"config = OmegaConf.load(MODEL_CONFIG)\n",
"print(OmegaConf.to_yaml(config))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Setup the `manifest_filepath` for `train_ds` and `validation_ds` by feeding the `json` file paths from `create_msdd_train_dataset.py`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config.model.train_ds.manifest_filepath = 'simulated_train/msdd_data.50step.json'\n",
"config.model.validation_ds.manifest_filepath = 'simulated_valid/msdd_data.50step.json'"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train MSDD with frozen speaker embedding model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Provide a batch size number for training in `config.model.batch_size`. In a batch, we will assign the given number of split samples (in this example, 50 step-size). Note that you might need to change this batch size if the following batch size maxes out your GPU memory size. \n",
"\n",
"`config.model.emb_batch_size` determines the number of embedding vectors attached to a computational graph. This means that \n",
"If you want to freeze the speaker embedding extractor, you should set `emb_batch_size=0`.\n",
"If you want to jointly optimize speaker embedding extractor, you need to assign an adequate number that does not max out the GPU memory. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config.batch_size=5\n",
"config.model.emb_batch_size=0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Provide paths to the temporary folders for saving timestamp data during training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config.model.train_ds.emb_dir=\"simulated_train\" \n",
"config.model.validation_ds.emb_dir=\"simulated_valid\" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Setup a speaker embedding model that will be used for speaker embedding extraction. We will use `titanet_large` model checkpoint from NGC. Note that this speaker embedding model will be saved together in a `.ckpt` file whenever pytorch lightning trainer saves checkpoint. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config.model.diarizer.speaker_embeddings.model_path=\"titanet_large\"\n",
"config.trainer.max_epochs = 5\n",
"config.trainer.strategy = None"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use `pytorch_lightning` and train a model instance from class `EncDecDiarLabelModel`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"from nemo.collections.asr.models import EncDecDiarLabelModel\n",
"from nemo.utils.exp_manager import exp_manager\n",
"\n",
"trainer = pl.Trainer(**config.trainer)\n",
"exp_manager(trainer, config.get(\"exp_manager\", None))\n",
"msdd_model = EncDecDiarLabelModel(cfg=config.model, trainer=trainer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before we start training, let's check a few of the weights in the speaker embedding model in `msdd_model.msdd._speaker_model`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"msdd_model.msdd._speaker_model.state_dict()[\"encoder.encoder.0.mconv.0.conv.weight\"][0,:,:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you successfully ran the previous step, now it is ready to initiate a training session of MSDD. Launch `trainer.fit()` function for `msdd_model`. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer.fit(msdd_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this way, you can train an MSDD model and use `.ckpt` files saved during training process."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check the weights in speaker embedding model again and see if the numbers are changed."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"msdd_model.msdd._speaker_model.state_dict()[\"encoder.encoder.0.mconv.0.conv.weight\"][0,:,:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train MSDD and speaker embedding model together"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this time, let's attach a few speaker embeddings to a graph and jointly train the speaker embedding mode."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config.model.emb_batch_size = 100 # choose the largest number that does not max out GPU memory.\n",
"config.trainer.max_epochs = 5"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Setup another trainer and initiate a training session with the new parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer = pl.Trainer(**config.trainer)\n",
"exp_manager(trainer, config.get(\"exp_manager\", None))\n",
"msdd_model = EncDecDiarLabelModel(cfg=config.model, trainer=trainer)\n",
"trainer.fit(msdd_model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check whether there is a change in the weights."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"msdd_model.msdd._speaker_model.state_dict()[\"encoder.encoder.0.mconv.0.conv.weight\"][0,:,:]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"pycharm": {
"stem_cell": {
"cell_type": "raw",
"metadata": {
"collapsed": false
},
"source": []
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}