Fzina commited on
Commit
8618a39
·
verified ·
1 Parent(s): b3e5ddf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -15
app.py CHANGED
@@ -146,28 +146,23 @@ def optimize_signal_rl(congestion_level):
146
  model.learn(total_timesteps=1000)
147
 
148
  # Reset the environment and get the initial observation
149
- obs, _ = env.reset() # We don't need to unpack infos for now
150
  logging.debug(f"Reset observation: {obs}, type: {type(obs)}, shape: {np.shape(obs)}")
151
 
152
- # Flatten obs for the model
153
- obs = np.array(obs).flatten()
154
-
155
  # RL Optimization loop
156
  for _ in range(10):
157
- action, _ = model.predict(obs, deterministic=True)
158
- obs, reward, done, truncated, _ = env.step(action) # We don't need to unpack infos here
159
-
160
- # Debug observation structure
161
- logging.debug(f"Step observation: {obs}, type: {type(obs)}, shape: {np.shape(obs)}")
162
-
163
- # Flatten the observation again for next step
164
- obs = np.array(obs).flatten()
165
 
166
- if done[0]: # Check if done for the first environment in the batch
 
167
  break
168
 
169
- # Get optimal duration
170
- optimal_duration = int(obs[1]) if len(obs) > 1 else 30
171
  return f"Green for {optimal_duration}s, Red for {60 - optimal_duration}s"
172
 
173
  except Exception as e:
 
146
  model.learn(total_timesteps=1000)
147
 
148
  # Reset the environment and get the initial observation
149
+ obs = env.reset() # No need to unpack anything; this gives the observation array
150
  logging.debug(f"Reset observation: {obs}, type: {type(obs)}, shape: {np.shape(obs)}")
151
 
 
 
 
152
  # RL Optimization loop
153
  for _ in range(10):
154
+ action, _ = model.predict(obs, deterministic=True) # Predict action
155
+ obs, rewards, dones, infos = env.step(action) # Step environment
156
+
157
+ # Flatten the observation for the next step
158
+ obs = obs.flatten()
 
 
 
159
 
160
+ # If any environment is done, stop
161
+ if dones[0]:
162
  break
163
 
164
+ # Get the optimal duration
165
+ optimal_duration = int(obs[1]) if len(obs) > 1 else 30 # Use default if obs is incorrect
166
  return f"Green for {optimal_duration}s, Red for {60 - optimal_duration}s"
167
 
168
  except Exception as e: