Fzina commited on
Commit
70891c8
·
verified ·
1 Parent(s): e7b8aae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -79
app.py CHANGED
@@ -8,130 +8,160 @@ from oauth2client.service_account import ServiceAccountCredentials
8
  from ultralytics import YOLO
9
  from stable_baselines3 import PPO
10
  from stable_baselines3.common.vec_env import DummyVecEnv
11
- from mapbox import Geocoder
12
- import random
 
 
13
 
14
  # Environment Variables
15
  WEATHER_API_KEY = os.getenv("WEATHER_API_KEY")
16
- MAPBOX_ACCESS_TOKEN = os.getenv("MAPBOX_ACCESS_TOKEN")
 
 
17
 
18
  # Load YOLOv8 Model
19
- model = YOLO('yolov8n.pt') # YOLOv8 small model
 
 
 
 
 
20
 
21
  # Google Sheets Setup
22
  def setup_google_sheets(sheet_name):
23
- scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
24
- creds = ServiceAccountCredentials.from_json_keyfile_name('credentials.json', scope)
25
- client = gspread.authorize(creds)
26
- sheet = client.open(sheet_name).sheet1
27
- return sheet
 
 
 
 
 
28
 
29
  sheet = setup_google_sheets("TrafficManagementData")
30
 
31
- # Mapbox API Setup
32
- geocoder = Geocoder(access_token=MAPBOX_ACCESS_TOKEN)
33
 
34
  # Analyze Traffic with YOLOv8
35
  def analyze_traffic(image):
36
- results = model.predict(image)
37
- detected_objects = results[0].boxes
38
- vehicle_count = sum(1 for box in detected_objects if box.cls in [2, 3, 5, 7]) # Vehicle class IDs: car, truck, bus, motorbike
39
-
40
- # Predict congestion level
41
- congestion_level = "High" if vehicle_count > 150 else "Medium" if vehicle_count > 75 else "Low"
42
- flow_rate = "Slow" if congestion_level == "High" else "Moderate" if congestion_level == "Medium" else "Smooth"
43
- return vehicle_count, congestion_level, flow_rate
44
-
45
- # Simulated Reinforcement Learning Environment for Signal Optimization
46
- class TrafficSimEnv:
 
 
 
47
  def __init__(self, congestion_level):
 
48
  self.congestion_level = congestion_level
49
- self.current_signal = 30 # Start with a 30-second green signal
 
 
 
 
 
 
50
  self.steps = 0
51
 
52
- def reset(self):
53
  self.steps = 0
54
- return [self.congestion_level, self.current_signal]
 
 
 
55
 
56
  def step(self, action):
57
- self.current_signal = max(20, min(60, self.current_signal + action)) # Adjust signal duration
 
58
  reward = self._calculate_reward()
59
  self.steps += 1
60
- done = self.steps >= 10 # Limit to 10 steps for optimization
61
- return [self.congestion_level, self.current_signal], reward, done, {}
 
62
 
63
  def _calculate_reward(self):
64
  if self.congestion_level == "High":
65
- return -abs(40 - self.current_signal) # Reward for closer to 40 seconds green
66
  elif self.congestion_level == "Medium":
67
  return -abs(30 - self.current_signal)
68
  else:
69
  return -abs(20 - self.current_signal)
70
 
71
- def sample_action(self):
72
- return random.choice([-5, 0, 5]) # Increase, decrease, or keep signal duration
 
 
 
73
 
74
  def optimize_signal_rl(congestion_level):
75
- # Simulate RL environment
76
- env = DummyVecEnv([lambda: TrafficSimEnv(congestion_level)])
77
- model = PPO("MlpPolicy", env, verbose=0)
78
- model.learn(total_timesteps=1000)
79
-
80
- # Predict optimal signal duration
81
- obs = env.reset()
82
- for _ in range(10):
83
- action, _states = model.predict(obs)
84
- obs, reward, done, info = env.step(action)
85
- if done:
86
- break
87
-
88
- optimal_duration = int(obs[0][1]) # Extract the signal duration
89
- return f"Green for {optimal_duration}s, Red for {60 - optimal_duration}s"
90
-
91
- # Location-Based Signal Detection
92
  def check_signal_nearby(latitude, longitude):
93
- response = geocoder.reverse(lon=longitude, lat=latitude)
94
- if response.status_code == 200:
95
- features = response.geojson().get('features', [])
96
- for feature in features:
97
- if 'traffic_signal' in feature.get('place_type', []):
98
- return True
99
- return False
100
-
101
- # Get Weather Impact on Traffic
102
  def get_weather_impact(api_key):
