diff --git a/TODO.md b/TODO.md index b710c5e..65a481b 100644 --- a/TODO.md +++ b/TODO.md @@ -207,6 +207,12 @@ After completing a milestone, create a pull request with your changes for review - [x] Add global theme toggle stored in session state - [x] Update tests to verify theme CSS output +## PR21: Forecasting Models + +- [x] Create utils/time_series.py with ARIMA and naive forecast functions +- [x] Add UI controls on Time Series page for model selection and forecast horizon +- [x] Write tests for forecasting utilities + ## Notes for Development - Create comprehensive commit messages that clearly describe changes diff --git a/pages/time_series.py b/pages/time_series.py index 30838ab..66ecbec 100644 --- a/pages/time_series.py +++ b/pages/time_series.py @@ -7,6 +7,7 @@ import streamlit as st from utils import ui, viz +from utils import time_series as ts from utils.logging import configure_logging configure_logging() @@ -38,6 +39,12 @@ def main() -> None: period = st.number_input( "Seasonal Period", min_value=2, value=2, step=1, key="ts_period" ) + model_choice = st.selectbox( + "Forecast Model", ["Naive", "ARIMA"], key="ts_model" + ) + horizon = st.number_input( + "Forecast Horizon", min_value=1, value=5, step=1, key="ts_horizon" + ) if st.button("Generate Plots"): ts_fig = viz.time_series_plot(df, time_col, value_col, title="Time Series") @@ -52,6 +59,20 @@ def main() -> None: mime=f"image/{export_fmt}", ) + series = df.set_index(time_col)[value_col] + if model_choice == "ARIMA": + try: + forecast = ts.arima_forecast(series, steps=horizon) + except ImportError: + st.error("statsmodels is required for ARIMA forecasting.") + forecast = None + else: + forecast = ts.naive_forecast(series, steps=horizon) + + if forecast is not None: + st.subheader("Forecast") + st.write(forecast.to_frame(name="forecast")) + dec_fig = viz.decomposition_plot( df.set_index(time_col)[value_col], period=period, title="Decomposition" ) diff --git a/tests/test_pages.py b/tests/test_pages.py index b1235a8..9054851 100644 --- a/tests/test_pages.py +++ b/tests/test_pages.py @@ -112,6 +112,13 @@ def test_time_series_page_contents(): assert "decomposition_plot" in content +def test_time_series_page_forecast_widgets(): + with open("pages/time_series.py", "r", encoding="utf-8") as f: + content = f.read() + assert "Forecast Model" in content + assert "Forecast Horizon" in content + + def test_datetime_cols_persist_after_transforms(): import streamlit as st from utils import transform, eda diff --git a/tests/test_time_series_utils.py b/tests/test_time_series_utils.py new file mode 100644 index 0000000..0bbb78e --- /dev/null +++ b/tests/test_time_series_utils.py @@ -0,0 +1,29 @@ +import pandas as pd +import pytest + +from utils import time_series + + +def sample_series(): + return pd.Series( + range(5), index=pd.date_range("2021-01-01", periods=5, freq="D") + ) + + +def test_naive_forecast_extends_index(): + series = sample_series() + forecast = time_series.naive_forecast(series, steps=3) + assert len(forecast) == 3 + assert forecast.iloc[0] == series.iloc[-1] + assert forecast.index[0] == series.index[-1] + series.index.freq + + +def test_arima_forecast_runs_or_errors(): + series = sample_series() + try: + fc = time_series.arima_forecast(series, steps=2) + assert len(fc) == 2 + except ImportError: + pytest.skip("statsmodels not available") + + diff --git a/utils/time_series.py b/utils/time_series.py new file mode 100644 index 0000000..4703768 --- /dev/null +++ b/utils/time_series.py @@ -0,0 +1,46 @@ +"""Time series forecasting utilities.""" + +from __future__ import annotations + +from typing import Tuple + +import pandas as pd + + +def _extend_index(index: pd.Index, steps: int) -> pd.Index: + """Return an extended index for forecast values.""" + if isinstance(index, pd.DatetimeIndex) and index.freq is not None: + start = index[-1] + index.freq + return pd.date_range(start, periods=steps, freq=index.freq) + return pd.RangeIndex(index[-1] + 1, index[-1] + 1 + steps) + + +def naive_forecast(series: pd.Series, steps: int = 1) -> pd.Series: + """Forecast future values using the last observed value.""" + last = series.iloc[-1] + index = _extend_index(series.index, steps) + return pd.Series([last] * steps, index=index, name="naive_forecast") + + +def arima_forecast( + series: pd.Series, + *, + order: Tuple[int, int, int] = (1, 1, 0), + steps: int = 1, +) -> pd.Series: + """Forecast future values using an ARIMA model. + + Requires the ``statsmodels`` package. If it is not installed an + ``ImportError`` is raised. + """ + try: + from statsmodels.tsa.arima.model import ARIMA + except Exception as exc: # pragma: no cover - optional dependency + raise ImportError("statsmodels is required for ARIMA forecasting") from exc + + model = ARIMA(series, order=order) + fitted = model.fit() + forecast = fitted.forecast(steps=steps) + index = _extend_index(series.index, steps) + forecast.index = index + return forecast