Musical Instrument Classification Model
This model is a fine-tuned version of facebook/wav2vec2-base-960h for musical instrument classification. It identifies 9 different musical instruments from audio recordings, reaching 93.33% accuracy on the held-out evaluation set.
Model Description
- Model type: Audio Classification
- Base model: facebook/wav2vec2-base-960h
- Language: Audio (no specific language)
- License: MIT
- Fine-tuned on: Custom musical instrument dataset (200 samples for each class)
Performance
After 5 epochs of training, the model achieves the following results on the evaluation set:
- Final Accuracy: 93.33%
- Final ROC AUC (Macro): 98.59%
- Final Validation Loss: 1.064
- Evaluation Runtime: 14.18 seconds
- Evaluation Speed: 25.39 samples/second
Training Progress
| Epoch | Training Loss | Validation Loss | ROC AUC | Accuracy |
|---|---|---|---|---|
| 1 | 1.9872 | 1.8875 | 0.9248 | 0.6639 |
| 2 | 1.8652 | 1.4793 | 0.9799 | 0.8000 |
| 3 | 1.3868 | 1.2311 | 0.9861 | 0.8194 |
| 4 | 1.3242 | 1.1121 | 0.9827 | 0.9250 |
| 5 | 1.1869 | 1.0639 | 0.9859 | 0.9333 |
Supported Instruments
The model can classify the following 9 musical instruments:
- Acoustic Guitar
- Bass Guitar
- Drum Set
- Electric Guitar
- Flute
- Hi-Hats
- Keyboard
- Trumpet
- Violin
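The exact label strings and their index order are stored in the model configuration and can be inspected directly. A minimal sketch, using the repository name from the usage examples below:

```python
from transformers import AutoConfig

# Load only the configuration to read the label mapping (no model weights are downloaded)
config = AutoConfig.from_pretrained("Bhaveen/epoch_musical_instruments_identification_2")
print(config.id2label)  # index -> instrument name, as stored in the repository
print(config.label2id)
```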
Usage
Quick Start with Pipeline
from transformers import pipeline
import torchaudio
# Load the classification pipeline
classifier = pipeline("audio-classification", model="Bhaveen/epoch_musical_instruments_identification_2")
# Load and preprocess audio
audio, rate = torchaudio.load("your_audio_file.wav")
transform = torchaudio.transforms.Resample(rate, 16000)
audio = transform(audio).numpy().reshape(-1)[:48000]
# Classify the audio
result = classifier(audio)
print(result)
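The pipeline returns a list of dictionaries with label and score keys. To get scores for all nine classes rather than just the top-ranked ones, you can pass the standard top_k pipeline argument, continuing from the snippet above:

```python
# Continuation of the snippet above: request scores for every instrument class
result = classifier(audio, top_k=9)
print(result)
```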
Using Transformers Directly
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import torchaudio
import torch
# Load model and feature extractor
model_name = "Bhaveen/epoch_musical_instruments_identification_2"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForAudioClassification.from_pretrained(model_name)
# Load and preprocess audio
audio, rate = torchaudio.load("your_audio_file.wav")
transform = torchaudio.transforms.Resample(rate, 16000)
audio = transform(audio).numpy().reshape(-1)[:48000]
# Extract features and make prediction
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(predictions, dim=-1)
print(f"Predicted instrument: {model.config.id2label[predicted_class.item()]}")
Training Details
Dataset and Preprocessing
- Custom dataset with audio recordings of 9 musical instruments
- Train/Test Split: 80/20 using file numbering (files < 160 for training)
- Data Balancing: Random oversampling applied to minority classes
- Audio Preprocessing (a code sketch follows this list):
  - Resampling to 16,000 Hz
  - Fixed length of 48,000 samples (3 seconds)
  - Truncation of longer audio files
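A minimal sketch of this loading and balancing step, assuming per-file WAV clips and imblearn's RandomOverSampler. The helper names, the random_state, and the zero-padding of short clips are illustrative choices, not taken from the original training code:

```python
import numpy as np
import torchaudio
from imblearn.over_sampling import RandomOverSampler

TARGET_SR = 16_000
TARGET_LEN = 48_000  # 3 seconds at 16 kHz

def load_clip(path: str) -> np.ndarray:
    """Load a WAV file, resample to 16 kHz, and truncate/pad to exactly 3 seconds."""
    audio, rate = torchaudio.load(path)
    audio = torchaudio.transforms.Resample(rate, TARGET_SR)(audio)
    samples = audio.numpy().reshape(-1)[:TARGET_LEN]
    return np.pad(samples, (0, TARGET_LEN - len(samples)))

def oversample(paths, labels):
    """Randomly duplicate minority-class files so every class has the same count."""
    idx = np.arange(len(paths)).reshape(-1, 1)
    idx_res, labels_res = RandomOverSampler(random_state=0).fit_resample(idx, labels)
    return [paths[i] for i in idx_res.ravel()], list(labels_res)
```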
Training Configuration
# Training hyperparameters
batch_size = 1
gradient_accumulation_steps = 4
learning_rate = 5e-6
num_train_epochs = 5
warmup_steps = 50
weight_decay = 0.02
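For reference, these values map onto a standard transformers Trainer configuration roughly as follows. The output directory and the epoch-level evaluation/save strategy are assumptions, not taken from the original training script:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./instrument-classifier",  # hypothetical output path
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,         # effective batch size of 4
    learning_rate=5e-6,
    num_train_epochs=5,
    warmup_steps=50,
    weight_decay=0.02,
    evaluation_strategy="epoch",           # assumed: the table above reports per-epoch metrics
    save_strategy="epoch",
)
```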
Model Architecture
- Base Model: facebook/wav2vec2-base-960h
- Classification Head: A 9-way classification head added on top of the wav2vec2 encoder (see the loading sketch after this list)
- Parameters: ~95M trainable parameters
- Features: Wav2Vec2 audio representations with fine-tuned classification layer
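A sketch of how a fresh 9-way head is typically attached when loading the base checkpoint with transformers; the original training code may also pass the label mappings, which are omitted here:

```python
from transformers import AutoModelForAudioClassification

# Attach a randomly initialized 9-way classification head on top of the pretrained encoder
model = AutoModelForAudioClassification.from_pretrained(
    "facebook/wav2vec2-base-960h",
    num_labels=9,
)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```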
Technical Specifications
- Audio Format: WAV files
- Sample Rate: 16,000 Hz
- Input Length: 3 seconds (48,000 samples)
- Model Framework: PyTorch + Transformers
- Inference Device: GPU recommended (CUDA)
Evaluation Metrics
The model uses the following evaluation metrics:
- Accuracy: Standard classification accuracy
- ROC AUC: Macro-averaged ROC AUC using a one-vs-rest approach (a metric-computation sketch follows this list)
- Multi-class Classification: Softmax probabilities for all 9 instrument classes
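A hedged sketch of a compute_metrics function that would produce these two numbers, here with scikit-learn and scipy; the original training code may instead use the evaluate library listed under Training Environment:

```python
import numpy as np
from scipy.special import softmax
from sklearn.metrics import accuracy_score, roc_auc_score

def compute_metrics(eval_pred):
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    probs = softmax(logits, axis=-1)
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        # Macro-averaged, one-vs-rest ROC AUC over the 9 instrument classes
        "roc_auc": roc_auc_score(labels, probs, multi_class="ovr", average="macro"),
    }
```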
Limitations and Considerations
- Audio Duration: The model expects 3-second clips; longer audio is truncated, and shorter clips may degrade accuracy (a padding workaround is sketched after this list)
- Single Instrument Focus: Optimized for single-instrument classification; recordings containing multiple instruments may produce unreliable predictions
- Audio Quality: Performance depends on recording quality and conditions
- Sample Rate: Input audio must be resampled to 16 kHz to match the model's training data
- Domain Specificity: Trained on a specific set of recordings and may not generalize to all instrument variants or playing styles
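If your clips are shorter than 3 seconds, zero-padding them to 48,000 samples before feature extraction is a reasonable workaround (a sketch, not part of the original pipeline):

```python
import numpy as np

def pad_or_truncate(audio: np.ndarray, target_len: int = 48_000) -> np.ndarray:
    """Force a 1-D audio array (16 kHz) to exactly 3 seconds."""
    audio = audio[:target_len]
    return np.pad(audio, (0, target_len - len(audio)))
```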
Training Environment
- Platform: Google Colab
- GPU: CUDA-enabled device
- Libraries:
- transformers==4.28.1
- torchaudio==0.12
- datasets
- evaluate
- imblearn
Model Files
The repository contains:
- Model weights and configuration
- Feature extractor configuration
- Training logs and metrics
- Label mappings (id2label, label2id)
This model was trained as part of a hackathon project.