Fzina commited on
Commit
758fd82
·
verified ·
1 Parent(s): 8618a39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -9
app.py CHANGED
@@ -146,25 +146,40 @@ def optimize_signal_rl(congestion_level):
146
  model.learn(total_timesteps=1000)
147
 
148
  # Reset the environment and get the initial observation
149
- obs = env.reset() # No need to unpack anything; this gives the observation array
150
- logging.debug(f"Reset observation: {obs}, type: {type(obs)}, shape: {np.shape(obs)}")
 
 
 
151
 
152
  # RL Optimization loop
153
  for _ in range(10):
154
- action, _ = model.predict(obs, deterministic=True) # Predict action
155
- obs, rewards, dones, infos = env.step(action) # Step environment
 
 
 
 
156
 
157
- # Flatten the observation for the next step
158
  obs = obs.flatten()
159
 
160
- # If any environment is done, stop
161
- if dones[0]:
 
162
  break
163
 
164
- # Get the optimal duration
165
- optimal_duration = int(obs[1]) if len(obs) > 1 else 30 # Use default if obs is incorrect
 
 
 
 
166
  return f"Green for {optimal_duration}s, Red for {60 - optimal_duration}s"
167
 
 
 
 
168
  except Exception as e:
169
  logging.error(f"Error optimizing signal with RL: {e}")
170
  return "Error in RL Optimization"
 
146
  model.learn(total_timesteps=1000)
147
 
148
  # Reset the environment and get the initial observation
149
+ obs = env.reset() # `obs` should be a 2D array (vectorized environment)
150
+ if not isinstance(obs, np.ndarray) or obs.ndim != 2: # Validate obs shape
151
+ raise ValueError(f"Unexpected observation shape: {obs}, type: {type(obs)}")
152
+
153
+ logging.info(f"Initial observation: {obs}")
154
 
155
  # RL Optimization loop
156
  for _ in range(10):
157
+ # Predict action and perform step
158
+ action, _ = model.predict(obs, deterministic=True)
159
+ obs, rewards, dones, infos = env.step(action)
160
+
161
+ if not isinstance(obs, np.ndarray) or obs.ndim != 2: # Validate obs shape again
162
+ raise ValueError(f"Unexpected observation after step: {obs}, type: {type(obs)}")
163
 
164
+ # Flatten the observation for easier handling
165
  obs = obs.flatten()
166
 
167
+ logging.debug(f"Step results - Obs: {obs}, Reward: {rewards}, Done: {dones}, Info: {infos}")
168
+
169
+ if dones[0]: # Stop if the environment signals it's done
170
  break
171
 
172
+ # Get optimal signal duration
173
+ if len(obs) > 1: # Ensure the array has enough elements
174
+ optimal_duration = int(obs[1])
175
+ else:
176
+ raise ValueError(f"Observation array too short: {obs}")
177
+
178
  return f"Green for {optimal_duration}s, Red for {60 - optimal_duration}s"
179
 
180
+ except ValueError as ve:
181
+ logging.error(f"Value error: {ve}")
182
+ return "Error: Unexpected values encountered during optimization."
183
  except Exception as e:
184
  logging.error(f"Error optimizing signal with RL: {e}")
185
  return "Error in RL Optimization"