from flask import Flask, request, jsonify
from faster_whisper import WhisperModel
import numpy as np
import torch
import io
import time
import datetime
from threading import Semaphore, Lock

app = Flask(__name__)

# Device check for faster-whisper
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"
print(f"Using device: {device} with compute_type: {compute_type}")

# Faster Whisper setup
beamsize = 2
wmodel = WhisperModel("guillaumekln/faster-whisper-small", device=device, compute_type=compute_type)

# Concurrency control
MAX_CONCURRENT_REQUESTS = 2  # Adjust based on your server capacity
request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)
active_requests = 0
counter_lock = Lock()  # Guards active_requests across Flask's worker threads

# Warm up the model (important for CUDA)
print("Warming up the model...")
try:
    # One second of 16 kHz silence; faster-whisper accepts float32 numpy audio.
    # (An empty byte stream cannot be decoded, so it would always fail the warmup.)
    dummy_audio = np.zeros(16000, dtype=np.float32)
    segments, info = wmodel.transcribe(dummy_audio, beam_size=beamsize)
    _ = [segment.text for segment in segments]  # Segments are lazy; iterate to force execution
    print("Model warmup complete")
except Exception as e:
    print(f"Model warmup failed: {str(e)}")
@app.route("/health", methods=["GET"])
def health_check():
"""Endpoint to check if API is running"""
return jsonify({
'status': 'API is running',
'timestamp': datetime.datetime.now().isoformat(),
'device': device,
'compute_type': compute_type,
'active_requests': active_requests
})
@app.route("/status/busy", methods=["GET"])
def server_busy():
"""Endpoint to check if server is busy"""
is_busy = active_requests >= MAX_CONCURRENT_REQUESTS
return jsonify({
'is_busy': is_busy,
'active_requests': active_requests,
'max_capacity': MAX_CONCURRENT_REQUESTS
})
@app.route("/whisper_transcribe", methods=["POST"])
def whisper_transcribe():
global active_requests
# Check if server is at capacity
if not request_semaphore.acquire(blocking=False):
return jsonify({
'status': 'Server busy',
'message': f'Currently processing {active_requests} requests',
'suggestion': 'Please try again shortly'
}), 503
active_requests += 1
print(f"Starting transcription (Active requests: {active_requests})")
try:
if 'audio' not in request.files:
return jsonify({'error': 'No file provided'}), 400
audio_file = request.files['audio']
allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a'}
if not (audio_file and audio_file.filename.lower().split('.')[-1] in allowed_extensions):
return jsonify({'error': 'Invalid file format'}), 400
audio_bytes = audio_file.read()
audio_file = io.BytesIO(audio_bytes)
try:
# Timeout handling (60 seconds max processing time)
start_time = time.time()
segments, info = wmodel.transcribe(audio_file, beam_size=beamsize)
text = ''
for segment in segments:
if time.time() - start_time > 60: # Timeout after 60 seconds
raise TimeoutError("Transcription took too long")
text += segment.text
processing_time = time.time() - start_time
print(f"Transcription completed in {processing_time:.2f} seconds")
return jsonify({
'transcription': text,
'processing_time': processing_time,
'language': info.language,
'language_probability': info.language_probability
})
except TimeoutError:
print("Transcription timeout")
return jsonify({'error': 'Transcription timeout'}), 504
except Exception as e:
print(f"Transcription error: {str(e)}")
return jsonify({'error': 'Transcription failed'}), 500
finally:
active_requests -= 1
request_semaphore.release()
print(f"Request completed (Active requests: {active_requests})")


if __name__ == "__main__":
    # threaded=True lets the dev server handle concurrent requests;
    # debug=True should be disabled outside of development
    app.run(host="0.0.0.0", debug=True, port=7860, threaded=True)
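
# Example client usage (illustrative; assumes the server is reachable on
# localhost:7860 and that a local file sample.mp3 exists):
#
#   curl http://localhost:7860/health
#   curl http://localhost:7860/status/busy
#   curl -X POST -F "audio=@sample.mp3" http://localhost:7860/whisper_transcribe
#
# A successful transcription returns JSON with 'transcription',
# 'processing_time', 'language', and 'language_probability'; a 503 response
# means all MAX_CONCURRENT_REQUESTS slots are currently in use.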