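"""Flask API server exposing faster-whisper speech-to-text.

Endpoints:
    GET  /health              -- liveness, device, and load info
    GET  /status/busy         -- whether the server is at capacity
    POST /whisper_transcribe  -- transcribe an uploaded audio file
"""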
from flask import Flask, request, jsonify
from faster_whisper import WhisperModel
import torch
import io
import time
import datetime
import wave
from threading import Semaphore, Lock

app = Flask(__name__)

# Device check for faster-whisper
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"
print(f"Using device: {device} with compute_type: {compute_type}")

# Faster Whisper setup
beamsize = 2
wmodel = WhisperModel("guillaumekln/faster-whisper-small", device=device, compute_type=compute_type)
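# Note: this pins the original conversion repo. Equivalent conversions are
# also published on Hugging Face under the Systran organization (assumption:
# e.g. "Systran/faster-whisper-small"), and passing a bare size like "small"
# lets faster-whisper resolve its current default repo itself.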

# Concurrency control
MAX_CONCURRENT_REQUESTS = 2  # Adjust based on your server capacity
request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)
counter_lock = Lock()  # Guards active_requests; += on a global is not atomic across threads
active_requests = 0

# Warm up the model (important for CUDA: the first call pays one-off kernel
# and memory-allocation costs). An empty buffer cannot be decoded, so use
# one second of 16 kHz mono silence packaged as a valid WAV.
print("Warming up the model...")
try:
    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, "wb") as wav_file:
        wav_file.setnchannels(1)
        wav_file.setsampwidth(2)       # 16-bit PCM
        wav_file.setframerate(16000)   # Whisper's native sample rate
        wav_file.writeframes(b"\x00\x00" * 16000)  # 1 second of silence
    wav_buffer.seek(0)
    segments, info = wmodel.transcribe(wav_buffer, beam_size=beamsize)
    _ = [segment.text for segment in segments]  # Segments are lazy; force decoding
    print("Model warmup complete")
except Exception as e:
    print(f"Model warmup failed: {e}")

@app.route("/health", methods=["GET"])
def health_check():
    """Endpoint to check if API is running"""
    return jsonify({
        'status': 'API is running',
        'timestamp': datetime.datetime.now().isoformat(),
        'device': device,
        'compute_type': compute_type,
        'active_requests': active_requests
    })

@app.route("/status/busy", methods=["GET"])
def server_busy():
    """Endpoint to check if server is busy"""
    is_busy = active_requests >= MAX_CONCURRENT_REQUESTS
    return jsonify({
        'is_busy': is_busy,
        'active_requests': active_requests,
        'max_capacity': MAX_CONCURRENT_REQUESTS
    })

@app.route("/whisper_transcribe", methods=["POST"])
def whisper_transcribe():
    global active_requests
    
    # Check if server is at capacity
    if not request_semaphore.acquire(blocking=False):
        return jsonify({
            'status': 'Server busy',
            'message': f'Currently processing {active_requests} requests',
            'suggestion': 'Please try again shortly'
        }), 503
    
    with counter_lock:
        active_requests += 1
    print(f"Starting transcription (Active requests: {active_requests})")
    
    try:
        if 'audio' not in request.files:
            return jsonify({'error': 'No file provided'}), 400

        audio_file = request.files['audio']
        allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a'}
        extension = audio_file.filename.rsplit('.', 1)[-1].lower() if audio_file.filename else ''
        if extension not in allowed_extensions:
            return jsonify({'error': 'Invalid file format'}), 400

        # Buffer the upload in memory so faster-whisper can seek within it
        audio_buffer = io.BytesIO(audio_file.read())

        try:
            # Timeout handling (60 seconds max processing time)
            start_time = time.time()
            segments, info = wmodel.transcribe(audio_buffer, beam_size=beamsize)
            
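            # segments is a lazy generator: iterating it drives the actual
            # decoding, so this loop doubles as the timeout checkpoint.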
            text = ''
            for segment in segments:
                if time.time() - start_time > 60:  # Timeout after 60 seconds
                    raise TimeoutError("Transcription took too long")
                text += segment.text
            
            processing_time = time.time() - start_time
            print(f"Transcription completed in {processing_time:.2f} seconds")
            
            return jsonify({
                'transcription': text,
                'processing_time': processing_time,
                'language': info.language,
                'language_probability': info.language_probability
            })
            
        except TimeoutError:
            print("Transcription timeout")
            return jsonify({'error': 'Transcription timeout'}), 504
        except Exception as e:
            print(f"Transcription error: {str(e)}")
            return jsonify({'error': 'Transcription failed'}), 500
            
    finally:
        with counter_lock:
            active_requests -= 1
        request_semaphore.release()
        print(f"Request completed (Active requests: {active_requests})")

if __name__ == "__main__":
    # Debug mode enables the Werkzeug debugger and auto-reloader (which would
    # load the model twice); keep it off outside local development.
    app.run(host="0.0.0.0", debug=False, port=7860, threaded=True)