Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -8,130 +8,160 @@ from oauth2client.service_account import ServiceAccountCredentials
|
|
8 |
from ultralytics import YOLO
|
9 |
from stable_baselines3 import PPO
|
10 |
from stable_baselines3.common.vec_env import DummyVecEnv
|
11 |
-
|
12 |
-
import
|
|
|
|
|
13 |
|
14 |
# Environment Variables
|
15 |
WEATHER_API_KEY = os.getenv("WEATHER_API_KEY")
|
16 |
-
|
|
|
|
|
17 |
|
18 |
# Load YOLOv8 Model
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
# Google Sheets Setup
|
22 |
def setup_google_sheets(sheet_name):
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
sheet = setup_google_sheets("TrafficManagementData")
|
30 |
|
31 |
-
#
|
32 |
-
|
33 |
|
34 |
# Analyze Traffic with YOLOv8
|
35 |
def analyze_traffic(image):
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
|
|
|
|
|
|
47 |
def __init__(self, congestion_level):
|
|
|
48 |
self.congestion_level = congestion_level
|
49 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
self.steps = 0
|
51 |
|
52 |
-
def reset(self):
|
53 |
self.steps = 0
|
54 |
-
|
|
|
|
|
|
|
55 |
|
56 |
def step(self, action):
|
57 |
-
|
|
|
58 |
reward = self._calculate_reward()
|
59 |
self.steps += 1
|
60 |
-
done = self.steps >= 10
|
61 |
-
|
|
|
62 |
|
63 |
def _calculate_reward(self):
|
64 |
if self.congestion_level == "High":
|
65 |
-
return -abs(40 - self.current_signal)
|
66 |
elif self.congestion_level == "Medium":
|
67 |
return -abs(30 - self.current_signal)
|
68 |
else:
|
69 |
return -abs(20 - self.current_signal)
|
70 |
|
71 |
-
def
|
72 |
-
|
|
|
|
|
|
|
73 |
|
74 |
def optimize_signal_rl(congestion_level):
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
# Location-Based Signal Detection
|
92 |
def check_signal_nearby(latitude, longitude):
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
def get_weather_impact(api_key):
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
if
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
114 |
def process_traffic_image(image):
|
115 |
-
# Analyze traffic using YOLOv8
|
116 |
vehicle_count, congestion_level, flow_rate = analyze_traffic(image)
|
117 |
-
|
118 |
-
# Optimize signal using RL
|
119 |
signal_timing = optimize_signal_rl(congestion_level)
|
120 |
-
|
121 |
-
# Get weather impact
|
122 |
weather_impact = get_weather_impact(WEATHER_API_KEY)
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
return (f"Detected Vehicles: {vehicle_count}\n"
|
129 |
f"Congestion Level: {congestion_level}\n"
|
130 |
f"Traffic Flow: {flow_rate}\n"
|
131 |
f"Signal Timing Suggestion: {signal_timing}\n"
|
132 |
f"Weather Impact: {weather_impact}")
|
133 |
|
134 |
-
# Gradio Interface
|
135 |
interface = gr.Interface(
|
136 |
fn=process_traffic_image,
|
137 |
inputs=gr.Image(type="numpy", label="Upload Traffic Image"),
|
|
|
8 |
from ultralytics import YOLO
|
9 |
from stable_baselines3 import PPO
|
10 |
from stable_baselines3.common.vec_env import DummyVecEnv
|
11 |
+
import gymnasium as gym
|
12 |
+
from gymnasium import spaces
|
13 |
+
import logging
|
14 |
+
from geopy.geocoders import Nominatim
|
15 |
|
16 |
# Environment Variables
|
17 |
WEATHER_API_KEY = os.getenv("WEATHER_API_KEY")
|
18 |
+
|
19 |
+
# Logging setup
|
20 |
+
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
21 |
|
22 |
# Load YOLOv8 Model
|
23 |
+
try:
|
24 |
+
model = YOLO("yolov8n.pt") # YOLOv8 small model
|
25 |
+
logging.info("YOLOv8 model loaded successfully.")
|
26 |
+
except Exception as e:
|
27 |
+
logging.error(f"Failed to load YOLOv8 model: {e}")
|
28 |
+
model = None
|
29 |
|
30 |
# Google Sheets Setup
|
31 |
def setup_google_sheets(sheet_name):
|
32 |
+
try:
|
33 |
+
scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
|
34 |
+
creds = ServiceAccountCredentials.from_json_keyfile_name("credentials.json", scope)
|
35 |
+
client = gspread.authorize(creds)
|
36 |
+
sheet = client.open(sheet_name).sheet1
|
37 |
+
logging.info("Google Sheets setup completed.")
|
38 |
+
return sheet
|
39 |
+
except Exception as e:
|
40 |
+
logging.error(f"Error setting up Google Sheets: {e}")
|
41 |
+
return None
|
42 |
|
43 |
sheet = setup_google_sheets("TrafficManagementData")
|
44 |
|
45 |
+
# OpenStreetMap Geocoder Setup
|
46 |
+
geolocator = Nominatim(user_agent="traffic_management_system")
|
47 |
|
48 |
# Analyze Traffic with YOLOv8
|
49 |
def analyze_traffic(image):
|
50 |
+
if model is None:
|
51 |
+
return 0, "Error", "Error"
|
52 |
+
try:
|
53 |
+
results = model.predict(image)
|
54 |
+
detected_objects = results[0].boxes
|
55 |
+
vehicle_count = sum(1 for box in detected_objects if box.cls in [2, 3, 5, 7]) # Vehicle class IDs
|
56 |
+
congestion_level = "High" if vehicle_count > 150 else "Medium" if vehicle_count > 75 else "Low"
|
57 |
+
flow_rate = "Slow" if congestion_level == "High" else "Moderate" if congestion_level == "Medium" else "Smooth"
|
58 |
+
return vehicle_count, congestion_level, flow_rate
|
59 |
+
except Exception as e:
|
60 |
+
logging.error(f"Error analyzing traffic: {e}")
|
61 |
+
return 0, "Error", "Error"
|
62 |
+
|
63 |
+
class TrafficSimEnv(gym.Env):
|
64 |
def __init__(self, congestion_level):
|
65 |
+
super(TrafficSimEnv, self).__init__()
|
66 |
self.congestion_level = congestion_level
|
67 |
+
self.action_space = spaces.Discrete(3)
|
68 |
+
self.observation_space = spaces.Box(
|
69 |
+
low=np.array([0, 20]),
|
70 |
+
high=np.array([2, 60]),
|
71 |
+
dtype=np.float32
|
72 |
+
)
|
73 |
+
self.current_signal = 30
|
74 |
self.steps = 0
|
75 |
|
76 |
+
def reset(self, seed=None, options=None):
|
77 |
self.steps = 0
|
78 |
+
self.current_signal = 30
|
79 |
+
congestion_map = {"Low": 0, "Medium": 1, "High": 2}
|
80 |
+
self.congestion_numeric = congestion_map.get(self.congestion_level, 0)
|
81 |
+
return np.array([self.congestion_numeric, self.current_signal], dtype=np.float32), {}
|
82 |
|
83 |
def step(self, action):
|
84 |
+
signal_change = {0: -5, 1: 0, 2: +5}[action]
|
85 |
+
self.current_signal = max(20, min(60, self.current_signal + signal_change))
|
86 |
reward = self._calculate_reward()
|
87 |
self.steps += 1
|
88 |
+
done = self.steps >= 10
|
89 |
+
obs = np.array([self.congestion_numeric, self.current_signal], dtype=np.float32)
|
90 |
+
return obs, reward, done, {}
|
91 |
|
92 |
def _calculate_reward(self):
|
93 |
if self.congestion_level == "High":
|
94 |
+
return -abs(40 - self.current_signal)
|
95 |
elif self.congestion_level == "Medium":
|
96 |
return -abs(30 - self.current_signal)
|
97 |
else:
|
98 |
return -abs(20 - self.current_signal)
|
99 |
|
100 |
+
def render(self, mode="human"):
|
101 |
+
print(f"Current Signal: {self.current_signal}s")
|
102 |
+
|
103 |
+
def close(self):
|
104 |
+
pass
|
105 |
|
106 |
def optimize_signal_rl(congestion_level):
|
107 |
+
try:
|
108 |
+
env = DummyVecEnv([lambda: TrafficSimEnv(congestion_level)])
|
109 |
+
model = PPO("MlpPolicy", env, verbose=0)
|
110 |
+
model.learn(total_timesteps=1000)
|
111 |
+
obs = env.reset()
|
112 |
+
for _ in range(10):
|
113 |
+
action, _ = model.predict(obs)
|
114 |
+
obs, reward, done, _ = env.step(action)
|
115 |
+
if done:
|
116 |
+
break
|
117 |
+
optimal_duration = int(obs[0][1])
|
118 |
+
return f"Green for {optimal_duration}s, Red for {60 - optimal_duration}s"
|
119 |
+
except Exception as e:
|
120 |
+
logging.error(f"Error optimizing signal with RL: {e}")
|
121 |
+
return "Error in RL Optimization"
|
122 |
+
|
|
|
123 |
def check_signal_nearby(latitude, longitude):
|
124 |
+
try:
|
125 |
+
location = geolocator.reverse((latitude, longitude))
|
126 |
+
if location and "traffic_signal" in location.raw.get("address", {}):
|
127 |
+
return True
|
128 |
+
return False
|
129 |
+
except Exception as e:
|
130 |
+
logging.error(f"Error checking signal nearby: {e}")
|
131 |
+
return False
|
132 |
+
|
133 |
def get_weather_impact(api_key):
|
134 |
+
if not api_key:
|
135 |
+
return "API Key not found."
|
136 |
+
try:
|
137 |
+
url = f"http://api.openweathermap.org/data/2.5/weather?q=New+York&appid={api_key}"
|
138 |
+
response = requests.get(url)
|
139 |
+
if response.status_code == 200:
|
140 |
+
data = response.json()
|
141 |
+
weather = data["weather"][0]["description"]
|
142 |
+
if "rain" in weather or "snow" in weather:
|
143 |
+
return "Traffic is likely to be slower due to weather conditions."
|
144 |
+
return "Weather is clear, no major impact on traffic."
|
145 |
+
return "Error fetching weather data."
|
146 |
+
except Exception as e:
|
147 |
+
logging.error(f"Error fetching weather impact: {e}")
|
148 |
+
return "Error fetching weather data."
|
149 |
+
|
150 |
def process_traffic_image(image):
|
|
|
151 |
vehicle_count, congestion_level, flow_rate = analyze_traffic(image)
|
|
|
|
|
152 |
signal_timing = optimize_signal_rl(congestion_level)
|
|
|
|
|
153 |
weather_impact = get_weather_impact(WEATHER_API_KEY)
|
154 |
+
try:
|
155 |
+
if sheet:
|
156 |
+
sheet.append_row([vehicle_count, congestion_level, flow_rate, signal_timing, weather_impact])
|
157 |
+
except Exception as e:
|
158 |
+
logging.error(f"Error logging data to Google Sheets: {e}")
|
159 |
return (f"Detected Vehicles: {vehicle_count}\n"
|
160 |
f"Congestion Level: {congestion_level}\n"
|
161 |
f"Traffic Flow: {flow_rate}\n"
|
162 |
f"Signal Timing Suggestion: {signal_timing}\n"
|
163 |
f"Weather Impact: {weather_impact}")
|
164 |
|
|
|
165 |
interface = gr.Interface(
|
166 |
fn=process_traffic_image,
|
167 |
inputs=gr.Image(type="numpy", label="Upload Traffic Image"),
|