# Traffic-AI / app.py
# Author: Fzina
# Last commit: "Update app.py" (7131d6e, verified)
import os
import logging
import requests
from geopy.geocoders import Nominatim
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
import gymnasium as gym
from gymnasium import spaces
import numpy as np
from PIL import Image
import gradio as gr
import io
import base64
# Configuration read from the environment.
HOSTED_API_URL = os.getenv("HOSTED_API_URL")  # base URL of the FastAPI backend
WEATHER_API_KEY = os.getenv("WEATHER_API_KEY")  # OpenWeatherMap key; NOTE(review): not used by code visible in this file

# Logging configuration for the whole app.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


def _require_env(name, value):
    """Log an error and raise ValueError if the environment value ``value`` is empty or unset."""
    if value:
        return
    logging.error(f"{name} environment variable is not set.")
    raise ValueError(f"{name} must be set.")


# Fail fast at import time when required configuration is missing.
_require_env("HOSTED_API_URL", HOSTED_API_URL)
_require_env("WEATHER_API_KEY", WEATHER_API_KEY)

# OpenStreetMap (Nominatim) geocoder. NOTE(review): not referenced elsewhere in this file.
geolocator = Nominatim(user_agent="traffic_management_system")
# Post an image to the FastAPI backend for analysis
def analyze_traffic(image_path, timeout=30):
    """
    Send the traffic image at ``image_path`` to the FastAPI backend for analysis.

    The file is POSTed as multipart field ``file`` to
    ``{HOSTED_API_URL}/analyze_traffic/``. If the JSON reply contains
    ``processed_image_url``, that image is downloaded as well; a failed download
    is logged and does not discard the analysis results.

    NOTE(review): not called by the code visible in this file.
    ``process_traffic_image`` hits the same endpoint but reads a base64
    ``processed_image`` field instead of ``processed_image_url``. Confirm which
    schema the backend actually returns.

    Args:
        image_path: Path of the image file to upload.
        timeout: Seconds to wait for each HTTP request (upload and image
            download) so a stalled backend cannot hang the call forever.

    Returns:
        ``(vehicle_count, congestion_level, flow_rate, processed_image)``, where
        ``processed_image`` is a PIL image or ``None``. Any failure of the
        analysis request itself yields ``(0, "Error", "Error", None)``.
    """
    try:
        with open(image_path, "rb") as image_file:
            response = requests.post(
                f"{HOSTED_API_URL}/analyze_traffic/",
                files={"file": image_file},
                timeout=timeout,
            )
        if response.status_code != 200:
            logging.error(f"Error analyzing traffic: {response.text}")
            return 0, "Error", "Error", None

        result = response.json()
        vehicle_count = result.get("vehicle_count", 0)
        congestion_level = result.get("congestion_level", "Unknown")
        flow_rate = result.get("flow_rate", "Unknown")

        # The processed image is optional: failing to fetch it should not throw
        # away the counts that were already computed successfully.
        processed_image = None
        processed_image_url = result.get("processed_image_url")
        if processed_image_url:
            try:
                img_response = requests.get(processed_image_url, timeout=timeout)
                if img_response.status_code == 200:
                    processed_image = Image.open(io.BytesIO(img_response.content))
                else:
                    logging.warning(
                        f"Could not fetch processed image "
                        f"(status {img_response.status_code}): {processed_image_url}"
                    )
            except Exception as img_err:
                logging.warning(f"Error fetching processed image: {img_err}")
                processed_image = None

        return vehicle_count, congestion_level, flow_rate, processed_image
    except Exception as e:
        logging.error(f"Error analyzing traffic: {e}")
        return 0, "Error", "Error", None
