Fzina committed
Commit 8ea7693 · verified · 1 Parent(s): 02611b6

Update app.py

Files changed (1):
  1. app.py +60 -52
app.py CHANGED
@@ -66,47 +66,61 @@ class TrafficSimEnv(gym.Env):
     def __init__(self, congestion_level):
         super(TrafficSimEnv, self).__init__()
         self.congestion_level = congestion_level
+
+        # Define observation space: [congestion_level, signal_duration]
+        self.observation_space = spaces.Box(low=np.array([0, 0]), high=np.array([10, 60]), dtype=np.float32)
+
+        # Define action space: [increase, decrease, maintain]
         self.action_space = spaces.Discrete(3)
-        self.observation_space = spaces.Box(low=np.array([0, 20]), high=np.array([2, 60]), dtype=np.float32)
-        self.current_signal = 30
-        self.steps = 0
 
-    def reset(self, seed=None, options=None):
-        self.steps = 0
-        self.current_signal = 30
-        congestion_map = {"Low": 0, "Medium": 1, "High": 2}
-        self.congestion_numeric = congestion_map.get(self.congestion_level, 0)
-
-        obs = np.array([self.congestion_numeric, self.current_signal], dtype=np.float32)
-        info = {}  # Empty dictionary as info
-        print(f"RESET: obs={obs}, info={info}")  # Debugging print
-        return obs, info
+        # Initial state
+        self.state = np.array([congestion_level, 30], dtype=np.float32)
+        self.done = False
+
+    def reset(self, *, seed=None, options=None):
+        super().reset(seed=seed)
+
+        self.state = np.array([self.congestion_level, 30], dtype=np.float32)
+        self.done = False
+        return self.state, {}
 
     def step(self, action):
-        signal_change = {0: -5, 1: 0, 2: 5}[action]
-        self.current_signal = max(20, min(60, self.current_signal + signal_change))
-
-        reward = self._calculate_reward()
-
-        self.steps += 1
-        done = self.steps >= 10
-        truncated = False
-
-        obs = np.array([self.congestion_numeric, self.current_signal], dtype=np.float32)
-        info = {}  # Info dictionary (can be populated with useful debugging data)
-        print(f"STEP: action={action}, obs={obs}, reward={reward}, done={done}, truncated={truncated}, info={info}")
-        return obs, reward, done, truncated, info
-
-    def _calculate_reward(self):
-        if self.congestion_level == "High":
-            return -abs(40 - self.current_signal)
-        elif self.congestion_level == "Medium":
-            return -abs(30 - self.current_signal)
-        else:
-            return -abs(20 - self.current_signal)
+        if self.done:
+            raise RuntimeError("Cannot call step() on a terminated environment. Please reset the environment.")
+
+        # Extract state components
+        congestion, signal_duration = self.state
+
+        # Apply action
+        if action == 0:  # Decrease signal duration
+            signal_duration = max(10, signal_duration - 5)
+        elif action == 1:  # Maintain signal duration
+            signal_duration = signal_duration
+        elif action == 2:  # Increase signal duration
+            signal_duration = min(60, signal_duration + 5)
+
+        # Update congestion (simple simulation logic for this example)
+        if signal_duration > 30:
+            congestion += 1
+        else:
+            congestion -= 1
+
+        # Set rewards (example logic)
+        if 20 <= signal_duration <= 40:
+            reward = 0
+        else:
+            reward = -abs(signal_duration - 30)
+
+        # Check if done
+        self.done = congestion <= 0 or congestion >= 10
+
+        # Update state
+        self.state = np.array([congestion, signal_duration], dtype=np.float32)
+
+        return self.state, reward, self.done, False, {}
 
-    def render(self, mode="human"):
-        print(f"Current Signal: {self.current_signal}s")
+    def render(self):
+        print(f"State: {self.state}")
 
     def close(self):
         pass
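
This hunk moves the environment onto the Gymnasium-style API: reset() now accepts keyword-only seed/options and returns (obs, info), and step() returns the 5-tuple (obs, reward, terminated, truncated, info). The reward is 0 anywhere in the 20-40 s band and -abs(signal_duration - 30) outside it, so a 45 s green phase scores -15. One thing to watch: the new __init__ builds np.array([congestion_level, 30]) directly, so congestion_level must already be numeric, whereas the old reset() mapped "Low"/"Medium"/"High" through congestion_map. A minimal sanity check of the new class (a sketch, assuming it runs inside app.py with stable-baselines3 installed and a numeric congestion level of 5):

    from stable_baselines3.common.env_checker import check_env

    env = TrafficSimEnv(congestion_level=5)  # must be numeric under the new __init__
    check_env(env, warn=True)  # warns/raises if reset()/step() violate the Gymnasium spec

    obs, info = env.reset(seed=0)
    obs, reward, terminated, truncated, info = env.step(0)  # action 0: shorten the phase
    print(obs, reward, terminated)  # -> [ 4. 25.] 0 False: congestion falls, reward stays 0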
@@ -114,7 +128,7 @@ class TrafficSimEnv(gym.Env):
 # Prior to commit had a lot of errors regarding expected output errors
 def optimize_signal_rl(congestion_level):
     try:
-        # Create the environment wrapped in DummyVecEnv
+        # Create environment
         env = DummyVecEnv([lambda: TrafficSimEnv(congestion_level)])
 
         # Initialize PPO model
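
The PPO setup between these hunks is unchanged and therefore elided from the diff. For context, the stable-baselines3 initialization the surrounding lines imply would typically look like this (a sketch only; the actual policy string and hyperparameters in app.py are not shown in this commit):

    from stable_baselines3 import PPO
    from stable_baselines3.common.vec_env import DummyVecEnv

    env = DummyVecEnv([lambda: TrafficSimEnv(congestion_level)])  # in-process VecEnv, batch size 1
    model = PPO("MlpPolicy", env, verbose=0)  # small MLP policy/value nets suit a 2-D observation
    model.learn(total_timesteps=1000)

DummyVecEnv runs the single environment in the calling process and batches every value it returns, which is why the loop in the next hunk indexes obs[0] and friends.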
@@ -124,40 +138,34 @@ def optimize_signal_rl(congestion_level):
         model.learn(total_timesteps=1000)
 
         # Reset the environment
-        obs, info = env.reset()  # For Gymnasium 1.0.0
-        obs = obs[0]  # Extract single observation from the batched observation
+        obs, info = env.reset()
+        obs = obs[0]  # Extract first observation from batch
 
         for _ in range(10):
-            # Predict action
             action, _ = model.predict(obs, deterministic=True)
+            obs, reward, done, truncated, info = env.step(action)
 
-            # Take a step
-            step_output = env.step(action)
-
-            # Handle batched output of step()
-            if len(step_output) == 5:
-                obs, reward, done, truncated, info = step_output
-                obs = obs[0]  # Extract single observation
-                reward = reward[0]  # Extract single reward
-                done = done[0]  # Extract single done flag
-                truncated = truncated[0]  # Extract single truncated flag
-                info = info[0]  # Extract single info
-            else:
-                raise ValueError(f"Unexpected step output: {step_output}")
-
-            # End simulation if done or truncated
+            # Ensure single observation extraction
+            obs = obs[0]
+            reward = reward[0]
+            done = done[0]
+            truncated = truncated[0]
+            info = info[0]
+
+            # Log each step for debugging
+            print(f"STEP: action={action}, obs={obs}, reward={reward}, done={done}, truncated={truncated}, info={info}")
+
             if done or truncated:
                 break
 
-        # Extract optimal signal duration
         optimal_duration = int(obs[1]) if len(obs) > 1 else 30
         return f"Green for {optimal_duration}s, Red for {60 - optimal_duration}s"
+
     except Exception as e:
         logging.error(f"Error optimizing signal with RL: {e}")
         return "Error in RL Optimization"
 
 
-
 def process_traffic_image(image):
     """
     Orchestrates the traffic analysis workflow.
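
A caveat on this last hunk: stable-baselines3's VecEnv wrappers (DummyVecEnv included) do not follow the Gymnasium 5-tuple. In SB3 2.x, vec_env.reset() returns only the batched observation, and vec_env.step() returns the 4-tuple (obs, rewards, dones, infos), with truncation reported through infos[i]["TimeLimit.truncated"]. If that holds for the version pinned here, both obs, info = env.reset() and the 5-way unpacking of env.step() raise ValueError at runtime and drop into the except branch. A version of the loop written against the documented VecEnv API (a sketch under that assumption):

    obs = env.reset()  # VecEnv reset: batched observation only, shape (1, 2)
    for _ in range(10):
        action, _ = model.predict(obs, deterministic=True)  # batched obs in, batched action out
        obs, rewards, dones, infos = env.step(action)  # 4-tuple; each element has batch dim 1
        if dones[0]:
            break  # DummyVecEnv auto-resets; the final obs lives in infos[0]["terminal_observation"]

    optimal_duration = int(obs[0][1])  # signal_duration component of the single env's observation
    print(f"Green for {optimal_duration}s, Red for {60 - optimal_duration}s")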
 