Commit f632e32
try scenarios
1 parent 1c8d4b5

File tree: 3 files changed (+265 -3 lines changed)

pytorch_forecasting/tests/_conftest.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ def gpus():
         return 0


-@pytest.fixture(scope="package")
+@pytest.fixture(scope="session")
 def data_with_covariates():
     data = get_stallion_data()
     data["month"] = data.date.dt.month.astype(str)
pytorch_forecasting/tests/_data_scenarios.py

Lines changed: 261 additions & 0 deletions (new file)
@@ -0,0 +1,261 @@
import numpy as np
import pytest
import torch

from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data import EncoderNormalizer, GroupNormalizer, NaNLabelEncoder
from pytorch_forecasting.data.examples import generate_ar_data, get_stallion_data

torch.manual_seed(23)


@pytest.fixture(scope="session")
def gpus():
    if torch.cuda.is_available():
        return [0]
    else:
        return 0


def data_with_covariates():
    data = get_stallion_data()
    data["month"] = data.date.dt.month.astype(str)
    data["log_volume"] = np.log1p(data.volume)
    data["weight"] = 1 + np.sqrt(data.volume)

    data["time_idx"] = data["date"].dt.year * 12 + data["date"].dt.month
    data["time_idx"] -= data["time_idx"].min()

    # convert special days into strings
    special_days = [
        "easter_day",
        "good_friday",
        "new_year",
        "christmas",
        "labor_day",
        "independence_day",
        "revolution_day_memorial",
        "regional_games",
        "fifa_u_17_world_cup",
        "football_gold_cup",
        "beer_capital",
        "music_fest",
    ]
    data[special_days] = (
        data[special_days].apply(lambda x: x.map({0: "", 1: x.name})).astype("category")
    )
    data = data.astype(dict(industry_volume=float))

    # select data subset
    data = data[lambda x: x.sku.isin(data.sku.unique()[:2])][
        lambda x: x.agency.isin(data.agency.unique()[:2])
    ]

    # default target
    data["target"] = data["volume"].clip(1e-3, 1.0)

    return data


def make_dataloaders(data_with_covariates, **kwargs):
    training_cutoff = "2016-09-01"
    max_encoder_length = 4
    max_prediction_length = 3

    kwargs.setdefault("target", "volume")
    kwargs.setdefault("group_ids", ["agency", "sku"])
    kwargs.setdefault("add_relative_time_idx", True)
    kwargs.setdefault("time_varying_unknown_reals", ["volume"])

    training = TimeSeriesDataSet(
        data_with_covariates[lambda x: x.date < training_cutoff].copy(),
        time_idx="time_idx",
        max_encoder_length=max_encoder_length,
        max_prediction_length=max_prediction_length,
        **kwargs,  # fixture parametrization
    )

    validation = TimeSeriesDataSet.from_dataset(
        training,
        data_with_covariates.copy(),
        min_prediction_idx=training.index.time.max() + 1,
    )
    train_dataloader = training.to_dataloader(train=True, batch_size=2, num_workers=0)
    val_dataloader = validation.to_dataloader(train=False, batch_size=2, num_workers=0)
    test_dataloader = validation.to_dataloader(train=False, batch_size=1, num_workers=0)

    return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader)


@pytest.fixture(
    params=[
        dict(),
        dict(
            static_categoricals=["agency", "sku"],
            static_reals=["avg_population_2017", "avg_yearly_household_income_2017"],
            time_varying_known_categoricals=["special_days", "month"],
            variable_groups=dict(
                special_days=[
                    "easter_day",
                    "good_friday",
                    "new_year",
                    "christmas",
                    "labor_day",
                    "independence_day",
                    "revolution_day_memorial",
                    "regional_games",
                    "fifa_u_17_world_cup",
                    "football_gold_cup",
                    "beer_capital",
                    "music_fest",
                ]
            ),
            time_varying_known_reals=[
                "time_idx",
                "price_regular",
                "price_actual",
                "discount",
                "discount_in_percent",
            ],
            time_varying_unknown_categoricals=[],
            time_varying_unknown_reals=[
                "volume",
                "log_volume",
                "industry_volume",
                "soda_volume",
                "avg_max_temp",
            ],
            constant_fill_strategy={"volume": 0},
            categorical_encoders={"sku": NaNLabelEncoder(add_nan=True)},
        ),
        dict(static_categoricals=["agency", "sku"]),
        dict(randomize_length=True, min_encoder_length=2),
        dict(target_normalizer=EncoderNormalizer(), min_encoder_length=2),
        dict(target_normalizer=GroupNormalizer(transformation="log1p")),
        dict(
            target_normalizer=GroupNormalizer(
                groups=["agency", "sku"], transformation="softplus", center=False
            )
        ),
        dict(target="agency"),
        # test multiple targets
        dict(target=["industry_volume", "volume"]),
        dict(target=["agency", "volume"]),
        dict(
            target=["agency", "volume"], min_encoder_length=1, min_prediction_length=1
        ),
        dict(target=["agency", "volume"], weight="volume"),
        # test weights
        dict(target="volume", weight="volume"),
    ],
    scope="session",
)
def multiple_dataloaders_with_covariates(data_with_covariates, request):
    return make_dataloaders(data_with_covariates, **request.param)


@pytest.fixture(scope="session")
def dataloaders_with_different_encoder_decoder_length(data_with_covariates):
    return make_dataloaders(
        data_with_covariates.copy(),
        target="target",
        time_varying_known_categoricals=["special_days", "month"],
        variable_groups=dict(
            special_days=[
                "easter_day",
                "good_friday",
                "new_year",
                "christmas",
                "labor_day",
                "independence_day",
                "revolution_day_memorial",
                "regional_games",
                "fifa_u_17_world_cup",
                "football_gold_cup",
                "beer_capital",
                "music_fest",
            ]
        ),
        time_varying_known_reals=[
            "time_idx",
            "price_regular",
            "price_actual",
            "discount",
            "discount_in_percent",
        ],
        time_varying_unknown_categoricals=[],
        time_varying_unknown_reals=[
            "target",
            "volume",
            "log_volume",
            "industry_volume",
            "soda_volume",
            "avg_max_temp",
        ],
        static_categoricals=["agency"],
        add_relative_time_idx=False,
        target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False),
    )


@pytest.fixture(scope="session")
def dataloaders_with_covariates(data_with_covariates):
    return make_dataloaders(
        data_with_covariates.copy(),
        target="target",
        time_varying_known_reals=["discount"],
        time_varying_unknown_reals=["target"],
        static_categoricals=["agency"],
        add_relative_time_idx=False,
        target_normalizer=GroupNormalizer(groups=["agency", "sku"], center=False),
    )


@pytest.fixture(scope="session")
def dataloaders_multi_target(data_with_covariates):
    return make_dataloaders(
        data_with_covariates.copy(),
        time_varying_unknown_reals=["target", "discount"],
        target=["target", "discount"],
        add_relative_time_idx=False,
    )


@pytest.fixture(scope="session")
def dataloaders_fixed_window_without_covariates():
    data = generate_ar_data(seasonality=10.0, timesteps=50, n_series=2)
    validation = data.series.iloc[:2]

    max_encoder_length = 30
    max_prediction_length = 10

    training = TimeSeriesDataSet(
        data[lambda x: ~x.series.isin(validation)],
        time_idx="time_idx",
        target="value",
        categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
        group_ids=["series"],
        static_categoricals=[],
        max_encoder_length=max_encoder_length,
        max_prediction_length=max_prediction_length,
        time_varying_unknown_reals=["value"],
        target_normalizer=EncoderNormalizer(),
    )

    validation = TimeSeriesDataSet.from_dataset(
        training,
        data[lambda x: x.series.isin(validation)],
        stop_randomization=True,
    )
    batch_size = 2
    train_dataloader = training.to_dataloader(
        train=True, batch_size=batch_size, num_workers=0
    )
    val_dataloader = validation.to_dataloader(
        train=False, batch_size=batch_size, num_workers=0
    )
    test_dataloader = validation.to_dataloader(
        train=False, batch_size=batch_size, num_workers=0
    )

    return dict(train=train_dataloader, val=val_dataloader, test=test_dataloader)
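
Note that data_with_covariates and make_dataloaders above carry no pytest decorator, so they can be called directly as plain functions. A hypothetical usage sketch, outside pytest, assuming the module path used by the import in test_all_estimators.py below:

# Hypothetical usage of the scenario helpers as plain functions.
from pytorch_forecasting.tests._data_scenarios import (
    data_with_covariates,
    make_dataloaders,
)

data = data_with_covariates()  # small Stallion subset with a clipped "target" column
loaders = make_dataloaders(
    data,
    target="volume",
    static_categoricals=["agency", "sku"],  # mirrors one of the parametrized scenarios
)
x, y = next(iter(loaders["train"]))  # one training batch from the TimeSeriesDataSet dataloader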

pytorch_forecasting/tests/test_all_estimators.py

Lines changed: 3 additions & 2 deletions
@@ -6,7 +6,6 @@
 import lightning.pytorch as pl
 from lightning.pytorch.callbacks import EarlyStopping
 from lightning.pytorch.loggers import TensorBoardLogger
-import pytest
 from skbase.testing import BaseFixtureGenerator as _BaseFixtureGenerator

 from pytorch_forecasting._registry import all_objects

@@ -276,11 +275,13 @@ def test_integration(
         self,
         object_metadata,
         trainer_kwargs,
-        data_with_covariates,
         tmp_path,
     ):
         """Fails for certain, for testing."""
         from pytorch_forecasting.metrics import NegativeBinomialDistributionLoss
+        from pytorch_forecasting.tests._data_scenarios import data_with_covariates
+
+        data_with_covariates = data_with_covariates()

         object_class = object_metadata.get_model_cls()
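
The net effect: test_integration no longer receives data_with_covariates as an injected fixture argument but imports the factory from _data_scenarios and calls it inside the test body. A hypothetical sketch of the same pattern applied to a standalone test (function name and assertions are illustrative only):

def test_data_scenario_has_target():
    from pytorch_forecasting.tests._data_scenarios import data_with_covariates

    data = data_with_covariates()
    # "target" is volume clipped to [1e-3, 1.0], so it is always positive
    assert "target" in data.columns
    assert (data["target"] > 0).all()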
