Upload 2 files
Browse files- app.py +104 -0
- requirements.txt +11 -0
app.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
|
3 |
+
os.environ["MPLCONFIGDIR"] = "/tmp/.config/matplotlib"
|
4 |
+
|
5 |
+
import streamlit as st
|
6 |
+
import pandas as pd
|
7 |
+
import numpy as np
|
8 |
+
import shap
|
9 |
+
import joblib
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import matplotlib.font_manager as fm
|
12 |
+
from huggingface_hub import hf_hub_download
|
13 |
+
|
14 |
+
# 🚩 載入模型
|
15 |
+
@st.cache_resource(show_spinner=True)
|
16 |
+
def load_model():
|
17 |
+
model_path = hf_hub_download(
|
18 |
+
repo_id="jung-ming/god-eye-traffic-predictor",
|
19 |
+
filename="multi_output_lgbm.pkl",
|
20 |
+
repo_type="model"
|
21 |
+
)
|
22 |
+
return joblib.load(model_path)
|
23 |
+
|
24 |
+
model = load_model()
|
25 |
+
|
26 |
+
# 🎯 定義 SHAP 解釋器(只針對預測60分鐘的模型)
|
27 |
+
explainer = shap.TreeExplainer(model.estimators_[0])
|
28 |
+
|
29 |
+
# ⚙️ 載入字型(顯示中文)
|
30 |
+
def find_chinese_font():
|
31 |
+
for font_path in fm.findSystemFonts(fontext='ttf'):
|
32 |
+
if "NotoSans" in font_path and ("TC" in font_path or "TraditionalChinese" in font_path):
|
33 |
+
return font_path
|
34 |
+
if "Heiti" in font_path or "LiHei" in font_path:
|
35 |
+
return font_path
|
36 |
+
return None
|
37 |
+
|
38 |
+
font_path = find_chinese_font()
|
39 |
+
if font_path:
|
40 |
+
chinese_font_prop = fm.FontProperties(fname=font_path)
|
41 |
+
plt.rcParams['font.family'] = chinese_font_prop.get_name()
|
42 |
+
else:
|
43 |
+
plt.rcParams['font.family'] = 'DejaVu Sans'
|
44 |
+
plt.rcParams['axes.unicode_minus'] = False
|
45 |
+
|
46 |
+
# 🔢 Streamlit UI
|
47 |
+
st.title("🚦 上帝視角:AI 國道壅塞預測系統")
|
48 |
+
st.markdown("輸入當前交通特徵,預測 60–90 分鐘後之速率與壅塞風險。")
|
49 |
+
|
50 |
+
# ✏️ 使用者輸入
|
51 |
+
hour = st.selectbox("小時 (hour)", list(range(0, 24)))
|
52 |
+
minute = st.selectbox("分鐘 (minute)", list(range(0, 60)))
|
53 |
+
count_total = st.number_input("總車流量 (count_total)", 0, 5000, 1000)
|
54 |
+
median_speed_avg = st.number_input("平均速率 (median_speed_avg)", 0.0, 200.0, 90.0)
|
55 |
+
median_time_avg = st.number_input("平均時間 (median_time_avg)", 0.0, 1000.0, 600.0)
|
56 |
+
start_detector_encode = st.number_input("起點檢測器編碼 (start_detector_encode)", 0.0, 400.0, 90.0)
|
57 |
+
direction = st.selectbox("方向 (direction)", [0, 1], format_func=lambda x: "南向" if x == 0 else "北向")
|
58 |
+
diff_speed_0_10 = st.number_input("0-10分鐘差速 (diff_speed_0_10)", -100.0, 100.0, 0.0)
|
59 |
+
|
60 |
+
if st.button("🔮 預測壅塞狀況"):
|
61 |
+
input_data = pd.DataFrame({
|
62 |
+
'hour': [hour],
|
63 |
+
'minute': [minute],
|
64 |
+
'count_total': [count_total],
|
65 |
+
'median_speed_avg': [median_speed_avg],
|
66 |
+
'median_time_avg': [median_time_avg],
|
67 |
+
'start_detector_encode': [start_detector_encode],
|
68 |
+
'direction': [direction],
|
69 |
+
'diff_speed_0_10': [diff_speed_0_10]
|
70 |
+
})
|
71 |
+
|
72 |
+
y_pred = model.predict(input_data)[0]
|
73 |
+
|
74 |
+
def classify_risk(speed):
|
75 |
+
if speed < 40:
|
76 |
+
return "高風險"
|
77 |
+
elif speed < 60:
|
78 |
+
return "中風險"
|
79 |
+
else:
|
80 |
+
return "低風險"
|
81 |
+
|
82 |
+
risk_levels = [classify_risk(s) for s in y_pred]
|
83 |
+
st.subheader("📈 預測結果")
|
84 |
+
for i, t in enumerate(['60', '70', '80', '90']):
|
85 |
+
st.write(f"{t} 分鐘後速率:{y_pred[i]:.2f} km/h → 風險:{risk_levels[i]}")
|
86 |
+
|
87 |
+
st.subheader("🧠 模型決策解釋圖(60分鐘 SHAP Waterfall)")
|
88 |
+
shap_values = explainer.shap_values(input_data)
|
89 |
+
ax = shap.plots.waterfall(shap.Explanation(
|
90 |
+
values=shap_values[0],
|
91 |
+
base_values=explainer.expected_value,
|
92 |
+
data=input_data.iloc[0],
|
93 |
+
feature_names=input_data.columns
|
94 |
+
), show=False)
|
95 |
+
|
96 |
+
# 設定中文字型(如果有)
|
97 |
+
if font_path:
|
98 |
+
for text in ax.texts:
|
99 |
+
text.set_fontproperties(chinese_font_prop)
|
100 |
+
if text.get_text().startswith('\u2212'):
|
101 |
+
text.set_text(text.get_text().replace('\u2212', '-'))
|
102 |
+
|
103 |
+
st.pyplot(ax.figure)
|
104 |
+
plt.close(ax.figure)
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==1.35.0
|
2 |
+
scikit-learn==1.3.2
|
3 |
+
pandas==2.2.2
|
4 |
+
numpy==1.24.4
|
5 |
+
shap==0.45.0
|
6 |
+
matplotlib==3.7.4
|
7 |
+
joblib==1.3.2
|
8 |
+
huggingface_hub==0.22.2
|
9 |
+
folium==0.16.0
|
10 |
+
geopandas==0.14.3
|
11 |
+
fiona==1.9.6
|