nihalaninihal commited on
Commit
0f005f1
·
verified ·
1 Parent(s): 4ee8f0b

Update webapp.py

Browse files
Files changed (1) hide show
  1. webapp.py +99 -103
webapp.py CHANGED
@@ -1,30 +1,31 @@
1
- # updated_webapp.py
2
 
3
  import asyncio
4
  import base64
5
  import json
6
  import os
7
- from typing import Optional, Dict, Any, List
8
 
9
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
10
  from fastapi.responses import HTMLResponse
11
  from fastapi.staticfiles import StaticFiles
12
  import uvicorn
13
 
14
- from handler import AudioLoop # Import AudioLoop class
 
15
 
16
  app = FastAPI()
17
 
18
  # Store active client connections
19
- client_connections: Dict[str, Dict[str, Any]] = {}
20
 
21
- # Mount the web_ui directory to serve static files
22
  current_dir = os.path.dirname(os.path.realpath(__file__))
23
  app.mount("/static", StaticFiles(directory=current_dir), name="static")
24
 
25
  @app.get("/")
26
  async def get_index():
27
- # Read and return the index.html file
28
  index_path = os.path.join(current_dir, "index.html")
29
  with open(index_path, "r", encoding="utf-8") as f:
30
  html_content = f.read()
@@ -32,6 +33,7 @@ async def get_index():
32
 
33
  @app.websocket("/ws")
34
  async def websocket_endpoint(websocket: WebSocket):
 
35
  await websocket.accept()
36
  print("[websocket_endpoint] Client connected.")
37
 
@@ -40,47 +42,45 @@ async def websocket_endpoint(websocket: WebSocket):
40
 
41
  # Create a new AudioLoop instance for this client
42
  audio_loop = AudioLoop()
43
- client_connections[client_id] = {
44
  "websocket": websocket,
45
  "audio_loop": audio_loop,
46
- "audio_ordering_buffer": {},
47
- "expected_audio_seq": 0,
48
- "repo_url": None,
49
- "preferences": None
50
  }
51
 
52
  # Start the AudioLoop for this client
53
  loop_task = asyncio.create_task(audio_loop.run())
54
- print(f"[websocket_endpoint] Started new AudioLoop for client {client_id}")
55
 
56
- async def from_client_to_gemini():
57
- """Handles incoming messages from the client and forwards them to Gemini."""
58
  try:
59
  while True:
60
  data = await websocket.receive_text()
61
  msg = json.loads(data)
62
- msg_type = msg.get("type")
63
 
64
- # Handle repository URL and preferences
65
  if msg_type == "init":
66
- client_connections[client_id]["repo_url"] = msg.get("repo_url", "")
67
- client_connections[client_id]["preferences"] = {
 
68
  "github_token": msg.get("github_token", ""),
69
  "user_type": msg.get("user_type", "coder"),
70
  "response_detail": msg.get("response_detail", "normal")
71
  }
72
- print(f"[from_client_to_gemini] Client {client_id} initialized with repo: {client_connections[client_id]['repo_url']}")
73
- # Send a confirmation back to client
 
74
  await websocket.send_text(json.dumps({
75
  "type": "status",
76
  "status": "initialized",
77
- "message": "G.E.N.I.E. is ready to assist with this repository."
78
  }))
79
-
80
- # Handle audio data from client
81
  elif msg_type == "audio":
82
- # Decode base64 audio data
83
  raw_pcm = base64.b64decode(msg["payload"])
 
84
  forward_msg = {
85
  "realtime_input": {
86
  "media_chunks": [
@@ -91,48 +91,27 @@ async def websocket_endpoint(websocket: WebSocket):
91
  ]
92
  }
93
  }
94
-
95
- # Retrieve the sequence number from the message
96
- seq = msg.get("seq")
97
- audio_ordering_buffer = client_connections[client_id]["audio_ordering_buffer"]
98
- expected_audio_seq = client_connections[client_id]["expected_audio_seq"]
99
-
100
- if seq is not None:
101
- # Store the message in the buffer
102
- audio_ordering_buffer[seq] = forward_msg
103
- # Forward any messages in order
104
- while expected_audio_seq in audio_ordering_buffer:
105
- msg_to_forward = audio_ordering_buffer.pop(expected_audio_seq)
106
- await audio_loop.out_queue.put(msg_to_forward)
107
- expected_audio_seq += 1
108
- client_connections[client_id]["expected_audio_seq"] = expected_audio_seq
109
- else:
110
- # If no sequence number is provided, forward immediately
111
- await audio_loop.out_queue.put(forward_msg)
112
-
113
- # Handle text data from client
114
  elif msg_type == "text":
 
