import os
import math
import tempfile
import warnings
import streamlit as st
import pandas as pd
import torch
import plotly.express as px
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from transformers import (
EarlyStoppingCallback,
Trainer,
TrainingArguments,
set_seed,
)
from transformers.integrations import INTEGRATION_TO_CALLBACK
from tsfm_public import (
TimeSeriesPreprocessor,
TrackingCallback,
count_parameters,
get_datasets,
)
from tsfm_public.toolkit.get_model import get_model
from tsfm_public.toolkit.lr_finder import optimal_lr_finder
from tsfm_public.toolkit.visualization import plot_predictions
# For M4 Hourly Example
from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction
# Suppress warnings and set a reproducible seed
warnings.filterwarnings("ignore")
SEED = 42
set_seed(SEED)  # seeds python/numpy/torch RNGs via transformers
# Default model parameters and output directory
# Hugging Face model id of the pre-trained Granite TTM (revision 2) checkpoint.
TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"
DEFAULT_CONTEXT_LENGTH = 512  # history steps fed to the model
DEFAULT_PREDICTION_LENGTH = 96  # forecast horizon in steps
# All generated plots/checkpoints land under this directory (created at import time).
OUT_DIR = "dashboard_outputs"
os.makedirs(OUT_DIR, exist_ok=True)
# --------------------------
# Helper: Interactive Plot
def interactive_plot(actual, forecast, title="Forecast vs Actual"):
    """Build an interactive Plotly line chart comparing actual vs forecast.

    Args:
        actual: Sequence of ground-truth values.
        forecast: Sequence of predicted values (same length as ``actual``).
        title: Chart title.

    Returns:
        A Plotly figure with both series drawn against a 0..N-1 time axis.
    """
    frame = pd.DataFrame(
        {
            "Time": range(len(actual)),
            "Actual": actual,
            "Forecast": forecast,
        }
    )
    return px.line(frame, x="Time", y=["Actual", "Forecast"], title=title)
