---
language: en
license: mit
tags:
- audio
- audio-classification
- musical-instruments
- wav2vec2
- transformers
- pytorch
datasets:
- custom
metrics:
- accuracy
- roc_auc
model-index:
- name: epoch_musical_instruments_identification_2
  results:
  - task:
      type: audio-classification
      name: Musical Instrument Classification
    metrics:
    - type: accuracy
      value: 0.9333
      name: Accuracy
    - type: roc_auc
      value: 0.9859
      name: ROC AUC (Macro)
    - type: loss
      value: 1.0639
      name: Validation Loss
base_model:
- facebook/wav2vec2-base-960h
---

# Musical Instrument Classification Model

This model is a fine-tuned version of [facebook/wav2vec2-base-960h](https://huggingface.co/facebook/wav2vec2-base-960h) for musical instrument classification. It identifies 9 different musical instruments from audio recordings with high accuracy.

## Model Description

- **Model type:** Audio Classification
- **Base model:** facebook/wav2vec2-base-960h
- **Language:** Audio (no specific language)
- **License:** MIT
- **Fine-tuned on:** Custom musical instrument dataset (200 samples per class)

## Performance

After 5 epochs of training, the model achieves the following results on the evaluation set:

- **Final Accuracy:** 93.33%
- **Final ROC AUC (Macro):** 98.59%
- **Final Validation Loss:** 1.0639
- **Evaluation Runtime:** 14.18 seconds
- **Evaluation Speed:** 25.39 samples/second

### Training Progress

| Epoch | Training Loss | Validation Loss | ROC AUC | Accuracy |
|-------|---------------|-----------------|---------|----------|
| 1     | 1.9872        | 1.8875          | 0.9248  | 0.6639   |
| 2     | 1.8652        | 1.4793          | 0.9799  | 0.8000   |
| 3     | 1.3868        | 1.2311          | 0.9861  | 0.8194   |
| 4     | 1.3242        | 1.1121          | 0.9827  | 0.9250   |
| 5     | 1.1869        | 1.0639          | 0.9859  | 0.9333   |

## Supported Instruments

The model classifies the following 9 musical instruments:

1. **Acoustic Guitar**
2. **Bass Guitar**
3. **Drum Set**
4. **Electric Guitar**
5. **Flute**
6. **Hi-Hats**
7. **Keyboard**
8. **Trumpet**
9. **Violin**
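The class names above can be cross-checked against the checkpoint's `id2label` mapping. A minimal sketch, assuming the repository id used in the usage examples below:

```python
from transformers import AutoConfig

# Load only the configuration to inspect the label mapping
config = AutoConfig.from_pretrained("Bhaveen/epoch_musical_instruments_identification_2")

# Print the id -> instrument name mapping shipped with the checkpoint
for idx, label in sorted(config.id2label.items()):
    print(idx, label)
```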
## Usage

### Quick Start with Pipeline

```python
from transformers import pipeline
import torchaudio

# Load the classification pipeline
classifier = pipeline("audio-classification", model="Bhaveen/epoch_musical_instruments_identification_2")

# Load and preprocess audio: resample to 16 kHz, downmix to mono, truncate to 3 s
audio, rate = torchaudio.load("your_audio_file.wav")
transform = torchaudio.transforms.Resample(rate, 16000)
audio = transform(audio).mean(dim=0).numpy()[:48000]

# Classify the audio
result = classifier(audio)
print(result)
```

### Using Transformers Directly

```python
import torch
import torchaudio
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

# Load model and feature extractor
model_name = "Bhaveen/epoch_musical_instruments_identification_2"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForAudioClassification.from_pretrained(model_name)

# Load and preprocess audio: resample to 16 kHz, downmix to mono, truncate to 3 s
audio, rate = torchaudio.load("your_audio_file.wav")
transform = torchaudio.transforms.Resample(rate, 16000)
audio = transform(audio).mean(dim=0).numpy()[:48000]

# Extract features and make a prediction
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
predicted_class = torch.argmax(predictions, dim=-1)

print(f"Predicted instrument: {model.config.id2label[predicted_class.item()]}")
```

## Training Details

### Dataset and Preprocessing

- **Custom dataset** with audio recordings of 9 musical instruments
- **Train/Test Split:** 80/20 using file numbering (files numbered below 160 used for training)
- **Data Balancing:** Random oversampling applied to minority classes
- **Audio Preprocessing:**
  - Resampling to 16,000 Hz
  - Fixed length of 48,000 samples (3 seconds)
  - Truncation of longer audio files

### Training Configuration

```python
# Training hyperparameters
batch_size = 1
gradient_accumulation_steps = 4
learning_rate = 5e-6
num_train_epochs = 5
warmup_steps = 50
weight_decay = 0.02
```

### Model Architecture

- **Base Model:** facebook/wav2vec2-base-960h
- **Classification Head:** Added for 9-class classification
- **Parameters:** ~95M trainable parameters
- **Features:** Wav2Vec2 audio representations with a fine-tuned classification layer

## Technical Specifications

- **Audio Format:** WAV files
- **Sample Rate:** 16,000 Hz
- **Input Length:** 3 seconds (48,000 samples)
- **Model Framework:** PyTorch + Transformers
- **Inference Device:** GPU recommended (CUDA)

## Evaluation Metrics

The model is evaluated with the following metrics:

- **Accuracy:** Standard classification accuracy
- **ROC AUC:** Macro-averaged ROC AUC with a one-vs-rest approach
- **Multi-class Classification:** Softmax probabilities over all 9 instrument classes

## Limitations and Considerations

1. **Audio Duration:** The model expects 3-second clips; longer audio is truncated, and shorter clips may reduce accuracy (see the zero-padding sketch after this list)
2. **Single Instrument Focus:** Optimized for single-instrument classification; mixed instruments may produce uncertain results
3. **Audio Quality:** Performance depends on audio quality and recording conditions
4. **Sample Rate:** Input must be resampled to 16 kHz
5. **Domain Specificity:** Trained on specific instrument recordings; may not generalize to all instrument variants or playing styles
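For clips shorter than 3 seconds (limitation 1 above), zero-padding to the expected 48,000 samples is one way to satisfy the fixed input length. The helper below is a minimal sketch and not part of the original training pipeline; `pad_or_truncate` is a hypothetical name, and silence padding is an assumption that should be validated on known examples:

```python
import numpy as np

TARGET_LEN = 48_000  # 3 seconds at 16 kHz

def pad_or_truncate(audio: np.ndarray, target_len: int = TARGET_LEN) -> np.ndarray:
    """Force a mono waveform to exactly target_len samples."""
    if len(audio) >= target_len:
        return audio[:target_len]
    # Zero-pad the tail of shorter clips (assumption: silence padding)
    return np.pad(audio, (0, target_len - len(audio)))
```

Apply it to the mono 16 kHz array before passing it to the pipeline or feature extractor.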
## Training Environment

- **Platform:** Google Colab
- **GPU:** CUDA-enabled device
- **Libraries:**
  - transformers==4.28.1
  - torchaudio==0.12
  - datasets
  - evaluate
  - imblearn

## Model Files

The repository contains:

- Model weights and configuration
- Feature extractor configuration
- Training logs and metrics
- Label mappings (id2label, label2id)

---

*Model trained as part of a hackathon project*
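For reference, the macro-averaged one-vs-rest ROC AUC described under Evaluation Metrics can be reproduced with scikit-learn. A minimal sketch; `y_true` and `y_score` below are hypothetical placeholders for real evaluation outputs:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Placeholder evaluation outputs: two samples per class and
# random per-class softmax-like scores (each row sums to 1)
y_true = np.repeat(np.arange(9), 2)
y_score = np.random.dirichlet(np.ones(9), size=len(y_true))

# Macro-averaged one-vs-rest ROC AUC, as reported in this card
auc = roc_auc_score(y_true, y_score, multi_class="ovr", average="macro")
print(f"Macro ROC AUC: {auc:.4f}")
```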