# Author: Chan-Y
# Commit message: Update app.py
# Commit: 37e562e (verified)
import gradio as gr
import torch
import numpy as np
from transformers import Wav2Vec2Processor, AutoModelForCTC
from transformers import AutoFeatureExtractor, AutoTokenizer
# Load model and processor.
# The CTC model must come from the same fine-tuned checkpoint as the tokenizer:
# facebook/wav2vec2-large-xlsr-53 is a pretrained-only encoder without a
# trained CTC head, so AutoModelForCTC would attach a randomly initialised
# output layer whose vocabulary need not match this tokenizer.
# NOTE(review): assumes the Chan-Y repo hosts the fine-tuned CTC weights
# alongside the tokenizer files — confirm on the Hub.
MODEL_ID = "Chan-Y/wav2vec2-large-xlsr-53-demo-colab"
# Raw-waveform preprocessing settings are shared with the base checkpoint.
BASE_MODEL_ID = "facebook/wav2vec2-large-xlsr-53"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
feature_extractor = AutoFeatureExtractor.from_pretrained(BASE_MODEL_ID)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
model = AutoModelForCTC.from_pretrained(MODEL_ID)
# wav2vec2 checkpoints expect 16 kHz input.
TARGET_SAMPLE_RATE = 16000


def _to_mono_float32(waveform):
    """Return *waveform* as a 1-D float32 array.

    Integer PCM is scaled by its dtype's maximum value (as before); float input
    of any width is simply cast. Multi-channel input is averaged to mono —
    assumed to be laid out as (samples, channels), Gradio's numpy audio layout.
    """
    waveform = np.asarray(waveform)
    if np.issubdtype(waveform.dtype, np.integer):
        waveform = waveform.astype(np.float32) / np.iinfo(waveform.dtype).max
    else:
        # np.iinfo() rejects float dtypes, so floats are only cast, never scaled.
        waveform = waveform.astype(np.float32)
    if waveform.ndim > 1:
        # A 2-D array would otherwise be read by the feature extractor as a batch.
        waveform = waveform.mean(axis=1)
    return waveform


def _resample(waveform, orig_sr, target_sr):
    """Resample 1-D *waveform* from *orig_sr* Hz to *target_sr* Hz.

    Uses plain linear interpolation (no anti-aliasing filter) so no extra
    dependency is needed; adequate for speech recognition, not hi-fi audio.
    """
    if orig_sr == target_sr or waveform.size == 0:
        return waveform
    n_in = waveform.shape[0]
    n_out = max(1, int(round(n_in * target_sr / orig_sr)))
    src_times = np.arange(n_in) / orig_sr
    dst_times = np.arange(n_out) / target_sr
    return np.interp(dst_times, src_times, waveform).astype(np.float32)


def transcribe_audio(audio_file):
    """
    Transcribe audio file using wav2vec2 model.

    Args:
        audio_file: ``(sample_rate, data)`` tuple as produced by
            ``gr.Audio(type="numpy")``, or ``None`` when nothing was uploaded.

    Returns:
        The decoded transcription, or a user-facing message when no/empty audio
        was given or processing failed (this function is the UI boundary, so it
        reports errors as text instead of raising).
    """
    if audio_file is None:
        return "Please upload an audio file"
    try:
        sample_rate, waveform = audio_file
        waveform = _to_mono_float32(waveform)
        if waveform.size == 0:
            return "The uploaded audio is empty"
        if sample_rate != TARGET_SAMPLE_RATE:
            print(f"Resampling from {sample_rate} to 16000 Hz")
            waveform = _resample(waveform, sample_rate, TARGET_SAMPLE_RATE)

        inputs = processor(waveform, sampling_rate=TARGET_SAMPLE_RATE, return_tensors="pt", padding=True)
        attention_mask = inputs.get("attention_mask")

        with torch.no_grad():
            logits = model(inputs.input_values, attention_mask=attention_mask).logits

        # Greedy CTC decoding: most likely token per frame; the tokenizer
        # collapses repeats and blanks.
        predicted_ids = torch.argmax(logits, dim=-1)
        return processor.batch_decode(predicted_ids)[0]
    except Exception as e:
        import traceback

        # Keep the full trace in the server log; the UI only gets the message.
        traceback.print_exc()
        return f"Error processing audio: {str(e)}"
# Assemble the web UI: one uploaded clip in, one transcription string out.
_interface_options = dict(
    fn=transcribe_audio,
    inputs=gr.Audio(type="numpy", label="Upload Audio"),
    outputs=gr.Textbox(label="Transcription"),
    title="Audio Transcription with Wav2Vec2",
    description="Upload an audio file to get a transcription using the Wav2Vec2 model.",
    allow_flagging="never",
)
iface = gr.Interface(**_interface_options)

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch(share=True)