Skip to content

Commit a67dd7d

Browse files
committed
[ENH] ARIMA: Add support for exogenous variables
1 parent c3694c9 commit a67dd7d

File tree

1 file changed

+119
-28
lines changed

1 file changed

+119
-28
lines changed

aeon/forecasting/stats/_arima.py

Lines changed: 119 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class ARIMA(BaseForecaster, IterativeForecastingMixin):
5959

6060
_tags = {
6161
"capability:horizon": False, # cannot fit to a horizon other than 1
62+
"capability:exogenous": True, # can handle exogenous variables
6263
}
6364

6465
def __init__(
@@ -84,6 +85,9 @@ def __init__(
8485
self.aic_ = 0
8586
self._model = []
8687
self._parameters = []
88+
self.exog_ = None
89+
self.beta_ = None
90+
self.exog_n_features_ = None
8791
super().__init__(horizon=1, axis=1)
8892

8993
def _fit(self, y, exog=None):
@@ -94,19 +98,39 @@ def _fit(self, y, exog=None):
9498
y : np.ndarray
9599
A time series on which to learn a forecaster to predict horizon ahead
96100
exog : np.ndarray, default=None
97-
Not allowed for this forecaster
101+
Optional exogenous time series data aligned with y. If provided, an
102+
OLS regression (with intercept) is fit and the ARIMA model is fit on
103+
the residual series (y - X beta).
98104
99105
Returns
100106
-------
101107
self
102108
Fitted ARIMA.
103109
"""
104110
self._series = np.array(y.squeeze(), dtype=np.float64)
111+
series_for_arima = self._series.copy()
112+
113+
if exog is not None:
114+
exog = np.asarray(exog)
115+
if exog.ndim == 1:
116+
exog = exog.reshape(-1, 1)
117+
if len(exog) != len(self._series):
118+
raise ValueError("exog must have the same number of rows as y")
119+
self.exog_ = exog
120+
self.exog_n_features_ = exog.shape[1]
121+
X = np.column_stack([np.ones(len(self._series)), exog])
122+
self.beta_ = np.linalg.lstsq(X, self._series, rcond=None)[0]
123+
series_for_arima = self._series - X @ self.beta_
124+
else:
125+
self.beta_ = None
126+
self.exog_ = None
127+
self.exog_n_features_ = None
128+
105129
# Model is an array of the (c,p,q)
106130
self._model = np.array(
107131
(1 if self.use_constant else 0, self.p, self.q), dtype=np.int32
108132
)
109-
self._differenced_series = np.diff(self._series, n=self.d)
133+
self._differenced_series = np.diff(series_for_arima, n=self.d)
110134
s = 0.1 / (np.sum(self._model) + 1) # Randomise
111135
# Nelder Mead returns the parameters in a single array
112136
(self._parameters, self.aic_) = nelder_mead(
@@ -152,7 +176,14 @@ def _predict(self, y, exog=None):
152176
A time series to predict the value of. y can be independent of the series
153177
seen in fit.
154178
exog : np.ndarray, default=None
155-
Optional exogenous time series data assumed to be aligned with y
179+
Optional exogenous time series data. If the model was fit with exogenous
180+
variables, a regression contribution will be added to the ARIMA forecast.
181+
For one-step forecasting, `exog` may be:
182+
- a 1D array for a single feature; its last element is the future value
183+
- a 2D array with shape (n_obs, n_features), in which case the last row
184+
will be used as the future exog row.
185+
- a 2D array aligned with `y` (same number of rows);
186+
last row will be used.
156187
157188
Returns
158189
-------
@@ -194,6 +225,36 @@ def _predict(self, y, exog=None):
194225

195226
forecast_diff = c + ar_forecast + ma_forecast
196227

228+
reg_component = 0.0
229+
if self.beta_ is not None:
230+
if exog is None:
231+
raise ValueError(
232+
"exog must be provided for prediction"
233+
"when model was fit with exogenous variables"
234+
)
235+
exog_arr = np.asarray(exog)
236+
if exog_arr.ndim == 1:
237+
if exog_arr.shape[0] == len(y):
238+
exog_row = np.atleast_1d(exog_arr[-1])
239+
elif exog_arr.shape[0] == 1:
240+
exog_row = np.atleast_1d(exog_arr.reshape(1, -1)[0])
241+
else:
242+
exog_row = np.atleast_1d(exog_arr.reshape(-1, 1)[-1])
243+
else:
244+
exog_row = exog_arr[-1]
245+
exog_row = np.asarray(exog_row).reshape(-1)
246+
if (
247+
self.exog_n_features_ is not None
248+
and exog_row.shape[0] != self.exog_n_features_
249+
):
250+
raise ValueError(
251+
f"exog must have {self.exog_n_features_} features,"
252+
f" got {exog_row.shape[0]}"
253+
)
254+
Xf = np.concatenate(([1.0], exog_row), axis=0)
255+
reg_component = float(Xf @ self.beta_)
256+
forecast_diff = forecast_diff + reg_component
257+
197258
# Undifference the forecast
198259
if self.d == 0:
199260
return forecast_diff
@@ -207,42 +268,63 @@ def _forecast(self, y, exog=None):
207268
self._fit(y, exog)
208269
return float(self.forecast_)
209270

210-
def iterative_forecast(self, y, prediction_horizon):
271+
def iterative_forecast(self, y, prediction_horizon, exog=None):
211272
self.fit(y)
212-
n = len(self._differenced_series)
213-
p, q = self.p, self.q
214-
phi, theta = self.phi_, self.theta_
215273
h = prediction_horizon
216-
c = 0.0
217-
if self.use_constant:
218-
c = self.c_
274+
p, q, d = self.p, self.q, self.d
275+
phi, theta = self.phi_, self.theta_
276+
c = self.c_ if self.use_constant else 0.0
277+
278+
if self.beta_ is not None:
279+
if exog is None:
280+
raise ValueError(
281+
"Future exogenous values must be provided"
282+
" for multi-step forecasting."
283+
)
284+
exog = np.asarray(exog)
285+
if exog.ndim == 1:
286+
exog = exog.reshape(-1, self.exog_n_features_)
287+
if exog.shape[0] != h:
288+
raise ValueError("Future exog must have prediction_horizon rows.")
289+
if exog.shape[1] != self.exog_n_features_:
290+
raise ValueError(
291+
f"Future exog must have {self.exog_n_features_} columns."
292+
)
293+
future_exog = exog
294+
else:
295+
future_exog = None
219296

220-
# Start with a copy of the original series and residuals
221-
residuals = np.zeros(len(self.residuals_) + h)
222-
residuals[: len(self.residuals_)] = self.residuals_
297+
n = len(self._differenced_series)
223298
forecast_series = np.zeros(n + h)
224299
forecast_series[:n] = self._differenced_series
300+
301+
residuals = np.zeros(len(self.residuals_) + h)
302+
residuals[: len(self.residuals_)] = self.residuals_
303+
225304
for i in range(h):
226-
# Get most recent p values (lags)
305+
227306
t = n + i
228-
ar_term = 0.0
229-
if p > 0:
230-
ar_term = np.dot(phi, forecast_series[t - np.arange(1, p + 1)])
231-
# Get most recent q residuals (lags)
232-
ma_term = 0.0
233-
if q > 0:
234-
ma_term = np.dot(theta, residuals[t - np.arange(1, q + 1)])
307+
ar_term = (
308+
np.dot(phi, forecast_series[t - np.arange(1, p + 1)]) if p > 0 else 0.0
309+
)
310+
ma_term = (
311+
np.dot(theta, residuals[t - np.arange(1, q + 1)]) if q > 0 else 0.0
312+
)
235313
next_value = c + ar_term + ma_term
236-
# Append prediction and a zero residual (placeholder)
237-
forecast_series[n + i] = next_value
238-
# Can't compute real residual during prediction, leave as zero
239314

240-
# Correct differencing using forecast values
315+
if future_exog is not None:
316+
Xf = np.concatenate(([1.0], future_exog[i]))
317+
next_value += float(Xf @ self.beta_)
318+
319+
forecast_series[t] = next_value
320+
241321
y_forecast_diff = forecast_series[n : n + h]
242-
if self.d == 0:
322+
323+
if d == 0:
243324
return y_forecast_diff
244325
else:
245-
return _undifference(y_forecast_diff, self._series[-self.d :])[self.d :]
326+
restored = _undifference(y_forecast_diff, self._series[-d:])
327+
return restored[d:]
246328

247329

248330
class AutoARIMA(BaseForecaster, IterativeForecastingMixin):
@@ -280,6 +362,11 @@ class AutoARIMA(BaseForecaster, IterativeForecastingMixin):
280362
481.87157356139943
281363
"""
282364

365+
_tags = {
366+
"capability:exogenous": True,
367+
"capability:horizon": False,
368+
}
369+
283370
def __init__(self, max_p=3, max_d=3, max_q=2):
284371
self.max_p = max_p
285372
self.max_d = max_d
@@ -336,7 +423,11 @@ def _fit(self, y, exog=None):
336423
) = model
337424
self.constant_term_ = constant_term_int == 1
338425
self.final_model_ = ARIMA(self.p_, self.d_, self.q_, self.constant_term_)
339-
self.final_model_.fit(y)
426+
427+
if exog is not None:
428+
self.final_model_.fit(y, exog=exog)
429+
else:
430+
self.final_model_.fit(y)
340431
self.forecast_ = self.final_model_.forecast_
341432
return self
342433

0 commit comments

Comments
 (0)