Skip to content

Commit 5494a81

Browse files
committed
support for lightgbm forecasting
1 parent 06dd598 commit 5494a81

File tree

4 files changed

+53
-68
lines changed

4 files changed

+53
-68
lines changed

ads/opctl/operator/lowcode/forecast/model/ml_forecast.py

Lines changed: 44 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def set_kwargs(self):
4141
model_kwargs["uppper_quantile"] = uppper_quantile
4242
return model_kwargs
4343

44+
4445
def preprocess(self, df, series_id):
4546
pass
4647

@@ -53,54 +54,64 @@ def preprocess(self, df, series_id):
5354
err_msg="lightgbm is not installed, please install it with 'pip install lightgbm'",
5455
)
5556
def _train_model(self, data_train, data_test, model_kwargs):
57+
import lightgbm as lgb
58+
from mlforecast import MLForecast
59+
from mlforecast.lag_transforms import ExpandingMean, RollingMean
60+
from mlforecast.target_transforms import Differences
61+
62+
def set_model_config(freq):
63+
seasonal_map = {
64+
"H": 24,
65+
"D": 7,
66+
"W": 52,
67+
"M": 12,
68+
"Q": 4,
69+
}
70+
sp = seasonal_map.get(freq.upper(), 7)
71+
default_lags = [1, sp, 2 * sp]
72+
lags = model_kwargs.get("lags", default_lags)
73+
74+
default_roll = 2 * sp
75+
roll = model_kwargs.get("RollingMean", default_roll)
76+
77+
diff = model_kwargs.get("Differences", sp)
78+
79+
return {
80+
"target_transforms": [Differences([diff])],
81+
"lags": lags,
82+
"lag_transforms": {
83+
1: [ExpandingMean()],
84+
sp: [RollingMean(window_size=roll, min_samples=1)]
85+
}
86+
}
87+
5688
try:
57-
import lightgbm as lgb
58-
from mlforecast import MLForecast
59-
from mlforecast.lag_transforms import ExpandingMean, RollingMean
60-
from mlforecast.target_transforms import Differences
6189

6290
lgb_params = {
6391
"verbosity": model_kwargs.get("verbosity", -1),
6492
"num_leaves": model_kwargs.get("num_leaves", 512),
6593
}
66-
additional_data_params = {}
67-
if len(self.datasets.get_additional_data_column_names()) > 0:
68-
additional_data_params = {
69-
"target_transforms": [
70-
Differences([model_kwargs.get("Differences", 12)])
71-
],
72-
"lags": model_kwargs.get("lags", [1, 6, 12]),
73-
"lag_transforms": (
74-
{
75-
1: [ExpandingMean()],
76-
12: [
77-
RollingMean(
78-
window_size=model_kwargs.get("RollingMean", 24),
79-
min_samples=1,
80-
)
81-
],
82-
}
83-
),
84-
}
94+
95+
data_freq = pd.infer_freq(data_train[self.date_col].drop_duplicates()) \
96+
or pd.infer_freq(data_train[self.date_col].drop_duplicates()[-5:])
97+
98+
additional_data_params = set_model_config(data_freq)
8599

86100
fcst = MLForecast(
87101
models={
88102
"forecast": lgb.LGBMRegressor(**lgb_params),
89-
# "p" + str(int(model_kwargs["uppper_quantile"] * 100))
90103
"upper": lgb.LGBMRegressor(
91104
**lgb_params,
92105
objective="quantile",
93106
alpha=model_kwargs["uppper_quantile"],
94107
),
95-
# "p" + str(int(model_kwargs["lower_quantile"] * 100))
96108
"lower": lgb.LGBMRegressor(
97109
**lgb_params,
98110
objective="quantile",
99111
alpha=model_kwargs["lower_quantile"],
100112
),
101113
},
102-
freq=pd.infer_freq(data_train[self.date_col].drop_duplicates())
103-
or pd.infer_freq(data_train[self.date_col].drop_duplicates()[-5:]),
114+
freq=data_freq,
104115
**additional_data_params,
105116
)
106117

@@ -158,6 +169,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
158169
self.model_parameters[s_id] = {
159170
"framework": SupportedModels.LGBForecast,
160171
**lgb_params,
172+
**fcst.models_['forecast'].get_params(),
161173
}
162174

