Spaces: Runtime error
Update app.py
app.py
CHANGED
@@ -7,50 +7,60 @@ import soundfile as sf
 import os
 from pydub import AudioSegment
 
+# --- FIX: Define a local cache directory ---
+# This tells transformers to download models here, inside our app's folder,
+# instead of the restricted '/.cache' directory.
+CACHE_DIR = "/code/cache"
+os.makedirs(CACHE_DIR, exist_ok=True)
+
 # Initialize the FastAPI app
 app = FastAPI()
 
-#
+# --- FIX: Load model and processor using the local cache_dir ---
 model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
-processor = Wav2Vec2Processor.from_pretrained(model_name)
-model = Wav2Vec2ForCTC.from_pretrained(model_name)
+processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=CACHE_DIR)
+model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=CACHE_DIR)
 
 # Ensure the model is in evaluation mode
 model.eval()
 
 # Function to convert audio to the required format
 def convert_audio(audio_bytes):
-
-
-
-
-
-
+    try:
+        # Load audio from bytes using pydub
+        audio = AudioSegment.from_file(io.BytesIO(audio_bytes))
+        # Set to mono
+        audio = audio.set_channels(1)
+        # Set sample rate to 16kHz
+        audio = audio.set_frame_rate(16000)
 
-
-
-
-
-
+        # Export to a buffer in WAV format
+        buffer = io.BytesIO()
+        audio.export(buffer, format="wav")
+        buffer.seek(0)
+        return buffer.read()
+    except Exception as e:
+        # This will catch errors if ffmpeg has issues with a specific file
+        raise ValueError(f"Error processing audio file: {e}")
 
 
 @app.post("/assess-pronunciation/")
 async def assess_pronunciation(audio_file: UploadFile = File(...)):
     """
-    This endpoint takes an audio file and returns the recognized phonemes.
+    This endpoint takes an audio file, converts it, and returns the recognized phonemes.
     """
     # Read the audio file content
     audio_bytes = await audio_file.read()
 
-    # Convert audio to the model's required format (16kHz, mono)
-
+    # Convert audio to the model's required format (16kHz, mono WAV)
+    try:
+        processed_audio_bytes = convert_audio(audio_bytes)
+    except ValueError as e:
+        return {"error": str(e)}
 
-    # Load the waveform and sample rate from the processed audio bytes
-    waveform, sample_rate = sf.read(io.BytesIO(processed_audio_bytes), dtype='float32')
+    # Load the waveform from the processed audio bytes
+    waveform, sample_rate = sf.read(io.BytesIO(processed_audio_bytes), dtype='float32')
 
-    #
-
-    waveform = waveform.mean(axis=1)
-
     # Process the audio waveform
     input_values = processor(waveform, sampling_rate=sample_rate, return_tensors="pt", padding="longest").input_values
@@ -63,6 +73,7 @@ async def assess_pronunciation(audio_file: UploadFile = File(...)):
     predicted_ids = torch.argmax(logits, dim=-1)
     transcription = processor.batch_decode(predicted_ids)
 
+    # The output is a list with one item, so we return the item itself
     return {"phoneme_transcription": transcription[0]}
 
 @app.get("/")
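A note on the cache fix above: the Space hit a runtime error because transformers defaults to a cache under the home directory ('/.cache'), which is not writable in this container. Passing cache_dir on every from_pretrained() call, as the commit does, works; an equivalent approach (an alternative sketch, not what this commit does) is to point the whole Hugging Face cache at a writable path before transformers is imported:

# Alternative sketch, not what the commit does: redirect the entire
# Hugging Face cache with the HF_HOME environment variable so every
# from_pretrained() call uses the writable directory automatically.
# HF_HOME must be set before transformers/huggingface_hub are imported.
import os

os.environ["HF_HOME"] = "/code/cache"  # same writable path the commit uses
os.makedirs("/code/cache", exist_ok=True)

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
processor = Wav2Vec2Processor.from_pretrained(model_name)  # no cache_dir needed
model = Wav2Vec2ForCTC.from_pretrained(model_name)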
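convert_audio relies on pydub, which shells out to ffmpeg, so a ValueError here usually means ffmpeg is missing from the image or the upload is a format it cannot parse. A quick local sanity check (assuming ffmpeg is installed and "input.mp3" stands in for any test file) confirms the function emits 16 kHz mono WAV that soundfile can read back:

# Local sanity check for convert_audio (assumes ffmpeg is installed and
# "input.mp3" is a hypothetical test file on disk).
import io
import soundfile as sf

with open("input.mp3", "rb") as f:
    wav_bytes = convert_audio(f.read())

waveform, sample_rate = sf.read(io.BytesIO(wav_bytes), dtype="float32")
print(sample_rate)    # expected: 16000
print(waveform.ndim)  # expected: 1, since the audio was forced to mono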
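The lines elided between the two hunks (old 57-62 / new 67-72) are not shown in the diff, but given the context lines they presumably run the CTC forward pass. A typical wav2vec2 inference step consistent with the surrounding code looks like this (a sketch of the standard pattern, not necessarily the commit's exact lines):

# Sketch of the step elided between the hunks (standard wav2vec2 CTC usage;
# the commit's exact lines are not visible in this diff).
import torch

with torch.no_grad():                    # inference only, no gradients
    logits = model(input_values).logits  # shape: (batch, time, vocab)

The context lines in the second hunk then take the per-frame argmax and let processor.batch_decode collapse CTC blanks and repeats into the phoneme string.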
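Once the Space starts cleanly, the endpoint can be exercised with a plain multipart upload; the field name must match the audio_file parameter in the signature above. The base URL below is a placeholder for wherever the Space is actually served:

# Client sketch for POST /assess-pronunciation/ (the URL is a placeholder).
import requests

BASE_URL = "http://localhost:7860"  # hypothetical; substitute the Space's URL

with open("sample.wav", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/assess-pronunciation/",
        files={"audio_file": ("sample.wav", f, "audio/wav")},
    )

print(resp.json())  # {"phoneme_transcription": "..."} on success

Note that conversion failures come back as a normal 200 response with an "error" key rather than an HTTP error status, so clients should check for that key explicitly.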