Update app.py
app.py
CHANGED
@@ -76,28 +76,49 @@ class TrafficSimEnv(gym.Env):
         self.steps = 0
 
     def reset(self, seed=None, options=None):
+        """
+        Resets the environment and returns the initial observation and info.
+        """
         self.steps = 0
         self.current_signal = 30
         congestion_map = {"Low": 0, "Medium": 1, "High": 2}
         self.congestion_numeric = congestion_map.get(self.congestion_level, 0)
+
+        # Initial observation: [congestion level, current signal]
         obs = np.array([self.congestion_numeric, self.current_signal], dtype=np.float32)
+
+        return obs, {}  # Returns observation and info (info can be empty)
+
 
     def step(self, action):
+        """
+        Takes an action and updates the environment state.
+        Returns 5 values: observation, reward, done, truncated, and info.
+        """
+
+        print(f"Action taken: {action}")
+        print(f"Signal before change: {self.current_signal}")
+        print(f"Congestion Level: {self.congestion_numeric}")
+
+        # Signal changes based on the action
+        signal_change = {0: -5, 1: 0, 2: 5}[action]
         self.current_signal = max(20, min(60, self.current_signal + signal_change))
+
+        # Reward Calculation based on congestion level and signal
         reward = self._calculate_reward()
+
+        # Increment the step count
         self.steps += 1
-        done = self.steps >= 10
-        truncated = False #
+        done = self.steps >= 10  # End condition
+        truncated = False  # Default to False; can change based on custom conditions
+
+        # Ensure we always return 5 values: obs, reward, done, truncated, info
         obs = np.array([self.congestion_numeric, self.current_signal], dtype=np.float32)
-        return obs, reward, done, truncated, info
+        print(f"Observation: {obs}")
 
+        info = {}  # Additional info (can remain empty or contain any useful data)
 
+        return obs, reward, done, truncated, info  # Must return 5 values here
 
     def _calculate_reward(self):
         if self.congestion_level == "High":
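The reworked reset() and step() above follow the Gymnasium-style contract: reset() returns (obs, info) and step() returns five values. A minimal sketch of exercising the class directly, outside the Space (not part of the commit; it assumes TrafficSimEnv can be imported from app.py and takes the congestion-level string seen in the diff):

from app import TrafficSimEnv                        # assumption: app.py exposes the class
from stable_baselines3.common.env_checker import check_env

env = TrafficSimEnv("Medium")                        # hypothetical congestion level
obs, info = env.reset()                              # new contract: (obs, info)
obs, reward, done, truncated, info = env.step(2)     # new contract: 5 return values
check_env(env, warn=True)                            # flags Gym API violations, if any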
@@ -116,25 +137,37 @@ class TrafficSimEnv(gym.Env):
 # Prior to commit had a lot of errors regarding expected output errors
 def optimize_signal_rl(congestion_level):
     try:
+        # Create the environment with DummyVecEnv to wrap TrafficSimEnv
         env = DummyVecEnv([lambda: TrafficSimEnv(congestion_level)])
+
+        # Initialize PPO model (policy = "MlpPolicy", for multi-layer perceptron model)
         model = PPO("MlpPolicy", env, verbose=0)
+
+        # Train the model on the environment for 1000 timesteps
         model.learn(total_timesteps=1000)
-        obs = env.reset()
 
+        # Reset environment to start the simulation
+        obs, _ = env.reset()
+
+        # Loop through to simulate for 10 timesteps
         for _ in range(10):
             action, _ = model.predict(obs)
+
+            # Step through the environment with the predicted action
             obs, reward, done, truncated, info = env.step(action)
-            logging.info(f"Step results: obs={obs}, reward={reward}, done={done}, info={info}")
 
+            # Extract the first value from each returned array (since env is wrapped in DummyVecEnv)
+            obs = obs[0]  # First observation (from batch)
+            reward = reward[0]  # Reward for the action (from batch)
+            done = done[0]  # Done flag (from batch)
+            truncated = truncated[0]  # Truncated flag (from batch)
 
+            # Stop when the environment signals that the episode is done or truncated
+            if done or truncated:
                 break
 
+        # Extract the optimal signal duration (second observation value) from `obs`
+        optimal_duration = int(obs[1]) if len(obs) > 1 else 30  # Ensure the signal value is within range
         return f"Green for {optimal_duration}s, Red for {60 - optimal_duration}s"
     except Exception as e:
         logging.error(f"Error optimizing signal with RL: {e}")
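Note: stable-baselines3's DummyVecEnv still exposes the older vectorized API, where reset() returns only the batched observation and step() returns four batched values (obs, rewards, dones, infos). Depending on the installed version, the added `obs, _ = env.reset()` and the five-value unpacking above may therefore still raise at runtime. A hedged sketch of the same loop written against the VecEnv contract, reusing the `env` and `model` objects created in this hunk:

obs = env.reset()                                    # VecEnv.reset() -> batched obs only
for _ in range(10):
    action, _ = model.predict(obs)
    obs, rewards, dones, infos = env.step(action)    # VecEnv.step() -> 4 batched values
    if dones[0]:                                     # index 0 = the single wrapped env
        break
optimal_duration = int(obs[0][1]) if obs.shape[1] > 1 else 30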
@@ -142,6 +175,7 @@ def optimize_signal_rl(congestion_level):
 
 
 
+
 def process_traffic_image(image):
     """
     Orchestrates the traffic analysis workflow.
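For reference, a hypothetical call to the function these hunks modify (the exact durations depend on the trained policy):

print(optimize_signal_rl("High"))                    # e.g. "Green for 45s, Red for 15s"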