Skip to content

Commit c1f1767

Browse files
committed
fixed errors to handle different data in Theta forecaster
1 parent c068dd8 commit c1f1767

File tree

1 file changed

+9
-3
lines changed
  • ads/opctl/operator/lowcode/forecast/model

1 file changed

+9
-3
lines changed

ads/opctl/operator/lowcode/forecast/model/theta.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def freq_to_sp(freq: str) -> int | None:
6363
logger.warning("Unable to infer data frequency and sp")
6464
return None
6565

66+
6667
class ThetaOperatorModel(ForecastOperatorBaseModel):
6768
"""Theta operator model"""
6869

@@ -78,6 +79,7 @@ def set_kwargs(self):
7879
model_kwargs = self.spec.model_kwargs
7980
model_kwargs["alpha"] = self.spec.model_kwargs.get("alpha", None)
8081
model_kwargs["initial_level"] = self.spec.model_kwargs.get("initial_level", None)
82+
model_kwargs["deseasonalize"] = self.spec.model_kwargs.get("deseasonalize", True)
8183

8284
if self.spec.confidence_interval_width is None:
8385
self.spec.confidence_interval_width = 1 - 0.90 if model_kwargs["alpha"] is None else model_kwargs["alpha"]
@@ -117,9 +119,12 @@ def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, A
117119
else:
118120
if self.perform_tuning:
119121
model_kwargs = self.run_tuning(y, model_kwargs)
122+
if len(y) < 2 * model_kwargs["sp"]:
123+
model_kwargs["deseasonalize"] = False
120124

121125
# Fit ThetaModel using params
122126
model = ThetaForecaster(initial_level=model_kwargs["initial_level"],
127+
deseasonalize=model_kwargs["deseasonalize"],
123128
deseasonalize_model=model_kwargs["deseasonalize_model"], sp=model_kwargs["sp"])
124129
model.fit(y)
125130

@@ -129,8 +134,8 @@ def _train_model(self, i, series_id, df: pd.DataFrame, model_kwargs: Dict[str, A
129134
forecast_values = model.predict(fh)
130135
forecast_range = model.predict_interval(fh=fh)
131136

132-
lower = forecast_range[("y", 0.9, "lower")].rename("yhat_lower")
133-
upper = forecast_range[("y", 0.9, "upper")].rename("yhat_upper")
137+
lower = forecast_range[(self.original_target_column, 0.9, "lower")].rename("yhat_lower")
138+
upper = forecast_range[(self.original_target_column, 0.9, "upper")].rename("yhat_upper")
134139
point = forecast_values.rename("yhat")
135140
forecast = pd.DataFrame(
136141
pd.concat([point, lower, upper], axis=1)
@@ -205,7 +210,8 @@ def objective(trial):
205210
model = ThetaForecaster(
206211
initial_level=initial_level,
207212
sp=sp,
208-
deseasonalize_model=deseason
213+
deseasonalize_model=deseason,
214+
deseasonalize=model_kwargs_i["deseasonalize"],
209215
)
210216

211217
cv = ExpandingWindowSplitter(

0 commit comments

Comments
 (0)