nyc_taxi_app / streamlit_taxi_app.py
Daniela-C's picture
Update streamlit_taxi_app.py
ac6d75c verified
raw
history blame
2.43 kB
# Streamlit App: NYC Taxi Anomaly Detector with Event Markers
# Shell setup step (run in a terminal, not inside this script): pip install --upgrade pip
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
from datasets import load_dataset
# Fetch the published dataset and turn its "train" split into a pandas
# DataFrame.
# NOTE(review): `df` is reassigned from load_data() further down, so the frame
# built here appears to be unused afterwards -- confirm which source is meant.
dataset = load_dataset("Daniela-C/Yellow_tripdata_2025")
train_split = dataset["train"]
df = train_split.to_pandas()
# Page-level configuration: browser tab title and wide layout.
st.set_page_config(page_title="NYC Taxi Anomaly Detector", layout="wide")
# The leading taxi emoji (U+1F695) is spelled as an escape sequence. The
# previous literal was that emoji's UTF-8 bytes decoded with the wrong code
# page, which rendered as garbage in the page title.
st.title("\U0001F695 NYC Taxi Passenger Count - Anomaly Detection")
# Load data
@st.cache_data
def load_data():
    """Read the trip CSV into a DataFrame indexed by pickup time.

    The ``tpep_pickup_datetime`` column is parsed as datetimes and becomes
    the index. The function is wrapped in ``st.cache_data``.

    NOTE(review): the path has no file extension and equals the Hub dataset
    name -- confirm a CSV really exists at this relative path.
    """
    timestamp_col = 'tpep_pickup_datetime'
    trips = pd.read_csv(
        'Yellow_tripdata_2025',
        index_col=timestamp_col,
        parse_dates=[timestamp_col],
    )
    return trips


df = load_data()
# Define NYC events/holidays
# Maps the label drawn on the chart to an ISO date string (YYYY-MM-DD).
# NOTE(review): these dates are in 2015-2016 while the data file name mentions
# 2025 -- confirm the markers can fall inside the loaded date range.
events = {
    "New Year's Eve": "2015-12-31",
    "New Year's Day": "2016-01-01",
    "Martin Luther King Jr. Day": "2016-01-18",
}
# Sidebar controls
st.sidebar.header("Filters")
# Both date pickers default to the first / last calendar day in the index.
start_date = st.sidebar.date_input("Start Date", value=df.index.min().date())
end_date = st.sidebar.date_input("End Date", value=df.index.max().date())
# Percentile used as the anomaly cut-off: selectable from 90 to 99, default 95.
threshold_slider = st.sidebar.slider(
    "Anomaly Threshold (%)",
    min_value=90,
    max_value=99,
    value=95,
)
# Filtered data
# Label-slice the datetime index between the two chosen dates. The .copy()
# makes `filtered_df` an independent frame, so adding the column below does
# not write into a slice of `df` (which raises pandas' SettingWithCopyWarning
# and may leave the column unset on the object that is used afterwards).
filtered_df = df.loc[str(start_date):str(end_date)].copy()
# Apply new threshold
# The slider value is a percentage; quantile() expects a fraction in [0, 1].
new_threshold = filtered_df['reconstruction_error'].quantile(threshold_slider / 100.0)
# Rows whose reconstruction error is strictly above the cut-off are flagged.
filtered_df['anomaly_custom'] = filtered_df['reconstruction_error'] > new_threshold
# Plot
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(filtered_df.index, filtered_df['passenger_count'], label='Passenger Count')
# Select the flagged rows once instead of re-applying the mask per argument.
anomalies = filtered_df[filtered_df['anomaly_custom']]
ax.scatter(anomalies.index, anomalies['passenger_count'],
           color='red', label='Anomaly')
# Add event markers
# An event is drawn when its date lies inside the displayed time span. The
# earlier test (`event_date in filtered_df.index`) required a row stamped at
# exactly midnight of the event day, which an index of pickup timestamps does
# not guarantee, so markers could silently go missing.
if not filtered_df.empty:
    span_start = filtered_df.index.min()
    span_end = filtered_df.index.max()
    # Match the index's timezone (None for a naive index) so the comparison
    # below never mixes naive and aware timestamps.
    index_tz = getattr(filtered_df.index, 'tz', None)
    for name, date_str in events.items():
        event_date = pd.to_datetime(date_str).tz_localize(index_tz)
        if span_start <= event_date <= span_end:
            ax.axvline(event_date, color='orange', linestyle='--', alpha=0.7)
            # Place the label at 90% of the current upper y-limit.
            ax.text(event_date, ax.get_ylim()[1]*0.9, name, rotation=90, color='orange', fontsize=8)
ax.set_title('Anomaly Detection with NYC Event Markers')
ax.legend()
ax.set_xlabel("Date")
ax.set_ylabel("Passenger Count")
plt.xticks(rotation=45)
st.pyplot(fig)
# Streamlit re-runs the script on every interaction; closing the figure after
# rendering keeps pyplot from accumulating open figures across reruns.
plt.close(fig)
# Show data table
# The page emoji (U+1F4C4) is spelled as an escape sequence; the previous
# literal was that emoji's UTF-8 bytes decoded with the wrong code page.
with st.expander("\U0001F4C4 View Data Table"):
    st.dataframe(filtered_df[['passenger_count', 'reconstruction_error', 'anomaly_custom']])
# Download
# NOTE(review): the export contains every row of the filtered range (with the
# `anomaly_custom` flag column), not only the flagged rows -- confirm that this
# matches the "Download Anomalies CSV" label.
st.download_button(
    "Download Anomalies CSV",
    data=filtered_df.to_csv().encode('utf-8'),
    file_name="filtered_anomalies.csv",
    mime="text/csv",
)