103
- url = f"http://api.openweathermap.org/data/2.5/weather?q=New+York&appid={api_key}"
104
- response = requests.get(url)
105
- if response.status_code == 200:
106
- data = response.json()
107
- weather = data['weather'][0]['description']
108
- if "rain" in weather or "snow" in weather:
109
- return "Traffic is likely to be slower due to weather conditions."
110
- return "Weather is clear, no major impact on traffic."
111
- return "Error fetching weather data."
112
-
113
- # Process Traffic Image
 
 
 
 
 
114
  def process_traffic_image(image):
115
- # Analyze traffic using YOLOv8
116
  vehicle_count, congestion_level, flow_rate = analyze_traffic(image)
117
-
118
- # Optimize signal using RL
119
  signal_timing = optimize_signal_rl(congestion_level)
120
-
121
- # Get weather impact
122
  weather_impact = get_weather_impact(WEATHER_API_KEY)
123
-
124
- # Log data into Google Sheets
125
- sheet.append_row([vehicle_count, congestion_level, flow_rate, signal_timing, weather_impact])
126
-
127
- # Provide detailed results
128
  return (f"Detected Vehicles: {vehicle_count}\n"
129
  f"Congestion Level: {congestion_level}\n"
130
  f"Traffic Flow: {flow_rate}\n"
131
  f"Signal Timing Suggestion: {signal_timing}\n"
132
  f"Weather Impact: {weather_impact}")
133
 
