Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import os
|
2 |
import logging
|
3 |
import requests
|
4 |
-
from geopy.geocoders import Nominatim
|
5 |
from stable_baselines3 import PPO
|
6 |
from stable_baselines3.common.vec_env import DummyVecEnv
|
7 |
import gymnasium as gym
|
@@ -17,9 +17,9 @@ HOSTED_API_URL = os.getenv("HOSTED_API_URL") # FastAPI backend URL
|
|
17 |
WEATHER_API_KEY = os.getenv("WEATHER_API_KEY") # OpenWeatherMap API key
|
18 |
|
19 |
# Logging setup
|
20 |
-
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
21 |
|
22 |
-
#
|
23 |
if not HOSTED_API_URL:
|
24 |
logging.error("HOSTED_API_URL environment variable is not set.")
|
25 |
raise ValueError("HOSTED_API_URL must be set.")
|
@@ -28,9 +28,9 @@ if not WEATHER_API_KEY:
|
|
28 |
raise ValueError("WEATHER_API_KEY must be set.")
|
29 |
|
30 |
# OpenStreetMap Setup
|
31 |
-
geolocator = Nominatim(user_agent="traffic_management_system")
|
32 |
|
33 |
-
#
|
34 |
def analyze_traffic(image_path):
|
35 |
"""
|
36 |
Sends the traffic image to the FastAPI backend for analysis.
|
@@ -69,17 +69,14 @@ class TrafficSimEnv(gym.Env):
|
|
69 |
super(TrafficSimEnv, self).__init__()
|
70 |
self.congestion_level = congestion_level
|
71 |
|
72 |
-
# Define observation space: [congestion_level, signal_duration]
|
73 |
self.observation_space = spaces.Box(
|
74 |
low=np.array([0, 0], dtype=np.float32),
|
75 |
high=np.array([10, 60], dtype=np.float32),
|
76 |
dtype=np.float32
|
77 |
)
|
78 |
|
79 |
-
# Define action space: [increase, decrease, maintain]
|
80 |
self.action_space = spaces.Discrete(3)
|
81 |
-
|
82 |
-
# Initial state
|
83 |
self.state = np.array([congestion_level, 30], dtype=np.float32)
|
84 |
self.done = False
|
85 |
|
@@ -97,26 +94,23 @@ class TrafficSimEnv(gym.Env):
|
|
97 |
congestion, signal_duration = self.state
|
98 |
|
99 |
# Apply action
|
100 |
-
if action == 0:
|
101 |
signal_duration = max(10, signal_duration - 5)
|
102 |
-
elif action == 1:
|
103 |
signal_duration = signal_duration
|
104 |
-
elif action == 2:
|
105 |
signal_duration = min(60, signal_duration + 5)
|
106 |
|
107 |
-
# Update congestion (simple simulation logic for this example)
|
108 |
if signal_duration > 30:
|
109 |
congestion += 1
|
110 |
else:
|
111 |
congestion -= 1
|
112 |
|
113 |
-
# Set rewards (example logic)
|
114 |
if 20 <= signal_duration <= 40:
|
115 |
reward = 0
|
116 |
else:
|
117 |
reward = -abs(signal_duration - 30)
|
118 |
|
119 |
-
# Check if done
|
120 |
self.done = congestion <= 0 or congestion >= 10
|
121 |
|
122 |
self.state = np.array([congestion, signal_duration], dtype=np.float32)
|
@@ -132,24 +126,22 @@ class TrafficSimEnv(gym.Env):
|
|
132 |
|
133 |
def optimize_signal_rl(congestion_level):
|
134 |
try:
|
135 |
-
# Map congestion levels (string to numeric)
|
136 |
congestion_map = {"Low": 2, "Medium": 5, "High": 8}
|
137 |
congestion_level = congestion_map.get(congestion_level, 5) if isinstance(congestion_level, str) else congestion_level
|
138 |
|
139 |
# Create environment
|
140 |
env = DummyVecEnv([lambda: TrafficSimEnv(congestion_level)])
|
141 |
-
|
142 |
-
# Initialize PPO model
|
143 |
model = PPO("MlpPolicy", env, verbose=0)
|
144 |
|
145 |
# Train the model
|
146 |
model.learn(total_timesteps=1000)
|
147 |
|
148 |
-
# Reset
|
149 |
obs = env.reset()
|
150 |
-
logging.info(f"Initial
|
151 |
|
152 |
-
# Ensure `obs` is valid
|
153 |
if not isinstance(obs, np.ndarray) or obs.ndim != 2:
|
154 |
raise ValueError(f"Invalid observation after reset: {obs} (type: {type(obs)}, ndim: {obs.ndim if isinstance(obs, np.ndarray) else 'N/A'})")
|
155 |
|
@@ -164,21 +156,21 @@ def optimize_signal_rl(congestion_level):
|
|
164 |
# Check and log the observation
|
165 |
logging.debug(f"Step {step_count}: Obs: {obs}, Rewards: {rewards}, Done: {dones}, Infos: {infos}")
|
166 |
|
167 |
-
# Validate `obs` after each step
|
168 |
if not isinstance(obs, np.ndarray):
|
169 |
raise ValueError(f"Observation is not an array after step {step_count}: {obs}")
|
170 |
-
if obs.ndim == 0:
|
171 |
-
obs = np.array([obs]) # Convert
|
172 |
logging.warning(f"Converted 0D obs to array: {obs}")
|
173 |
-
elif obs.ndim == 1: #
|
174 |
obs = obs.reshape(1, -1)
|
175 |
|
176 |
-
# Stop if the
|
177 |
if dones[0]:
|
178 |
break
|
179 |
|
180 |
-
#
|
181 |
-
obs = obs.flatten() #
|
182 |
if len(obs) < 2:
|
183 |
raise ValueError(f"Observation does not contain enough elements: {obs}")
|
184 |
optimal_duration = int(obs[1])
|
@@ -199,7 +191,7 @@ def process_traffic_image(image):
|
|
199 |
Orchestrates the traffic analysis workflow.
|
200 |
"""
|
201 |
|
202 |
-
# Save the uploaded image
|
203 |
image_path = "temp_traffic_image.jpg"
|
204 |
image.save(image_path)
|
205 |
|
@@ -214,7 +206,7 @@ def process_traffic_image(image):
|
|
214 |
logging.error(f"Error in backend request: {e}")
|
215 |
return "Error in backend request.", None
|
216 |
finally:
|
217 |
-
os.remove(image_path) # Clean up the
|
218 |
|
219 |
# Process backend response
|
220 |
if response.status_code == 200:
|
|
|
1 |
import os
|
2 |
import logging
|
3 |
import requests
|
4 |
+
from geopy.geocoders import Nominatim
|
5 |
from stable_baselines3 import PPO
|
6 |
from stable_baselines3.common.vec_env import DummyVecEnv
|
7 |
import gymnasium as gym
|
|
|
17 |
WEATHER_API_KEY = os.getenv("WEATHER_API_KEY") # OpenWeatherMap API key
|
18 |
|
19 |
# Logging setup
|
20 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
21 |
|
22 |
+
# Validation Environment Variables
|
23 |
if not HOSTED_API_URL:
|
24 |
logging.error("HOSTED_API_URL environment variable is not set.")
|
25 |
raise ValueError("HOSTED_API_URL must be set.")
|
|
|
28 |
raise ValueError("WEATHER_API_KEY must be set.")
|
29 |
|
30 |
# OpenStreetMap Setup
|
31 |
+
geolocator = Nominatim(user_agent="traffic_management_system")
|
32 |
|
33 |
+
# Post Image to FastAPI API URL
|
34 |
def analyze_traffic(image_path):
|
35 |
"""
|
36 |
Sends the traffic image to the FastAPI backend for analysis.
|
|
|
69 |
super(TrafficSimEnv, self).__init__()
|
70 |
self.congestion_level = congestion_level
|
71 |
|
|
|
72 |
self.observation_space = spaces.Box(
|
73 |
low=np.array([0, 0], dtype=np.float32),
|
74 |
high=np.array([10, 60], dtype=np.float32),
|
75 |
dtype=np.float32
|
76 |
)
|
77 |
|
|
|
78 |
self.action_space = spaces.Discrete(3)
|
79 |
+
|
|
|
80 |
self.state = np.array([congestion_level, 30], dtype=np.float32)
|
81 |
self.done = False
|
82 |
|
|
|
94 |
congestion, signal_duration = self.state
|
95 |
|
96 |
# Apply action
|
97 |
+
if action == 0:
|
98 |
signal_duration = max(10, signal_duration - 5)
|
99 |
+
elif action == 1:
|
100 |
signal_duration = signal_duration
|
101 |
+
elif action == 2:
|
102 |
signal_duration = min(60, signal_duration + 5)
|
103 |
|
|
|
104 |
if signal_duration > 30:
|
105 |
congestion += 1
|
106 |
else:
|
107 |
congestion -= 1
|
108 |
|
|
|
109 |
if 20 <= signal_duration <= 40:
|
110 |
reward = 0
|
111 |
else:
|
112 |
reward = -abs(signal_duration - 30)
|
113 |
|
|
|
114 |
self.done = congestion <= 0 or congestion >= 10
|
115 |
|
116 |
self.state = np.array([congestion, signal_duration], dtype=np.float32)
|
|
|
126 |
|
127 |
def optimize_signal_rl(congestion_level):
|
128 |
try:
|
129 |
+
# Map congestion levels (string to numeric) converter
|
130 |
congestion_map = {"Low": 2, "Medium": 5, "High": 8}
|
131 |
congestion_level = congestion_map.get(congestion_level, 5) if isinstance(congestion_level, str) else congestion_level
|
132 |
|
133 |
# Create environment
|
134 |
env = DummyVecEnv([lambda: TrafficSimEnv(congestion_level)])
|
|
|
|
|
135 |
model = PPO("MlpPolicy", env, verbose=0)
|
136 |
|
137 |
# Train the model
|
138 |
model.learn(total_timesteps=1000)
|
139 |
|
140 |
+
# Reset environment and get the initial observation
|
141 |
obs = env.reset()
|
142 |
+
logging.info(f"Initial Observation: {obs}")
|
143 |
|
144 |
+
# Ensure `obs` is valid (It was a very bugged area I faced)
|
145 |
if not isinstance(obs, np.ndarray) or obs.ndim != 2:
|
146 |
raise ValueError(f"Invalid observation after reset: {obs} (type: {type(obs)}, ndim: {obs.ndim if isinstance(obs, np.ndarray) else 'N/A'})")
|
147 |
|
|
|
156 |
# Check and log the observation
|
157 |
logging.debug(f"Step {step_count}: Obs: {obs}, Rewards: {rewards}, Done: {dones}, Infos: {infos}")
|
158 |
|
159 |
+
# Validate `obs` after each step (To make sure it doesnt go 0 and create an error)
|
160 |
if not isinstance(obs, np.ndarray):
|
161 |
raise ValueError(f"Observation is not an array after step {step_count}: {obs}")
|
162 |
+
if obs.ndim == 0:
|
163 |
+
obs = np.array([obs]) # Convert at least 1D
|
164 |
logging.warning(f"Converted 0D obs to array: {obs}")
|
165 |
+
elif obs.ndim == 1: # Makesure its 2D for consistency
|
166 |
obs = obs.reshape(1, -1)
|
167 |
|
168 |
+
# Stop if the env signals is done
|
169 |
if dones[0]:
|
170 |
break
|
171 |
|
172 |
+
# Get the optimal signal duration from the last valid observation
|
173 |
+
obs = obs.flatten() # Confirm its a 1D array for indexing
|
174 |
if len(obs) < 2:
|
175 |
raise ValueError(f"Observation does not contain enough elements: {obs}")
|
176 |
optimal_duration = int(obs[1])
|
|
|
191 |
Orchestrates the traffic analysis workflow.
|
192 |
"""
|
193 |
|
194 |
+
# Save the uploaded image temp
|
195 |
image_path = "temp_traffic_image.jpg"
|
196 |
image.save(image_path)
|
197 |
|
|
|
206 |
logging.error(f"Error in backend request: {e}")
|
207 |
return "Error in backend request.", None
|
208 |
finally:
|
209 |
+
os.remove(image_path) # Clean up the temp
|
210 |
|
211 |
# Process backend response
|
212 |
if response.status_code == 200:
|