# --------------------------
# Mode 1: Zero-shot Evaluation
def run_zero_shot_forecasting(
    data,
    context_length,
    prediction_length,
    batch_size,
    selected_target_columns,
    selected_conditional_columns,
    rolling_forecast_extension,
    selected_forecast_index,
):
    """Evaluate the pre-trained TTM model on ``data`` without fine-tuning.

    Splits the data chronologically 70/10/20 into train/valid/test, builds a
    standard-scaled ``TimeSeriesPreprocessor``, evaluates the pre-trained
    checkpoint on the test split with a bare HF ``Trainer``, and renders an
    interactive Plotly plot plus static PNGs into the Streamlit app.

    Args:
        data: DataFrame with a "date" timestamp column plus series columns.
        context_length: Number of history steps fed to the model.
        prediction_length: Forecast horizon in steps.
        batch_size: Per-device evaluation batch size.
        selected_target_columns: Columns to forecast; when empty, all
            non-"date" columns are used.
        selected_conditional_columns: Exogenous columns passed to the
            preprocessor as ``control_columns``.
        rolling_forecast_extension: Extra steps requested by the UI
            (currently only acknowledged — no rolling logic is implemented).
        selected_forecast_index: Test-sample index used for plotting.
    """
    st.write("### Preparing Data for Forecasting")
    timestamp_column = "date"
    id_columns = []  # Modify if needed.
    # Use selected target columns; default to all columns (except "date") if not provided.
    if not selected_target_columns:
        target_columns = [col for col in data.columns if col != timestamp_column]
    else:
        target_columns = selected_target_columns
    # Incorporate exogenous/control columns.
    conditional_columns = selected_conditional_columns
    # Define column specifiers (if your preprocessor supports static columns, add here)
    column_specifiers = {
        "timestamp_column": timestamp_column,
        "id_columns": id_columns,
        "target_columns": target_columns,
        "control_columns": conditional_columns,
    }
    # Chronological 70/10/20 split by row position.
    n = len(data)
    split_config = {
        "train": [0, int(n * 0.7)],
        "valid": [int(n * 0.7), int(n * 0.8)],
        "test": [int(n * 0.8), n],
    }
    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=prediction_length,
        scaling=True,
        encode_categorical=False,
        scaler_type="standard",
    )
    dset_train, dset_valid, dset_test = get_datasets(tsp, data, split_config)
    st.write("Data split into train, validation, and test sets.")
    st.write("### Loading the Pre-trained TTM Model")
    model = get_model(
        TTM_MODEL_PATH,
        context_length=context_length,
        prediction_length=prediction_length,
    )
    # Trainer is used for evaluation/prediction only, so a throwaway output
    # directory is sufficient.
    temp_dir = tempfile.mkdtemp()
    training_args = TrainingArguments(
        output_dir=temp_dir,
        per_device_eval_batch_size=batch_size,
        seed=SEED,
        report_to="none",
    )
    trainer = Trainer(model=model, args=training_args)
    st.write("### Running Zero-shot Evaluation")
    st.info("Evaluating on the test set...")
    eval_output = trainer.evaluate(dset_test)
    st.write("**Zero-shot Evaluation Metrics:**")
    st.json(eval_output)
    st.write("### Generating Forecast Predictions")
    predictions_dict = trainer.predict(dset_test)
    try:
        # Take the first element of the predictions tuple as the forecast array.
        predictions_np = predictions_dict.predictions[0]
    except Exception as e:
        st.error("Error extracting predictions: " + str(e))
        return
    st.write("Predictions shape:", predictions_np.shape)
    if rolling_forecast_extension > 0:
        # Placeholder: the extension is displayed but not applied.
        st.write(
            f"### Rolling Forecast Extension: {rolling_forecast_extension} extra steps"
        )
        st.info("Rolling forecast logic can be implemented here.")
    # Interactive plot for a selected forecast index.
    idx = selected_forecast_index
    try:
        # This example assumes dset_test[idx] is a dict with a "target" key; adjust as needed.
        actual = (
            dset_test[idx]["target"]
            if isinstance(dset_test[idx], dict)
            else dset_test[idx][0]
        )
    except Exception:
        actual = predictions_np[idx]  # Fallback if actual is not available.
    fig = interactive_plot(
        actual, predictions_np[idx], title=f"Forecast vs Actual for index {idx}"
    )
    st.plotly_chart(fig)
    # Static plots (generated via plot_predictions)
    plot_dir = os.path.join(OUT_DIR, "zero_shot_plots")
    os.makedirs(plot_dir, exist_ok=True)
    try:
        plot_predictions(
            model=trainer.model,
            dset=dset_test,
            plot_dir=plot_dir,
            plot_prefix="test_zeroshot",
            indices=[idx],
            channel=0,
        )
    except Exception as e:
        st.error("Error during static plotting: " + str(e))
        return
    # Display every PNG the plotting helper wrote into plot_dir.
    for file in os.listdir(plot_dir):
        if file.endswith(".png"):
            st.image(os.path.join(plot_dir, file), caption=file)
