import fastapi
import numpy as np
import torch
import torchaudio
from silero_vad import get_speech_timestamps, load_silero_vad
import whisperx
import edge_tts
import gc
import logging
import time
import os
from openai import OpenAI
import asyncio
from pydub import AudioSegment
from io import BytesIO
import threading

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

app = fastapi.FastAPI()

# Load the models once at startup.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logging.info(f'Using device: {device}')
vad_model = load_silero_vad().to(device)
logging.info('Loaded Silero VAD model')

# float16 is not supported on CPU backends; fall back to int8 there.
compute_type = "float16" if device == 'cuda' else "int8"
whisper_model = whisperx.load_model("tiny", device, compute_type=compute_type)
logging.info('Loaded WhisperX model')

# Read the API key from the environment rather than hard-coding it.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
if not OPENAI_API_KEY:
    logging.error("OpenAI API key not found. Please set the OPENAI_API_KEY environment variable.")
    raise ValueError("OpenAI API key not found.")
llm_client = OpenAI(api_key=OPENAI_API_KEY)
logging.info('Initialized OpenAI client')

TTS_VOICE = "en-GB-SoniaNeural"


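# Run Silero VAD on one audio chunk and report whether any speech was detected.
# The audio is resampled to 16 kHz, which is the rate the VAD model expects.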
def check_vad(audio_data, sample_rate):
    logging.info('Checking voice activity')
    # Silero VAD and torchaudio's resampler both expect float32 input, so
    # convert from (assumed 16-bit) PCM samples before processing.
    audio_tensor = torch.from_numpy(audio_data.astype(np.float32) / 32768.0)
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_tensor = resampler(audio_tensor)
    audio_tensor = audio_tensor.to(device)

    speech_timestamps = get_speech_timestamps(audio_tensor, vad_model, sampling_rate=target_sample_rate)
    logging.info(f'Found {len(speech_timestamps)} speech timestamps')
    return len(speech_timestamps) > 0


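# Transcribe an audio chunk with WhisperX. The audio is normalized to float32
# and resampled to 16 kHz before being passed to the model.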
def transcribe(audio_data, sample_rate):
    logging.info('Transcribing audio')
    # Convert (assumed 16-bit) PCM samples to normalized float32 for WhisperX.
    audio_data = audio_data.astype(np.float32) / 32768.0
    target_sample_rate = 16000
    if sample_rate != target_sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        audio_data = resampler(torch.from_numpy(audio_data)).numpy()

    batch_size = 16
    result = whisper_model.transcribe(audio_data, batch_size=batch_size)
    # Join all segments rather than keeping only the first one.
    text = " ".join(segment["text"].strip() for segment in result["segments"])
    logging.info(f'Transcription result: {text}')
    del result
    gc.collect()
    if device == 'cuda':
        torch.cuda.empty_cache()
    return text


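# Stream text chunks into edge-tts. Text is buffered until a sentence-ending
# punctuation mark appears, then each complete sentence is synthesized and the
# resulting audio chunks are yielded as they arrive.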
def tts_streaming(text_stream):
    logging.info('Performing TTS')
    buffer = ""
    punctuation = {'.', '!', '?'}
    for text_chunk in text_stream:
        if text_chunk is not None:
            buffer += text_chunk

        sentences = []
        start = 0
        for i, char in enumerate(buffer):
            if char in punctuation:
                sentences.append(buffer[start:i + 1].strip())
                start = i + 1
        buffer = buffer[start:]

        for sentence in sentences:
            if sentence:
                communicate = edge_tts.Communicate(sentence, TTS_VOICE)
                for chunk in communicate.stream_sync():
                    if chunk["type"] == "audio":
                        yield chunk["data"]

    if buffer.strip():
        communicate = edge_tts.Communicate(buffer.strip(), TTS_VOICE)
        for chunk in communicate.stream_sync():
            if chunk["type"] == "audio":
                yield chunk["data"]


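# Send the transcribed text to the OpenAI chat completions API and yield the
# response text incrementally as it streams back.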
def llm(text):
    logging.info('Getting response from OpenAI API')
    response = llm_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "Respond to the following transcript from the conversation you are having with the user."},
            {"role": "user", "content": text}
        ],
        stream=True,
        temperature=0.7,
        top_p=0.9
    )
    for chunk in response:
        # The final chunk of a stream carries no text; skip empty deltas.
        if chunk.choices and chunk.choices[0].delta.content is not None:
            yield chunk.choices[0].delta.content


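# State for a single websocket conversation. The conversation moves between
# three modes: 'idle' (no speech detected yet), 'listening' (the user is
# speaking and audio is being transcribed), and 'speaking' (the LLM response is
# being synthesized and streamed back to the client).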
class Conversation:
    def __init__(self):
        self.mode = 'idle'                      # 'idle', 'listening' or 'speaking'
        self.audio_stream = []
        self.valid_chunk_queue = []             # decoded chunks waiting to be transcribed
        self.first_valid_chunk = None           # chunk just before speech started
        self.last_valid_chunks = []             # trailing chunks after speech stopped
        self.valid_chunk_transcriptions = ''    # transcript of the current utterance
        self.in_transcription = False           # True while an utterance is still coming in
        self.llm_n_tts_task = None
        self.stop_signal = False
        self.sample_rate = 0
        self.out_audio_stream = []              # synthesized audio chunks waiting to be sent
        self.chunk_buffer = 0.5

    def llm_n_tts(self):
        # Stream the LLM response through TTS. The whole text stream is handed
        # to tts_streaming so it can buffer the text into complete sentences.
        for audio_chunk in tts_streaming(llm(self.valid_chunk_transcriptions)):
            if self.stop_signal:
                break
            # edge-tts yields MP3-encoded bytes; forward them to the client as-is.
            self.out_audio_stream.append(audio_chunk)

    def process_audio_chunk(self, audio_chunk):
        # Decode the incoming WAV chunk into PCM samples.
        segment = AudioSegment.from_file(BytesIO(audio_chunk), format="wav")
        self.sample_rate = segment.frame_rate
        audio_data = np.array(segment.get_array_of_samples())

        vad = check_vad(audio_data, self.sample_rate)

        if vad:
            # Prepend the chunk that immediately preceded the detected speech so
            # the start of the utterance is not clipped.
            if self.first_valid_chunk is not None:
                self.valid_chunk_queue.append(self.first_valid_chunk)
                self.first_valid_chunk = None
            self.valid_chunk_queue.append(audio_data)

            if len(self.valid_chunk_queue) > 2:
                if self.mode == 'idle':
                    self.mode = 'listening'
                    self.in_transcription = True
                    self.valid_chunk_transcriptions = ''
                elif self.mode == 'speaking':
                    # The user interrupted the assistant: stop the current
                    # LLM/TTS task before listening again.
                    if self.llm_n_tts_task is not None:
                        self.stop_signal = True
                        self.llm_n_tts_task.join()
                        self.stop_signal = False
                    self.out_audio_stream = []
                    self.mode = 'listening'
                    self.in_transcription = True
                    self.valid_chunk_transcriptions = ''

        else:
            if self.mode == 'listening':
                self.last_valid_chunks.append(audio_data)

                if len(self.last_valid_chunks) > 2:
                    # The user stopped speaking: keep a little trailing audio,
                    # let the transcription loop drain the queue, then respond.
                    self.valid_chunk_queue.extend(self.last_valid_chunks[:2])
                    self.last_valid_chunks = []
                    self.in_transcription = False

                    while len(self.valid_chunk_queue) > 0:
                        time.sleep(0.1)

                    self.mode = 'speaking'
                    self.llm_n_tts_task = threading.Thread(target=self.llm_n_tts)
                    self.llm_n_tts_task.start()
            else:
                # Remember the most recent non-speech chunk so it can be
                # prepended once speech starts.
                self.first_valid_chunk = audio_data

    def transcribe_loop(self):
        # Background loop: incrementally transcribe audio while the user is
        # speaking and drain the queue once the utterance has finished.
        while True:
            if self.mode == 'listening' and len(self.valid_chunk_queue) > 0:
                accumulated_chunks = np.concatenate(self.valid_chunk_queue)
                total_duration = len(accumulated_chunks) / self.sample_rate

                if self.in_transcription:
                    if total_duration >= 3.0:
                        # Transcribe the first 2 seconds and keep the remainder
                        # queued for the next pass.
                        first_2s_audio = accumulated_chunks[:int(2 * self.sample_rate)]
                        transcribed_text = transcribe(first_2s_audio, self.sample_rate)
                        self.valid_chunk_transcriptions += transcribed_text
                        self.valid_chunk_queue = [accumulated_chunks[int(2 * self.sample_rate):]]
                    else:
                        time.sleep(0.1)
                else:
                    # The utterance has ended: transcribe whatever is left and
                    # empty the queue so process_audio_chunk can move on.
                    transcribed_text = transcribe(accumulated_chunks, self.sample_rate)
                    self.valid_chunk_transcriptions += transcribed_text
                    self.valid_chunk_queue = []
            else:
                time.sleep(0.1)

    def stream_out_audio(self):
        # Hand back whatever synthesized audio is currently buffered. The
        # websocket loop calls this after every received chunk, so output is
        # streamed without blocking the receive side.
        while len(self.out_audio_stream) > 0:
            yield self.out_audio_stream.pop(0)


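# Websocket endpoint: receives WAV audio chunks from the client and sends back
# synthesized audio while the assistant is speaking.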
@app.websocket("/ws")
async def websocket_endpoint(websocket: fastapi.WebSocket):
    await websocket.accept()

    conversation = Conversation()

    # Run transcription in a daemon thread so it does not block shutdown.
    transcribe_thread = threading.Thread(target=conversation.transcribe_loop, daemon=True)
    transcribe_thread.start()

    chunk_buffer_size = conversation.chunk_buffer
    while True:
        try:
            audio_chunk = await websocket.receive_bytes()
            conversation.process_audio_chunk(audio_chunk)

            if conversation.mode == 'speaking':
                for out_chunk in conversation.stream_out_audio():
                    await websocket.send_bytes(out_chunk)
            else:
                await websocket.send_bytes(b'')
        except Exception as e:
            logging.error(e)
            break


@app.get("/")
async def index():
    return fastapi.responses.FileResponse("index.html")


if __name__ == '__main__':
    import uvicorn
    uvicorn.run(app, host='0.0.0.0', port=8000)