# Hugging Face Space (page status: Running)
import gradio as gr | |
import torch | |
import numpy as np | |
from transformers import Wav2Vec2Processor, AutoModelForCTC | |
from transformers import AutoFeatureExtractor, AutoTokenizer | |
# Load model and processor.
# The CTC model must come from the same fine-tuned checkpoint as the tokenizer:
# "facebook/wav2vec2-large-xlsr-53" is a pretraining-only checkpoint without a
# trained CTC head, so loading it through AutoModelForCTC initializes the output
# layer randomly and its token ids do not line up with this tokenizer's vocab.
FINETUNED_CHECKPOINT = "Chan-Y/wav2vec2-large-xlsr-53-demo-colab"
# Feature-extractor settings are taken from the base XLSR-53 checkpoint.
BASE_CHECKPOINT = "facebook/wav2vec2-large-xlsr-53"

tokenizer = AutoTokenizer.from_pretrained(FINETUNED_CHECKPOINT)
feature_extractor = AutoFeatureExtractor.from_pretrained(BASE_CHECKPOINT)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
model = AutoModelForCTC.from_pretrained(FINETUNED_CHECKPOINT)
# wav2vec2 / XLSR-53 checkpoints expect 16 kHz input audio.
_TARGET_SAMPLE_RATE = 16000


def _to_mono_float(waveform):
    """Convert a numpy waveform to a 1-D float64 signal scaled to about [-1, 1].

    Integer PCM is centred (for unsigned formats such as 8-bit WAV) and divided
    by 2**(bits - 1); float input keeps its values. Any axes after the first
    (channels) are averaged into a single mono channel.
    """
    signal = np.atleast_1d(np.asarray(waveform))
    if np.issubdtype(signal.dtype, np.integer):
        info = np.iinfo(signal.dtype)
        # offset is 0 for signed PCM and 2**(bits-1) for unsigned PCM; scale is
        # 2**(bits-1) in both cases. Python ints avoid overflow for int64.
        offset = (int(info.max) + int(info.min) + 1) / 2
        scale = (int(info.max) - int(info.min) + 1) / 2
        signal = (signal.astype(np.float64) - offset) / scale
    else:
        # The original called np.iinfo on any non-float32 dtype, which raises
        # for float64; floats only need a dtype cast.
        signal = signal.astype(np.float64)
    if signal.ndim > 1:
        # Multi-channel audio arrives as (samples, channels); the model takes a
        # single channel, so downmix by averaging. Passing 2-D data straight to
        # the processor would be read as a batch of tiny clips.
        signal = signal.mean(axis=tuple(range(1, signal.ndim)))
    return signal


def _resample_linear(signal, orig_rate, target_rate):
    """Resample a 1-D signal from ``orig_rate`` Hz to ``target_rate`` Hz.

    Uses linear interpolation (``np.interp``) to avoid adding a dependency.
    There is no anti-aliasing low-pass filter, so when downsampling, content
    above ``target_rate / 2`` folds back into the band. A dedicated resampler
    (e.g. torchaudio or librosa) gives better quality if one is available.

    Raises:
        ValueError: if ``orig_rate`` is not positive.
    """
    if orig_rate <= 0:
        raise ValueError(f"Invalid sample rate: {orig_rate}")
    if orig_rate == target_rate or signal.size == 0:
        return signal
    n_out = max(1, int(round(signal.shape[0] * target_rate / orig_rate)))
    src_times = np.arange(signal.shape[0]) / orig_rate
    dst_times = np.arange(n_out) / target_rate
    return np.interp(dst_times, src_times, signal)


def transcribe_audio(audio_file):
    """
    Transcribe an uploaded clip with the module-level wav2vec2 CTC model.

    Args:
        audio_file: Value from ``gr.Audio(type="numpy")``: a
            ``(sample_rate, waveform)`` tuple, or ``None`` when nothing was
            uploaded.

    Returns:
        The greedy-decoded transcription, or a user-facing message when no
        audio was provided or processing failed (shown verbatim in the UI).
    """
    if audio_file is None:
        return "Please upload an audio file"
    try:
        sample_rate, waveform = audio_file
        signal = _to_mono_float(waveform)
        if signal.size == 0:
            return "Please upload an audio file"
        # Actually resample rather than mislabelling e.g. 44.1 kHz audio as 16 kHz.
        signal = _resample_linear(signal, sample_rate, _TARGET_SAMPLE_RATE)
        signal = signal.astype(np.float32)

        inputs = processor(
            signal,
            sampling_rate=_TARGET_SAMPLE_RATE,
            return_tensors="pt",
            padding=True,
        )
        with torch.no_grad():
            logits = model(
                inputs.input_values,
                attention_mask=inputs.get("attention_mask"),
            ).logits

        # Greedy CTC decoding: take the most likely token id per frame and let
        # the processor's tokenizer turn the id sequence into text.
        predicted_ids = torch.argmax(logits, dim=-1)
        return processor.batch_decode(predicted_ids)[0]
    except Exception as e:  # UI boundary: report instead of crashing the request
        import traceback

        traceback.print_exc()  # keep the full traceback in the server logs
        return f"Error processing audio: {str(e)}"
# Assemble the web UI: one audio upload feeding one text box.
iface = gr.Interface(
    transcribe_audio,  # receives the (sample_rate, data) tuple from gr.Audio
    inputs=gr.Audio(label="Upload Audio", type="numpy"),
    outputs=gr.Textbox(label="Transcription"),
    title="Audio Transcription with Wav2Vec2",
    description=(
        "Upload an audio file to get a transcription "
        "using the Wav2Vec2 model."
    ),
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the pinned Gradio version.
    allow_flagging="never",
)
# Launch the app only when this file is executed directly, not when imported.
def main():
    """Start the Gradio server for ``iface``.

    ``share=True`` asks Gradio for a public share link in addition to the
    local server.
    NOTE(review): Hugging Face Spaces ignores ``share=True`` (with a warning),
    so it only has an effect on local runs — confirm a public link is wanted.
    """
    iface.launch(share=True)


if __name__ == "__main__":
    main()