Fzina commited on
Commit
27b976b
·
verified ·
1 Parent(s): 12eccb4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -4,7 +4,7 @@ import requests
4
  from geopy.geocoders import Nominatim # Valid import
5
  from stable_baselines3 import PPO
6
  from stable_baselines3.common.vec_env import DummyVecEnv
7
- import gymnasium as gym
8
  from gymnasium import spaces
9
  import numpy as np
10
  from PIL import Image
@@ -148,14 +148,22 @@ def optimize_signal_rl(congestion_level):
148
  model.learn(total_timesteps=1000)
149
 
150
  # Reset the environment
151
- obs = env.reset() # Returns a single value (batch of observations)
 
 
 
 
152
 
153
  # RL Optimization loop
154
  for _ in range(10):
155
  action, _ = model.predict(obs, deterministic=True)
156
  obs, rewards, dones, infos = env.step(action)
157
 
158
- # Unpack single observation and reward (batch size of 1)
 
 
 
 
159
  obs = obs[0]
160
  reward = rewards[0]
161
  done = dones[0]
@@ -175,6 +183,7 @@ def optimize_signal_rl(congestion_level):
175
  return "Error in RL Optimization"
176
 
177
 
 
178
  def process_traffic_image(image):
179
  """
180
  Orchestrates the traffic analysis workflow.
 
4
  from geopy.geocoders import Nominatim # Valid import
5
  from stable_baselines3 import PPO
6
  from stable_baselines3.common.vec_env import DummyVecEnv
7
+ import gym as gym
8
  from gymnasium import spaces
9
  import numpy as np
10
  from PIL import Image
 
148
  model.learn(total_timesteps=1000)
149
 
150
  # Reset the environment
151
+ obs = env.reset()
152
+
153
+ # Ensure obs is properly handled (flatten if necessary)
154
+ if obs.ndim == 0:
155
+ obs = np.expand_dims(obs, axis=0) # Convert 0D array to 1D
156
 
157
  # RL Optimization loop
158
  for _ in range(10):
159
  action, _ = model.predict(obs, deterministic=True)
160
  obs, rewards, dones, infos = env.step(action)
161
 
162
+ # Ensure obs is properly handled
163
+ if obs.ndim == 0:
164
+ obs = np.expand_dims(obs, axis=0)
165
+
166
+ # Extract single observation and reward
167
  obs = obs[0]
168
  reward = rewards[0]
169
  done = dones[0]
 
183
  return "Error in RL Optimization"
184
 
185
 
186
+
187
  def process_traffic_image(image):
188
  """
189
  Orchestrates the traffic analysis workflow.