Spaces:
Sleeping
Sleeping
import os | |
import logging | |
import requests | |
from geopy.geocoders import Nominatim | |
from stable_baselines3 import PPO | |
from stable_baselines3.common.vec_env import DummyVecEnv | |
import gymnasium as gym | |
from gymnasium import spaces | |
import numpy as np | |
from PIL import Image | |
import gradio as gr | |
import io | |
import base64 | |
# Logging: timestamped records at INFO level and above for the whole app.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Configuration pulled from the process environment.
HOSTED_API_URL = os.getenv("HOSTED_API_URL")  # base URL of the FastAPI analysis backend
WEATHER_API_KEY = os.getenv("WEATHER_API_KEY")  # OpenWeatherMap API key

# Fail fast at import time: both settings are mandatory, and an empty string
# counts as missing just like an unset variable.
if HOSTED_API_URL is None or HOSTED_API_URL == "":
    logging.error("HOSTED_API_URL environment variable is not set.")
    raise ValueError("HOSTED_API_URL must be set.")
if WEATHER_API_KEY is None or WEATHER_API_KEY == "":
    logging.error("WEATHER_API_KEY environment variable is not set.")
    raise ValueError("WEATHER_API_KEY must be set.")

# Geocoder client for OpenStreetMap's Nominatim service.
geolocator = Nominatim(user_agent="traffic_management_system")
# Upload an image file to the FastAPI backend and collect its analysis.
def analyze_traffic(image_path, timeout=60):
    """
    Send the traffic image at ``image_path`` to the FastAPI backend for analysis.

    Args:
        image_path: Path of the image file to upload. It is sent as the
            ``file`` multipart field to ``{HOSTED_API_URL}/analyze_traffic/``.
        timeout: Seconds passed as ``timeout`` to each HTTP request (the
            upload and the optional processed-image download). Without a
            timeout a stalled backend would block the caller indefinitely.

    Returns:
        A tuple ``(vehicle_count, congestion_level, flow_rate, processed_image)``.
        ``processed_image`` is a PIL image when the JSON reply contains a
        ``processed_image_url`` that downloads with status 200, else ``None``.
        On any failure, ``(0, "Error", "Error", None)`` is returned instead of
        raising.
    """
    try:
        # Only the upload needs the file handle; close it before the
        # follow-up download.
        with open(image_path, "rb") as image_file:
            response = requests.post(
                f"{HOSTED_API_URL}/analyze_traffic/",
                files={"file": image_file},
                timeout=timeout,
            )

        if response.status_code != 200:
            logging.error(f"Error analyzing traffic: {response.text}")
            return 0, "Error", "Error", None

        result = response.json()
        vehicle_count = result.get("vehicle_count", 0)
        congestion_level = result.get("congestion_level", "Unknown")
        flow_rate = result.get("flow_rate", "Unknown")

        # Fetch the backend's processed image, if it advertised one.
        processed_image = None
        processed_image_url = result.get("processed_image_url")
        if processed_image_url:
            img_response = requests.get(processed_image_url, timeout=timeout)
            if img_response.status_code == 200:
                processed_image = Image.open(io.BytesIO(img_response.content))
            else:
                # Non-fatal: the numeric analysis is still returned.
                logging.warning(
                    f"Could not fetch processed image "
                    f"(status {img_response.status_code}): {processed_image_url}"
                )
        return vehicle_count, congestion_level, flow_rate, processed_image
    except Exception as e:
        # Best-effort contract: callers always receive a 4-tuple.
        logging.error(f"Error analyzing traffic: {e}")
        return 0, "Error", "Error", None
