{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "code", "source": [ "!pip install librosa numpy tensorflow scikit-learn sounddevice\n", "!pip install gradio" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "QdDz0RCbIBwe", "outputId": "8cb2c152-52ae-4a62-a0bf-c3067481af1f" }, "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Requirement already satisfied: librosa in /usr/local/lib/python3.11/dist-packages (0.10.2.post1)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (1.26.4)\n", "Requirement already satisfied: tensorflow in /usr/local/lib/python3.11/dist-packages (2.18.0)\n", "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.11/dist-packages (1.6.1)\n", "Requirement already satisfied: sounddevice in /usr/local/lib/python3.11/dist-packages (0.5.1)\n", "Requirement already satisfied: audioread>=2.1.9 in /usr/local/lib/python3.11/dist-packages (from librosa) (3.0.1)\n", "Requirement already satisfied: scipy>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from librosa) (1.14.1)\n", "Requirement already satisfied: joblib>=0.14 in /usr/local/lib/python3.11/dist-packages (from librosa) (1.4.2)\n", "Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.11/dist-packages (from librosa) (4.4.2)\n", "Requirement already satisfied: numba>=0.51.0 in /usr/local/lib/python3.11/dist-packages (from librosa) (0.60.0)\n", "Requirement already satisfied: soundfile>=0.12.1 in /usr/local/lib/python3.11/dist-packages (from librosa) (0.13.1)\n", "Requirement already satisfied: pooch>=1.1 in /usr/local/lib/python3.11/dist-packages (from librosa) (1.8.2)\n", "Requirement already satisfied: soxr>=0.3.2 in /usr/local/lib/python3.11/dist-packages (from librosa) (0.5.0.post1)\n", "Requirement already satisfied: typing-extensions>=4.1.1 in /usr/local/lib/python3.11/dist-packages (from librosa) (4.12.2)\n", "Requirement already satisfied: lazy-loader>=0.1 in /usr/local/lib/python3.11/dist-packages (from librosa) (0.4)\n", "Requirement already satisfied: msgpack>=1.0 in /usr/local/lib/python3.11/dist-packages (from librosa) (1.1.0)\n", "Requirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (1.4.0)\n", "Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (1.6.3)\n", "Requirement already satisfied: flatbuffers>=24.3.25 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (25.2.10)\n", "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (0.6.0)\n", "Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (0.2.0)\n", "Requirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (18.1.1)\n", "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (3.4.0)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from tensorflow) (24.2)\n", "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.3 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (4.25.6)\n", "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (2.32.3)\n", "Requirement already satisfied: setuptools in /usr/local/lib/python3.11/dist-packages (from tensorflow) (75.1.0)\n", "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (1.17.0)\n", "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (2.5.0)\n", "Requirement already satisfied: wrapt>=1.11.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (1.17.2)\n", "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (1.71.0)\n", "Requirement already satisfied: tensorboard<2.19,>=2.18 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (2.18.0)\n", "Requirement already satisfied: keras>=3.5.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (3.8.0)\n", "Requirement already satisfied: h5py>=3.11.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (3.12.1)\n", "Requirement already satisfied: ml-dtypes<0.5.0,>=0.4.0 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (0.4.1)\n", "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.11/dist-packages (from tensorflow) (0.37.1)\n", "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn) (3.5.0)\n", "Requirement already satisfied: CFFI>=1.0 in /usr/local/lib/python3.11/dist-packages (from sounddevice) (1.17.1)\n", "Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.11/dist-packages (from astunparse>=1.6.0->tensorflow) (0.45.1)\n", "Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from CFFI>=1.0->sounddevice) (2.22)\n", "Requirement already satisfied: rich in /usr/local/lib/python3.11/dist-packages (from keras>=3.5.0->tensorflow) (13.9.4)\n", "Requirement already satisfied: namex in /usr/local/lib/python3.11/dist-packages (from keras>=3.5.0->tensorflow) (0.0.8)\n", "Requirement already satisfied: optree in /usr/local/lib/python3.11/dist-packages (from keras>=3.5.0->tensorflow) (0.14.1)\n", "Requirement already satisfied: llvmlite<0.44,>=0.43.0dev0 in /usr/local/lib/python3.11/dist-packages (from numba>=0.51.0->librosa) (0.43.0)\n", "Requirement already satisfied: platformdirs>=2.5.0 in /usr/local/lib/python3.11/dist-packages (from pooch>=1.1->librosa) (4.3.6)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorflow) (3.4.1)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorflow) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorflow) (2.3.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests<3,>=2.21.0->tensorflow) (2025.1.31)\n", "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.19,>=2.18->tensorflow) (3.7)\n", "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.19,>=2.18->tensorflow) (0.7.2)\n", "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from tensorboard<2.19,>=2.18->tensorflow) (3.1.3)\n", "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.11/dist-packages (from werkzeug>=1.0.1->tensorboard<2.19,>=2.18->tensorflow) (2.1.5)\n", "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.11/dist-packages (from rich->keras>=3.5.0->tensorflow) (3.0.0)\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.11/dist-packages (from rich->keras>=3.5.0->tensorflow) (2.18.0)\n", "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.11/dist-packages (from markdown-it-py>=2.2.0->rich->keras>=3.5.0->tensorflow) (0.1.2)\n", "Requirement already satisfied: gradio in /usr/local/lib/python3.11/dist-packages (5.21.0)\n", "Requirement already satisfied: aiofiles<24.0,>=22.0 in /usr/local/lib/python3.11/dist-packages (from gradio) (23.2.1)\n", "Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.11/dist-packages (from gradio) (3.7.1)\n", "Requirement already satisfied: fastapi<1.0,>=0.115.2 in /usr/local/lib/python3.11/dist-packages (from gradio) (0.115.11)\n", "Requirement already satisfied: ffmpy in /usr/local/lib/python3.11/dist-packages (from gradio) (0.5.0)\n", "Requirement already satisfied: gradio-client==1.7.2 in /usr/local/lib/python3.11/dist-packages (from gradio) (1.7.2)\n", "Requirement already satisfied: groovy~=0.1 in /usr/local/lib/python3.11/dist-packages (from gradio) (0.1.2)\n", "Requirement already satisfied: httpx>=0.24.1 in /usr/local/lib/python3.11/dist-packages (from gradio) (0.28.1)\n", "Requirement already satisfied: huggingface-hub>=0.28.1 in /usr/local/lib/python3.11/dist-packages (from gradio) (0.28.1)\n", "Requirement already satisfied: jinja2<4.0 in /usr/local/lib/python3.11/dist-packages (from gradio) (3.1.6)\n", "Requirement already satisfied: markupsafe~=2.0 in /usr/local/lib/python3.11/dist-packages (from gradio) (2.1.5)\n", "Requirement already satisfied: numpy<3.0,>=1.0 in /usr/local/lib/python3.11/dist-packages (from gradio) (1.26.4)\n", "Requirement already satisfied: orjson~=3.0 in /usr/local/lib/python3.11/dist-packages (from gradio) (3.10.15)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.11/dist-packages (from gradio) (24.2)\n", "Requirement already satisfied: pandas<3.0,>=1.0 in /usr/local/lib/python3.11/dist-packages (from gradio) (2.2.2)\n", "Requirement already satisfied: pillow<12.0,>=8.0 in /usr/local/lib/python3.11/dist-packages (from gradio) (11.1.0)\n", "Requirement already satisfied: pydantic>=2.0 in /usr/local/lib/python3.11/dist-packages (from gradio) (2.10.6)\n", "Requirement already satisfied: pydub in /usr/local/lib/python3.11/dist-packages (from gradio) (0.25.1)\n", "Requirement already satisfied: python-multipart>=0.0.18 in /usr/local/lib/python3.11/dist-packages (from gradio) (0.0.20)\n", "Requirement already satisfied: pyyaml<7.0,>=5.0 in /usr/local/lib/python3.11/dist-packages (from gradio) (6.0.2)\n", "Requirement already satisfied: ruff>=0.9.3 in /usr/local/lib/python3.11/dist-packages (from gradio) (0.11.0)\n", "Requirement already satisfied: safehttpx<0.2.0,>=0.1.6 in /usr/local/lib/python3.11/dist-packages (from gradio) (0.1.6)\n", "Requirement already satisfied: semantic-version~=2.0 in /usr/local/lib/python3.11/dist-packages (from gradio) (2.10.0)\n", "Requirement already satisfied: starlette<1.0,>=0.40.0 in /usr/local/lib/python3.11/dist-packages (from gradio) (0.46.1)\n", "Requirement already satisfied: tomlkit<0.14.0,>=0.12.0 in /usr/local/lib/python3.11/dist-packages (from gradio) (0.13.2)\n", "Requirement already satisfied: typer<1.0,>=0.12 in /usr/local/lib/python3.11/dist-packages (from gradio) (0.15.2)\n", "Requirement already satisfied: typing-extensions~=4.0 in /usr/local/lib/python3.11/dist-packages (from gradio) (4.12.2)\n", "Requirement already satisfied: uvicorn>=0.14.0 in /usr/local/lib/python3.11/dist-packages (from gradio) (0.34.0)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from gradio-client==1.7.2->gradio) (2024.10.0)\n", "Requirement already satisfied: websockets<16.0,>=10.0 in /usr/local/lib/python3.11/dist-packages (from gradio-client==1.7.2->gradio) (14.2)\n", "Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.11/dist-packages (from anyio<5.0,>=3.0->gradio) (3.10)\n", "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio<5.0,>=3.0->gradio) (1.3.1)\n", "Requirement already satisfied: certifi in /usr/local/lib/python3.11/dist-packages (from httpx>=0.24.1->gradio) (2025.1.31)\n", "Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.11/dist-packages (from httpx>=0.24.1->gradio) (1.0.7)\n", "Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.11/dist-packages (from httpcore==1.*->httpx>=0.24.1->gradio) (0.14.0)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.28.1->gradio) (3.17.0)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.28.1->gradio) (2.32.3)\n", "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub>=0.28.1->gradio) (4.67.1)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0,>=1.0->gradio) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0,>=1.0->gradio) (2025.1)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas<3.0,>=1.0->gradio) (2025.1)\n", "Requirement already satisfied: annotated-types>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from pydantic>=2.0->gradio) (0.7.0)\n", "Requirement already satisfied: pydantic-core==2.27.2 in /usr/local/lib/python3.11/dist-packages (from pydantic>=2.0->gradio) (2.27.2)\n", "Requirement already satisfied: click>=8.0.0 in /usr/local/lib/python3.11/dist-packages (from typer<1.0,>=0.12->gradio) (8.1.8)\n", "Requirement already satisfied: shellingham>=1.3.0 in /usr/local/lib/python3.11/dist-packages (from typer<1.0,>=0.12->gradio) (1.5.4)\n", "Requirement already satisfied: rich>=10.11.0 in /usr/local/lib/python3.11/dist-packages (from typer<1.0,>=0.12->gradio) (13.9.4)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas<3.0,>=1.0->gradio) (1.17.0)\n", "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.11/dist-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (3.0.0)\n", "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.11/dist-packages (from rich>=10.11.0->typer<1.0,>=0.12->gradio) (2.18.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub>=0.28.1->gradio) (3.4.1)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub>=0.28.1->gradio) (2.3.0)\n", "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.11/dist-packages (from markdown-it-py>=2.2.0->rich>=10.11.0->typer<1.0,>=0.12->gradio) (0.1.2)\n" ] } ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "D5BPAg-OGnnD" }, "outputs": [], "source": [ "import os\n", "import numpy as np\n", "import librosa\n", "import tensorflow as tf\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import LabelEncoder\n", "from tensorflow.keras.utils import to_categorical\n", "from transformers import TFWav2Vec2Model, Wav2Vec2FeatureExtractor\n", "from imblearn.over_sampling import SMOTE\n", "from pydub import AudioSegment\n", "import gradio as gr\n", "import joblib\n", "from sklearn.utils.class_weight import compute_class_weight" ] }, { "cell_type": "code", "source": [ "# === Configuration ===\n", "SAMPLE_RATE = 16000 # Matches Wav2Vec2 requirements\n", "DURATION = 3 # Standardize audio clips to 3 seconds\n", "MAX_AUDIO_LENGTH = SAMPLE_RATE * DURATION\n", "PRETRAINED_MODEL_NAME = \"facebook/wav2vec2-base-960h\"\n", "DATASET_PATH = \"/content/drive/MyDrive/dataset/YAF DATASET\" # Adjust this path as needed\n", "\n", "# Initialize Wav2Vec2 model and feature extractor\n", "wav2vec2 = TFWav2Vec2Model.from_pretrained(PRETRAINED_MODEL_NAME)\n", "feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(PRETRAINED_MODEL_NAME)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "bWV4e2kKGs39", "outputId": "0d721d1c-307a-4117-8e7a-9ded57986367" }, "execution_count": 3, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", "You will be able to reuse this secret in all of your notebooks.\n", "Please note that authentication is recommended but still optional to access public models or datasets.\n", " warnings.warn(\n", "\n", "TFWav2Vec2Model has backpropagation operations that are NOT supported on CPU. If you wish to train/fine-tune this model, you need a GPU or a TPU\n", "Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFWav2Vec2Model: ['lm_head.weight', 'lm_head.bias']\n", "- This IS expected if you are initializing TFWav2Vec2Model from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing TFWav2Vec2Model from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights or buffers of the TF 2.0 model TFWav2Vec2Model were not initialized from the PyTorch model and are newly initialized: ['wav2vec2.masked_spec_embed']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ] }, { "cell_type": "code", "source": [ "# === Audio Preprocessing ===\n", "def load_and_preprocess_audio(file_path):\n", " \"\"\"Load and preprocess audio to a fixed length.\"\"\"\n", " try:\n", " audio, _ = librosa.load(file_path, sr=SAMPLE_RATE, duration=DURATION)\n", " except Exception as e:\n", " print(f\"Error loading {file_path}: {e}\")\n", " return None\n", " if len(audio) < MAX_AUDIO_LENGTH:\n", " audio = np.pad(audio, (0, MAX_AUDIO_LENGTH - len(audio)))\n", " else:\n", " audio = audio[:MAX_AUDIO_LENGTH]\n", " return audio" ], "metadata": { "id": "meaDR05VG8ke" }, "execution_count": 4, "outputs": [] }, { "cell_type": "code", "source": [ "# === Standard Data Augmentation ===\n", "def augment(audio):\n", " \"\"\"Apply random pitch shifting and noise addition (no time stretching).\"\"\"\n", " if np.random.rand() < 0.5:\n", " audio = librosa.effects.pitch_shift(audio, sr=SAMPLE_RATE, n_steps=np.random.uniform(-2, 2))\n", " if np.random.rand() < 0.5:\n", " noise = np.random.normal(0, 0.005, audio.shape)\n", " audio = audio + noise\n", " return audio\n" ], "metadata": { "id": "PiEncyiTG9VE" }, "execution_count": 5, "outputs": [] }, { "cell_type": "code", "source": [ "# === Feature Extraction ===\n", "def extract_features(audio):\n", " \"\"\"Extract full Wav2Vec2 feature sequence.\"\"\"\n", " inputs = feature_extractor(audio, sampling_rate=SAMPLE_RATE, return_tensors=\"tf\")\n", " wav_features = wav2vec2(inputs.input_values).last_hidden_state[0] # Shape: (time_steps, 768)\n", " return wav_features.numpy()" ], "metadata": { "id": "Pnk5-vbEHA-V" }, "execution_count": 6, "outputs": [] }, { "cell_type": "code", "source": [ "\n", "# === Enhanced Data Loading ===\n", "def load_enhanced_dataset(dataset_path):\n", " \"\"\"Load dataset with speaker IDs and apply augmentation to minority classes.\"\"\"\n", " features, labels, speakers = [], [], []\n", " for emotion in os.listdir(dataset_path):\n", " emotion_dir = os.path.join(dataset_path, emotion)\n", " if os.path.isdir(emotion_dir):\n", " for file in os.listdir(emotion_dir):\n", " if file.endswith(\".wav\"):\n", " file_path = os.path.join(emotion_dir, file)\n", " audio = load_and_preprocess_audio(file_path)\n", " if audio is None:\n", " continue\n", " feat = extract_features(audio)\n", " features.append(feat)\n", " # Extract speaker ID from emotion label (e.g., 'YAF' from 'YAF_happy')\n", " speaker_id = emotion.split('_')[0]\n", " speakers.append(speaker_id)\n", " labels.append(emotion)\n", " # Augment for minority classes\n", " if emotion in ['YAF_happy', 'YAF_sad']:\n", " aug_audio = augment(audio)\n", " aug_feat = extract_features(aug_audio)\n", " features.append(aug_feat)\n", " speakers.append(speaker_id)\n", " labels.append(emotion)\n", " # Encode speaker IDs numerically\n", " speaker_encoder = LabelEncoder()\n", " speakers_encoded = speaker_encoder.fit_transform(speakers)\n", " return np.array(features), np.array(labels), speakers_encoded, speaker_encoder\n" ], "metadata": { "id": "ZMGSVxAzHFW_" }, "execution_count": 7, "outputs": [] }, { "cell_type": "code", "source": [ "# === Advanced Model Architecture ===\n", "def build_enhanced_model(time_steps, feature_dim, num_speakers, num_classes):\n", " \"\"\"Build a model with sequence input, speaker embedding, LSTM, and attention.\"\"\"\n", " audio_input = tf.keras.Input(shape=(time_steps, feature_dim), name='audio_input')\n", " speaker_input = tf.keras.Input(shape=(1,), name='speaker_input')\n", " # Embed speaker ID and repeat across time steps\n", " speaker_embed = tf.keras.layers.Embedding(num_speakers, 8)(speaker_input)\n", " speaker_embed = tf.keras.layers.Flatten()(speaker_embed)\n", " speaker_embed = tf.keras.layers.RepeatVector(time_steps)(speaker_embed)\n", " # Combine audio and speaker features\n", " combined = tf.keras.layers.concatenate([audio_input, speaker_embed], axis=-1)\n", " # Bidirectional LSTM with attention\n", " x = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(128, return_sequences=True))(combined)\n", " x = tf.keras.layers.Attention()([x, x])\n", " x = tf.keras.layers.GlobalAveragePooling1D()(x)\n", " x = tf.keras.layers.Dense(128, activation='relu')(x)\n", " x = tf.keras.layers.Dropout(0.5)(x)\n", " # Final Dense layer outputs raw logits (no activation)\n", " outputs = tf.keras.layers.Dense(num_classes, activation=None)(x)\n", "\n", " model = tf.keras.Model(inputs=[audio_input, speaker_input], outputs=outputs)\n", " # Fine-tuning: unfreeze wav2vec2 for adaptation\n", " wav2vec2.trainable = True\n", " optimizer = tf.keras.optimizers.Adam(1e-5) # Low learning rate\n", " model.compile(optimizer=optimizer,\n", " loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),\n", " metrics=['accuracy'])\n", " return model" ], "metadata": { "id": "nh4EeWwiHHpJ" }, "execution_count": 8, "outputs": [] }, { "cell_type": "code", "source": [ "\n", "# === Main Execution ===\n", "if __name__ == \"__main__\":\n", " # Load dataset with speaker information\n", " X, y, speakers, speaker_encoder = load_enhanced_dataset(DATASET_PATH)\n", " label_encoder = LabelEncoder()\n", " y_encoded = label_encoder.fit_transform(y)\n", " y_onehot = to_categorical(y_encoded)\n", "\n", " # Handle class imbalance with SMOTE\n", " X_flat = X.reshape(-1, X.shape[1] * X.shape[2])\n", " smote = SMOTE()\n", " X_res_flat, y_res = smote.fit_resample(X_flat, y_encoded)\n", " X_res = X_res_flat.reshape(-1, X.shape[1], X.shape[2])\n", " y_res = to_categorical(y_res)\n", "" ], "metadata": { "id": "rphJ4dtmHJjm" }, "execution_count": 9, "outputs": [] }, { "cell_type": "code", "source": [ "# Split dataset\n", "X_train, X_test, y_train, y_test = train_test_split(\n", " X_res, y_res, test_size=0.2, random_state=42, stratify=y_res # Stratify by y_res\n", ")\n", "\n", "# Resample the speaker data to match the new data shape\n", "speaker_res = np.repeat(speakers, np.ceil(len(X_res) / len(speakers)))[:len(X_res)]\n", "\n", "# Now, split the resampled speaker data along with the data\n", "speaker_train, speaker_test = train_test_split(\n", " speaker_res, test_size=0.2, random_state=42, stratify=y_res # Stratify by y_res\n", ")\n", "\n", "time_steps, feature_dim = X_train.shape[1], X_train.shape[2]\n", "num_speakers = len(speaker_encoder.classes_) # Removed extra indentation\n", "num_classes = y_onehot.shape[1] # Removed extra indentation" ], "metadata": { "id": "pqR9blxIHNqJ" }, "execution_count": 13, "outputs": [] }, { "cell_type": "code", "source": [ "# Build and train model\n", "model = build_enhanced_model(time_steps, feature_dim, num_speakers, num_classes)\n", "model.summary()\n", "class_weights = compute_class_weight('balanced', classes=np.unique(y_encoded), y=y_encoded)\n", "\n", "history = model.fit(\n", " [X_train, speaker_train], y_train,\n", " validation_data=([X_test, speaker_test], y_test),\n", " epochs=50,\n", " batch_size=32,\n", " class_weight=dict(enumerate(class_weights)),\n", " callbacks=[\n", " tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True),\n", " tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=3)\n", " ]\n", " )" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "id": "hyJSdOskHXh2", "outputId": "00abc60f-bd38-435c-efa7-193b2ced64f5" }, "execution_count": 15, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "\u001b[1mModel: \"functional\"\u001b[0m\n" ], "text/html": [ "
Model: \"functional\"\n",
"
\n"
]
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mLayer (type) \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m Param #\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mConnected to \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
"│ speaker_input │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"│ (\u001b[38;5;33mInputLayer\u001b[0m) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ embedding (\u001b[38;5;33mEmbedding\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m1\u001b[0m, \u001b[38;5;34m8\u001b[0m) │ \u001b[38;5;34m8\u001b[0m │ speaker_input[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ flatten (\u001b[38;5;33mFlatten\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m8\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ embedding[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ audio_input (\u001b[38;5;33mInputLayer\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m149\u001b[0m, \u001b[38;5;34m768\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ - │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ repeat_vector │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m149\u001b[0m, \u001b[38;5;34m8\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ flatten[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mRepeatVector\u001b[0m) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ concatenate (\u001b[38;5;33mConcatenate\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m149\u001b[0m, \u001b[38;5;34m776\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ audio_input[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ │ │ │ repeat_vector[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ bidirectional │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m149\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m926,720\u001b[0m │ concatenate[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mBidirectional\u001b[0m) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ attention (\u001b[38;5;33mAttention\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m149\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ bidirectional[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m], │\n",
"│ │ │ │ bidirectional[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ global_average_pooling1d │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ attention[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"│ (\u001b[38;5;33mGlobalAveragePooling1D\u001b[0m) │ │ │ │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ dense (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m32,896\u001b[0m │ global_average_poolin… │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ dropout (\u001b[38;5;33mDropout\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m128\u001b[0m) │ \u001b[38;5;34m0\u001b[0m │ dense[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n",
"│ dense_1 (\u001b[38;5;33mDense\u001b[0m) │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m5\u001b[0m) │ \u001b[38;5;34m645\u001b[0m │ dropout[\u001b[38;5;34m0\u001b[0m][\u001b[38;5;34m0\u001b[0m] │\n",
"└───────────────────────────┴────────────────────────┴────────────────┴────────────────────────┘\n"
],
"text/html": [
"┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│ speaker_input │ (None, 1) │ 0 │ - │\n", "│ (InputLayer) │ │ │ │\n", "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", "│ embedding (Embedding) │ (None, 1, 8) │ 8 │ speaker_input[0][0] │\n", "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", "│ flatten (Flatten) │ (None, 8) │ 0 │ embedding[0][0] │\n", "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", "│ audio_input (InputLayer) │ (None, 149, 768) │ 0 │ - │\n", "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", "│ repeat_vector │ (None, 149, 8) │ 0 │ flatten[0][0] │\n", "│ (RepeatVector) │ │ │ │\n", "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", "│ concatenate (Concatenate) │ (None, 149, 776) │ 0 │ audio_input[0][0], │\n", "│ │ │ │ repeat_vector[0][0] │\n", "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", "│ bidirectional │ (None, 149, 256) │ 926,720 │ concatenate[0][0] │\n", "│ (Bidirectional) │ │ │ │\n", "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", "│ attention (Attention) │ (None, 149, 256) │ 0 │ bidirectional[0][0], │\n", "│ │ │ │ bidirectional[0][0] │\n", "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", "│ global_average_pooling1d │ (None, 256) │ 0 │ attention[0][0] │\n", "│ (GlobalAveragePooling1D) │ │ │ │\n", "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", "│ dense (Dense) │ (None, 128) │ 32,896 │ global_average_poolin… │\n", "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", "│ dropout (Dropout) │ (None, 128) │ 0 │ dense[0][0] │\n", "├───────────────────────────┼────────────────────────┼────────────────┼────────────────────────┤\n", "│ dense_1 (Dense) │ (None, 5) │ 645 │ dropout[0][0] │\n", "└───────────────────────────┴────────────────────────┴────────────────┴────────────────────────┘\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m960,269\u001b[0m (3.66 MB)\n" ], "text/html": [ "
Total params: 960,269 (3.66 MB)\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m960,269\u001b[0m (3.66 MB)\n" ], "text/html": [ "
Trainable params: 960,269 (3.66 MB)\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n" ], "text/html": [ "
Non-trainable params: 0 (0.00 B)\n", "\n" ] }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 52ms/step - accuracy: 0.2126 - loss: 1.8167 - val_accuracy: 0.3550 - val_loss: 1.5980 - learning_rate: 1.0000e-05\n", "Epoch 2/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 32ms/step - accuracy: 0.2259 - loss: 1.7496 - val_accuracy: 0.3725 - val_loss: 1.5877 - learning_rate: 1.0000e-05\n", "Epoch 3/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 38ms/step - accuracy: 0.2732 - loss: 1.7249 - val_accuracy: 0.3900 - val_loss: 1.5707 - learning_rate: 1.0000e-05\n", "Epoch 4/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 34ms/step - accuracy: 0.2897 - loss: 1.6866 - val_accuracy: 0.4175 - val_loss: 1.5469 - learning_rate: 1.0000e-05\n", "Epoch 5/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 32ms/step - accuracy: 0.3233 - loss: 1.6432 - val_accuracy: 0.4200 - val_loss: 1.5183 - learning_rate: 1.0000e-05\n", "Epoch 6/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 34ms/step - accuracy: 0.3692 - loss: 1.5907 - val_accuracy: 0.4400 - val_loss: 1.4816 - learning_rate: 1.0000e-05\n", "Epoch 7/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 32ms/step - accuracy: 0.3881 - loss: 1.5541 - val_accuracy: 0.4650 - val_loss: 1.4324 - learning_rate: 1.0000e-05\n", "Epoch 8/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 33ms/step - accuracy: 0.4236 - loss: 1.4783 - val_accuracy: 0.4575 - val_loss: 1.3681 - learning_rate: 1.0000e-05\n", "Epoch 9/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 40ms/step - accuracy: 0.4236 - loss: 1.4111 - val_accuracy: 0.4675 - val_loss: 1.3021 - learning_rate: 1.0000e-05\n", "Epoch 10/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 32ms/step - accuracy: 0.4405 - loss: 1.3170 - val_accuracy: 0.4775 - val_loss: 1.2408 - learning_rate: 1.0000e-05\n", "Epoch 11/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 32ms/step - accuracy: 0.4661 - loss: 1.2461 - val_accuracy: 0.5200 - val_loss: 1.1904 - learning_rate: 1.0000e-05\n", "Epoch 12/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 34ms/step - accuracy: 0.4960 - loss: 1.1909 - val_accuracy: 0.5400 - val_loss: 1.1487 - learning_rate: 1.0000e-05\n", "Epoch 13/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 31ms/step - accuracy: 0.4760 - loss: 1.1721 - val_accuracy: 0.5600 - val_loss: 1.1179 - learning_rate: 1.0000e-05\n", "Epoch 14/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 34ms/step - accuracy: 0.5220 - loss: 1.1201 - val_accuracy: 0.5775 - val_loss: 1.0832 - learning_rate: 1.0000e-05\n", "Epoch 15/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 39ms/step - accuracy: 0.5509 - loss: 1.0576 - val_accuracy: 0.5675 - val_loss: 1.0315 - learning_rate: 1.0000e-05\n", "Epoch 16/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 32ms/step - accuracy: 0.5761 - loss: 1.0416 - val_accuracy: 0.6150 - val_loss: 0.9876 - learning_rate: 1.0000e-05\n", "Epoch 17/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 32ms/step - accuracy: 0.5882 - loss: 1.0076 - val_accuracy: 0.6075 - val_loss: 0.9579 - learning_rate: 1.0000e-05\n", "Epoch 18/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 31ms/step - accuracy: 0.5964 - loss: 0.9858 - val_accuracy: 0.6375 - val_loss: 0.9172 - learning_rate: 1.0000e-05\n", "Epoch 19/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 34ms/step - accuracy: 0.6365 - loss: 0.9032 - val_accuracy: 0.6425 - val_loss: 0.9103 - learning_rate: 1.0000e-05\n", "Epoch 20/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 34ms/step - accuracy: 0.6461 - loss: 0.8624 - val_accuracy: 0.6525 - val_loss: 0.8516 - learning_rate: 1.0000e-05\n", "Epoch 21/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 40ms/step - accuracy: 0.6431 - loss: 0.8475 - val_accuracy: 0.6750 - val_loss: 0.8211 - learning_rate: 1.0000e-05\n", "Epoch 22/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 32ms/step - accuracy: 0.6869 - loss: 0.8131 - val_accuracy: 0.6875 - val_loss: 0.7836 - learning_rate: 1.0000e-05\n", "Epoch 23/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 33ms/step - accuracy: 0.6748 - loss: 0.8001 - val_accuracy: 0.7050 - val_loss: 0.7506 - learning_rate: 1.0000e-05\n", "Epoch 24/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 34ms/step - accuracy: 0.6901 - loss: 0.7635 - val_accuracy: 0.7075 - val_loss: 0.7369 - learning_rate: 1.0000e-05\n", "Epoch 25/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 35ms/step - accuracy: 0.7058 - loss: 0.7296 - val_accuracy: 0.7275 - val_loss: 0.6809 - learning_rate: 1.0000e-05\n", "Epoch 26/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 44ms/step - accuracy: 0.7418 - loss: 0.6508 - val_accuracy: 0.7550 - val_loss: 0.6529 - learning_rate: 1.0000e-05\n", "Epoch 27/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 50ms/step - accuracy: 0.7381 - loss: 0.6755 - val_accuracy: 0.7275 - val_loss: 0.6935 - learning_rate: 1.0000e-05\n", "Epoch 28/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 39ms/step - accuracy: 0.7410 - loss: 0.6279 - val_accuracy: 0.7625 - val_loss: 0.6355 - learning_rate: 1.0000e-05\n", "Epoch 29/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 52ms/step - accuracy: 0.7481 - loss: 0.6329 - val_accuracy: 0.7600 - val_loss: 0.6302 - learning_rate: 1.0000e-05\n", "Epoch 30/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m5s\u001b[0m 45ms/step - accuracy: 0.7426 - loss: 0.6093 - val_accuracy: 0.7850 - val_loss: 0.5961 - learning_rate: 1.0000e-05\n", "Epoch 31/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m4s\u001b[0m 75ms/step - accuracy: 0.7727 - loss: 0.5551 - val_accuracy: 0.8000 - val_loss: 0.5790 - learning_rate: 1.0000e-05\n", "Epoch 32/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 61ms/step - accuracy: 0.7872 - loss: 0.5304 - val_accuracy: 0.7975 - val_loss: 0.5452 - learning_rate: 1.0000e-05\n", "Epoch 33/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 35ms/step - accuracy: 0.8003 - loss: 0.5277 - val_accuracy: 0.8150 - val_loss: 0.5350 - learning_rate: 1.0000e-05\n", "Epoch 34/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 36ms/step - accuracy: 0.7944 - loss: 0.5542 - val_accuracy: 0.7700 - val_loss: 0.6045 - learning_rate: 1.0000e-05\n", "Epoch 35/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 39ms/step - accuracy: 0.8164 - loss: 0.4469 - val_accuracy: 0.8050 - val_loss: 0.5263 - learning_rate: 1.0000e-05\n", "Epoch 36/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 38ms/step - accuracy: 0.7986 - loss: 0.4953 - val_accuracy: 0.8300 - val_loss: 0.4949 - learning_rate: 1.0000e-05\n", "Epoch 37/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 39ms/step - accuracy: 0.8164 - loss: 0.4580 - val_accuracy: 0.8350 - val_loss: 0.4910 - learning_rate: 1.0000e-05\n", "Epoch 38/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 33ms/step - accuracy: 0.8153 - loss: 0.4393 - val_accuracy: 0.8500 - val_loss: 0.4688 - learning_rate: 1.0000e-05\n", "Epoch 39/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 34ms/step - accuracy: 0.8406 - loss: 0.4233 - val_accuracy: 0.8475 - val_loss: 0.4530 - learning_rate: 1.0000e-05\n", "Epoch 40/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 34ms/step - accuracy: 0.8276 - loss: 0.4257 - val_accuracy: 0.8450 - val_loss: 0.4463 - learning_rate: 1.0000e-05\n", "Epoch 41/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 34ms/step - accuracy: 0.8446 - loss: 0.4167 - val_accuracy: 0.8500 - val_loss: 0.4663 - learning_rate: 1.0000e-05\n", "Epoch 42/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 34ms/step - accuracy: 0.8445 - loss: 0.4211 - val_accuracy: 0.8350 - val_loss: 0.4317 - learning_rate: 1.0000e-05\n", "Epoch 43/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 39ms/step - accuracy: 0.8697 - loss: 0.3649 - val_accuracy: 0.8350 - val_loss: 0.4538 - learning_rate: 1.0000e-05\n", "Epoch 44/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 41ms/step - accuracy: 0.8651 - loss: 0.3714 - val_accuracy: 0.8525 - val_loss: 0.4089 - learning_rate: 1.0000e-05\n", "Epoch 45/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 34ms/step - accuracy: 0.8670 - loss: 0.3753 - val_accuracy: 0.8675 - val_loss: 0.4030 - learning_rate: 1.0000e-05\n", "Epoch 46/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 34ms/step - accuracy: 0.8578 - loss: 0.3713 - val_accuracy: 0.8625 - val_loss: 0.4232 - learning_rate: 1.0000e-05\n", "Epoch 47/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 34ms/step - accuracy: 0.8739 - loss: 0.3369 - val_accuracy: 0.8775 - val_loss: 0.3967 - learning_rate: 1.0000e-05\n", "Epoch 48/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 34ms/step - accuracy: 0.8895 - loss: 0.3170 - val_accuracy: 0.8700 - val_loss: 0.3946 - learning_rate: 1.0000e-05\n", "Epoch 49/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m2s\u001b[0m 36ms/step - accuracy: 0.8870 - loss: 0.3139 - val_accuracy: 0.8625 - val_loss: 0.3739 - learning_rate: 1.0000e-05\n", "Epoch 50/50\n", "\u001b[1m50/50\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m3s\u001b[0m 40ms/step - accuracy: 0.9062 - loss: 0.2966 - val_accuracy: 0.8800 - val_loss: 0.3581 - learning_rate: 1.0000e-05\n" ] } ] }, { "cell_type": "code", "source": [ "# Post-hoc calibration with temperature scaling\n", "scaled_logits = model.output / 2.0 # Temperature scaling factor\n", "calibrated_outputs = tf.keras.layers.Activation('softmax')(scaled_logits)\n", "calibrated_model = tf.keras.Model(inputs=model.input, outputs=calibrated_outputs)\n", "\n", "# Save calibrated model and encoders\n", "calibrated_model.save(\"improved_emotion_model.keras\")\n", "joblib.dump(label_encoder, 'label_encoder.pkl')\n", "joblib.dump(speaker_encoder, 'speaker_encoder.pkl')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "EbxtL-GIHYCV", "outputId": "8622160b-6575-4e7b-b161-6204c4efa38b" }, "execution_count": 16, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "['speaker_encoder.pkl']" ] }, "metadata": {}, "execution_count": 16 } ] }, { "cell_type": "code", "source": [ "# === Gradio Interface ===\n", "def convert_mp3_to_wav(mp3_filepath):\n", " \"\"\"Convert MP3 to WAV for processing.\"\"\"\n", " wav_filepath = mp3_filepath.replace('.mp3', '.wav')\n", " sound = AudioSegment.from_mp3(mp3_filepath)\n", " sound.export(wav_filepath, format=\"wav\")\n", " return wav_filepath" ], "metadata": { "id": "p0swcOtTHdHE" }, "execution_count": 17, "outputs": [] }, { "cell_type": "code", "source": [ "def predict_emotion(audio_file):\n", " \"\"\"Predict emotion from audio input using the trained model.\"\"\"\n", " if audio_file.lower().endswith('.mp3'):\n", " audio_file = convert_mp3_to_wav(audio_file)\n", " audio = load_and_preprocess_audio(audio_file)\n", " features = extract_features(audio)\n", " speaker_id = np.array([0]) # Default speaker ID for inference\n", " pred = calibrated_model.predict([np.expand_dims(features, 0), speaker_id.reshape(-1, 1)])\n", " return label_encoder.inverse_transform([np.argmax(pred)])[0]" ], "metadata": { "id": "xvC9KHe7HjYZ" }, "execution_count": 18, "outputs": [] }, { "cell_type": "code", "source": [ "iface = gr.Interface(\n", " fn=predict_emotion,\n", " inputs=gr.Audio(type=\"filepath\"),\n", " outputs=\"text\",\n", " title=\"Enhanced Emotion Recognition\",\n", " description=\"Record or upload audio (MP3/WAV) for emotion prediction with improved accuracy\"\n", ")\n" ], "metadata": { "id": "H0tGQXqepvYe" }, "execution_count": 21, "outputs": [] }, { "cell_type": "code", "source": [ "iface.launch()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 646 }, "id": "1WxOHPpypv_c", "outputId": "78d7614d-00dc-4e42-c3f2-6376d4fa71b3" }, "execution_count": 22, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n", "\n", "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n", "* Running on public URL: https://962c8bd414079ad9b5.gradio.live\n", "\n", "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "