Fzina commited on
Commit
7131d6e
·
verified ·
1 Parent(s): 94264f4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -30
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import logging
3
  import requests
4
- from geopy.geocoders import Nominatim # Valid import
5
  from stable_baselines3 import PPO
6
  from stable_baselines3.common.vec_env import DummyVecEnv
7
  import gymnasium as gym
@@ -17,9 +17,9 @@ HOSTED_API_URL = os.getenv("HOSTED_API_URL") # FastAPI backend URL
17
  WEATHER_API_KEY = os.getenv("WEATHER_API_KEY") # OpenWeatherMap API key
18
 
19
  # Logging setup
20
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") # Valid logging format
21
 
22
- # Validate Environment Variables
23
  if not HOSTED_API_URL:
24
  logging.error("HOSTED_API_URL environment variable is not set.")
25
  raise ValueError("HOSTED_API_URL must be set.")
@@ -28,9 +28,9 @@ if not WEATHER_API_KEY:
28
  raise ValueError("WEATHER_API_KEY must be set.")
29
 
30
  # OpenStreetMap Setup
31
- geolocator = Nominatim(user_agent="traffic_management_system") # Valid usage of Nominatim
32
 
33
- # Analyze Traffic using FastAPI API URL
34
  def analyze_traffic(image_path):
35
  """
36
  Sends the traffic image to the FastAPI backend for analysis.
@@ -69,17 +69,14 @@ class TrafficSimEnv(gym.Env):
69
  super(TrafficSimEnv, self).__init__()
70
  self.congestion_level = congestion_level
71
 
72
- # Define observation space: [congestion_level, signal_duration]
73
  self.observation_space = spaces.Box(
74
  low=np.array([0, 0], dtype=np.float32),
75
  high=np.array([10, 60], dtype=np.float32),
76
  dtype=np.float32
77
  )
78
 
79
- # Define action space: [increase, decrease, maintain]
80
  self.action_space = spaces.Discrete(3)
81
-
82
- # Initial state
83
  self.state = np.array([congestion_level, 30], dtype=np.float32)
84
  self.done = False
85
 
@@ -97,26 +94,23 @@ class TrafficSimEnv(gym.Env):
97
  congestion, signal_duration = self.state
98
 
99
  # Apply action
100
- if action == 0: # Decrease signal duration
101
  signal_duration = max(10, signal_duration - 5)
102
- elif action == 1: # Maintain signal duration
103
  signal_duration = signal_duration
104
- elif action == 2: # Increase signal duration
105
  signal_duration = min(60, signal_duration + 5)
106
 
107
- # Update congestion (simple simulation logic for this example)
108
  if signal_duration > 30:
109
  congestion += 1
110
  else:
111
  congestion -= 1
112
 
113
- # Set rewards (example logic)
114
  if 20 <= signal_duration <= 40:
115
  reward = 0
116
  else:
117
  reward = -abs(signal_duration - 30)
118
 
119
- # Check if done
120
  self.done = congestion <= 0 or congestion >= 10
121
 
122
  self.state = np.array([congestion, signal_duration], dtype=np.float32)
@@ -132,24 +126,22 @@ class TrafficSimEnv(gym.Env):
132
 
133
  def optimize_signal_rl(congestion_level):
134
  try:
135
- # Map congestion levels (string to numeric)
136
  congestion_map = {"Low": 2, "Medium": 5, "High": 8}
137
  congestion_level = congestion_map.get(congestion_level, 5) if isinstance(congestion_level, str) else congestion_level
138
 
139
  # Create environment
140
  env = DummyVecEnv([lambda: TrafficSimEnv(congestion_level)])
141
-
142
- # Initialize PPO model
143
  model = PPO("MlpPolicy", env, verbose=0)
144
 
145
  # Train the model
146
  model.learn(total_timesteps=1000)
147
 
148
- # Reset the environment and get the initial observation
149
  obs = env.reset()
150
- logging.info(f"Initial observation: {obs}")
151
 
152
- # Ensure `obs` is valid
153
  if not isinstance(obs, np.ndarray) or obs.ndim != 2:
154
  raise ValueError(f"Invalid observation after reset: {obs} (type: {type(obs)}, ndim: {obs.ndim if isinstance(obs, np.ndarray) else 'N/A'})")
155
 
@@ -164,21 +156,21 @@ def optimize_signal_rl(congestion_level):
164
  # Check and log the observation
165
  logging.debug(f"Step {step_count}: Obs: {obs}, Rewards: {rewards}, Done: {dones}, Infos: {infos}")
166
 
167
- # Validate `obs` after each step
168
  if not isinstance(obs, np.ndarray):
169
  raise ValueError(f"Observation is not an array after step {step_count}: {obs}")
170
- if obs.ndim == 0: # Handle 0-dimensional observation
171
- obs = np.array([obs]) # Convert to at least 1D
172
  logging.warning(f"Converted 0D obs to array: {obs}")
173
- elif obs.ndim == 1: # Ensure it's 2D for consistency
174
  obs = obs.reshape(1, -1)
175
 
176
- # Stop if the environment signals it's done
177
  if dones[0]:
178
  break
179
 
180
- # Extract the optimal signal duration from the last valid observation
181
- obs = obs.flatten() # Ensure it's a 1D array for indexing
182
  if len(obs) < 2:
183
  raise ValueError(f"Observation does not contain enough elements: {obs}")
184
  optimal_duration = int(obs[1])
@@ -199,7 +191,7 @@ def process_traffic_image(image):
199
  Orchestrates the traffic analysis workflow.
200
  """
