Skip to content

Commit cbb5205

Browse files
committed
lightgbm updates
1 parent 5494a81 commit cbb5205

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

ads/opctl/operator/lowcode/forecast/model/ml_forecast.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,18 @@ def set_model_config(freq):
6868
"Q": 4,
6969
}
7070
sp = seasonal_map.get(freq.upper(), 7)
71-
default_lags = [1, sp, 2 * sp]
71+
series_lengths = data_train.groupby(ForecastOutputColumns.SERIES).size()
72+
min_len = series_lengths.min()
73+
max_allowed = min_len - sp
74+
75+
default_lags = [lag for lag in [1, sp, 2 * sp] if lag <= max_allowed]
7276
lags = model_kwargs.get("lags", default_lags)
7377

7478
default_roll = 2 * sp
7579
roll = model_kwargs.get("RollingMean", default_roll)
7680

77-
diff = model_kwargs.get("Differences", sp)
81+
default_diff = sp if sp <= max_allowed else None
82+
diff = model_kwargs.get("Differences", default_diff)
7883

7984
return {
8085
"target_transforms": [Differences([diff])],
@@ -112,6 +117,7 @@ def set_model_config(freq):
112117
),
113118
},
114119
freq=data_freq,
120+
date_features=['year', 'month', 'day', 'dayofweek', 'dayofyear'],
115121
**additional_data_params,
116122
)
117123

tests/operators/forecast/test_datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def test_load_datasets(model, data_details):
177177
subprocess.run(f"ls -a {output_data_path}", shell=True)
178178
if yaml_i["spec"]["generate_explanations"] and model not in [
179179
"automlx",
180-
# "lgbforecast",
180+
"lgbforecast",
181181
"auto-select",
182182
]:
183183
verify_explanations(

0 commit comments

Comments
 (0)