from typing import List

import numpy as np
import pandas as pd
import plotly.graph_objects as go


def plot_train_test(df1: pd.DataFrame, df2: pd.DataFrame) -> go.Figure:
    """
    Plot the training and test datasets as two line traces using Plotly.

    Only the first column of each DataFrame is plotted, against its index.

    Args:
        df1 (pd.DataFrame): Train dataset; its first column is drawn in steelblue.
        df2 (pd.DataFrame): Test dataset; its first column is drawn in gold.

    Returns:
        go.Figure: The figure containing both traces, with a white template,
        a legend, and "Date"/"Value" axis titles.
    """

    # Create a Plotly figure
    fig = go.Figure()

    # Training data: first column vs. index, in steelblue
    fig.add_trace(
        go.Scatter(
            x=df1.index,
            y=df1.iloc[:, 0],
            mode="lines",
            name="Training Data",
            line=dict(color="steelblue"),
            marker=dict(color="steelblue"),
        )
    )

    # Test data: first column vs. index, in gold
    fig.add_trace(
        go.Scatter(
            x=df2.index,
            y=df2.iloc[:, 0],
            mode="lines",
            name="Test Data",
            line=dict(color="gold"),
            marker=dict(color="gold"),
        )
    )

    # Customize the layout
    fig.update_layout(
        title="Univariate Time Series",
        xaxis=dict(title="Date"),
        yaxis=dict(title="Value"),
        showlegend=True,
        template="plotly_white",
    )
    return fig


def _index_to_datetime(index: pd.Index) -> pd.DatetimeIndex:
    """
    Convert an index to datetimes suitable for a Plotly x-axis.

    A ``PeriodIndex`` is converted with ``to_timestamp()`` (start of each
    period) because ``pd.to_datetime`` rejects Period data; any other index
    is passed through ``pd.to_datetime``.
    """
    if isinstance(index, pd.PeriodIndex):
        return index.to_timestamp()
    return pd.to_datetime(index)


def plot_forecast(df: pd.DataFrame, forecasts: List[pd.DataFrame]) -> go.Figure:
    """
    Plot the true values and sampled forecasts using Plotly.

    Each forecast is drawn as its individual sample paths (faint lines, one
    colour per forecast, cycling through green/blue/purple) plus the
    per-timestep mean of those samples as a dashed red line. All mean traces
    share one legend group and only the first is listed in the legend, so
    they toggle together.

    Args:
        df (pd.DataFrame): True values; the first column is plotted against
            the index. The index may be a ``PeriodIndex`` or anything
            ``pd.to_datetime`` accepts.
        forecasts (List[pd.DataFrame]): Forecasts to overlay. Despite the
            annotation, each item must expose ``.samples`` (2-D; rows are
            iterated as sample paths and averaged over axis 0) and ``.index``
            (a ``PeriodIndex`` or datetime-like index whose length matches a
            sample row).
            NOTE(review): this looks like GluonTS ``SampleForecast`` objects
            rather than DataFrames — confirm and correct the annotation.

    Returns:
        go.Figure: Plotly figure object.
    """

    # Create a Plotly figure
    fig = go.Figure()

    # True values: first column of df against its datetime-converted index
    fig.add_trace(
        go.Scatter(
            x=_index_to_datetime(df.index),
            y=df.iloc[:, 0],
            mode="lines",
            name="True values",
            line=dict(color="black"),
        )
    )

    # Add the forecast traces
    colors = ["green", "blue", "purple"]
    for i, forecast in enumerate(forecasts):
        color = colors[i % len(colors)]
        # Every sample path and the mean share the same x values; convert once.
        x_values = _index_to_datetime(forecast.index)
        for sample in forecast.samples:
            fig.add_trace(
                go.Scatter(
                    x=x_values,
                    y=sample,
                    mode="lines",
                    opacity=0.15,  # Faint, so overlapping sample paths show density
                    name=f"Forecast {i + 1}",
                    showlegend=False,  # Hide the individual sample paths from the legend
                    hoverinfo="none",  # Disable hover information for the sample paths
                    line=dict(color=color),
                )
            )
        # Per-timestep mean across sample paths (rows)
        mean_forecast = np.mean(forecast.samples, axis=0)
        fig.add_trace(
            go.Scatter(
                x=x_values,
                y=mean_forecast,
                mode="lines",
                name="Mean Forecast",
                line=dict(color="red", dash="dash"),
                legendgroup="mean forecast",
                showlegend=i == 0,  # One legend entry controls every mean trace
            )
        )

    # Customize the layout
    fig.update_layout(
        title=f"{df.columns[0]} Forecast",
        yaxis=dict(title=df.columns[0]),
        showlegend=True,
        legend=dict(x=0, y=1),
        hovermode="x",  # Enable x-axis hover for better interactivity
    )

    # Return the figure
    return fig