|
from flask import Flask, request, jsonify |
|
from faster_whisper import WhisperModel |
|
import torch |
|
import io |
|
import time |
|
import datetime |
|
from threading import Semaphore |
|
|
|
app = Flask(__name__)



# Pick the fastest configuration the host supports: fp16 on GPU,
# int8 quantization on CPU (faster-whisper's recommended CPU mode).
device = "cuda" if torch.cuda.is_available() else "cpu"

compute_type = "float16" if device == "cuda" else "int8"

print(f"Using device: {device} with compute_type: {compute_type}")



# Decoding beam width; 2 trades a little accuracy for speed.
beamsize = 2

# Loaded once at import time and shared by all request threads.
wmodel = WhisperModel("guillaumekln/faster-whisper-small", device=device, compute_type=compute_type)



# Hard cap on simultaneous transcriptions; the semaphore enforces it and
# active_requests mirrors it for status reporting only.
MAX_CONCURRENT_REQUESTS = 2

request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)

active_requests = 0



# Best-effort warmup to pay first-inference costs before real traffic.
# NOTE(review): an empty BytesIO is not decodable audio, so this likely
# always lands in the except branch — confirm whether it actually primes
# the model or should feed a short silent clip instead.
print("Warming up the model...")

try:

    dummy_audio = io.BytesIO(b'')

    segments, info = wmodel.transcribe(dummy_audio, beam_size=beamsize)

    _ = [segment.text for segment in segments]

    print("Model warmup complete")

except Exception as e:

    print(f"Model warmup failed: {str(e)}")
|
|
|
@app.route("/health", methods=["GET"])
def health_check():
    """Liveness probe: report service status and runtime configuration."""
    payload = {
        'status': 'API is running',
        'timestamp': datetime.datetime.now().isoformat(),
        'device': device,
        'compute_type': compute_type,
        'active_requests': active_requests,
    }
    return jsonify(payload)
|
|
|
@app.route("/status/busy", methods=["GET"])
def server_busy():
    """Report whether the server is at its concurrency limit."""
    return jsonify({
        'is_busy': active_requests >= MAX_CONCURRENT_REQUESTS,
        'active_requests': active_requests,
        'max_capacity': MAX_CONCURRENT_REQUESTS,
    })
|
|
|
@app.route("/whisper_transcribe", methods=["POST"])
def whisper_transcribe():
    """Transcribe an uploaded audio file with faster-whisper.

    Expects multipart/form-data with the file under the 'audio' field.
    Responses:
        200 -- JSON with transcription, processing_time, language info.
        400 -- missing file or disallowed extension.
        503 -- server already at MAX_CONCURRENT_REQUESTS.
        504 -- transcription exceeded the 60-second budget.
        500 -- any other decoding failure.
    """
    global active_requests

    # Non-blocking acquire: refuse immediately rather than queueing, so
    # clients get fast feedback when the server is saturated.
    if not request_semaphore.acquire(blocking=False):
        return jsonify({
            'status': 'Server busy',
            'message': f'Currently processing {active_requests} requests',
            'suggestion': 'Please try again shortly'
        }), 503

    # Enter the try block immediately after a successful acquire so the
    # finally clause is guaranteed to release the permit. (Previously the
    # counter increment and a print sat between acquire() and try, leaving
    # a window where an exception would leak the semaphore.)
    try:
        # NOTE(review): '+=' on a shared int is not atomic under the
        # threaded server; the counter is advisory (status messages only),
        # so a rare off-by-one is tolerated instead of adding a lock.
        active_requests += 1
        print(f"Starting transcription (Active requests: {active_requests})")

        if 'audio' not in request.files:
            return jsonify({'error': 'No file provided'}), 400

        upload = request.files['audio']
        allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a'}
        # Extension gate only blocks obviously wrong uploads; the decoder
        # is the real validator. Require an actual dot so a file literally
        # named "mp3" is rejected (the old split('.')[-1] accepted it).
        if not (upload and upload.filename and '.' in upload.filename
                and upload.filename.lower().rsplit('.', 1)[-1] in allowed_extensions):
            return jsonify({'error': 'Invalid file format'}), 400

        # Buffer the upload in memory so faster-whisper gets a seekable stream.
        audio_stream = io.BytesIO(upload.read())

        try:
            start_time = time.time()
            segments, info = wmodel.transcribe(audio_stream, beam_size=beamsize)

            # `segments` is evaluated lazily: decoding happens during this
            # loop, which is why the wall-clock timeout check can fire here.
            parts = []
            for segment in segments:
                if time.time() - start_time > 60:
                    raise TimeoutError("Transcription took too long")
                parts.append(segment.text)
            # join once instead of quadratic string concatenation.
            text = ''.join(parts)

            processing_time = time.time() - start_time
            print(f"Transcription completed in {processing_time:.2f} seconds")

            return jsonify({
                'transcription': text,
                'processing_time': processing_time,
                'language': info.language,
                'language_probability': info.language_probability
            })

        except TimeoutError:
            print("Transcription timeout")
            return jsonify({'error': 'Transcription timeout'}), 504
        except Exception as e:
            print(f"Transcription error: {str(e)}")
            return jsonify({'error': 'Transcription failed'}), 500

    finally:
        active_requests -= 1
        request_semaphore.release()
        print(f"Request completed (Active requests: {active_requests})")
|
|
|
if __name__ == "__main__":
    # SECURITY: keep debug mode off when binding to all interfaces — the
    # Werkzeug interactive debugger allows arbitrary code execution if it
    # is reachable over the network. threaded=True lets the semaphore in
    # whisper_transcribe do its job of capping concurrent transcriptions.
    app.run(host="0.0.0.0", debug=False, port=7860, threaded=True)