Nusri7 committed on
Commit
20acaf7
·
1 Parent(s): 456232f

Initial commit with FastAPI + Gradio app

Browse files
Files changed (1) hide show
  1. app.py +14 -18
app.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
  from fastapi import FastAPI, HTTPException, File, UploadFile
6
  from speechbrain.inference import SpeakerRecognition
7
  from fastapi.responses import JSONResponse
 
8
 
9
  # Initialize the speaker verification model
10
  speaker_verification = SpeakerRecognition.from_hparams(
@@ -15,7 +16,7 @@ speaker_verification = SpeakerRecognition.from_hparams(
15
  # Function to calculate similarity score
16
  def get_similarity(audio1, audio2, sample_rate=16000):
17
  try:
18
- # Convert numpy arrays to tensors
19
  signal1 = torch.tensor(audio1)
20
  signal2 = torch.tensor(audio2)
21
 
@@ -34,23 +35,15 @@ def get_similarity(audio1, audio2, sample_rate=16000):
34
  # API function to compare voices
35
  def compare_voices(file1, file2):
36
  try:
37
- # Debug: Print the file types
38
  print(f"Received file1: {type(file1)}")
39
  print(f"Received file2: {type(file2)}")
40
-
41
- if not file1 or not file2:
42
- return {"error": "One or both audio inputs are missing."}
43
-
44
- # Ensure file1 and file2 are tuples (numpy_array, sample_rate)
45
- if isinstance(file1, tuple) and len(file1) == 2:
46
- audio1, _ = file1 # Audio1 is a tuple (numpy_array, sample_rate)
47
- else:
48
- return {"error": "Invalid format for the first audio input."}
49
 
50
- if isinstance(file2, tuple) and len(file2) == 2:
51
- audio2, _ = file2 # Audio2 is a tuple (numpy_array, sample_rate)
 
52
  else:
53
- return {"error": "Invalid format for the second audio input."}
54
 
55
  # Get similarity score
56
  score, is_same_user = get_similarity(audio1, audio2)
@@ -79,13 +72,16 @@ async def compare_voices_api(file1: UploadFile = File(...), file2: UploadFile =
79
  file1_data = await file1.read()
80
  file2_data = await file2.read()
81
 
82
- # You need to process these byte strings into numpy arrays
83
  # Assuming the audio is decoded into numpy arrays here (e.g., using torchaudio)
84
  # For example:
85
- # audio1 = torchaudio.load(io.BytesIO(file1_data))[0].numpy()
86
- # audio2 = torchaudio.load(io.BytesIO(file2_data))[0].numpy()
 
 
 
87
 
88
- return {"message": "Processing files directly without saving them."}
 
89
 
90
  except Exception as e:
91
  raise HTTPException(status_code=400, detail=str(e))
 
5
  from fastapi import FastAPI, HTTPException, File, UploadFile
6
  from speechbrain.inference import SpeakerRecognition
7
  from fastapi.responses import JSONResponse
8
+ import numpy as np
9
 
10
  # Initialize the speaker verification model
11
  speaker_verification = SpeakerRecognition.from_hparams(
 
16
  # Function to calculate similarity score
17
  def get_similarity(audio1, audio2, sample_rate=16000):
18
  try:
19
+ # Ensure audio1 and audio2 are numpy arrays
20
  signal1 = torch.tensor(audio1)
21
  signal2 = torch.tensor(audio2)
22
 
 
35
  # API function to compare voices
36
  def compare_voices(file1, file2):
37
  try:
38
+ # Debugging: Check the types of inputs
39
  print(f"Received file1: {type(file1)}")
40
  print(f"Received file2: {type(file2)}")
 
 
 
 
 
 
 
 
 
41
 
42
+ # Ensure file1 and file2 are numpy arrays
43
+ if isinstance(file1, np.ndarray) and isinstance(file2, np.ndarray):
44
+ audio1, audio2 = file1, file2
45
  else:
46
+ return {"error": "Invalid input format. Both inputs must be numpy arrays."}
47
 
48
  # Get similarity score
49
  score, is_same_user = get_similarity(audio1, audio2)
 
72
  file1_data = await file1.read()
73
  file2_data = await file2.read()
74
 
 
75
  # Assuming the audio is decoded into numpy arrays here (e.g., using torchaudio)
76
  # For example:
77
+ audio1, _ = torchaudio.load(io.BytesIO(file1_data)) # (Tensor, sample_rate)
78
+ audio2, _ = torchaudio.load(io.BytesIO(file2_data)) # (Tensor, sample_rate)
79
+
80
+ audio1 = audio1.numpy()
81
+ audio2 = audio2.numpy()
82
 
83
+ # Compare the two audio files and return the result
84
+ return compare_voices(audio1, audio2)
85
 
86
  except Exception as e:
87
  raise HTTPException(status_code=400, detail=str(e))