#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Minimal inference for SpeechBrain ECAPA-TDNN (ShEMO fine-tuned).
"""

import os

import torch
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from speechbrain.dataio.dataio import read_audio

# ------------------------------------------------------------------
# 1) Paths
# ------------------------------------------------------------------
EXP_DIR = (
    "/mnt/c/Users/NoteBook/Documents/fineTuningSpeechbrain/recipes/ShEMO/"
    "emotion_recognition/results(2)/content/results/ECAPA-TDNN/1968"
)
HP_FILE = os.path.join(EXP_DIR, "hyperparams.yaml")
CKPT_DIR = os.path.join(EXP_DIR, "save")

# ------------------------------------------------------------------
# 2) Load hyperparams and modules
# ------------------------------------------------------------------
with open(HP_FILE) as f:
    hparams = load_hyperpyyaml(f)

modules = {
    "compute_features": hparams["compute_features"],
    "mean_var_norm": hparams["mean_var_norm"],
    "embedding_model": hparams["embedding_model"],
    "classifier": hparams["classifier"],
}

# ------------------------------------------------------------------
# 3) Device setup and checkpoint recovery
# ------------------------------------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# allow_partial_load lets recovery succeed even when the checkpoint
# has no entry for a stateless module (e.g. compute_features).
checkpointer = sb.utils.checkpoints.Checkpointer(
    checkpoints_dir=CKPT_DIR,
    recoverables=modules,
    allow_partial_load=True,
)
checkpointer.recover_if_possible()

# ------------------------------------------------------------------
# 4) Simple batch container
# ------------------------------------------------------------------
class SimpleBatch:
    """Mimics the `.sig` attribute of a SpeechBrain PaddedBatch."""

    def __init__(self, wav, lens):
        self.sig = (wav, lens)

    def to(self, device):
        wav, lens = self.sig
        self.sig = (wav.to(device), lens.to(device))
        return self

# ------------------------------------------------------------------
# 5) Brain class
# ------------------------------------------------------------------
class EmoIdBrain(sb.Brain):
    def compute_forward(self, batch, stage):
        wavs, lens = batch.sig
        feats = self.modules.compute_features(wavs)
        feats = self.modules.mean_var_norm(feats, lens)
        emb = self.modules.embedding_model(feats, lens)
        out = self.modules.classifier(emb)
        return out

brain = EmoIdBrain(
    modules=modules,
    hparams=hparams,
    run_opts={"device": device},
    checkpointer=checkpointer,
)

# ------------------------------------------------------------------
# 6) Emotion labels (order must match the label encoding used in training)
# ------------------------------------------------------------------
IDX2LAB = ["anger", "sadness", "neutral", "surprise", "happiness", "fear"]

# ------------------------------------------------------------------
# 7) Prediction function
# ------------------------------------------------------------------
def predict(wav_path: str) -> str:
    wav_raw = read_audio(wav_path)
    if isinstance(wav_raw, torch.Tensor):
        wav = wav_raw.clone().detach().float().unsqueeze(0)
    else:
        wav = torch.tensor(wav_raw, dtype=torch.float32).unsqueeze(0)
    # Relative length 1.0: the single utterance is unpadded.
    lens = torch.tensor([1.0])
    batch = SimpleBatch(wav, lens).to(device)
    brain.modules.eval()
    with torch.no_grad():
        logits = brain.compute_forward(batch, stage=sb.Stage.TEST)
    idx = int(logits.argmax(dim=-1))
    return IDX2LAB[idx]

# ------------------------------------------------------------------
# 8) Run
# ------------------------------------------------------------------
if __name__ == "__main__":
    WAV_FILE = "shortvoice.wav"
    print("Predicted emotion:", predict(WAV_FILE))
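
# ------------------------------------------------------------------
# 9) Optional: per-class scores (sketch)
# ------------------------------------------------------------------
# A minimal sketch for inspecting per-class confidence; not part of
# the recipe, and `predict_with_scores` is a hypothetical helper name.
# It assumes the classifier emits unnormalized scores, which
# torch.softmax turns into a distribution; if the head defined in
# hyperparams.yaml already ends in LogSoftmax, use out.exp() instead.
def predict_with_scores(wav_path: str):
    wav = read_audio(wav_path).clone().detach().float().unsqueeze(0)
    lens = torch.tensor([1.0])
    batch = SimpleBatch(wav, lens).to(device)
    brain.modules.eval()
    with torch.no_grad():
        out = brain.compute_forward(batch, stage=sb.Stage.TEST)
    probs = torch.softmax(out, dim=-1).squeeze()  # shape: (n_classes,)
    idx = int(probs.argmax())
    return IDX2LAB[idx], {lab: float(p) for lab, p in zip(IDX2LAB, probs)}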