from flask import Flask, request, jsonify
from faster_whisper import WhisperModel
import numpy as np
import torch
import io
import time
import datetime
from threading import Semaphore, Lock

app = Flask(__name__)

# Device check for faster-whisper
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"
print(f"Using device: {device} with compute_type: {compute_type}")

# Faster Whisper setup
beam_size = 2
wmodel = WhisperModel("guillaumekln/faster-whisper-small", device=device, compute_type=compute_type)

# Concurrency control
MAX_CONCURRENT_REQUESTS = 2  # Adjust based on your server capacity
request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)
active_requests = 0
active_requests_lock = Lock()  # Guards the active_requests counter across worker threads

# Warm up the model (important for CUDA)
print("Warming up the model...")
try:
    # One second of silence at 16 kHz. faster-whisper accepts a float32 numpy
    # array directly; an empty byte stream would fail to decode, so the warmup
    # would never actually run.
    dummy_audio = np.zeros(16000, dtype=np.float32)
    segments, info = wmodel.transcribe(dummy_audio, beam_size=beam_size)
    _ = [segment.text for segment in segments]  # Force execution of the lazy generator
    print("Model warmup complete")
except Exception as e:
    print(f"Model warmup failed: {str(e)}")

@app.route("/health", methods=["GET"])
def health_check():
    """Endpoint to check if the API is running"""
    return jsonify({
        'status': 'API is running',
        'timestamp': datetime.datetime.now().isoformat(),
        'device': device,
        'compute_type': compute_type,
        'active_requests': active_requests
    })
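
# Example check (assuming the app runs locally on the port set in app.run below):
#   curl http://localhost:7860/health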

@app.route("/status/busy", methods=["GET"])
def server_busy():
    """Endpoint to check if the server is at capacity"""
    is_busy = active_requests >= MAX_CONCURRENT_REQUESTS
    return jsonify({
        'is_busy': is_busy,
        'active_requests': active_requests,
        'max_capacity': MAX_CONCURRENT_REQUESTS
    })
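
# Example check (same local-port assumption as above):
#   curl http://localhost:7860/status/busy
#   -> {"is_busy": false, "active_requests": 0, "max_capacity": 2}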

@app.route("/whisper_transcribe", methods=["POST"])
def whisper_transcribe():
    global active_requests

    # Reject immediately if the server is at capacity
    if not request_semaphore.acquire(blocking=False):
        return jsonify({
            'status': 'Server busy',
            'message': f'Currently processing {active_requests} requests',
            'suggestion': 'Please try again shortly'
        }), 503

    with active_requests_lock:  # += on a plain int is not atomic across threads
        active_requests += 1
    print(f"Starting transcription (Active requests: {active_requests})")

    try:
        if 'audio' not in request.files:
            return jsonify({'error': 'No file provided'}), 400

        audio_file = request.files['audio']
        allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a'}
        if not (audio_file and audio_file.filename.lower().split('.')[-1] in allowed_extensions):
            return jsonify({'error': 'Invalid file format'}), 400

        audio_bytes = audio_file.read()
        audio_stream = io.BytesIO(audio_bytes)

        try:
            # Timeout handling (60 seconds max processing time)
            start_time = time.time()
            segments, info = wmodel.transcribe(audio_stream, beam_size=beam_size)
            text = ''
            # Segments are generated lazily, so this check fires while decoding progresses
            for segment in segments:
                if time.time() - start_time > 60:
                    raise TimeoutError("Transcription took too long")
                text += segment.text

            processing_time = time.time() - start_time
            print(f"Transcription completed in {processing_time:.2f} seconds")
            return jsonify({
                'transcription': text,
                'processing_time': processing_time,
                'language': info.language,
                'language_probability': info.language_probability
            })
        except TimeoutError:
            print("Transcription timeout")
            return jsonify({'error': 'Transcription timeout'}), 504
        except Exception as e:
            print(f"Transcription error: {str(e)}")
            return jsonify({'error': 'Transcription failed'}), 500
    finally:
        with active_requests_lock:
            active_requests -= 1
        request_semaphore.release()
        print(f"Request completed (Active requests: {active_requests})")

if __name__ == "__main__":
    # Note: debug=True enables Flask's reloader and debugger; disable it in production
    app.run(host="0.0.0.0", debug=True, port=7860, threaded=True)