Spaces:
Sleeping
Sleeping
Update webapp.py
Browse files
webapp.py
CHANGED
@@ -1,30 +1,31 @@
|
|
1 |
-
#
|
2 |
|
3 |
import asyncio
|
4 |
import base64
|
5 |
import json
|
6 |
import os
|
7 |
-
|
8 |
|
9 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
10 |
from fastapi.responses import HTMLResponse
|
11 |
from fastapi.staticfiles import StaticFiles
|
12 |
import uvicorn
|
13 |
|
14 |
-
|
|
|
15 |
|
16 |
app = FastAPI()
|
17 |
|
18 |
# Store active client connections
|
19 |
-
|
20 |
|
21 |
-
# Mount
|
22 |
current_dir = os.path.dirname(os.path.realpath(__file__))
|
23 |
app.mount("/static", StaticFiles(directory=current_dir), name="static")
|
24 |
|
25 |
@app.get("/")
|
26 |
async def get_index():
|
27 |
-
|
28 |
index_path = os.path.join(current_dir, "index.html")
|
29 |
with open(index_path, "r", encoding="utf-8") as f:
|
30 |
html_content = f.read()
|
@@ -32,6 +33,7 @@ async def get_index():
|
|
32 |
|
33 |
@app.websocket("/ws")
|
34 |
async def websocket_endpoint(websocket: WebSocket):
|
|
|
35 |
await websocket.accept()
|
36 |
print("[websocket_endpoint] Client connected.")
|
37 |
|
@@ -40,47 +42,45 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
40 |
|
41 |
# Create a new AudioLoop instance for this client
|
42 |
audio_loop = AudioLoop()
|
43 |
-
|
44 |
"websocket": websocket,
|
45 |
"audio_loop": audio_loop,
|
46 |
-
"
|
47 |
-
"expected_audio_seq": 0,
|
48 |
-
"repo_url": None,
|
49 |
-
"preferences": None
|
50 |
}
|
51 |
|
52 |
# Start the AudioLoop for this client
|
53 |
loop_task = asyncio.create_task(audio_loop.run())
|
54 |
-
print(f"[websocket_endpoint] Started
|
55 |
|
56 |
-
async def
|
57 |
-
"""
|
58 |
try:
|
59 |
while True:
|
60 |
data = await websocket.receive_text()
|
61 |
msg = json.loads(data)
|
62 |
-
msg_type = msg.get("type")
|
63 |
|
64 |
-
# Handle repository URL and preferences
|
65 |
if msg_type == "init":
|
66 |
-
|
67 |
-
|
|
|
68 |
"github_token": msg.get("github_token", ""),
|
69 |
"user_type": msg.get("user_type", "coder"),
|
70 |
"response_detail": msg.get("response_detail", "normal")
|
71 |
}
|
72 |
-
print(f"[
|
73 |
-
|
|
|
74 |
await websocket.send_text(json.dumps({
|
75 |
"type": "status",
|
76 |
"status": "initialized",
|
77 |
-
"message": "
|
78 |
}))
|
79 |
-
|
80 |
-
# Handle audio data from client
|
81 |
elif msg_type == "audio":
|
82 |
-
#
|
83 |
raw_pcm = base64.b64decode(msg["payload"])
|
|
|
84 |
forward_msg = {
|
85 |
"realtime_input": {
|
86 |
"media_chunks": [
|
@@ -91,48 +91,27 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
91 |
]
|
92 |
}
|
93 |
}
|
94 |
-
|
95 |
-
|
96 |
-
seq = msg.get("seq")
|
97 |
-
audio_ordering_buffer = client_connections[client_id]["audio_ordering_buffer"]
|
98 |
-
expected_audio_seq = client_connections[client_id]["expected_audio_seq"]
|
99 |
-
|
100 |
-
if seq is not None:
|
101 |
-
# Store the message in the buffer
|
102 |
-
audio_ordering_buffer[seq] = forward_msg
|
103 |
-
# Forward any messages in order
|
104 |
-
while expected_audio_seq in audio_ordering_buffer:
|
105 |
-
msg_to_forward = audio_ordering_buffer.pop(expected_audio_seq)
|
106 |
-
await audio_loop.out_queue.put(msg_to_forward)
|
107 |
-
expected_audio_seq += 1
|
108 |
-
client_connections[client_id]["expected_audio_seq"] = expected_audio_seq
|
109 |
-
else:
|
110 |
-
# If no sequence number is provided, forward immediately
|
111 |
-
await audio_loop.out_queue.put(forward_msg)
|
112 |
-
|
113 |
-
# Handle text data from client
|
114 |
elif msg_type == "text":
|
|
|
115 |
user_text = msg.get("content", "")
|
116 |
-
# Augment the query with repository context if available
|
117 |
-
repo_context = ""
|
118 |
-
if client_connections[client_id]["repo_url"]:
|
119 |
-
repo_context = f"For GitHub repository: {client_connections[client_id]['repo_url']}"
|
120 |
-
|
121 |
-
# Add preferences context
|
122 |
-
prefs = client_connections[client_id]["preferences"]
|
123 |
-
if prefs:
|
124 |
-
user_type = prefs.get("user_type", "coder")
|
125 |
-
detail_level = prefs.get("response_detail", "normal")
|
126 |
-
repo_context += f"\nUser role: {user_type}, Preferred detail level: {detail_level}"
|
127 |
|
128 |
-
#
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
-
print(f"[
|
135 |
|
|
|
136 |
forward_msg = {
|
137 |
"client_content": {
|
138 |
"turn_complete": True,
|
@@ -140,81 +119,98 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
140 |
{
|
141 |
"role": "user",
|
142 |
"parts": [
|
143 |
-
{"text":
|
144 |
]
|
145 |
}
|
146 |
]
|
147 |
}
|
148 |
}
|
149 |
await audio_loop.out_queue.put(forward_msg)
|
150 |
-
|
151 |
-
# Handle interrupt request
|
152 |
elif msg_type == "interrupt":
|
153 |
-
|
154 |
-
#
|
155 |
-
#
|
|
|
156 |
await websocket.send_text(json.dumps({
|
157 |
"type": "status",
|
158 |
"status": "interrupted",
|
159 |
-
"message": "
|
160 |
}))
|
161 |
-
|
162 |
else:
|
163 |
-
print(f"[
|
164 |
|
165 |
except WebSocketDisconnect:
|
166 |
-
print(f"[
|
167 |
cleanup_client(client_id, loop_task)
|
168 |
except Exception as e:
|
169 |
-
print(f"[
|
170 |
cleanup_client(client_id, loop_task)
|
171 |
|
172 |
-
async def
|
173 |
-
"""
|
174 |
try:
|
175 |
while True:
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
|
186 |
except WebSocketDisconnect:
|
187 |
-
print(f"[
|
188 |
cleanup_client(client_id, loop_task)
|
189 |
except Exception as e:
|
190 |
-
print(f"[
|
191 |
cleanup_client(client_id, loop_task)
|
192 |
|
193 |
-
def cleanup_client(client_id,
|
194 |
-
"""Clean up resources
|
195 |
-
if client_id in
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
204 |
try:
|
205 |
await asyncio.gather(
|
206 |
-
|
207 |
-
|
208 |
)
|
209 |
finally:
|
210 |
-
print(f"[websocket_endpoint] WebSocket handler finished for
|
211 |
cleanup_client(client_id, loop_task)
|
212 |
|
213 |
if __name__ == "__main__":
|
214 |
-
#
|
215 |
if "GOOGLE_API_KEY" not in os.environ:
|
216 |
print("Error: GOOGLE_API_KEY environment variable not set")
|
217 |
-
print("Please set it with: export GOOGLE_API_KEY='
|
218 |
-
exit(1)
|
219 |
-
|
|
|
220 |
uvicorn.run("webapp:app", host="0.0.0.0", port=7860, reload=True)
|
|
|
1 |
+
# basic_webapp.py
|
2 |
|
3 |
import asyncio
|
4 |
import base64
|
5 |
import json
|
6 |
import os
|
7 |
+
import sys
|
8 |
|
9 |
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
10 |
from fastapi.responses import HTMLResponse
|
11 |
from fastapi.staticfiles import StaticFiles
|
12 |
import uvicorn
|
13 |
|
14 |
+
# Import the simplified AudioLoop
|
15 |
+
from basic_handler import AudioLoop
|
16 |
|
17 |
app = FastAPI()
|
18 |
|
19 |
# Store active client connections
|
20 |
+
active_clients = {}
|
21 |
|
22 |
+
# Mount static files directory
|
23 |
current_dir = os.path.dirname(os.path.realpath(__file__))
|
24 |
app.mount("/static", StaticFiles(directory=current_dir), name="static")
|
25 |
|
26 |
@app.get("/")
|
27 |
async def get_index():
|
28 |
+
"""Serve the main HTML interface."""
|
29 |
index_path = os.path.join(current_dir, "index.html")
|
30 |
with open(index_path, "r", encoding="utf-8") as f:
|
31 |
html_content = f.read()
|
|
|
33 |
|
34 |
@app.websocket("/ws")
|
35 |
async def websocket_endpoint(websocket: WebSocket):
|
36 |
+
"""Handle WebSocket connections from clients."""
|
37 |
await websocket.accept()
|
38 |
print("[websocket_endpoint] Client connected.")
|
39 |
|
|
|
42 |
|
43 |
# Create a new AudioLoop instance for this client
|
44 |
audio_loop = AudioLoop()
|
45 |
+
active_clients[client_id] = {
|
46 |
"websocket": websocket,
|
47 |
"audio_loop": audio_loop,
|
48 |
+
"repo_context": None
|
|
|
|
|
|
|
49 |
}
|
50 |
|
51 |
# Start the AudioLoop for this client
|
52 |
loop_task = asyncio.create_task(audio_loop.run())
|
53 |
+
print(f"[websocket_endpoint] Started AudioLoop for client {client_id}")
|
54 |
|
55 |
+
async def process_client_messages():
|
56 |
+
"""Handle messages from the client and forward to Gemini."""
|
57 |
try:
|
58 |
while True:
|
59 |
data = await websocket.receive_text()
|
60 |
msg = json.loads(data)
|
61 |
+
msg_type = msg.get("type", "")
|
62 |
|
|
|
63 |
if msg_type == "init":
|
64 |
+
# Store repository context info
|
65 |
+
active_clients[client_id]["repo_context"] = {
|
66 |
+
"repo_url": msg.get("repo_url", ""),
|
67 |
"github_token": msg.get("github_token", ""),
|
68 |
"user_type": msg.get("user_type", "coder"),
|
69 |
"response_detail": msg.get("response_detail", "normal")
|
70 |
}
|
71 |
+
print(f"[process_client_messages] Stored context for {client_id}: {msg.get('repo_url', '')}")
|
72 |
+
|
73 |
+
# Send confirmation
|
74 |
await websocket.send_text(json.dumps({
|
75 |
"type": "status",
|
76 |
"status": "initialized",
|
77 |
+
"message": "Ready to assist with this repository."
|
78 |
}))
|
79 |
+
|
|
|
80 |
elif msg_type == "audio":
|
81 |
+
# Forward audio data to Gemini
|
82 |
raw_pcm = base64.b64decode(msg["payload"])
|
83 |
+
|
84 |
forward_msg = {
|
85 |
"realtime_input": {
|
86 |
"media_chunks": [
|
|
|
91 |
]
|
92 |
}
|
93 |
}
|
94 |
+
await audio_loop.out_queue.put(forward_msg)
|
95 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
elif msg_type == "text":
|
97 |
+
# Process text query from client
|
98 |
user_text = msg.get("content", "")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
+
# Add repository context if available
|
101 |
+
context = active_clients[client_id]["repo_context"]
|
102 |
+
if context and context["repo_url"]:
|
103 |
+
# Format context info for Gemini
|
104 |
+
context_text = (
|
105 |
+
f"The GitHub repository being discussed is: {context['repo_url']}\n"
|
106 |
+
f"User role: {context['user_type']}\n"
|
107 |
+
f"Preferred detail level: {context['response_detail']}\n\n"
|
108 |
+
f"Please consider this context when answering the following question:\n"
|
109 |
+
)
|
110 |
+
user_text = context_text + user_text
|
111 |
|
112 |
+
print(f"[process_client_messages] Sending text to Gemini: {user_text[:100]}...")
|
113 |
|
114 |
+
# Format message for Gemini
|
115 |
forward_msg = {
|
116 |
"client_content": {
|
117 |
"turn_complete": True,
|
|
|
119 |
{
|
120 |
"role": "user",
|
121 |
"parts": [
|
122 |
+
{"text": user_text}
|
123 |
]
|
124 |
}
|
125 |
]
|
126 |
}
|
127 |
}
|
128 |
await audio_loop.out_queue.put(forward_msg)
|
129 |
+
|
|
|
130 |
elif msg_type == "interrupt":
|
131 |
+
# For now, just acknowledge the interrupt
|
132 |
+
# This is a simple implementation because true interruption
|
133 |
+
# may require additional API support
|
134 |
+
print(f"[process_client_messages] Interrupt requested by {client_id}")
|
135 |
await websocket.send_text(json.dumps({
|
136 |
"type": "status",
|
137 |
"status": "interrupted",
|
138 |
+
"message": "Processing interrupted by user."
|
139 |
}))
|
140 |
+
|
141 |
else:
|
142 |
+
print(f"[process_client_messages] Unknown message type: {msg_type}")
|
143 |
|
144 |
except WebSocketDisconnect:
|
145 |
+
print(f"[process_client_messages] Client {client_id} disconnected")
|
146 |
cleanup_client(client_id, loop_task)
|
147 |
except Exception as e:
|
148 |
+
print(f"[process_client_messages] Error: {e}")
|
149 |
cleanup_client(client_id, loop_task)
|
150 |
|
151 |
+
async def forward_gemini_responses():
|
152 |
+
"""Read responses from Gemini and send them to the client."""
|
153 |
try:
|
154 |
while True:
|
155 |
+
# Check for audio data
|
156 |
+
try:
|
157 |
+
pcm_data = await asyncio.wait_for(audio_loop.audio_in_queue.get(), 0.5)
|
158 |
+
b64_pcm = base64.b64encode(pcm_data).decode()
|
159 |
+
|
160 |
+
# Send audio to client
|
161 |
+
out_msg = {
|
162 |
+
"type": "audio",
|
163 |
+
"payload": b64_pcm
|
164 |
+
}
|
165 |
+
print(f"[forward_gemini_responses] Sending audio chunk to client {client_id}")
|
166 |
+
await websocket.send_text(json.dumps(out_msg))
|
167 |
+
except asyncio.TimeoutError:
|
168 |
+
# No audio available, continue checking
|
169 |
+
pass
|
170 |
+
|
171 |
+
# We could add additional processing for text responses here
|
172 |
+
# if we had a separate queue for text content
|
173 |
|
174 |
except WebSocketDisconnect:
|
175 |
+
print(f"[forward_gemini_responses] Client {client_id} disconnected")
|
176 |
cleanup_client(client_id, loop_task)
|
177 |
except Exception as e:
|
178 |
+
print(f"[forward_gemini_responses] Error: {e}")
|
179 |
cleanup_client(client_id, loop_task)
|
180 |
|
181 |
+
def cleanup_client(client_id, task):
|
182 |
+
"""Clean up resources when a client disconnects."""
|
183 |
+
if client_id in active_clients:
|
184 |
+
client_data = active_clients[client_id]
|
185 |
+
|
186 |
+
# Stop the AudioLoop
|
187 |
+
if "audio_loop" in client_data:
|
188 |
+
client_data["audio_loop"].stop()
|
189 |
+
|
190 |
+
# Cancel the task if it's still running
|
191 |
+
if task and not task.done():
|
192 |
+
task.cancel()
|
193 |
+
|
194 |
+
# Remove from active clients
|
195 |
+
del active_clients[client_id]
|
196 |
+
print(f"[cleanup_client] Cleaned up resources for {client_id}")
|
197 |
+
|
198 |
+
# Run both tasks concurrently
|
199 |
try:
|
200 |
await asyncio.gather(
|
201 |
+
process_client_messages(),
|
202 |
+
forward_gemini_responses()
|
203 |
)
|
204 |
finally:
|
205 |
+
print(f"[websocket_endpoint] WebSocket handler finished for {client_id}")
|
206 |
cleanup_client(client_id, loop_task)
|
207 |
|
208 |
if __name__ == "__main__":
|
209 |
+
# Verify API key is present
|
210 |
if "GOOGLE_API_KEY" not in os.environ:
|
211 |
print("Error: GOOGLE_API_KEY environment variable not set")
|
212 |
+
print("Please set it with: export GOOGLE_API_KEY='your_api_key_here'")
|
213 |
+
sys.exit(1)
|
214 |
+
|
215 |
+
# Start the server
|
216 |
uvicorn.run("webapp:app", host="0.0.0.0", port=7860, reload=True)
|