134
- # Gradio Interface
135
  interface = gr.Interface(
136
  fn=process_traffic_image,
137
  inputs=gr.Image(type="numpy", label="Upload Traffic Image"),
 
8
  from ultralytics import YOLO
9
  from stable_baselines3 import PPO
10
  from stable_baselines3.common.vec_env import DummyVecEnv
11
+ import gymnasium as gym
12
+ from gymnasium import spaces
13
+ import logging
14
+ from geopy.geocoders import Nominatim
15
 
16
  # Environment Variables
17
  WEATHER_API_KEY = os.getenv("WEATHER_API_KEY")
18
+
19
+ # Logging setup
20
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
21
 
22
  # Load YOLOv8 Model
23
+ try:
24
+ model = YOLO("yolov8n.pt") # YOLOv8 small model
25
+ logging.info("YOLOv8 model loaded successfully.")
26
+ except Exception as e:
27
+ logging.error(f"Failed to load YOLOv8 model: {e}")
28
+ model = None
29
 
30
  # Google Sheets Setup
31
  def setup_google_sheets(sheet_name):
32
+ try:
33
+ scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
34
+ creds = ServiceAccountCredentials.from_json_keyfile_name("credentials.json", scope)
35
+ client = gspread.authorize(creds)
36
+ sheet = client.open(sheet_name).sheet1
37
+ logging.info("Google Sheets setup completed.")
38
+ return sheet
39
+ except Exception as e:
40
+ logging.error(f"Error setting up Google Sheets: {e}")
41
+ return None
42
 
43
  sheet = setup_google_sheets("TrafficManagementData")
44
 
45
+ # OpenStreetMap Geocoder Setup
46
+ geolocator = Nominatim(user_agent="traffic_management_system")
47
 
48
  # Analyze Traffic with YOLOv8
49
  def analyze_traffic(image):
50
+ if model is None:
51
+ return 0, "Error", "Error"
52
+ try:
53
+ results = model.predict(image)
54
+ detected_objects = results[0].boxes
55
+ vehicle_count = sum(1 for box in detected_objects if box.cls in [2, 3, 5, 7]) # Vehicle class IDs
56
+ congestion_level = "High" if vehicle_count > 150 else "Medium" if vehicle_count > 75 else "Low"
57
+ flow_rate = "Slow" if congestion_level == "High" else "Moderate" if congestion_level == "Medium" else "Smooth"
58
+ return vehicle_count, congestion_level, flow_rate
59
+ except Exception as e:
60
+ logging.error(f"Error analyzing traffic: {e}")
61
+ return 0, "Error", "Error"
62
+
63
+ class TrafficSimEnv(gym.Env):
64
  def __init__(self, congestion_level):
65
+ super(TrafficSimEnv, self).__init__()
66
  self.congestion_level = congestion_level
67
+ self.action_space = spaces.Discrete(3)
68
+ self.observation_space = spaces.Box(
69
+ low=np.array([0, 20]),
70
+ high=np.array([2, 60]),
71
+ dtype=np.float32
72
+ )
73
+ self.current_signal = 30
74
  self.steps = 0
75
 
76
+ def reset(self, seed=None, options=None):
77
  self.steps = 0
78
+ self.current_signal = 30
79
+ congestion_map = {"Low": 0, "Medium": 1, "High": 2}
80
+ self.congestion_numeric = congestion_map.get(self.congestion_level, 0)
81
+ return np.array([self.congestion_numeric, self.current_signal], dtype=np.float32), {}
82
 
83
  def step(self, action):
84
+ signal_change = {0: -5, 1: 0, 2: +5}[action]
85
+ self.current_signal = max(20, min(60, self.current_signal + signal_change))
86
  reward = self._calculate_reward()
87
  self.steps += 1
88
+ done = self.steps >= 10
89
+ obs = np.array([self.congestion_numeric, self.current_signal], dtype=np.float32)
90
+ return obs, reward, done, {}
91
 
92
  def _calculate_reward(self):
93
  if self.congestion_level == "High":
94
+ return -abs(40 - self.current_signal)
95
  elif self.congestion_level == "Medium":
96
  return -abs(30 - self.current_signal)
97
  else:
98
  return -abs(20 - self.current_signal)
99
 
100
+ def render(self, mode="human"):
101
+ print(f"Current Signal: {self.current_signal}s")
102
+
103
+ def close(self):
104
+ pass
105
 
106
  def optimize_signal_rl(congestion_level):
107
+ try:
108
+ env = DummyVecEnv([lambda: TrafficSimEnv(congestion_level)])
109
+ model = PPO("MlpPolicy", env, verbose=0)
110
+ model.learn(total_timesteps=1000)
111
+ obs = env.reset()
112
+ for _ in range(10):
113
+ action, _ = model.predict(obs)
114
+ obs, reward, done, _ = env.step(action)
115
+ if done:
116
+ break
117
+ optimal_duration = int(obs[0][1])
118
+ return f"Green for {optimal_duration}s, Red for {60 - optimal_duration}s"
119
+ except Exception as e:
120
+ logging.error(f"Error optimizing signal with RL: {e}")
121
+ return "Error in RL Optimization"
122
+
 
123
  def check_signal_nearby(latitude, longitude):
124
+ try:
125
+ location = geolocator.reverse((latitude, longitude))
126
+ if location and "traffic_signal" in location.raw.get("address", {}):
127
+ return True
128
+ return False
129
+ except Exception as e:
130
+ logging.error(f"Error checking signal nearby: {e}")
131
+ return False
132
+
133
  def get_weather_impact(api_key):
134
+ if not api_key:
135
+ return "API Key not found."
136
+ try:
137
+ url = f"http://api.openweathermap.org/data/2.5/weather?q=New+York&appid={api_key}"
138
+ response = requests.get(url)
139
+ if response.status_code == 200:
140
+ data = response.json()
141
+ weather = data["weather"][0]["description"]
142
+ if "rain" in weather or "snow" in weather:
143
+ return "Traffic is likely to be slower due to weather conditions."
144
+ return "Weather is clear, no major impact on traffic."
145
+ return "Error fetching weather data."
146
+ except Exception as e:
147
+ logging.error(f"Error fetching weather impact: {e}")
148
+ return "Error fetching weather data."
149
+
150
  def process_traffic_image(image):
 
151
  vehicle_count, congestion_level, flow_rate = analyze_traffic(image)
 
 
152
  signal_timing = optimize_signal_rl(congestion_level)
 
 
153
  weather_impact = get_weather_impact(WEATHER_API_KEY)
154
+ try:
155
+ if sheet:
156
+ sheet.append_row([vehicle_count, congestion_level, flow_rate, signal_timing, weather_impact])
157
+ except Exception as e:
158
+ logging.error(f"Error logging data to Google Sheets: {e}")
159
  return (f"Detected Vehicles: {vehicle_count}\n"
160
  f"Congestion Level: {congestion_level}\n"
161
  f"Traffic Flow: {flow_rate}\n"
162
  f"Signal Timing Suggestion: {signal_timing}\n"
163
  f"Weather Impact: {weather_impact}")
164
 
 
165
  interface = gr.Interface(
166
  fn=process_traffic_image,
167
  inputs=gr.Image(type="numpy", label="Upload Traffic Image"),