Fzina commited on
Commit
28bf265
·
verified ·
1 Parent(s): 8b73aa6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -9,6 +9,8 @@ from gymnasium import spaces
9
  import numpy as np
10
  from PIL import Image
11
  import gradio as gr
 
 
12
 
13
  # Environment Variables
14
  HOSTED_API_URL = os.getenv("HOSTED_API_URL") # FastAPI backend URL
@@ -51,7 +53,7 @@ def analyze_traffic(image_path):
51
  if processed_image_url:
52
  img_response = requests.get(processed_image_url)
53
  if img_response.status_code == 200:
54
- processed_image = Image.open(BytesIO(img_response.content))
55
 
56
  return vehicle_count, congestion_level, flow_rate, processed_image
57
  else:
@@ -85,7 +87,7 @@ class TrafficSimEnv(gym.Env):
85
  super().reset(seed=seed)
86
  self.state = np.array([self.congestion_level, 30], dtype=np.float32)
87
  self.done = False
88
- return self.state
89
 
90
  def step(self, action):
91
  if self.done:
@@ -144,7 +146,7 @@ def optimize_signal_rl(congestion_level):
144
  model.learn(total_timesteps=1000)
145
 
146
  # Reset the environment and get the initial observation
147
- obs = env.reset()
148
  logging.debug(f"Reset observation: {obs}, type: {type(obs)}, shape: {np.shape(obs)}")
149
 
150
  # Ensure obs is a valid array
@@ -153,7 +155,7 @@ def optimize_signal_rl(congestion_level):
153
  # RL Optimization loop
154
  for _ in range(10):
155
  action, _ = model.predict(obs, deterministic=True)
156
- obs, rewards, dones, infos = env.step(action)
157
 
158
  # Debug observation structure
159
  logging.debug(f"Step observation: {obs}, type: {type(obs)}, shape: {np.shape(obs)}")
@@ -253,4 +255,4 @@ if __name__ == "__main__":
253
  )
254
 
255
  # Launch Gradio app
256
- interface.launch()
 
9
  import numpy as np
10
  from PIL import Image
11
  import gradio as gr
12
+ import io
13
+ import base64
14
 
15
  # Environment Variables
16
  HOSTED_API_URL = os.getenv("HOSTED_API_URL") # FastAPI backend URL
 
53
  if processed_image_url:
54
  img_response = requests.get(processed_image_url)
55
  if img_response.status_code == 200:
56
+ processed_image = Image.open(io.BytesIO(img_response.content))
57
 
58
  return vehicle_count, congestion_level, flow_rate, processed_image
59
  else:
 
87
  super().reset(seed=seed)
88
  self.state = np.array([self.congestion_level, 30], dtype=np.float32)
89
  self.done = False
90
+ return self.state, {}
91
 
92
  def step(self, action):
93
  if self.done:
 
146
  model.learn(total_timesteps=1000)
147
 
148
  # Reset the environment and get the initial observation
149
+ obs, _ = env.reset() # Ensure this is a tuple (obs, info)
150
  logging.debug(f"Reset observation: {obs}, type: {type(obs)}, shape: {np.shape(obs)}")
151
 
152
  # Ensure obs is a valid array
 
155
  # RL Optimization loop
156
  for _ in range(10):
157
  action, _ = model.predict(obs, deterministic=True)
158
+ obs, rewards, dones, _, infos = env.step(action)
159
 
160
  # Debug observation structure
161
  logging.debug(f"Step observation: {obs}, type: {type(obs)}, shape: {np.shape(obs)}")
 
255
  )
256
 
257
  # Launch Gradio app
258
+ interface.launch()