115
  user_text = msg.get("content", "")
116
- # Augment the query with repository context if available
117
- repo_context = ""
118
- if client_connections[client_id]["repo_url"]:
119
- repo_context = f"For GitHub repository: {client_connections[client_id]['repo_url']}"
120
-
121
- # Add preferences context
122
- prefs = client_connections[client_id]["preferences"]
123
- if prefs:
124
- user_type = prefs.get("user_type", "coder")
125
- detail_level = prefs.get("response_detail", "normal")
126
- repo_context += f"\nUser role: {user_type}, Preferred detail level: {detail_level}"
127
 
128
- # Combine context with user query
129
- if repo_context:
130
- enhanced_text = f"{repo_context}\n\nUser query: {user_text}"
131
- else:
132
- enhanced_text = user_text
 
 
 
 
 
 
133
 
134
- print(f"[from_client_to_gemini] Forwarding user text to Gemini: {enhanced_text[:100]}...")
135
 
 
136
  forward_msg = {
137
  "client_content": {
138
  "turn_complete": True,
@@ -140,81 +119,98 @@ async def websocket_endpoint(websocket: WebSocket):
140
  {
141
  "role": "user",
142
  "parts": [
143
- {"text": enhanced_text}
144
  ]
145
  }
146
  ]
147
  }
148
  }
149
  await audio_loop.out_queue.put(forward_msg)
150
-
151
- # Handle interrupt request
152
  elif msg_type == "interrupt":
153
- print(f"[from_client_to_gemini] Client {client_id} requested interrupt")
154
- # TODO: Send interrupt signal to Gemini if possible
155
- # For now, just acknowledge the interrupt request
 
156
  await websocket.send_text(json.dumps({
157
  "type": "status",
158
  "status": "interrupted",
159
- "message": "G.E.N.I.E. processing interrupted by user."
160
  }))
161
-
162
  else:
163
- print(f"[from_client_to_gemini] Unknown message type: {msg_type}")
164
 
165
  except WebSocketDisconnect:
166
- print(f"[from_client_to_gemini] Client {client_id} disconnected.")
167
  cleanup_client(client_id, loop_task)
168
  except Exception as e:
169
- print(f"[from_client_to_gemini] Error: {e}")
170
  cleanup_client(client_id, loop_task)
171
 
172
- async def from_gemini_to_client():
173
- """Reads PCM audio from Gemini and sends it back to the client."""
174
  try:
175
  while True:
176
- pcm_data = await audio_loop.audio_in_queue.get()
177
- b64_pcm = base64.b64encode(pcm_data).decode()
178
-
179
- out_msg = {
180
- "type": "audio",
181
- "payload": b64_pcm
182
- }
183
- print(f"[from_gemini_to_client] Sending audio chunk to client {client_id}. Size: {len(pcm_data)}")
184
- await websocket.send_text(json.dumps(out_msg))
 
 
 
 
 
 
 
 
 
185
 
186
  except WebSocketDisconnect:
187
- print(f"[from_gemini_to_client] Client {client_id} disconnected.")
188
  cleanup_client(client_id, loop_task)
189
  except Exception as e:
190
- print(f"[from_gemini_to_client] Error: {e}")
191
  cleanup_client(client_id, loop_task)
192
 
193
- def cleanup_client(client_id, loop_task):
194
- """Clean up resources for a disconnected client."""
195
- if client_id in client_connections:
196
- # Cancel the AudioLoop task
197
- if loop_task and not loop_task.done():
198
- loop_task.cancel()
199
- # Remove the client from active connections
200
- del client_connections[client_id]
201
- print(f"[cleanup_client] Cleaned up resources for client {client_id}")
202
-
203
- # Launch both tasks concurrently. If either fails or disconnects, we exit.
 
 
 
 
 
 
 
204
  try:
