Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -66,47 +66,61 @@ class TrafficSimEnv(gym.Env):
|
|
66 |
def __init__(self, congestion_level):
|
67 |
super(TrafficSimEnv, self).__init__()
|
68 |
self.congestion_level = congestion_level
|
|
|
|
|
|
|
|
|
|
|
69 |
self.action_space = spaces.Discrete(3)
|
70 |
-
self.observation_space = spaces.Box(low=np.array([0, 20]), high=np.array([2, 60]), dtype=np.float32)
|
71 |
-
self.current_signal = 30
|
72 |
-
self.steps = 0
|
73 |
|
74 |
-
|
75 |
-
self.
|
76 |
-
self.
|
77 |
-
|
78 |
-
|
|
|
79 |
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
return obs, info
|
84 |
|
85 |
def step(self, action):
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
-
|
|
|
|
|
|
|
|
|
90 |
|
91 |
-
|
92 |
-
done =
|
93 |
-
truncated = False
|
94 |
|
95 |
-
|
96 |
-
|
97 |
-
print(f"STEP: action={action}, obs={obs}, reward={reward}, done={done}, truncated={truncated}, info={info}")
|
98 |
-
return obs, reward, done, truncated, info
|
99 |
|
100 |
-
|
101 |
-
if self.congestion_level == "High":
|
102 |
-
return -abs(40 - self.current_signal)
|
103 |
-
elif self.congestion_level == "Medium":
|
104 |
-
return -abs(30 - self.current_signal)
|
105 |
-
else:
|
106 |
-
return -abs(20 - self.current_signal)
|
107 |
|
108 |
-
def render(self
|
109 |
-
print(f"
|
110 |
|
111 |
def close(self):
|
112 |
pass
|
@@ -114,7 +128,7 @@ class TrafficSimEnv(gym.Env):
|
|
114 |
# Prior to commit had a lot of errors regarding expected output errors
|
115 |
def optimize_signal_rl(congestion_level):
|
116 |
try:
|
117 |
-
# Create
|
118 |
env = DummyVecEnv([lambda: TrafficSimEnv(congestion_level)])
|
119 |
|
120 |
# Initialize PPO model
|
@@ -124,40 +138,34 @@ def optimize_signal_rl(congestion_level):
|
|
124 |
model.learn(total_timesteps=1000)
|
125 |
|
126 |
# Reset the environment
|
127 |
-
obs, info = env.reset()
|
128 |
-
obs = obs[0] # Extract
|
129 |
|
130 |
for _ in range(10):
|
131 |
-
# Predict action
|
132 |
action, _ = model.predict(obs, deterministic=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
-
# Take a step
|
135 |
-
step_output = env.step(action)
|
136 |
-
|
137 |
-
# Handle batched output of step()
|
138 |
-
if len(step_output) == 5:
|
139 |
-
obs, reward, done, truncated, info = step_output
|
140 |
-
obs = obs[0] # Extract single observation
|
141 |
-
reward = reward[0] # Extract single reward
|
142 |
-
done = done[0] # Extract single done flag
|
143 |
-
truncated = truncated[0] # Extract single truncated flag
|
144 |
-
info = info[0] # Extract single info
|
145 |
-
else:
|
146 |
-
raise ValueError(f"Unexpected step output: {step_output}")
|
147 |
-
|
148 |
-
# End simulation if done or truncated
|
149 |
if done or truncated:
|
150 |
break
|
151 |
|
152 |
-
# Extract optimal signal duration
|
153 |
optimal_duration = int(obs[1]) if len(obs) > 1 else 30
|
154 |
return f"Green for {optimal_duration}s, Red for {60 - optimal_duration}s"
|
|
|
155 |
except Exception as e:
|
156 |
logging.error(f"Error optimizing signal with RL: {e}")
|
157 |
return "Error in RL Optimization"
|
158 |
|
159 |
|
160 |
-
|
161 |
def process_traffic_image(image):
|
162 |
"""
|
163 |
Orchestrates the traffic analysis workflow.
|
|
|
66 |
def __init__(self, congestion_level):
|
67 |
super(TrafficSimEnv, self).__init__()
|
68 |
self.congestion_level = congestion_level
|
69 |
+
|
70 |
+
# Define observation space: [congestion_level, signal_duration]
|
71 |
+
self.observation_space = spaces.Box(low=np.array([0, 0]), high=np.array([10, 60]), dtype=np.float32)
|
72 |
+
|
73 |
+
# Define action space: [increase, decrease, maintain]
|
74 |
self.action_space = spaces.Discrete(3)
|
|
|
|
|
|
|
75 |
|
76 |
+
# Initial state
|
77 |
+
self.state = np.array([congestion_level, 30], dtype=np.float32)
|
78 |
+
self.done = False
|
79 |
+
|
80 |
+
def reset(self, *, seed=None, options=None):
|
81 |
+
super().reset(seed=seed)
|
82 |
|
83 |
+
self.state = np.array([self.congestion_level, 30], dtype=np.float32)
|
84 |
+
self.done = False
|
85 |
+
return self.state, {}
|
|
|
86 |
|
87 |
def step(self, action):
|
88 |
+
if self.done:
|
89 |
+
raise RuntimeError("Cannot call step() on a terminated environment. Please reset the environment.")
|
90 |
+
|
91 |
+
# Extract state components
|
92 |
+
congestion, signal_duration = self.state
|
93 |
+
|
94 |
+
# Apply action
|
95 |
+
if action == 0: # Decrease signal duration
|
96 |
+
signal_duration = max(10, signal_duration - 5)
|
97 |
+
elif action == 1: # Maintain signal duration
|
98 |
+
signal_duration = signal_duration
|
99 |
+
elif action == 2: # Increase signal duration
|
100 |
+
signal_duration = min(60, signal_duration + 5)
|
101 |
+
|
102 |
+
# Update congestion (simple simulation logic for this example)
|
103 |
+
if signal_duration > 30:
|
104 |
+
congestion += 1
|
105 |
+
else:
|
106 |
+
congestion -= 1
|
107 |
|
108 |
+
# Set rewards (example logic)
|
109 |
+
if 20 <= signal_duration <= 40:
|
110 |
+
reward = 0
|
111 |
+
else:
|
112 |
+
reward = -abs(signal_duration - 30)
|
113 |
|
114 |
+
# Check if done
|
115 |
+
self.done = congestion <= 0 or congestion >= 10
|
|
|
116 |
|
117 |
+
# Update state
|
118 |
+
self.state = np.array([congestion, signal_duration], dtype=np.float32)
|
|
|
|
|
119 |
|
120 |
+
return self.state, reward, self.done, False, {}
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
+
def render(self):
|
123 |
+
print(f"State: {self.state}")
|
124 |
|
125 |
def close(self):
|
126 |
pass
|
|
|
128 |
# Prior to commit had a lot of errors regarding expected output errors
|
129 |
def optimize_signal_rl(congestion_level):
|
130 |
try:
|
131 |
+
# Create environment
|
132 |
env = DummyVecEnv([lambda: TrafficSimEnv(congestion_level)])
|
133 |
|
134 |
# Initialize PPO model
|
|
|
138 |
model.learn(total_timesteps=1000)
|
139 |
|
140 |
# Reset the environment
|
141 |
+
obs, info = env.reset()
|
142 |
+
obs = obs[0] # Extract first observation from batch
|
143 |
|
144 |
for _ in range(10):
|
|
|
145 |
action, _ = model.predict(obs, deterministic=True)
|
146 |
+
obs, reward, done, truncated, info = env.step(action)
|
147 |
+
|
148 |
+
# Ensure single observation extraction
|
149 |
+
obs = obs[0]
|
150 |
+
reward = reward[0]
|
151 |
+
done = done[0]
|
152 |
+
truncated = truncated[0]
|
153 |
+
info = info[0]
|
154 |
+
|
155 |
+
# Log each step for debugging
|
156 |
+
print(f"STEP: action={action}, obs={obs}, reward={reward}, done={done}, truncated={truncated}, info={info}")
|
157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
if done or truncated:
|
159 |
break
|
160 |
|
|
|
161 |
optimal_duration = int(obs[1]) if len(obs) > 1 else 30
|
162 |
return f"Green for {optimal_duration}s, Red for {60 - optimal_duration}s"
|
163 |
+
|
164 |
except Exception as e:
|
165 |
logging.error(f"Error optimizing signal with RL: {e}")
|
166 |
return "Error in RL Optimization"
|
167 |
|
168 |
|
|
|
169 |
def process_traffic_image(image):
|
170 |
"""
|
171 |
Orchestrates the traffic analysis workflow.
|