Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,12 @@ After completing a milestone, create a pull request with your changes for review
- [x] Add global theme toggle stored in session state
- [x] Update tests to verify theme CSS output

## PR21: Forecasting Models

- [x] Create utils/time_series.py with ARIMA and naive forecast functions
- [x] Add UI controls on Time Series page for model selection and forecast horizon
- [x] Write tests for forecasting utilities

## Notes for Development

- Create comprehensive commit messages that clearly describe changes
Expand Down
21 changes: 21 additions & 0 deletions pages/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import streamlit as st

from utils import ui, viz
from utils import time_series as ts
from utils.logging import configure_logging

configure_logging()
Expand Down Expand Up @@ -38,6 +39,12 @@ def main() -> None:
period = st.number_input(
"Seasonal Period", min_value=2, value=2, step=1, key="ts_period"
)
model_choice = st.selectbox(
"Forecast Model", ["Naive", "ARIMA"], key="ts_model"
)
horizon = st.number_input(
"Forecast Horizon", min_value=1, value=5, step=1, key="ts_horizon"
)

if st.button("Generate Plots"):
ts_fig = viz.time_series_plot(df, time_col, value_col, title="Time Series")
Expand All @@ -52,6 +59,20 @@ def main() -> None:
mime=f"image/{export_fmt}",
)

series = df.set_index(time_col)[value_col]
if model_choice == "ARIMA":
try:
forecast = ts.arima_forecast(series, steps=horizon)
except ImportError:
st.error("statsmodels is required for ARIMA forecasting.")
forecast = None
else:
forecast = ts.naive_forecast(series, steps=horizon)

if forecast is not None:
st.subheader("Forecast")
st.write(forecast.to_frame(name="forecast"))

dec_fig = viz.decomposition_plot(
df.set_index(time_col)[value_col], period=period, title="Decomposition"
)
Expand Down
7 changes: 7 additions & 0 deletions tests/test_pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ def test_time_series_page_contents():
assert "decomposition_plot" in content


def test_time_series_page_forecast_widgets():
    """Check that the Time Series page source mentions both forecast widgets."""
    with open("pages/time_series.py", "r", encoding="utf-8") as page_file:
        page_source = page_file.read()

    # The widget labels are matched as plain substrings of the page source.
    for widget_label in ("Forecast Model", "Forecast Horizon"):
        assert widget_label in page_source


def test_datetime_cols_persist_after_transforms():
import streamlit as st
from utils import transform, eda
Expand Down
29 changes: 29 additions & 0 deletions tests/test_time_series_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import pandas as pd
import pytest

from utils import time_series


def sample_series():
    """Build a five-point daily series holding 0..4, starting 2021-01-01."""
    daily_index = pd.date_range("2021-01-01", periods=5, freq="D")
    values = range(5)
    return pd.Series(values, index=daily_index)


def test_naive_forecast_extends_index():
    """The naive forecast repeats the last value and continues the index."""
    history = sample_series()
    predicted = time_series.naive_forecast(history, steps=3)

    # One forecast value per requested step.
    assert len(predicted) == 3
    # The first forecast equals the last observation.
    assert predicted.iloc[0] == history.iloc[-1]
    # The forecast index starts one frequency step after the history ends.
    expected_start = history.index[-1] + history.index.freq
    assert predicted.index[0] == expected_start


def test_arima_forecast_runs_or_errors():
    """ARIMA forecasting returns two values, or the test is skipped.

    The test is skipped when ``arima_forecast`` raises ``ImportError``.
    """
    history = sample_series()
    try:
        predicted = time_series.arima_forecast(history, steps=2)
    except ImportError:
        # pytest.skip raises, so the assertion below is never reached here.
        pytest.skip("statsmodels not available")
    assert len(predicted) == 2


46 changes: 46 additions & 0 deletions utils/time_series.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""Time series forecasting utilities."""

from __future__ import annotations

from typing import Tuple

import pandas as pd


def _extend_index(index: pd.Index, steps: int) -> pd.Index:
"""Return an extended index for forecast values."""
if isinstance(index, pd.DatetimeIndex) and index.freq is not None:
start = index[-1] + index.freq
return pd.date_range(start, periods=steps, freq=index.freq)
return pd.RangeIndex(index[-1] + 1, index[-1] + 1 + steps)


def naive_forecast(series: pd.Series, steps: int = 1) -> pd.Series:
    """Repeat the most recent observation ``steps`` times.

    The result is named ``"naive_forecast"`` and is indexed by the labels
    that ``_extend_index`` produces for ``series.index``.
    """
    latest_value = series.iloc[-1]
    future_index = _extend_index(series.index, steps)
    repeated = [latest_value for _ in range(steps)]
    return pd.Series(data=repeated, index=future_index, name="naive_forecast")


def arima_forecast(
    series: pd.Series,
    *,
    order: Tuple[int, int, int] = (1, 1, 0),
    steps: int = 1,
) -> pd.Series:
    """Fit an ARIMA model to ``series`` and forecast ``steps`` values ahead.

    ``statsmodels`` is imported lazily inside the function. Any failure
    while importing it is re-raised as ``ImportError``.

    Args:
        series: Observed values to fit the model on.
        order: ARIMA order passed straight to the model.
        steps: Number of future values to forecast.

    Returns:
        The forecast, re-indexed with the labels that ``_extend_index``
        produces for ``series.index``.

    Raises:
        ImportError: If ``statsmodels`` cannot be imported.
    """
    try:
        from statsmodels.tsa.arima.model import ARIMA
    except Exception as import_failure:  # pragma: no cover - optional dependency
        raise ImportError(
            "statsmodels is required for ARIMA forecasting"
        ) from import_failure

    predicted = ARIMA(series, order=order).fit().forecast(steps=steps)
    # Replace the model's own index with this module's continuation of the input index.
    predicted.index = _extend_index(series.index, steps)
    return predicted