205
  await asyncio.gather(
206
- from_client_to_gemini(),
207
- from_gemini_to_client(),
208
  )
209
  finally:
210
- print(f"[websocket_endpoint] WebSocket handler finished for client {client_id}.")
211
  cleanup_client(client_id, loop_task)
212
 
213
  if __name__ == "__main__":
214
- # Make sure the GOOGLE_API_KEY environment variable is set before running
215
  if "GOOGLE_API_KEY" not in os.environ:
216
  print("Error: GOOGLE_API_KEY environment variable not set")
217
- print("Please set it with: export GOOGLE_API_KEY='your_api_key'")
218
- exit(1)
219
-
 
220
  uvicorn.run("webapp:app", host="0.0.0.0", port=7860, reload=True)
 
1
+ # basic_webapp.py
2
 
3
  import asyncio
4
  import base64
5
  import json
6
  import os
7
+ import sys
8
 
9
  from fastapi import FastAPI, WebSocket, WebSocketDisconnect
10
  from fastapi.responses import HTMLResponse
11
  from fastapi.staticfiles import StaticFiles
12
  import uvicorn
13
 
14
+ # Import the simplified AudioLoop
15
+ from basic_handler import AudioLoop
16
 
17
  app = FastAPI()
18
 
19
  # Store active client connections
20
+ active_clients = {}
21
 
22
+ # Mount static files directory
23
  current_dir = os.path.dirname(os.path.realpath(__file__))
24
  app.mount("/static", StaticFiles(directory=current_dir), name="static")
25
 
26
  @app.get("/")
27
  async def get_index():
28
+ """Serve the main HTML interface."""
29
  index_path = os.path.join(current_dir, "index.html")
30
  with open(index_path, "r", encoding="utf-8") as f:
31
  html_content = f.read()
 
33
 
34
  @app.websocket("/ws")
35
  async def websocket_endpoint(websocket: WebSocket):
36
+ """Handle WebSocket connections from clients."""
37
  await websocket.accept()
38
  print("[websocket_endpoint] Client connected.")
39
 
 
42
 
43
  # Create a new AudioLoop instance for this client
44
  audio_loop = AudioLoop()
45
+ active_clients[client_id] = {
46
  "websocket": websocket,
47
  "audio_loop": audio_loop,
48
+ "repo_context": None
 
 
 
49
  }
50
 
51
  # Start the AudioLoop for this client
52
  loop_task = asyncio.create_task(audio_loop.run())
53
+ print(f"[websocket_endpoint] Started AudioLoop for client {client_id}")
54
 
55
+ async def process_client_messages():
56
+ """Handle messages from the client and forward to Gemini."""
57
  try:
58
  while True:
59
  data = await websocket.receive_text()
60
  msg = json.loads(data)
61
+ msg_type = msg.get("type", "")
62
 
 
63
  if msg_type == "init":
64
+ # Store repository context info
65
+ active_clients[client_id]["repo_context"] = {
66
+ "repo_url": msg.get("repo_url", ""),
67
  "github_token": msg.get("github_token", ""),
68
  "user_type": msg.get("user_type", "coder"),
69
  "response_detail": msg.get("response_detail", "normal")
70
  }
71
+ print(f"[process_client_messages] Stored context for {client_id}: {msg.get('repo_url', '')}")
72
+
73
+ # Send confirmation
74
  await websocket.send_text(json.dumps({
75
  "type": "status",
76
  "status": "initialized",
77
+ "message": "Ready to assist with this repository."
78
  }))
79
+
 
80
  elif msg_type == "audio":
81
+ # Forward audio data to Gemini
82
  raw_pcm = base64.b64decode(msg["payload"])
83
+
84
  forward_msg = {
85
  "realtime_input": {
86
  "media_chunks": [
 
91
  ]
92
  }
93
  }
94
+ await audio_loop.out_queue.put(forward_msg)
95
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  elif msg_type == "text":
97
+ # Process text query from client
98
  user_text = msg.get("content", "")
 
 
 
 
 
 
 
 
 
 
 
99
 
