Musical Instrument Classification Model

This model is a fine-tuned version of facebook/wav2vec2-base-960h for musical instrument classification. It identifies 9 different musical instruments from audio recordings, reaching 93.33% accuracy on the held-out evaluation set.

Model Description

  • Model type: Audio Classification
  • Base model: facebook/wav2vec2-base-960h
  • Language: Audio (no specific language)
  • License: MIT
  • Fine-tuned on: Custom musical instrument dataset (200 samples for each class)

Performance

The model achieves the following results on the evaluation set after 5 epochs of training:

  • Final Accuracy: 93.33%
  • Final ROC AUC (Macro): 98.59%
  • Final Validation Loss: 1.064
  • Evaluation Runtime: 14.18 seconds
  • Evaluation Speed: 25.39 samples/second

Training Progress

  Epoch   Training Loss   Validation Loss   ROC AUC   Accuracy
  1       1.9872          1.8875            0.9248    0.6639
  2       1.8652          1.4793            0.9799    0.8000
  3       1.3868          1.2311            0.9861    0.8194
  4       1.3242          1.1121            0.9827    0.9250
  5       1.1869          1.0639            0.9859    0.9333

Supported Instruments

The model can classify the following 9 musical instruments:

  1. Acoustic Guitar
  2. Bass Guitar
  3. Drum Set
  4. Electric Guitar
  5. Flute
  6. Hi-Hats
  7. Keyboard
  8. Trumpet
  9. Violin
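
The label mapping is also stored in the model configuration, so the same list can be read programmatically; a minimal sketch (it assumes the repository's id2label mapping uses these instrument names):

from transformers import AutoConfig

# Load only the configuration to inspect the class labels
config = AutoConfig.from_pretrained("Bhaveen/epoch_musical_instruments_identification_2")

# id2label maps class indices to instrument names
for idx, label in sorted(config.id2label.items()):
    print(idx, label)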

Usage

Quick Start with Pipeline

from transformers import pipeline
import torchaudio

# Load the classification pipeline
classifier = pipeline("audio-classification", model="Bhaveen/epoch_musical_instruments_identification_2")

# Load and preprocess audio: resample to 16 kHz, flatten to 1-D,
# and keep the first 3 seconds (48,000 samples); assumes a mono recording
audio, rate = torchaudio.load("your_audio_file.wav")
transform = torchaudio.transforms.Resample(rate, 16000)
audio = transform(audio).numpy().reshape(-1)[:48000]

# Classify the audio
result = classifier(audio)
print(result)

Using Transformers Directly

from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import torchaudio
import torch

# Load model and feature extractor
model_name = "Bhaveen/epoch_musical_instruments_identification_2"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForAudioClassification.from_pretrained(model_name)

# Load and preprocess audio
audio, rate = torchaudio.load("your_audio_file.wav")
transform = torchaudio.transforms.Resample(rate, 16000)
audio = transform(audio).numpy().reshape(-1)[:48000]

# Extract features and make prediction
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
    predicted_class = torch.argmax(predictions, dim=-1)

print(f"Predicted instrument: {model.config.id2label[predicted_class.item()]}")

Training Details

Dataset and Preprocessing

  • Custom dataset with audio recordings of 9 musical instruments
  • Train/Test Split: 80/20 based on file numbering (files numbered below 160 used for training)
  • Data Balancing: Random oversampling applied to minority classes
  • Audio Preprocessing (see the sketch after this list):
    • Resampling to 16,000 Hz
    • Fixed length of 48,000 samples (3 seconds)
    • Truncation of longer audio files
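
The steps above can be collected into a small loading helper; this is a minimal sketch rather than the original training code — the helper name and the zero-padding of short clips are assumptions (the card only documents truncation):

import torch
import torchaudio

TARGET_SR = 16_000    # resample everything to 16 kHz
TARGET_LEN = 48_000   # 3 seconds at 16 kHz

def load_clip(path):
    # Load, resample, mix down to mono, and force a fixed 3-second length
    audio, rate = torchaudio.load(path)
    audio = torchaudio.transforms.Resample(rate, TARGET_SR)(audio)
    audio = audio.mean(dim=0)
    if audio.numel() >= TARGET_LEN:
        audio = audio[:TARGET_LEN]                        # truncate long clips
    else:
        pad = TARGET_LEN - audio.numel()
        audio = torch.nn.functional.pad(audio, (0, pad))  # zero-pad short clips (assumption)
    return audio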

Training Configuration

# Training hyperparameters
batch_size = 1
gradient_accumulation_steps = 4
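# effective batch size = batch_size * gradient_accumulation_steps = 4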
learning_rate = 5e-6
num_train_epochs = 5
warmup_steps = 50
weight_decay = 0.02
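
These values map directly onto the Hugging Face Trainer API; a minimal sketch of the wiring (the output directory name and the per-epoch evaluation strategy are assumptions, not taken from the original notebook):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./instrument-classifier",  # assumed name
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=5e-6,
    num_train_epochs=5,
    warmup_steps=50,
    weight_decay=0.02,
    evaluation_strategy="epoch",           # assumed: metrics above are reported per epoch
)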

Model Architecture

  • Base Model: facebook/wav2vec2-base-960h
  • Classification Head: Added for 9-class classification
  • Parameters: ~95M trainable parameters
  • Features: Wav2Vec2 audio representations with fine-tuned classification layer
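
For reference, this is roughly how such a 9-class head is attached to the base checkpoint when starting a fine-tune; the exact label strings stored in the repository's config may differ from the list above:

from transformers import AutoModelForAudioClassification

labels = ["Acoustic Guitar", "Bass Guitar", "Drum Set", "Electric Guitar",
          "Flute", "Hi-Hats", "Keyboard", "Trumpet", "Violin"]

# Load the wav2vec2 encoder and add a freshly initialized 9-class classification head
model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-base-960h",
    num_labels=len(labels),
    id2label={i: name for i, name in enumerate(labels)},
    label2id={name: i for i, name in enumerate(labels)},
)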

Technical Specifications

  • Audio Format: WAV files
  • Sample Rate: 16,000 Hz
  • Input Length: 3 seconds (48,000 samples)
  • Model Framework: PyTorch + Transformers
  • Inference Device: GPU recommended (CUDA)
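
A GPU can be selected directly when building the classification pipeline from the Usage section; a minimal sketch:

from transformers import pipeline
import torch

# Run the pipeline on GPU when one is available (device=0 selects the first CUDA device)
device = 0 if torch.cuda.is_available() else -1
classifier = pipeline(
    "audio-classification",
    model="Bhaveen/epoch_musical_instruments_identification_2",
    device=device,
)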

Evaluation Metrics

The model uses the following evaluation metrics:

  • Accuracy: Standard classification accuracy
  • ROC AUC: Macro-averaged ROC AUC with one-vs-rest approach
  • Multi-class Classification: Softmax probabilities for all 9 instrument classes
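
A compute_metrics function matching these two metrics might look like the following; this is a sketch built on the evaluate library listed under Training Environment, and the original notebook's implementation may differ:

import numpy as np
import evaluate
from scipy.special import softmax

accuracy_metric = evaluate.load("accuracy")
roc_auc_metric = evaluate.load("roc_auc", "multiclass")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    probs = softmax(logits, axis=-1)                      # softmax probabilities per class
    acc = accuracy_metric.compute(
        predictions=np.argmax(logits, axis=-1), references=labels
    )["accuracy"]
    auc = roc_auc_metric.compute(
        prediction_scores=probs, references=labels,
        multi_class="ovr", average="macro",               # macro-averaged, one-vs-rest
    )["roc_auc"]
    return {"accuracy": acc, "roc_auc": auc}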

Limitations and Considerations

  1. Audio Duration: The model expects 3-second clips; longer clips are truncated, and shorter clips may not be handled well
  2. Single Instrument Focus: Optimized for single-instrument classification; recordings with multiple instruments may produce uncertain results
  3. Audio Quality: Performance depends on audio quality and recording conditions
  4. Sample Rate: Input must be resampled to 16kHz for optimal performance
  5. Domain Specificity: Trained on specific instrument recordings, may not generalize to all variants or playing styles

Training Environment

  • Platform: Google Colab
  • GPU: CUDA-enabled device
  • Libraries:
    • transformers==4.28.1
    • torchaudio==0.12
    • datasets
    • evaluate
    • imblearn

Model Files

The repository contains:

  • Model weights and configuration
  • Feature extractor configuration
  • Training logs and metrics
  • Label mappings (id2label, label2id)

Model trained as part of a hackathon project
