Spaces:
Sleeping
Sleeping
# basic_webapp.py | |
import asyncio | |
import base64 | |
import json | |
import os | |
import sys | |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
from fastapi.responses import HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
import uvicorn | |
# Import the simplified AudioLoop | |
from handler import AudioLoop | |
# Application instance; routes and the static mount are registered on it.
app = FastAPI()

# Registry of live connections keyed by client id. Each value is a dict with
# "websocket", "audio_loop" and "repo_context" entries.
active_clients = dict()

# Directory that contains this module; it is also used as the /static root.
# NOTE(review): mounting the module's own directory publishes every file in it
# over HTTP (including this Python source). A dedicated static/ subdirectory
# would be safer, but the asset URLs in index.html must be updated to match.
current_dir = os.path.split(os.path.realpath(__file__))[0]
app.mount("/static", StaticFiles(directory=current_dir), name="static")
# NOTE(review): the route decorator was missing, which left this handler
# unregistered; "/" is assumed from the docstring — confirm against the frontend.
@app.get("/")
async def get_index():
    """Serve the main HTML interface.

    Reads ``index.html`` from the directory containing this module on every
    request, so edits to the page are picked up without restarting the server.
    If the file is missing, the ``FileNotFoundError`` from ``open`` propagates.

    Returns:
        HTMLResponse: the page contents.
    """
    index_path = os.path.join(current_dir, "index.html")
    with open(index_path, "r", encoding="utf-8") as f:
        html_content = f.read()
    return HTMLResponse(content=html_content)