# --------------------------
# Mode 2: Channel-Mix Finetuning Example
def run_channel_mix_finetuning():
    """Fine-tune TTM-R1 on the bike-sharing dataset with channel mixing.

    Downloads the hourly bike-sharing CSV, builds a chronological 50/25/25
    split, freezes the model backbone, fine-tunes the rest of the model in
    ``mix_channel`` decoder mode (targets ``casual``/``registered``/``cnt``
    conditioned on weather/calendar channels), evaluates on the test split,
    and renders static forecast plots — all reported into the Streamlit app.
    """
    st.write("## Channel-Mix Finetuning Example (Bike Sharing Data)")
    # Load bike sharing dataset
    target_dataset = "bike_sharing"
    DATA_ROOT_PATH = (
        "https://raw.githubusercontent.com/blobibob/bike-sharing-dataset/main/hour.csv"
    )
    timestamp_column = "dteday"
    id_columns = []
    try:
        data = pd.read_csv(DATA_ROOT_PATH, parse_dates=[timestamp_column])
    except Exception as e:
        st.error("Error loading bike sharing dataset: " + str(e))
        return
    # NOTE(review): redundant — parse_dates above already converts this column.
    data[timestamp_column] = pd.to_datetime(data[timestamp_column])
    # Adjust timestamps (to add hourly information)
    # The raw file stores only a date; spread each day's rows over consecutive
    # hours using the per-day cumulative row count.
    data[timestamp_column] = data[timestamp_column] + pd.to_timedelta(
        data.groupby(data[timestamp_column].dt.date).cumcount(), unit="h"
    )
    st.write("### Bike Sharing Data Preview")
    st.dataframe(data.head())
    # Define columns: targets and conditional (exogenous) channels
    column_specifiers = {
        "timestamp_column": timestamp_column,
        "id_columns": id_columns,
        "target_columns": ["casual", "registered", "cnt"],
        "conditional_columns": [
            "season",
            "yr",
            "mnth",
            "holiday",
            "weekday",
            "workingday",
            "weathersit",
            "temp",
            "atemp",
            "hum",
            "windspeed",
        ],
    }
    # Chronological 50/25/25 split by row position.
    n = len(data)
    split_config = {
        "train": [0, int(n * 0.5)],
        "valid": [int(n * 0.5), int(n * 0.75)],
        "test": [int(n * 0.75), n],
    }
    context_length = 512
    forecast_length = 96
    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=forecast_length,
        scaling=True,
        encode_categorical=False,
        scaler_type="standard",
    )
    train_dataset, valid_dataset, test_dataset = get_datasets(tsp, data, split_config)
    st.write("Data split completed.")
    # For channel-mix finetuning, we use TTM-R1 (as per provided script)
    TTM_MODEL_PATH_CM = "ibm-granite/granite-timeseries-ttm-r1"
    finetune_forecast_model = get_model(
        TTM_MODEL_PATH_CM,
        context_length=context_length,
        prediction_length=forecast_length,
        num_input_channels=tsp.num_input_channels,
        decoder_mode="mix_channel",  # decoder mixes information across channels
        prediction_channel_indices=tsp.prediction_channel_indices,
    )
    st.write(
        "Number of params before freezing backbone:",
        count_parameters(finetune_forecast_model),
    )
    # Freeze the backbone so only the remaining (decoder/head) params train.
    for param in finetune_forecast_model.backbone.parameters():
        param.requires_grad = False
    st.write(
        "Number of params after freezing backbone:",
        count_parameters(finetune_forecast_model),
    )
    num_epochs = 50
    batch_size = 64
    learning_rate = 0.001
    # Custom optimizer + OneCycle schedule passed directly to the Trainer.
    optimizer = AdamW(finetune_forecast_model.parameters(), lr=learning_rate)
    scheduler = OneCycleLR(
        optimizer,
        learning_rate,
        epochs=num_epochs,
        steps_per_epoch=math.ceil(len(train_dataset) / batch_size),
    )
    out_dir = os.path.join(OUT_DIR, target_dataset)
    os.makedirs(out_dir, exist_ok=True)
    # NOTE(review): "evaluation_strategy" was renamed to "eval_strategy" in
    # newer transformers releases — confirm the pinned version accepts it.
    finetune_args = TrainingArguments(
        output_dir=os.path.join(out_dir, "output"),
        overwrite_output_dir=True,
        learning_rate=learning_rate,
        num_train_epochs=num_epochs,
        do_eval=True,
        evaluation_strategy="epoch",
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        dataloader_num_workers=8,
        report_to="none",
        save_strategy="epoch",
        logging_strategy="epoch",
        save_total_limit=1,
        logging_dir=os.path.join(out_dir, "logs"),
        load_best_model_at_end=True,  # best (lowest eval_loss) checkpoint is restored
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        seed=SEED,
    )
    # Stop training when eval_loss stops improving for 10 evaluations.
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=10,
        early_stopping_threshold=1e-5,
    )
    tracking_callback = TrackingCallback()
    finetune_trainer = Trainer(
        model=finetune_forecast_model,
        args=finetune_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        callbacks=[early_stopping_callback, tracking_callback],
        optimizers=(optimizer, scheduler),
    )
    # Remove the codecarbon integration callback (not needed for this demo).
    finetune_trainer.remove_callback(INTEGRATION_TO_CALLBACK["codecarbon"])
    st.write("Starting channel-mix finetuning...")
    finetune_trainer.train()
    st.write("Evaluating finetuned model on test set...")
    eval_output = finetune_trainer.evaluate(test_dataset)
    st.write("Few-shot (channel-mix) evaluation metrics:")
    st.json(eval_output)
    # Plot predictions
    plot_dir = os.path.join(out_dir, "channel_mix_plots")
    os.makedirs(plot_dir, exist_ok=True)
    try:
        plot_predictions(
            model=finetune_trainer.model,
            dset=test_dataset,
            plot_dir=plot_dir,
            plot_prefix="test_channel_mix",
            indices=[0],
            channel=0,
        )
    except Exception as e:
        st.error("Error plotting channel mix predictions: " + str(e))
        return
    # Display every PNG the plotting helper wrote into plot_dir.
    for file in os.listdir(plot_dir):
        if file.endswith(".png"):
            st.image(os.path.join(plot_dir, file), caption=file)