100
+ # Add repository context if available
101
+ context = active_clients[client_id]["repo_context"]
102
+ if context and context["repo_url"]:
103
+ # Format context info for Gemini
104
+ context_text = (
105
+ f"The GitHub repository being discussed is: {context['repo_url']}\n"
106
+ f"User role: {context['user_type']}\n"
107
+ f"Preferred detail level: {context['response_detail']}\n\n"
108
+ f"Please consider this context when answering the following question:\n"
109
+ )
110
+ user_text = context_text + user_text
111
 
112
+ print(f"[process_client_messages] Sending text to Gemini: {user_text[:100]}...")
113
 
114
+ # Format message for Gemini
115
  forward_msg = {
116
  "client_content": {
117
  "turn_complete": True,
 
119
  {
120
  "role": "user",
121
  "parts": [
122
+ {"text": user_text}
123
  ]
124
  }
125
  ]
126
  }
127
  }
128
  await audio_loop.out_queue.put(forward_msg)
129
+
 
130
  elif msg_type == "interrupt":
131
+ # For now, just acknowledge the interrupt
132
+ # This is a simple implementation because true interruption
133
+ # may require additional API support
134
+ print(f"[process_client_messages] Interrupt requested by {client_id}")
135
  await websocket.send_text(json.dumps({
136
  "type": "status",
137
  "status": "interrupted",
138
+ "message": "Processing interrupted by user."
139
  }))
140
+
141
  else:
142
+ print(f"[process_client_messages] Unknown message type: {msg_type}")
143
 
144
  except WebSocketDisconnect:
145
+ print(f"[process_client_messages] Client {client_id} disconnected")
146
  cleanup_client(client_id, loop_task)
147
  except Exception as e:
148
+ print(f"[process_client_messages] Error: {e}")
149
  cleanup_client(client_id, loop_task)
150
 
151
+ async def forward_gemini_responses():
152
+ """Read responses from Gemini and send them to the client."""
153
  try:
154
  while True:
155
+ # Check for audio data
156
+ try:
157
+ pcm_data = await asyncio.wait_for(audio_loop.audio_in_queue.get(), 0.5)
158
+ b64_pcm = base64.b64encode(pcm_data).decode()
159
+
160
+ # Send audio to client
161
+ out_msg = {
162
+ "type": "audio",
163
+ "payload": b64_pcm
164
+ }
165
+ print(f"[forward_gemini_responses] Sending audio chunk to client {client_id}")
166
+ await websocket.send_text(json.dumps(out_msg))
167
+ except asyncio.TimeoutError:
168
+ # No audio available, continue checking
169
+ pass
170
+
171
+ # We could add additional processing for text responses here
172
+ # if we had a separate queue for text content
173
 
174
  except WebSocketDisconnect:
175
+ print(f"[forward_gemini_responses] Client {client_id} disconnected")
176
  cleanup_client(client_id, loop_task)
177
  except Exception as e:
178
+ print(f"[forward_gemini_responses] Error: {e}")
179
  cleanup_client(client_id, loop_task)
180
 
181
+ def cleanup_client(client_id, task):
182
+ """Clean up resources when a client disconnects."""
183
+ if client_id in active_clients:
184
+ client_data = active_clients[client_id]
185
+
186
+ # Stop the AudioLoop
187
+ if "audio_loop" in client_data:
188
+ client_data["audio_loop"].stop()
189
+
190
+ # Cancel the task if it's still running
191
+ if task and not task.done():
192
+ task.cancel()
193
+
194
+ # Remove from active clients
195
+ del active_clients[client_id]
196
+ print(f"[cleanup_client] Cleaned up resources for {client_id}")
197
+
198
+ # Run both tasks concurrently
199
  try:
200
  await asyncio.gather(
201
+ process_client_messages(),
202
+ forward_gemini_responses()
203
  )
204
  finally:
205
+ print(f"[websocket_endpoint] WebSocket handler finished for {client_id}")
206
  cleanup_client(client_id, loop_task)
207
 
208
  if __name__ == "__main__":
209
+ # Verify API key is present
210
  if "GOOGLE_API_KEY" not in os.environ:
211
  print("Error: GOOGLE_API_KEY environment variable not set")
212
+ print("Please set it with: export GOOGLE_API_KEY='your_api_key_here'")
213
+ sys.exit(1)
214
+
215
+ # Start the server
216
  uvicorn.run("webapp:app", host="0.0.0.0", port=7860, reload=True)