# NOTE(review): the route decorator was missing, which left this handler
# unregistered; the "/ws" path is assumed and must match the URL the frontend
# (index.html) connects to — confirm.
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Bridge one browser WebSocket connection to its own AudioLoop.

    Client -> server messages are JSON text frames dispatched on ``"type"``:

    * ``init``      - store repository context (``repo_url``, ``github_token``,
      ``user_type``, ``response_detail``) and reply with an ``initialized``
      status message.
    * ``audio``     - base64 PCM in ``payload``; forwarded to
      ``audio_loop.out_queue`` as a ``realtime_input`` message.
    * ``text``      - ``content`` string; prefixed with the stored repository
      context (if any) and forwarded as a ``client_content`` user turn.
    * ``interrupt`` - only acknowledged with an ``interrupted`` status message.

    Audio chunks taken from ``audio_loop.audio_in_queue`` are sent back to the
    client as ``{"type": "audio", "payload": <base64>}``.

    Both directions run as separate tasks. As soon as either one stops
    (client disconnect or error), the other is cancelled and the client's
    AudioLoop and registry entry are released.
    """
    await websocket.accept()
    print("[websocket_endpoint] Client connected.")

    # The registry entry below keeps `websocket` alive, so its id() cannot be
    # reused by another connection while this entry exists.
    client_id = f"client_{id(websocket)}"

    # Every connection gets its own AudioLoop instance.
    audio_loop = AudioLoop()
    active_clients[client_id] = {
        "websocket": websocket,
        "audio_loop": audio_loop,
        "repo_context": None,
    }
    loop_task = asyncio.create_task(audio_loop.run())
    print(f"[websocket_endpoint] Started AudioLoop for client {client_id}")

    def cleanup_client(client_id, task):
        """Stop the client's AudioLoop, cancel its task and unregister it.

        Safe to call more than once: later calls find no entry and return.
        """
        client_data = active_clients.pop(client_id, None)
        if client_data is None:
            return
        try:
            if "audio_loop" in client_data:
                client_data["audio_loop"].stop()
        finally:
            # Cancel the AudioLoop task even if stop() raised.
            if task and not task.done():
                task.cancel()
        print(f"[cleanup_client] Cleaned up resources for {client_id}")

    def build_context_prefix(context):
        """Return the repository-context preamble for a text query, or ""."""
        if not context or not context["repo_url"]:
            return ""
        return (
            f"The GitHub repository being discussed is: {context['repo_url']}\n"
            f"User role: {context['user_type']}\n"
            f"Preferred detail level: {context['response_detail']}\n\n"
            "Please consider this context when answering the following question:\n"
        )

    async def process_client_messages():
        """Handle messages from the client and forward them to Gemini."""
        try:
            while True:
                data = await websocket.receive_text()
                try:
                    msg = json.loads(data)
                except ValueError:
                    # One malformed frame should not end the whole session.
                    print(f"[process_client_messages] Ignoring invalid JSON from {client_id}")
                    continue
                if not isinstance(msg, dict):
                    print(f"[process_client_messages] Ignoring non-object message from {client_id}")
                    continue

                msg_type = msg.get("type", "")
                if msg_type == "init":
                    active_clients[client_id]["repo_context"] = {
                        "repo_url": msg.get("repo_url", ""),
                        "github_token": msg.get("github_token", ""),
                        "user_type": msg.get("user_type", "coder"),
                        "response_detail": msg.get("response_detail", "normal"),
                    }
                    print(f"[process_client_messages] Stored context for {client_id}: {msg.get('repo_url', '')}")
                    await websocket.send_text(json.dumps({
                        "type": "status",
                        "status": "initialized",
                        "message": "Ready to assist with this repository.",
                    }))
                elif msg_type == "audio":
                    try:
                        raw_pcm = base64.b64decode(msg["payload"])
                    except (KeyError, TypeError, ValueError) as e:
                        print(f"[process_client_messages] Ignoring invalid audio message from {client_id}: {e!r}")
                        continue
                    # Re-encoding the decoded bytes normalises the payload
                    # (b64decode discards non-alphabet characters by default).
                    forward_msg = {
                        "realtime_input": {
                            "media_chunks": [
                                {
                                    "data": base64.b64encode(raw_pcm).decode(),
                                    "mime_type": "audio/pcm",
                                }
                            ]
                        }
                    }
                    await audio_loop.out_queue.put(forward_msg)
                elif msg_type == "text":
                    user_text = msg.get("content", "")
                    prefix = build_context_prefix(active_clients[client_id]["repo_context"])
                    if prefix:
                        user_text = prefix + user_text
                    print(f"[process_client_messages] Sending text to Gemini: {user_text[:100]}...")
                    forward_msg = {
                        "client_content": {
                            "turn_complete": True,
                            "turns": [
                                {
                                    "role": "user",
                                    "parts": [
                                        {"text": user_text}
                                    ],
                                }
                            ],
                        }
                    }
                    await audio_loop.out_queue.put(forward_msg)
                elif msg_type == "interrupt":
                    # Only acknowledged; nothing is actually cancelled here.
                    print(f"[process_client_messages] Interrupt requested by {client_id}")
                    await websocket.send_text(json.dumps({
                        "type": "status",
                        "status": "interrupted",
                        "message": "Processing interrupted by user.",
                    }))
                else:
                    print(f"[process_client_messages] Unknown message type: {msg_type}")
        except WebSocketDisconnect:
            print(f"[process_client_messages] Client {client_id} disconnected")
        except Exception as e:
            print(f"[process_client_messages] Error: {e}")

    async def forward_gemini_responses():
        """Read audio produced by the AudioLoop and send it to the client."""
        try:
            while True:
                try:
                    pcm_data = await asyncio.wait_for(audio_loop.audio_in_queue.get(), 0.5)
                except asyncio.TimeoutError:
                    # No audio yet; keep polling until cancelled by the handler.
                    continue
                out_msg = {
                    "type": "audio",
                    "payload": base64.b64encode(pcm_data).decode(),
                }
                print(f"[forward_gemini_responses] Sending audio chunk to client {client_id}")
                await websocket.send_text(json.dumps(out_msg))
        except WebSocketDisconnect:
            print(f"[forward_gemini_responses] Client {client_id} disconnected")
        except Exception as e:
            print(f"[forward_gemini_responses] Error: {e}")

    pumps = [
        asyncio.create_task(process_client_messages()),
        asyncio.create_task(forward_gemini_responses()),
    ]
    try:
        # Return as soon as either direction stops. asyncio.gather() waited for
        # both, but forward_gemini_responses() keeps polling forever once the
        # client is gone and no audio arrives, so the handler never finished.
        await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for pump in pumps:
            pump.cancel()
        # Let the cancelled pump unwind before releasing shared resources.
        await asyncio.gather(*pumps, return_exceptions=True)
        print(f"[websocket_endpoint] WebSocket handler finished for {client_id}")
        cleanup_client(client_id, loop_task)
if __name__ == "__main__":
    # Refuse to start without an API key (an empty value is rejected too).
    # NOTE(review): presumably consumed by AudioLoop in handler.py — confirm.
    if not os.environ.get("GOOGLE_API_KEY"):
        print("Error: GOOGLE_API_KEY environment variable not set")
        print("Please set it with: export GOOGLE_API_KEY='your_api_key_here'")
        sys.exit(1)

    # reload=True requires the app as an import string. Derive the module name
    # from this file instead of hard-coding "webapp", which breaks if the file
    # is named differently (its header comment says basic_webapp.py).
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=7860, reload=True)