# --------------------------
# Mode 3: M4 Hourly Example
def run_m4_hourly_example():
    """Run a simplified M4-hourly-style zero-shot evaluation in Streamlit.

    Loads a placeholder CSV (an ETTh1 copy standing in for real M4 hourly
    data), splits it 70/15/15 chronologically, evaluates the TTM-V1
    checkpoint zero-shot with a 48-step horizon, and shows static forecast
    plots. Fine-tuning is intentionally left as a follow-up.
    """
    st.write("## M4 Hourly Example")
    st.info("This example reproduces a simplified version of the M4 hourly evaluation.")
    # For demonstration, we attempt to load an M4 hourly dataset from a URL.
    # (In practice, you would need to download and prepare the dataset.)
    M4_DATASET_URL = "https://raw.githubusercontent.com/IBM/TSFM-public/main/tsfm_public/notebooks/ETTh1.csv"  # Placeholder URL
    try:
        m4_data = pd.read_csv(M4_DATASET_URL, parse_dates=["date"])
    except Exception as e:
        st.error("Could not load M4 hourly dataset: " + str(e))
        return
    st.write("### M4 Hourly Data Preview")
    st.dataframe(m4_data.head())
    context_length = 512
    forecast_length = 48  # M4 hourly forecast horizon
    timestamp_column = "date"
    id_columns = []
    # Forecast every column except the timestamp.
    target_columns = [col for col in m4_data.columns if col != timestamp_column]
    # Chronological 70/15/15 split by row position.
    n = len(m4_data)
    split_config = {
        "train": [0, int(n * 0.7)],
        "valid": [int(n * 0.7), int(n * 0.85)],
        "test": [int(n * 0.85), n],
    }
    column_specifiers = {
        "timestamp_column": timestamp_column,
        "id_columns": id_columns,
        "target_columns": target_columns,
        "control_columns": [],
    }
    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=forecast_length,
        scaling=True,
        encode_categorical=False,
        scaler_type="standard",
    )
    dset_train, dset_valid, dset_test = get_datasets(tsp, m4_data, split_config)
    st.write("Data split completed.")
    # Load model from Hugging Face TTM Model Repository (TTM-V1 for M4)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = TinyTimeMixerForPrediction.from_pretrained(
        "ibm-granite/granite-timeseries-ttm-v1",
        revision="main",
        # Truncate the model's native horizon down to 48 steps.
        prediction_filter_length=forecast_length,
    ).to(device)
    st.write("Running zero-shot evaluation on M4 hourly data...")
    # Evaluation-only Trainer; output directory is a throwaway temp dir.
    temp_dir = tempfile.mkdtemp()
    trainer = Trainer(
        model=model,
        args=TrainingArguments(
            output_dir=temp_dir,
            per_device_eval_batch_size=64,
            report_to="none",
        ),
    )
    eval_output = trainer.evaluate(dset_test)
    st.write("Zero-shot evaluation metrics on M4 hourly:")
    st.json(eval_output)
    plot_dir = os.path.join(OUT_DIR, "m4_hourly", "zero_shot")
    os.makedirs(plot_dir, exist_ok=True)
    try:
        plot_predictions(
            model=trainer.model,
            dset=dset_test,
            plot_dir=plot_dir,
            plot_prefix="m4_zero_shot",
            indices=[0],
            channel=0,
        )
    except Exception as e:
        st.error("Error plotting M4 zero-shot predictions: " + str(e))
        return
    # Display every PNG the plotting helper wrote into plot_dir.
    for file in os.listdir(plot_dir):
        if file.endswith(".png"):
            st.image(os.path.join(plot_dir, file), caption=file)
    st.info("Fine-tuning on M4 hourly data can be added similarly.")