# RL optimization: simulated traffic environment
class TrafficSimEnv(gym.Env):
    """Toy single-intersection environment used to train the signal-timing agent.

    Observation is a float32 vector ``[congestion, signal_duration]``, declared
    within ``[0, 0]``..``[10, 60]``.

    Actions:
        * 0 shortens the duration by 5 s, with a floor of 10 s.
        * 2 lengthens it by 5 s, with a cap of 60 s.
        * Any other value (1) keeps it unchanged.

    After each action, congestion rises by 1 if the duration exceeds 30 s and
    falls by 1 otherwise. The reward is 0 for a duration in [20, 40] and
    ``-|duration - 30|`` outside that range. The episode terminates once
    congestion reaches 0 or 10.
    """

    def __init__(self, congestion_level):
        super().__init__()
        # Starting congestion; reset() restores it.
        self.congestion_level = congestion_level
        lower = np.array([0, 0], dtype=np.float32)
        upper = np.array([10, 60], dtype=np.float32)
        self.observation_space = spaces.Box(low=lower, high=upper, dtype=np.float32)
        self.action_space = spaces.Discrete(3)
        self.state = self._initial_state()
        self.done = False

    def _initial_state(self):
        """Return the starting observation: configured congestion with a 30 s duration."""
        return np.array([self.congestion_level, 30], dtype=np.float32)

    def reset(self, *, seed=None, options=None):
        """Restore the starting state and return ``(observation, info)``; ``options`` is unused."""
        super().reset(seed=seed)
        self.state = self._initial_state()
        self.done = False
        return self.state, {}

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``."""
        if self.done:
            raise RuntimeError("Cannot call step() on a terminated environment. Please reset the environment.")

        congestion, duration = self.state

        if action == 0:
            duration = max(10, duration - 5)
        elif action == 2:
            duration = min(60, duration + 5)
        # Action 1 (or any other value) leaves the duration untouched.

        congestion = congestion + 1 if duration > 30 else congestion - 1
        reward = 0 if 20 <= duration <= 40 else -abs(duration - 30)

        self.done = congestion <= 0 or congestion >= 10
        self.state = np.array([congestion, duration], dtype=np.float32)
        # This env never truncates, so the fourth element is always False.
        return self.state, reward, self.done, False, {}

    def render(self):
        """Print the current state to stdout."""
        print(f"State: {self.state}")

    def close(self):
        """Nothing to release."""
        pass
def optimize_signal_rl(congestion_level, total_timesteps=1000, max_steps=10):
    """
    Suggest a green/red split by training a small PPO agent on ``TrafficSimEnv``.

    A fresh PPO ("MlpPolicy") model is trained on a single ``TrafficSimEnv``
    started at ``congestion_level``. It is then rolled out deterministically
    for at most ``max_steps`` steps. The signal duration of the last state
    reached is reported as green time, and red is the rest of a 60 s cycle.

    NOTE(review): a new model is trained on every call, which is expensive per
    request. Consider training once and reusing the model.

    Args:
        congestion_level: Either a string or a numeric level.
            * "Low", "Medium" and "High" map to 2, 5 and 8.
            * Any other string maps to 5.
            * A numeric value is used as-is.
        total_timesteps: Training budget passed to ``model.learn``.
        max_steps: Maximum number of rollout steps after training.

    Returns:
        ``"Green for {g}s, Red for {60 - g}s"`` on success, otherwise an error
        message string. Errors are logged, never raised.
    """
    try:
        congestion_map = {"Low": 2, "Medium": 5, "High": 8}
        if isinstance(congestion_level, str):
            congestion_level = congestion_map.get(congestion_level, 5)

        env = DummyVecEnv([lambda: TrafficSimEnv(congestion_level)])
        try:
            model = PPO("MlpPolicy", env, verbose=0)
            model.learn(total_timesteps=total_timesteps)

            # A VecEnv returns a batch of observations with shape (n_envs, obs_dim).
            obs = env.reset()
            logging.info(f"Initial Observation: {obs}")
            if not isinstance(obs, np.ndarray) or obs.ndim != 2:
                raise ValueError(f"Invalid observation after reset: {obs} (type: {type(obs)}, ndim: {obs.ndim if isinstance(obs, np.ndarray) else 'N/A'})")

            final_obs = obs[0]
            for step_count in range(max_steps):
                action, _ = model.predict(obs, deterministic=True)
                obs, rewards, dones, infos = env.step(action)
                logging.debug(f"Step {step_count}: Obs: {obs}, Rewards: {rewards}, Done: {dones}, Infos: {infos}")

                if not isinstance(obs, np.ndarray):
                    raise ValueError(f"Observation is not an array after step {step_count}: {obs}")
                # Keep the batch dimension that predict() expects.
                obs = np.atleast_2d(obs)

                if dones[0]:
                    # DummyVecEnv auto-resets finished episodes. At this point `obs`
                    # already holds the *reset* state (duration 30). The real final
                    # state is stored in infos[...]["terminal_observation"].
                    final_obs = infos[0].get("terminal_observation", obs[0])
                    break
                final_obs = obs[0]
        finally:
            env.close()

        final_obs = np.asarray(final_obs).flatten()
        if len(final_obs) < 2:
            raise ValueError(f"Observation does not contain enough elements: {final_obs}")
        optimal_duration = int(final_obs[1])
        return f"Green for {optimal_duration}s, Red for {60 - optimal_duration}s"
    except ValueError as ve:
        logging.error(f"Value error during RL optimization: {ve}")
        return "Error: Unexpected values encountered during optimization."
    except Exception as e:
        logging.error(f"Error optimizing signal with RL: {e}")
        return "Error in RL Optimization"
def process_traffic_image(image, timeout=60):
    """
    Orchestrate the traffic-analysis workflow for one uploaded image.

    Steps:
        1. Encode ``image`` as JPEG in memory.
        2. POST it as multipart field ``file`` to ``{HOSTED_API_URL}/analyze_traffic/``.
        3. Decode the optional base64 ``processed_image`` from the JSON reply.
        4. Append an RL signal-timing suggestion from ``optimize_signal_rl``.

    Args:
        image: PIL image from the Gradio upload, or ``None`` if nothing was uploaded.
        timeout: Seconds to wait for the backend before giving up.

    Returns:
        ``(summary_text, processed_image)``, where ``processed_image`` is a PIL
        image or ``None``. On failure the text describes the error and the
        image is ``None``.
    """
    if image is None:
        return "No image provided. Please upload a traffic image.", None

    # Encode in memory rather than into a fixed "temp_traffic_image.jpg" in the
    # working directory, which concurrent requests would overwrite or delete
    # under each other. JPEG cannot store alpha/palette modes (e.g. RGBA PNG
    # uploads), so normalise those to RGB first.
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    buffer.seek(0)

    try:
        # Send the same upload filename the backend received previously.
        response = requests.post(
            f"{HOSTED_API_URL}/analyze_traffic/",
            files={"file": ("temp_traffic_image.jpg", buffer)},
            timeout=timeout,
        )
    except Exception as e:
        logging.error(f"Error in backend request: {e}")
        return "Error in backend request.", None

    if response.status_code != 200:
        logging.error(f"Error from backend: {response.text}")
        return f"Error from backend: {response.status_code}", None

    try:
        data = response.json()
    except ValueError as e:  # requests' JSON decode error subclasses ValueError
        logging.error(f"Invalid JSON from backend: {e}")
        return "Error from backend: invalid response.", None

    vehicle_count = data.get("vehicle_count", 0)
    congestion_level = data.get("congestion_level", "Unknown")
    flow_rate = data.get("flow_rate", "Unknown")

    # The backend may return the annotated image base64-encoded; it is optional.
    processed_image = None
    processed_image_base64 = data.get("processed_image")
    if processed_image_base64:
        try:
            processed_image = Image.open(io.BytesIO(base64.b64decode(processed_image_base64)))
        except Exception as e:
            logging.error(f"Error decoding processed image: {e}")
            processed_image = None

    signal_timing = optimize_signal_rl(congestion_level)

    return (
        f"Detected Vehicles: {vehicle_count}\n"
        f"Congestion Level: {congestion_level}\n"
        f"Traffic Flow: {flow_rate}\n"
        f"Signal Timing Suggestion: {signal_timing}",
        processed_image,
    )
# Gradio callback
def gradio_interface(image):
    """
    Gradio entry point: run ``process_traffic_image`` and keep exceptions out of the UI.

    Returns a ``(results_text, analyzed_image)`` pair. Any exception, including
    a result that does not unpack into exactly two values, is logged. In that
    case the pair is a generic error message and ``None``.
    """
    try:
        text, picture = process_traffic_image(image)
    except Exception as err:
        logging.error(f"Error in Gradio interface: {err}")
        return "An error occurred. Please try again with a valid traffic image.", None
    return text, picture
if __name__ == "__main__":
    # Components are built in the same order as they are passed to the Interface:
    # the upload input first, then the two outputs.
    upload = gr.Image(type="pil", label="Upload Traffic Image")
    results_box = gr.Textbox(label="Traffic Analysis Results")
    analyzed_view = gr.Image(label="Analyzed Traffic Image")

    interface = gr.Interface(
        fn=gradio_interface,
        inputs=upload,
        outputs=[results_box, analyzed_view],
        title="Traffic Management System",
        description="Upload a traffic image to analyze congestion and get signal timing suggestions.",
    )

    # Start the Gradio app.
    interface.launch()