prakashp1893 committed
Commit 556e71e · verified · 1 Parent(s): 39ea74e

Create app.py

Files changed (1)
app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
+ from fastapi import FastAPI, File, UploadFile
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
+ import torch
+ import io
+ import soundfile as sf
+ from pydub import AudioSegment
+
+ # Initialize the FastAPI app
+ app = FastAPI()
+
+ # Load the pre-trained model and processor
+ model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
+ model = Wav2Vec2ForCTC.from_pretrained(model_name)
+
+ # Ensure the model is in evaluation mode
+ model.eval()
+
+ # Function to convert audio to the format the model expects
+ def convert_audio(audio_bytes):
+     # Load audio from bytes (pydub relies on ffmpeg for non-WAV formats)
+     audio = AudioSegment.from_file(io.BytesIO(audio_bytes))
+     # Downmix to mono
+     audio = audio.set_channels(1)
+     # Resample to 16 kHz, the rate wav2vec 2.0 expects
+     audio = audio.set_frame_rate(16000)
+
+     # Export to an in-memory WAV buffer
+     buffer = io.BytesIO()
+     audio.export(buffer, format="wav")
+     buffer.seek(0)
+     return buffer.read()
+
+
+ @app.post("/assess-pronunciation/")
+ async def assess_pronunciation(audio_file: UploadFile = File(...)):
+     """
+     Take an uploaded audio file and return the recognized phonemes.
+     """
+     # Read the uploaded file's content
+     audio_bytes = await audio_file.read()
+
+     # Convert the audio to the model's required format (16 kHz, mono)
+     processed_audio_bytes = convert_audio(audio_bytes)
+
+     # Load the waveform and sample rate from the processed audio bytes
+     waveform, sample_rate = sf.read(io.BytesIO(processed_audio_bytes), dtype='float32')
+
+     # Ensure the audio is a 1D array
+     if waveform.ndim > 1:
+         waveform = waveform.mean(axis=1)
+
+     # Extract model input features from the waveform
+     input_values = processor(waveform, sampling_rate=sample_rate, return_tensors="pt", padding="longest").input_values
+
+     # Perform inference without tracking gradients
+     with torch.no_grad():
+         logits = model(input_values).logits
+
+     # Greedy CTC decoding: most likely token per frame, then collapse to phonemes
+     predicted_ids = torch.argmax(logits, dim=-1)
+     transcription = processor.batch_decode(predicted_ids)
+
+     return {"phoneme_transcription": transcription[0]}
+
+ @app.get("/")
+ def read_root():
+     return {"message": "Wav2Vec2 Pronunciation Assessment API is running."}
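For anyone reproducing this Space locally, here is a minimal dependency sketch inferred from the imports above (package names and pins are assumptions, not part of the commit). Two non-obvious requirements: FastAPI needs python-multipart to parse UploadFile form data, and pydub shells out to the ffmpeg binary for non-WAV inputs.

# requirements.txt (assumed; inferred from app.py's imports)
fastapi
uvicorn            # ASGI server used to run the app; not imported, but needed to serve it
python-multipart   # required by FastAPI for UploadFile/File form parsing
transformers
torch
soundfile
pydub              # also expects the ffmpeg binary on PATH for non-WAV formats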
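A quick usage sketch, assuming the file is saved as app.py, the server runs on localhost:8000, and sample.wav stands in for any local recording (all three are assumptions, not part of the commit):

# Start the server:
#   uvicorn app:app --host 0.0.0.0 --port 8000

# Call the endpoint from Python; "sample.wav" is a hypothetical local file.
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/assess-pronunciation/",
        # The field name must match the endpoint's parameter name, audio_file
        files={"audio_file": ("sample.wav", f, "audio/wav")},
    )
print(resp.json())  # e.g. {"phoneme_transcription": "..."}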