201
 
202
- # Save the uploaded image temporarily
203
  image_path = "temp_traffic_image.jpg"
204
  image.save(image_path)
205
 
@@ -214,7 +206,7 @@ def process_traffic_image(image):
214
  logging.error(f"Error in backend request: {e}")
215
  return "Error in backend request.", None
216
  finally:
217
- os.remove(image_path) # Clean up the temporary image
218
 
219
  # Process backend response
220
  if response.status_code == 200:
 
1
  import os
2
  import logging
3
  import requests
4
+ from geopy.geocoders import Nominatim
5
  from stable_baselines3 import PPO
6
  from stable_baselines3.common.vec_env import DummyVecEnv
7
  import gymnasium as gym
 
17
  WEATHER_API_KEY = os.getenv("WEATHER_API_KEY") # OpenWeatherMap API key
18
 
19
  # Logging setup
20
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
21
 
22
+ # Validation Environment Variables
23
  if not HOSTED_API_URL:
24
  logging.error("HOSTED_API_URL environment variable is not set.")
25
  raise ValueError("HOSTED_API_URL must be set.")
 
28
  raise ValueError("WEATHER_API_KEY must be set.")
29
 
30
  # OpenStreetMap Setup
31
+ geolocator = Nominatim(user_agent="traffic_management_system")
32
 
33
+ # Post Image to FastAPI API URL
34
  def analyze_traffic(image_path):
35
  """
36
  Sends the traffic image to the FastAPI backend for analysis.
 
69
  super(TrafficSimEnv, self).__init__()
70
  self.congestion_level = congestion_level
71
 
 
72
  self.observation_space = spaces.Box(
73
  low=np.array([0, 0], dtype=np.float32),
74
  high=np.array([10, 60], dtype=np.float32),
75
  dtype=np.float32
76
  )
77
 
 
78
  self.action_space = spaces.Discrete(3)
79
+
 
80
  self.state = np.array([congestion_level, 30], dtype=np.float32)
81
  self.done = False
82
 
 
94
  congestion, signal_duration = self.state
95
 
96
  # Apply action
97
+ if action == 0:
98
  signal_duration = max(10, signal_duration - 5)
99
+ elif action == 1:
100
  signal_duration = signal_duration
101
+ elif action == 2:
102
  signal_duration = min(60, signal_duration + 5)
103
 
 
104
  if signal_duration > 30:
105
  congestion += 1
106
  else:
107
  congestion -= 1
108
 
 
109
  if 20 <= signal_duration <= 40:
110
  reward = 0
111
  else:
112
  reward = -abs(signal_duration - 30)
113
 
 
114
  self.done = congestion <= 0 or congestion >= 10
115
 
116
  self.state = np.array([congestion, signal_duration], dtype=np.float32)
 
126
 
127
  def optimize_signal_rl(congestion_level):
128
  try:
129
+ # Map congestion levels (string to numeric) converter
130
  congestion_map = {"Low": 2, "Medium": 5, "High": 8}
131
  congestion_level = congestion_map.get(congestion_level, 5) if isinstance(congestion_level, str) else congestion_level
132
 
133
  # Create environment
134
  env = DummyVecEnv([lambda: TrafficSimEnv(congestion_level)])
 
 
135
  model = PPO("MlpPolicy", env, verbose=0)
136
 
137
  # Train the model
138
  model.learn(total_timesteps=1000)
139
 
140
+ # Reset environment and get the initial observation
141
  obs = env.reset()
142
+ logging.info(f"Initial Observation: {obs}")
143
 
144
+ # Ensure `obs` is valid (It was a very bugged area I faced)
145
  if not isinstance(obs, np.ndarray) or obs.ndim != 2:
146
  raise ValueError(f"Invalid observation after reset: {obs} (type: {type(obs)}, ndim: {obs.ndim if isinstance(obs, np.ndarray) else 'N/A'})")
147
 
 
156
  # Check and log the observation
157
  logging.debug(f"Step {step_count}: Obs: {obs}, Rewards: {rewards}, Done: {dones}, Infos: {infos}")
158
 
159
+ # Validate `obs` after each step (To make sure it doesnt go 0 and create an error)
160
  if not isinstance(obs, np.ndarray):
161
  raise ValueError(f"Observation is not an array after step {step_count}: {obs}")
162
+ if obs.ndim == 0:
163
+ obs = np.array([obs]) # Convert at least 1D
164
  logging.warning(f"Converted 0D obs to array: {obs}")
165
+ elif obs.ndim == 1: # Makesure its 2D for consistency
166
  obs = obs.reshape(1, -1)
167
 
168
+ # Stop if the env signals is done
169
  if dones[0]:
170
  break
171
 
172
+ # Get the optimal signal duration from the last valid observation
173
+ obs = obs.flatten() # Confirm its a 1D array for indexing
174
  if len(obs) < 2:
175
  raise ValueError(f"Observation does not contain enough elements: {obs}")
176
  optimal_duration = int(obs[1])
 
191
  Orchestrates the traffic analysis workflow.
192
  """
193
 
194
+ # Save the uploaded image temp
195
  image_path = "temp_traffic_image.jpg"
196
  image.save(image_path)
197
 
 
206
  logging.error(f"Error in backend request: {e}")
207
  return "Error in backend request.", None
208
  finally:
209
+ os.remove(image_path) # Clean up the temp
210
 
211
  # Process backend response
212
  if response.status_code == 200: