Fzina committed · Commit 91028ce · verified · 1 Parent(s): ee598c7

Update app.py

Files changed (1): app.py (+51 -17)
app.py CHANGED
@@ -76,28 +76,49 @@ class TrafficSimEnv(gym.Env):
         self.steps = 0
 
     def reset(self, seed=None, options=None):
+        """
+        Resets the environment and returns the initial observation and info.
+        """
         self.steps = 0
         self.current_signal = 30
         congestion_map = {"Low": 0, "Medium": 1, "High": 2}
         self.congestion_numeric = congestion_map.get(self.congestion_level, 0)
+
+        # Initial observation: [congestion level, current signal]
         obs = np.array([self.congestion_numeric, self.current_signal], dtype=np.float32)
-        return obs, {}  # Fixed format for DummyVecEnv
+
+        return obs, {}  # Returns observation and info (info can be empty)
+
 
     def step(self, action):
-        print("Step: ", action)
-        print("Obs before update: ", self.current_signal)
-        print("Action Taken: ", action)
-        signal_change = {0: -5, 1: 0, 2: +5}[action]
+        """
+        Takes an action and updates the environment state.
+        Returns 5 values: observation, reward, done, truncated, and info.
+        """
+
+        print(f"Action taken: {action}")
+        print(f"Signal before change: {self.current_signal}")
+        print(f"Congestion Level: {self.congestion_numeric}")
+
+        # Signal changes based on the action
+        signal_change = {0: -5, 1: 0, 2: 5}[action]
         self.current_signal = max(20, min(60, self.current_signal + signal_change))
+
+        # Reward Calculation based on congestion level and signal
         reward = self._calculate_reward()
+
+        # Increment the step count
         self.steps += 1
-        done = self.steps >= 10
-        truncated = False  # Compatibility
+        done = self.steps >= 10  # End condition
+        truncated = False  # Default to False; can change based on custom conditions
+
+        # Ensure we always return 5 values: obs, reward, done, truncated, info
         obs = np.array([self.congestion_numeric, self.current_signal], dtype=np.float32)
-        info = {}  # Additional info (if needbe)
-        return obs, reward, done, truncated, info
+        print(f"Observation: {obs}")
 
+        info = {}  # Additional info (can remain empty or contain any useful data)
 
+        return obs, reward, done, truncated, info  # Must return 5 values here
 
     def _calculate_reward(self):
         if self.congestion_level == "High":
@@ -116,25 +137,37 @@ class TrafficSimEnv(gym.Env):
 # Prior to commit had a lot of errors regarding expected output errors
 def optimize_signal_rl(congestion_level):
     try:
+        # Create the environment with DummyVecEnv to wrap TrafficSimEnv
         env = DummyVecEnv([lambda: TrafficSimEnv(congestion_level)])
+
+        # Initialize PPO model (policy = "MlpPolicy", for multi-layer perceptron model)
         model = PPO("MlpPolicy", env, verbose=0)
+
+        # Train the model on the environment for 1000 timesteps
         model.learn(total_timesteps=1000)
-        obs = env.reset()
 
+        # Reset environment to start the simulation
+        obs, _ = env.reset()
+
+        # Loop through to simulate for 10 timesteps
         for _ in range(10):
             action, _ = model.predict(obs)
+
+            # Step through the environment with the predicted action
             obs, reward, done, truncated, info = env.step(action)
-            logging.info(f"Step results: obs={obs}, reward={reward}, done={done}, info={info}")
 
-            obs = obs[0]  # Extract the first (and only) observation from the batch
-            reward = reward[0]  # Reward for the current action
-            done = done[0]  # Done flag for the current step
-            truncated = truncated[0]  # Truncated flag
+            # Extract the first value from each returned array (since env is wrapped in DummyVecEnv)
+            obs = obs[0]  # First observation (from batch)
+            reward = reward[0]  # Reward for the action (from batch)
+            done = done[0]  # Done flag (from batch)
+            truncated = truncated[0]  # Truncated flag (from batch)
 
-            if done or truncated:  # Check if episode ends (either 'done' or 'truncated')
+            # Stop when the environment signals that the episode is done or truncated
+            if done or truncated:
                 break
 
-        optimal_duration = int(obs[1]) if len(obs) > 1 else 30  # Get the optimal signal duration
+        # Extract the optimal signal duration (second observation value) from `obs`
+        optimal_duration = int(obs[1]) if len(obs) > 1 else 30  # Ensure the signal value is within range
         return f"Green for {optimal_duration}s, Red for {60 - optimal_duration}s"
     except Exception as e:
         logging.error(f"Error optimizing signal with RL: {e}")
@@ -142,6 +175,7 @@ def optimize_signal_rl(congestion_level):
 
 
 
+
 def process_traffic_image(image):
     """
     Orchestrates the traffic analysis workflow.
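
Review note: the updated reset()/step() follow the Gymnasium-style API (reset returns a 2-tuple, step returns a 5-tuple). A quick standalone sanity check of that contract, assuming the class defines its observation and action spaces as usual and that stable-baselines3 is the RL library in use (this snippet is a sketch, not part of the commit):

    # Sketch: validate TrafficSimEnv's API before wrapping it in DummyVecEnv.
    from stable_baselines3.common.env_checker import check_env

    env = TrafficSimEnv("Medium")
    obs, info = env.reset()                           # 2-tuple: observation, info
    obs, reward, done, truncated, info = env.step(1)  # 5-tuple after each action
    assert obs.shape == (2,)                          # [congestion, signal]
    check_env(env)                                    # raises/warns on API violations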
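
Review note: if DummyVecEnv is stable-baselines3's wrapper (as the PPO usage suggests), its step() still returns four batched values and its reset() returns only the observation, so the five-value unpacking in optimize_signal_rl would raise a ValueError at runtime. A minimal rollout sketch against that 4-tuple VecEnv API (rollout_optimal_duration is a hypothetical helper, not in app.py):

    # Sketch: roll the trained policy through an SB3 VecEnv (4-tuple API).
    def rollout_optimal_duration(model, env, n_steps=10, default=30):
        obs = env.reset()                     # batched obs only; no info dict
        optimal = default
        for _ in range(n_steps):
            action, _ = model.predict(obs, deterministic=True)
            obs, rewards, dones, infos = env.step(action)  # 4 values, batched
            if dones[0]:
                # The VecEnv auto-resets on done; SB3 stashes the true final
                # observation in the info dict.
                final_obs = infos[0].get("terminal_observation", obs[0])
                optimal = int(final_obs[1])
                break
            optimal = int(obs[0][1])          # second feature = signal duration
        return optimal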
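
Review note: step() resolves the action through a plain dict, and actions coming back from predict()/the VecEnv are numpy values. numpy integer scalars hash like Python ints, but a 1-element array is unhashable and would raise a TypeError, so normalizing defensively is cheap insurance (signal_delta is a hypothetical helper, not in app.py):

    import numpy as np

    # Sketch: map an action (int, numpy scalar, or 1-element array) to a signal change.
    def signal_delta(action):
        a = int(np.asarray(action).item())    # np.int64 / array([2]) -> 2
        return {0: -5, 1: 0, 2: 5}.get(a, 0)  # unknown actions leave the signal unchanged

    # e.g. signal_delta(np.int64(0)) == -5, signal_delta(np.array([2])) == 5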