jung-ming commited on
Commit
e60de6d
·
verified ·
1 Parent(s): 05c5842

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +104 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
3
+ os.environ["MPLCONFIGDIR"] = "/tmp/.config/matplotlib"
4
+
5
+ import streamlit as st
6
+ import pandas as pd
7
+ import numpy as np
8
+ import shap
9
+ import joblib
10
+ import matplotlib.pyplot as plt
11
+ import matplotlib.font_manager as fm
12
+ from huggingface_hub import hf_hub_download
13
+
14
+ # 🚩 載入模型
15
+ @st.cache_resource(show_spinner=True)
16
+ def load_model():
17
+ model_path = hf_hub_download(
18
+ repo_id="jung-ming/god-eye-traffic-predictor",
19
+ filename="multi_output_lgbm.pkl",
20
+ repo_type="model"
21
+ )
22
+ return joblib.load(model_path)
23
+
24
+ model = load_model()
25
+
26
+ # 🎯 定義 SHAP 解釋器(只針對預測60分鐘的模型)
27
+ explainer = shap.TreeExplainer(model.estimators_[0])
28
+
29
+ # ⚙️ 載入字型(顯示中文)
30
+ def find_chinese_font():
31
+ for font_path in fm.findSystemFonts(fontext='ttf'):
32
+ if "NotoSans" in font_path and ("TC" in font_path or "TraditionalChinese" in font_path):
33
+ return font_path
34
+ if "Heiti" in font_path or "LiHei" in font_path:
35
+ return font_path
36
+ return None
37
+
38
+ font_path = find_chinese_font()
39
+ if font_path:
40
+ chinese_font_prop = fm.FontProperties(fname=font_path)
41
+ plt.rcParams['font.family'] = chinese_font_prop.get_name()
42
+ else:
43
+ plt.rcParams['font.family'] = 'DejaVu Sans'
44
+ plt.rcParams['axes.unicode_minus'] = False
45
+
46
+ # 🔢 Streamlit UI
47
+ st.title("🚦 上帝視角:AI 國道壅塞預測系統")
48
+ st.markdown("輸入當前交通特徵,預測 60–90 分鐘後之速率與壅塞風險。")
49
+
50
+ # ✏️ 使用者輸入
51
+ hour = st.selectbox("小時 (hour)", list(range(0, 24)))
52
+ minute = st.selectbox("分鐘 (minute)", list(range(0, 60)))
53
+ count_total = st.number_input("總車流量 (count_total)", 0, 5000, 1000)
54
+ median_speed_avg = st.number_input("平均速率 (median_speed_avg)", 0.0, 200.0, 90.0)
55
+ median_time_avg = st.number_input("平均時間 (median_time_avg)", 0.0, 1000.0, 600.0)
56
+ start_detector_encode = st.number_input("起點檢測器編碼 (start_detector_encode)", 0.0, 400.0, 90.0)
57
+ direction = st.selectbox("方向 (direction)", [0, 1], format_func=lambda x: "南向" if x == 0 else "北向")
58
+ diff_speed_0_10 = st.number_input("0-10分鐘差速 (diff_speed_0_10)", -100.0, 100.0, 0.0)
59
+
60
+ if st.button("🔮 預測壅塞狀況"):
61
+ input_data = pd.DataFrame({
62
+ 'hour': [hour],
63
+ 'minute': [minute],
64
+ 'count_total': [count_total],
65
+ 'median_speed_avg': [median_speed_avg],
66
+ 'median_time_avg': [median_time_avg],
67
+ 'start_detector_encode': [start_detector_encode],
68
+ 'direction': [direction],
69
+ 'diff_speed_0_10': [diff_speed_0_10]
70
+ })
71
+
72
+ y_pred = model.predict(input_data)[0]
73
+
74
+ def classify_risk(speed):
75
+ if speed < 40:
76
+ return "高風險"
77
+ elif speed < 60:
78
+ return "中風險"
79
+ else:
80
+ return "低風險"
81
+
82
+ risk_levels = [classify_risk(s) for s in y_pred]
83
+ st.subheader("📈 預測結果")
84
+ for i, t in enumerate(['60', '70', '80', '90']):
85
+ st.write(f"{t} 分鐘後速率:{y_pred[i]:.2f} km/h → 風險:{risk_levels[i]}")
86
+
87
+ st.subheader("🧠 模型決策解釋圖(60分鐘 SHAP Waterfall)")
88
+ shap_values = explainer.shap_values(input_data)
89
+ ax = shap.plots.waterfall(shap.Explanation(
90
+ values=shap_values[0],
91
+ base_values=explainer.expected_value,
92
+ data=input_data.iloc[0],
93
+ feature_names=input_data.columns
94
+ ), show=False)
95
+
96
+ # 設定中文字型(如果有)
97
+ if font_path:
98
+ for text in ax.texts:
99
+ text.set_fontproperties(chinese_font_prop)
100
+ if text.get_text().startswith('\u2212'):
101
+ text.set_text(text.get_text().replace('\u2212', '-'))
102
+
103
+ st.pyplot(ax.figure)
104
+ plt.close(ax.figure)
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit==1.35.0
2
+ scikit-learn==1.3.2
3
+ pandas==2.2.2
4
+ numpy==1.24.4
5
+ shap==0.45.0
6
+ matplotlib==3.7.4
7
+ joblib==1.3.2
8
+ huggingface_hub==0.22.2
9
+ folium==0.16.0
10
+ geopandas==0.14.3
11
+ fiona==1.9.6