163175
logger.debug("===========Done===========")
@@ -191,48 +203,21 @@ def _generate_report(self):
191203
Generates the report for the model
192204
"""
193205
import report_creator as rc
194-
from utilsforecast.plotting import plot_series
195206

196207
logging.getLogger("report_creator").setLevel(logging.WARNING)
197208

198-
# Section 1: Forecast Overview
199-
sec1_text = rc.Block(
200-
rc.Heading("Forecast Overview", level=2),
201-
rc.Text(
202-
"These plots show your forecast in the context of historical data."
203-
),
204-
)
205-
sec_1 = _select_plot_list(
206-
lambda s_id: plot_series(
207-
self.datasets.get_all_data_long(include_horizon=False),
208-
pd.concat(
209-
[self.fitted_values, self.outputs], axis=0, ignore_index=True
210-
),
211-
id_col=ForecastOutputColumns.SERIES,
212-
time_col=self.spec.datetime_column.name,
213-
target_col=self.original_target_column,
214-
seed=42,
215-
ids=[s_id],
216-
),
217-
self.datasets.list_series_ids(),
218-
)
219-
220209
# Section 2: LGBForecast Model Parameters
221210
sec2_text = rc.Block(
222211
rc.Heading("LGBForecast Model Parameters", level=2),
223212
rc.Text("These are the parameters used for the LGBForecast model."),
224213
)
225214

226-
blocks = [
227-
rc.Html(
228-
str(s_id[1]),
229-
label=s_id[0],
230-
)
231-
for _, s_id in enumerate(self.model_parameters.items())
232-
]
233-
sec_2 = rc.Select(blocks=blocks)
215+
k, v = next(iter(self.model_parameters.items()))
216+
sec_2 = rc.Html(
217+
pd.DataFrame(list(v.items())).to_html(index=False, header=False),
218+
)
234219

235-
all_sections = [sec1_text, sec_1, sec2_text, sec_2]
220+
all_sections = [sec2_text, sec_2]
236221
model_description = rc.Text(
237222
"LGBForecast uses mlforecast framework to perform time series forecasting using machine learning models"
238223
"with the option to scale to massive amounts of data using remote clusters."

ads/opctl/operator/lowcode/forecast/schema.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ spec:
455455
- prophet
456456
- arima
457457
- neuralprophet
458-
# - lgbforecast
458+
- lgbforecast
459459
- automlx
460460
- autots
461461
- auto-select

tests/operators/forecast/test_datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"prophet",
3333
"neuralprophet",
3434
"autots",
35-
# "lgbforecast",
35+
"lgbforecast",
3636
"auto-select",
3737
"auto-select-series",
3838
]

tests/operators/forecast/test_errors.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,9 @@
141141
"arima",
142142
"automlx",
143143
"prophet",
144-
"neuralprophet",
145-
"autots",
146-
# "lgbforecast",
144+
# "neuralprophet",
145+
# "autots",
146+
"lgbforecast",
147147
]
148148

149149
TEMPLATE_YAML = {
@@ -415,8 +415,8 @@ def test_0_series(operator_setup, model):
415415
"local_explanation.csv",
416416
"global_explanation.csv",
417417
]
418-
if model == "autots":
419-
# explanations are not supported for autots
418+
if model in ["autots", "lgbforecast"]:
419+
# explanations are not supported for autots or lgbforecast
420420
output_files.remove("local_explanation.csv")
421421
output_files.remove("global_explanation.csv")
422422
for file in output_files:
@@ -709,7 +709,7 @@ def test_arima_automlx_errors(operator_setup, model):
709709
in error_content["13"]["model_fitting"]["error"]
710710
), f"Error message mismatch: {error_content}"
711711

712-
if model not in ["autots", "automlx"]: # , "lgbforecast"
712+
if model not in ["autots", "automlx", "lgbforecast"]:
713713
if yaml_i["spec"].get("explanations_accuracy_mode") != "AUTOMLX":
714714
global_fn = f"{tmpdirname}/results/global_explanation.csv"
715715
assert os.path.exists(
@@ -816,7 +816,7 @@ def test_date_format(operator_setup, model):
816816
@pytest.mark.parametrize("model", MODELS)
817817
def test_what_if_analysis(operator_setup, model):
818818
os.environ["TEST_MODE"] = "True"
819-
if model == "auto-select":
819+
if model in ["auto-select", "lgbforecast"]:
820820
pytest.skip("Skipping what-if scenario for auto-select")
821821
tmpdirname = operator_setup
822822
historical_data_path, additional_data_path = setup_small_rossman()

0 commit comments

Comments (0)