# Reinforcement-learning environment used to tune signal timing.
class TrafficSimEnv(gym.Env):
    """
    Toy single-intersection simulator for signal-timing optimisation.

    Observation: float32 vector ``[congestion, signal_duration]`` declared in
    the box ``[0, 0]`` .. ``[10, 60]``.
    Actions (``Discrete(3)``):
        0 -> shorten the phase by 5 s (never below 10 s),
        1 -> keep the phase,
        2 -> lengthen the phase by 5 s (never above 60 s).
    A phase longer than 30 s adds 1 to congestion; otherwise 1 is subtracted.
    Reward is 0 while the phase is within [20, 40] s and ``-|duration - 30|``
    otherwise. The episode terminates once congestion is <= 0 or >= 10.
    """

    def __init__(self, congestion_level):
        super().__init__()
        self.congestion_level = congestion_level
        lower_bounds = np.array([0, 0], dtype=np.float32)
        upper_bounds = np.array([10, 60], dtype=np.float32)
        self.observation_space = spaces.Box(low=lower_bounds, high=upper_bounds, dtype=np.float32)
        self.action_space = spaces.Discrete(3)
        self.state = self._initial_state()
        self.done = False

    def _initial_state(self):
        """Return a fresh ``[congestion_level, 30]`` float32 observation."""
        return np.array([self.congestion_level, 30], dtype=np.float32)

    def reset(self, *, seed=None, options=None):
        """Restore the starting congestion and a 30 s phase; returns ``(obs, info)``."""
        super().reset(seed=seed)
        self.done = False
        self.state = self._initial_state()
        return self.state, {}

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``."""
        if self.done:
            raise RuntimeError("Cannot call step() on a terminated environment. Please reset the environment.")

        congestion, duration = self.state

        # Action 1 (or anything that is neither 0 nor 2) leaves the phase as is.
        if action == 0:
            duration = max(10, duration - 5)
        elif action == 2:
            duration = min(60, duration + 5)

        congestion = congestion + 1 if duration > 30 else congestion - 1
        reward = 0 if 20 <= duration <= 40 else -abs(duration - 30)

        self.done = congestion <= 0 or congestion >= 10
        self.state = np.array([congestion, duration], dtype=np.float32)
        # This env never truncates, so the fourth element is always False.
        return self.state, reward, self.done, False, {}

    def render(self):
        """Print the current observation to stdout."""
        print(f"State: {self.state}")

    def close(self):
        """No resources to release."""
        pass
def optimize_signal_rl(congestion_level):
    """
    Suggest a green/red split by training a small PPO agent on ``TrafficSimEnv``.

    Args:
        congestion_level: A backend label ("Low" -> 2, "Medium" -> 5,
            "High" -> 8; any other string falls back to 5) or a number used
            directly as the starting congestion.

    Returns:
        ``"Green for {g}s, Red for {60 - g}s"``, where ``g`` is the signal
        duration of the last state actually reached in a greedy rollout of at
        most 10 steps. On failure, returns an error string instead of raising.
    """
    congestion_map = {"Low": 2, "Medium": 5, "High": 8}
    if isinstance(congestion_level, str):
        congestion_level = congestion_map.get(congestion_level, 5)

    env = None
    try:
        env = DummyVecEnv([lambda: TrafficSimEnv(congestion_level)])
        model = PPO("MlpPolicy", env, verbose=0)
        model.learn(total_timesteps=1000)

        # VecEnv.reset() returns only the batched observation, shape (n_envs, obs_dim).
        obs = env.reset()
        logging.info(f"Initial Observation: {obs}")
        if not isinstance(obs, np.ndarray) or obs.ndim != 2:
            raise ValueError(f"Invalid observation after reset: {obs} (type: {type(obs)}, ndim: {obs.ndim if isinstance(obs, np.ndarray) else 'N/A'})")

        final_obs = obs[0]
        for step_count in range(10):
            action, _ = model.predict(obs, deterministic=True)
            obs, rewards, dones, infos = env.step(action)
            logging.debug(f"Step {step_count}: Obs: {obs}, Rewards: {rewards}, Done: {dones}, Infos: {infos}")
            if dones[0]:
                # DummyVecEnv auto-resets a finished env, so ``obs`` already holds
                # the *next* episode's reset state (duration 30). The state the
                # agent actually reached is stored under "terminal_observation".
                final_obs = infos[0].get("terminal_observation", obs[0])
                break
            final_obs = obs[0]

        final_obs = np.asarray(final_obs).flatten()
        if final_obs.size < 2:
            raise ValueError(f"Observation does not contain enough elements: {final_obs}")
        optimal_duration = int(final_obs[1])
        return f"Green for {optimal_duration}s, Red for {60 - optimal_duration}s"
    except ValueError as ve:
        logging.error(f"Value error during RL optimization: {ve}")
        return "Error: Unexpected values encountered during optimization."
    except Exception as e:
        logging.error(f"Error optimizing signal with RL: {e}")
        return "Error in RL Optimization"
    finally:
        # Release the vectorized environment on every path.
        if env is not None:
            env.close()
def process_traffic_image(image, timeout=60):
    """
    Orchestrate the traffic-analysis workflow for one uploaded image.

    The image is encoded as JPEG and posted as the ``file`` field to
    ``{HOSTED_API_URL}/analyze_traffic/``. The optional base64
    ``processed_image`` in the JSON reply is decoded, and a signal-timing
    suggestion from ``optimize_signal_rl`` is appended.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was uploaded.
        timeout: Seconds to wait for the backend before giving up.

    Returns:
        ``(summary_text, processed_image_or_None)``.
    """
    if image is None:
        return "No image provided.", None

    # Encode in memory rather than via a fixed temp file in the working
    # directory, which concurrent requests would overwrite or delete.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_modes:
        # JPEG cannot store alpha/palette modes such as RGBA or P.
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    buffer.seek(0)

    try:
        response = requests.post(
            f"{HOSTED_API_URL}/analyze_traffic/",
            # Same field name and filename the previous temp-file upload sent.
            files={"file": ("temp_traffic_image.jpg", buffer)},
            timeout=timeout,
        )
    except Exception as e:
        logging.error(f"Error in backend request: {e}")
        return "Error in backend request.", None

    if response.status_code != 200:
        logging.error(f"Error from backend: {response.text}")
        return f"Error from backend: {response.status_code}", None

    try:
        data = response.json()
    except ValueError as e:
        logging.error(f"Invalid JSON from backend: {e}")
        return "Error from backend: invalid response.", None

    vehicle_count = data.get("vehicle_count", 0)
    congestion_level = data.get("congestion_level", "Unknown")
    flow_rate = data.get("flow_rate", "Unknown")

    # Decode the processed image (if provided). load() forces decoding now,
    # so corrupt data is caught here rather than later in the UI.
    processed_image = None
    processed_image_base64 = data.get("processed_image")
    if processed_image_base64:
        try:
            processed_image = Image.open(io.BytesIO(base64.b64decode(processed_image_base64)))
            processed_image.load()
        except Exception as e:
            logging.error(f"Error decoding processed image: {e}")
            processed_image = None

    signal_timing = optimize_signal_rl(congestion_level)
    return (
        f"Detected Vehicles: {vehicle_count}\n"
        f"Congestion Level: {congestion_level}\n"
        f"Traffic Flow: {flow_rate}\n"
        f"Signal Timing Suggestion: {signal_timing}",
        processed_image,
    )
# Gradio entry point
def gradio_interface(image):
    """
    Adapt ``process_traffic_image`` to the Gradio UI.

    Returns the ``(results_text, analyzed_image)`` pair for the two output
    components. Any exception, including a malformed return value, is logged
    and replaced by a generic message with no image.
    """
    try:
        summary, annotated = process_traffic_image(image)
    except Exception as err:
        logging.error(f"Error in Gradio interface: {err}")
        return "An error occurred. Please try again with a valid traffic image.", None
    else:
        return summary, annotated
if __name__ == "__main__":
    # Single-page UI: one image in; analysis text and the annotated image out.
    upload_component = gr.Image(type="pil", label="Upload Traffic Image")
    output_components = [
        gr.Textbox(label="Traffic Analysis Results"),
        gr.Image(label="Analyzed Traffic Image"),
    ]
    interface = gr.Interface(
        fn=gradio_interface,
        inputs=upload_component,
        outputs=output_components,
        title="Traffic Management System",
        description="Upload a traffic image to analyze congestion and get signal timing suggestions.",
    )
    # Start the Gradio web server.
    interface.launch()