prakashp1893 committed on
Commit
9077cd6
·
verified ·
1 Parent(s): 6c42c01

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -22
app.py CHANGED
@@ -7,50 +7,60 @@ import soundfile as sf
7
  import os
8
  from pydub import AudioSegment
9
 
 
 
 
 
 
 
10
  # Initialize the FastAPI app
11
  app = FastAPI()
12
 
13
- # Load the pre-trained model and processor
14
  model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
15
- processor = Wav2Vec2Processor.from_pretrained(model_name)
16
- model = Wav2Vec2ForCTC.from_pretrained(model_name)
17
 
18
  # Ensure the model is in evaluation mode
19
  model.eval()
20
 
21
  # Function to convert audio to the required format
22
  def convert_audio(audio_bytes):
23
- # Load audio from bytes
24
- audio = AudioSegment.from_file(io.BytesIO(audio_bytes))
25
- # Set to mono
26
- audio = audio.set_channels(1)
27
- # Set sample rate to 16kHz
28
- audio = audio.set_frame_rate(16000)
 
29
 
30
- # Export to a buffer
31
- buffer = io.BytesIO()
32
- audio.export(buffer, format="wav")
33
- buffer.seek(0)
34
- return buffer.read()
 
 
 
35
 
36
 
37
  @app.post("/assess-pronunciation/")
38
  async def assess_pronunciation(audio_file: UploadFile = File(...)):
39
  """
40
- This endpoint takes an audio file and returns the recognized phonemes.
41
  """
42
  # Read the audio file content
43
  audio_bytes = await audio_file.read()
44
 
45
- # Convert audio to the model's required format (16kHz, mono)
46
- processed_audio_bytes = convert_audio(audio_bytes)
 
 
 
47
 
48
- # Load the waveform and sample rate from the processed audio bytes
49
- waveform, sample_rate = sf.read(io.BytesIO(processed_audio_bytes), dtype='float32')
50
 
51
- # Ensure the audio is a 1D tensor
52
- if waveform.ndim > 1:
53
- waveform = waveform.mean(axis=1)
54
 
55
  # Process the audio waveform
56
  input_values = processor(waveform, sampling_rate=sample_rate, return_tensors="pt", padding="longest").input_values
@@ -63,6 +73,7 @@ async def assess_pronunciation(audio_file: UploadFile = File(...)):
63
  predicted_ids = torch.argmax(logits, dim=-1)
64
  transcription = processor.batch_decode(predicted_ids)
65
 
 
66
  return {"phoneme_transcription": transcription[0]}
67
 
68
  @app.get("/")
 
7
  import os
8
  from pydub import AudioSegment
9
 
10
+ # --- FIX: Define a local cache directory ---
11
+ # This tells transformers to download models here, inside our app's folder,
12
+ # instead of the restricted '/.cache' directory.
13
+ CACHE_DIR = "/code/cache"
14
+ os.makedirs(CACHE_DIR, exist_ok=True)
15
+
16
  # Initialize the FastAPI app
17
  app = FastAPI()
18
 
19
+ # --- FIX: Load model and processor using the local cache_dir ---
20
  model_name = "facebook/wav2vec2-lv-60-espeak-cv-ft"
21
+ processor = Wav2Vec2Processor.from_pretrained(model_name, cache_dir=CACHE_DIR)
22
+ model = Wav2Vec2ForCTC.from_pretrained(model_name, cache_dir=CACHE_DIR)
23
 
24
  # Ensure the model is in evaluation mode
25
  model.eval()
26
 
27
  # Function to convert audio to the required format
28
  def convert_audio(audio_bytes):
29
+ try:
30
+ # Load audio from bytes using pydub
31
+ audio = AudioSegment.from_file(io.BytesIO(audio_bytes))
32
+ # Set to mono
33
+ audio = audio.set_channels(1)
34
+ # Set sample rate to 16kHz
35
+ audio = audio.set_frame_rate(16000)
36
 
37
+ # Export to a buffer in WAV format
38
+ buffer = io.BytesIO()
39
+ audio.export(buffer, format="wav")
40
+ buffer.seek(0)
41
+ return buffer.read()
42
+ except Exception as e:
43
+ # This will catch errors if ffmpeg has issues with a specific file
44
+ raise ValueError(f"Error processing audio file: {e}")
45
 
46
 
47
  @app.post("/assess-pronunciation/")
48
  async def assess_pronunciation(audio_file: UploadFile = File(...)):
49
  """
50
+ This endpoint takes an audio file, converts it, and returns the recognized phonemes.
51
  """
52
  # Read the audio file content
53
  audio_bytes = await audio_file.read()
54
 
55
+ # Convert audio to the model's required format (16kHz, mono WAV)
56
+ try:
57
+ processed_audio_bytes = convert_audio(audio_bytes)
58
+ except ValueError as e:
59
+ return {"error": str(e)}
60
 
 
 
61
 
62
+ # Load the waveform from the processed audio bytes
63
+ waveform, sample_rate = sf.read(io.BytesIO(processed_audio_bytes), dtype='float32')
 
64
 
65
  # Process the audio waveform
66
  input_values = processor(waveform, sampling_rate=sample_rate, return_tensors="pt", padding="longest").input_values
 
73
  predicted_ids = torch.argmax(logits, dim=-1)
74
  transcription = processor.batch_decode(predicted_ids)
75
 
76
+ # The output is a list with one item, so we return the item itself
77
  return {"phoneme_transcription": transcription[0]}
78
 
79
  @app.get("/")