from flask import Flask, request, jsonify
from faster_whisper import WhisperModel
import torch
import io
import time
import datetime
import wave
from threading import Semaphore, Lock

app = Flask(__name__)

# Device check for faster-whisper
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"
print(f"Using device: {device} with compute_type: {compute_type}")

# Faster Whisper setup
beam_size = 2
wmodel = WhisperModel("guillaumekln/faster-whisper-small", device=device, compute_type=compute_type)

# Concurrency control
MAX_CONCURRENT_REQUESTS = 2  # Adjust based on your server capacity
request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)
active_requests = 0
counter_lock = Lock()  # Guards active_requests; Flask serves requests from multiple threads


def make_silent_wav(seconds=1.0, sample_rate=16000):
    """Build an in-memory WAV of silence (an empty buffer cannot be decoded)."""
    buf = io.BytesIO()
    with wave.open(buf, "wb") as w:
        w.setnchannels(1)
        w.setsampwidth(2)  # 16-bit PCM
        w.setframerate(sample_rate)
        w.writeframes(b"\x00\x00" * int(seconds * sample_rate))
    buf.seek(0)
    return buf


# Warm up the model (important for CUDA: the first call pays kernel/initialization cost)
print("Warming up the model...")
try:
    segments, info = wmodel.transcribe(make_silent_wav(), beam_size=beam_size)
    _ = [segment.text for segment in segments]  # transcribe() is lazy; iterate to force execution
    print("Model warmup complete")
except Exception as e:
    print(f"Model warmup failed: {e}")


@app.route("/health", methods=["GET"])
def health_check():
    """Endpoint to check if the API is running."""
    return jsonify({
        'status': 'API is running',
        'timestamp': datetime.datetime.now().isoformat(),
        'device': device,
        'compute_type': compute_type,
        'active_requests': active_requests
    })


@app.route("/status/busy", methods=["GET"])
def server_busy():
    """Endpoint to check whether the server is at capacity."""
    is_busy = active_requests >= MAX_CONCURRENT_REQUESTS
    return jsonify({
        'is_busy': is_busy,
        'active_requests': active_requests,
        'max_capacity': MAX_CONCURRENT_REQUESTS
    })


@app.route("/whisper_transcribe", methods=["POST"])
def whisper_transcribe():
    global active_requests

    # Reject immediately if the server is at capacity
    if not request_semaphore.acquire(blocking=False):
        return jsonify({
            'status': 'Server busy',
            'message': f'Currently processing {active_requests} requests',
            'suggestion': 'Please try again shortly'
        }), 503

    with counter_lock:
        active_requests += 1
    print(f"Starting transcription (Active requests: {active_requests})")

    try:
        if 'audio' not in request.files:
            return jsonify({'error': 'No file provided'}), 400

        upload = request.files['audio']
        allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a'}
        if not (upload.filename and '.' in upload.filename
                and upload.filename.rsplit('.', 1)[-1].lower() in allowed_extensions):
            return jsonify({'error': 'Invalid file format'}), 400

        audio_stream = io.BytesIO(upload.read())

        try:
            # Soft timeout: segments is a lazy generator, so decoding happens inside
            # the loop and elapsed time can be checked between segments.
            start_time = time.time()
            segments, info = wmodel.transcribe(audio_stream, beam_size=beam_size)
            text = ''
            for segment in segments:
                if time.time() - start_time > 60:  # Abort after 60 seconds
                    raise TimeoutError("Transcription took too long")
                text += segment.text
            processing_time = time.time() - start_time
            print(f"Transcription completed in {processing_time:.2f} seconds")

            return jsonify({
                'transcription': text,
                'processing_time': processing_time,
                'language': info.language,
                'language_probability': info.language_probability
            })
        except TimeoutError:
            print("Transcription timeout")
            return jsonify({'error': 'Transcription timeout'}), 504
    except Exception as e:
        print(f"Transcription error: {e}")
        return jsonify({'error': 'Transcription failed'}), 500
    finally:
        with counter_lock:
            active_requests -= 1
        request_semaphore.release()
        print(f"Request completed (Active requests: {active_requests})")


if __name__ == "__main__":
    # Note: debug=True runs Werkzeug's reloader, which imports this module twice
    # and therefore loads the model twice; turn it off outside development.
    app.run(host="0.0.0.0", debug=True, port=7860, threaded=True)
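

# --- Example client (sketch) ---
# A minimal sketch of how a caller might use this API, assuming the server runs
# locally on port 7860, a file named "sample.wav" exists, and the third-party
# `requests` package is installed; the `transcribe` helper name, the retry count,
# and the backoff delay are all illustrative choices, not part of the server above.
# The retry-on-503 loop mirrors the busy signal this server returns at capacity.
#
#   import time
#   import requests
#
#   def transcribe(path, url="http://localhost:7860/whisper_transcribe", retries=3):
#       for _ in range(retries):
#           with open(path, "rb") as f:  # reopen each attempt so the stream is fresh
#               resp = requests.post(url, files={"audio": f})
#           if resp.status_code == 503:  # server busy; back off and retry
#               time.sleep(2)
#               continue
#           resp.raise_for_status()
#           return resp.json()["transcription"]
#       raise RuntimeError("Server busy; gave up after retries")
#
#   print(transcribe("sample.wav"))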