# --------------------------
# Main UI
def main():
    """Streamlit entry point: render the dashboard UI and dispatch a run mode.

    Offers three modes:
      * Zero-shot Evaluation — default ETTh1 or a user-uploaded CSV, with
        target/exogenous column selection and plotting options.
      * Channel-Mix Finetuning Example — fixed bike-sharing demo.
      * M4 Hourly Example — fixed M4-hourly-style demo.

    Fixes over the original: the default-dataset load error now includes the
    caught exception text (it was bound but discarded), and the uploaded-CSV
    read is wrapped so a file without a parseable "date" column shows an
    error instead of crashing the app.
    """
    st.title("Interactive Time-Series Forecasting Dashboard")
    st.markdown(
        """
This dashboard lets you run advanced forecasting experiments using the Granite-TimeSeries-TTM model.
Select one of the modes below:
- **Zero-shot Evaluation**
- **Channel-Mix Finetuning Example**
- **M4 Hourly Example**
"""
    )
    mode = st.selectbox(
        "Select Evaluation Mode",
        options=[
            "Zero-shot Evaluation",
            "Channel-Mix Finetuning Example",
            "M4 Hourly Example",
        ],
    )
    if mode == "Zero-shot Evaluation":
        # Allow user to choose dataset source
        dataset_source = st.radio(
            "Dataset Source", options=["Default (ETTh1)", "Upload CSV"]
        )
        if dataset_source == "Default (ETTh1)":
            DATASET_PATH = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv"
            try:
                data = pd.read_csv(DATASET_PATH, parse_dates=["date"])
            except Exception as e:
                # Surface the underlying error so failures are diagnosable
                # (matches the error style used by the run_* helpers).
                st.error("Error loading default dataset: " + str(e))
                return
            st.write("### Default Dataset Preview")
            st.dataframe(data.head())
            # ETTh1's seven known series columns.
            selected_target_columns = [
                "HUFL",
                "HULL",
                "MUFL",
                "MULL",
                "LUFL",
                "LULL",
                "OT",
            ]
        else:
            uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
            if not uploaded_file:
                st.info("Awaiting CSV file upload.")
                return
            try:
                # The uploaded CSV must contain a "date" column for parse_dates.
                data = pd.read_csv(uploaded_file, parse_dates=["date"])
            except Exception as e:
                st.error("Error reading uploaded CSV: " + str(e))
                return
            st.write("### Uploaded Data Preview")
            st.dataframe(data.head())
            available_columns = [col for col in data.columns if col != "date"]
            selected_target_columns = st.multiselect(
                "Select Target Column(s)",
                options=available_columns,
                default=available_columns,
            )
        # Advanced options
        available_exog = [
            col
            for col in data.columns
            if col not in (["date"] + selected_target_columns)
        ]
        selected_conditional_columns = st.multiselect(
            "Select Exogenous/Control Columns", options=available_exog, default=[]
        )
        rolling_extension = st.number_input(
            "Rolling Forecast Extension (Extra Steps)", value=0, min_value=0, step=1
        )
        # NOTE(review): the slider bound uses the raw row count, while the
        # number of test-set prediction windows is smaller — large indices
        # will fail inside run_zero_shot_forecasting; confirm acceptable.
        forecast_index = st.slider(
            "Select Forecast Index for Plotting",
            min_value=0,
            max_value=len(data) - 1,
            value=0,
        )
        context_length = st.number_input(
            "Context Length", value=DEFAULT_CONTEXT_LENGTH, step=64
        )
        prediction_length = st.number_input(
            "Prediction Length", value=DEFAULT_PREDICTION_LENGTH, step=1
        )
        batch_size = st.number_input("Batch Size", value=64, step=1)
        if st.button("Run Zero-shot Evaluation"):
            with st.spinner("Running zero-shot evaluation..."):
                run_zero_shot_forecasting(
                    data,
                    context_length,
                    prediction_length,
                    batch_size,
                    selected_target_columns,
                    selected_conditional_columns,
                    rolling_extension,
                    forecast_index,
                )
    elif mode == "Channel-Mix Finetuning Example":
        if st.button("Run Channel-Mix Finetuning Example"):
            with st.spinner("Running channel-mix finetuning..."):
                run_channel_mix_finetuning()
    elif mode == "M4 Hourly Example":
        if st.button("Run M4 Hourly Example"):
            with st.spinner("Running M4 hourly example..."):
                run_m4_hourly_example()
if __name__ == "__main__":
main()