Skip to content

Commit 3445305

Browse files
author
PranavBhatP
committed
[ENH] initial tests for tide - errors prevalent
1 parent e87230b commit 3445305

File tree

1 file changed

+324
-0
lines changed

1 file changed

+324
-0
lines changed

tests/test_models/test_tide.py

Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
import pickle
2+
import shutil
3+
import sys
4+
5+
import lightning.pytorch as pl
6+
from lightning.pytorch.callbacks import EarlyStopping
7+
from lightning.pytorch.loggers import TensorBoardLogger
8+
import numpy as np
9+
import pandas as pd
10+
import pytest
11+
from test_models.conftest import make_dataloaders
12+
import torch
13+
14+
from pytorch_forecasting import TimeSeriesDataSet
15+
from pytorch_forecasting.data.encoders import (
16+
GroupNormalizer,
17+
MultiNormalizer,
18+
NaNLabelEncoder,
19+
)
20+
from pytorch_forecasting.metrics import (
21+
MAE,
22+
MAPE,
23+
SMAPE,
24+
CrossEntropy,
25+
MultiLoss,
26+
PoissonLoss,
27+
QuantileLoss,
28+
)
29+
from pytorch_forecasting.metrics.distributions import NegativeBinomialDistributionLoss
30+
from pytorch_forecasting.models import TiDEModel
31+
from pytorch_forecasting.utils._dependencies import _get_installed_packages
32+
33+
34+
def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs):
    """Train, test, reload and predict with a small TiDEModel end to end.

    Parameters
    ----------
    dataloader : dict
        Mapping with the keys ``"train"``, ``"val"`` and ``"test"``, each
        holding a dataloader.
    tmp_path : path-like
        Directory used for the TensorBoard logger and as the trainer root
        directory. It is removed when the function exits.
    loss : optional
        Loss to train with. When ``None``, a loss is selected from the target
        normalizer of the training dataset.
    trainer_kwargs : dict, optional
        Extra keyword arguments forwarded to ``pl.Trainer`` and to
        ``net.predict``.
    **kwargs
        Extra keyword arguments forwarded to ``TiDEModel.from_dataset``.
        ``learning_rate`` defaults to ``0.15`` when not given.
    """
    train_dataloader = dataloader["train"]
    val_dataloader = dataloader["val"]
    test_dataloader = dataloader["test"]

    early_stop = EarlyStopping(
        monitor="val_loss",
        patience=1,
        verbose=False,
        mode="min",
    )

    logger = TensorBoardLogger(tmp_path)

    if trainer_kwargs is None:
        trainer_kwargs = {}

    trainer = pl.Trainer(
        max_epochs=2,
        gradient_clip_val=0.1,
        callbacks=[early_stop],
        enable_checkpointing=True,
        default_root_dir=tmp_path,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        logger=logger,
        **trainer_kwargs,
    )

    kwargs.setdefault("learning_rate", 0.15)

    if loss is None:
        # pick a loss that matches how the target(s) are encoded
        target_normalizer = train_dataloader.dataset.target_normalizer
        if isinstance(target_normalizer, NaNLabelEncoder):
            loss = CrossEntropy()
        elif isinstance(target_normalizer, MultiNormalizer):
            # one metric per target; each list element must be the metric
            # itself (the previous code wrapped every metric in a 1-tuple)
            loss = MultiLoss(
                [
                    (
                        CrossEntropy()
                        if isinstance(normalizer, NaNLabelEncoder)
                        else QuantileLoss()
                    )
                    for normalizer in target_normalizer.normalizers
                ]
            )
        else:
            loss = QuantileLoss()

    net = TiDEModel.from_dataset(
        train_dataloader.dataset,
        hidden_size=4,
        decoder_output_dim=4,
        num_encoder_layers=2,
        num_decoder_layers=2,
        dropout=0.2,
        loss=loss,
        add_relative_time_idx=False,
        temporal_decoder_hidden=4,
        temporal_width_future=2,
        temporal_hidden_size_future=4,
        log_interval=5,
        log_val_interval=1,
        **kwargs,
    )

    # smoke check: the call must not raise; the returned value is not used
    net.size()

    try:
        trainer.fit(
            net,
            train_dataloaders=train_dataloader,
            val_dataloaders=val_dataloader,
        )

        # Lightning's Trainer.test takes ``dataloaders``; ``test_dataloaders``
        # is not an accepted keyword
        test_outputs = trainer.test(
            net,
            dataloaders=test_dataloader,
        )
        assert len(test_outputs) > 0

        # reload the best checkpoint written during fit and predict with it
        net = TiDEModel.load_from_checkpoint(
            trainer.checkpoint_callback.best_model_path
        )

        predictions = net.predict(
            val_dataloader,
            return_index=True,
            return_x=True,
            return_y=True,
            fast_dev_run=True,
            trainer_kwargs=trainer_kwargs,
        )

        pred_len = len(predictions.index)

        def check(x):
            # walk nested tuples/lists/dicts down to leaves exposing ``shape``
            if isinstance(x, (tuple, list)):
                for xi in x:
                    check(xi)
            elif isinstance(x, dict):
                for xi in x.values():
                    check(xi)
            else:
                assert (
                    pred_len == x.shape[0]
                ), "first dimension should be prediction length"

        check(predictions.output)
        if isinstance(predictions.output, torch.Tensor):
            assert (
                predictions.output.ndim == 2
            ), "shape of predictions should be batch_size x timesteps"
        else:
            assert all(
                p.ndim == 2 for p in predictions.output
            ), "shape of predictions should be batch_size x timesteps"

        check(predictions.x)
        check(predictions.index)
    finally:
        # always clean up logs and checkpoints, even when an assertion failed
        shutil.rmtree(tmp_path, ignore_errors=True)
171+
172+
173+
def test_integration(multiple_dataloaders_with_covariates, tmp_path):
    """Run the shared TiDE integration routine on CPU with covariate data."""
    cpu_trainer_options = {"accelerator": "cpu"}
    _integration(
        multiple_dataloaders_with_covariates,
        tmp_path,
        trainer_kwargs=cpu_trainer_options,
    )
180+
181+
182+
@pytest.fixture
def model(dataloaders_with_covariates):
    """Return a small TiDEModel built from the training dataset.

    The network uses a Poisson loss with a single output and deliberately
    tiny layer sizes.
    """
    train_dataset = dataloaders_with_covariates["train"].dataset

    # network dimensions
    architecture = {
        "hidden_size": 4,
        "num_encoder_layers": 2,
        "num_decoder_layers": 2,
        "decoder_output_dim": 4,
        "dropout": 0.2,
        "temporal_decoder_hidden": 4,
        "temporal_width_future": 2,
        "temporal_hidden_size_future": 4,
    }
    # optimisation, loss and logging settings
    training = {
        "learning_rate": 0.15,
        "loss": PoissonLoss(),
        "output_size": 1,
        "log_interval": 5,
        "log_val_interval": 1,
    }

    return TiDEModel.from_dataset(dataset=train_dataset, **architecture, **training)
205+
206+
207+
def test_tensorboard_graph_log(dataloaders_with_covariates, model, tmp_path):
    """Check that the model graph can be written through a TensorBoard logger."""
    train_loader = dataloaders_with_covariates["train"]
    first_batch = next(iter(train_loader))
    # first element of the batch is passed as example input for graph tracing
    example_input = first_batch[0]

    tb_logger = TensorBoardLogger("test", str(tmp_path), log_graph=True)
    tb_logger.log_graph(model, example_input)
212+
213+
214+
def test_pickle(model):
    """Round-trip the model through pickle; neither direction may raise."""
    serialized = pickle.dumps(model)
    # data was produced by this test a line above, so loading it is safe
    pickle.loads(serialized)  # noqa: S301
218+
219+
220+
@pytest.mark.parametrize(
    "kwargs",
    [{"mode": "dataframe"}, {"mode": "series"}, {"mode": "raw"}],
)
def test_predict_dependency(
    model, dataloaders_with_covariates, data_with_covariates, kwargs
):
    """Call ``predict_dependency`` for a real and a categorical variable.

    The prediction dataset is restricted to the agency found in the first row
    of the data, and each output ``mode`` is exercised via ``kwargs``.
    """
    frame = data_with_covariates.copy()

    # keep only the rows that belong to the first agency in the data
    first_agency = frame.agency.iloc[0]
    single_agency_rows = frame[frame.agency == first_agency]

    prediction_dataset = TimeSeriesDataSet.from_dataset(
        dataloaders_with_covariates["train"].dataset,
        single_agency_rows,
        predict=True,
    )

    model.predict_dependency(
        prediction_dataset, variable="discount", values=[0.1, 0.0], **kwargs
    )

    two_agencies = frame.agency.unique()[:2]
    model.predict_dependency(
        prediction_dataset,
        variable="agency",
        values=two_agencies,
        **kwargs,
    )
241+
242+
243+
@pytest.mark.parametrize(
    "kwargs",
    [
        {"mode": "raw"},
        {"mode": "quantiles"},
        {"return_index": True},
        {"return_decoder_lengths": True},
        {"return_x": True},
        {"return_y": True},
    ],
)
def test_prediction_with_dataloader(model, dataloaders_with_covariates, kwargs):
    """Predict from the validation dataloader with each option set in turn."""
    validation_loader = dataloaders_with_covariates["val"]
    model.predict(validation_loader, fast_dev_run=True, **kwargs)
258+
259+
260+
def test_prediction_with_dataset(model, dataloaders_with_covariates):
    """Predict directly from the dataset behind the validation dataloader."""
    validation_dataset = dataloaders_with_covariates["val"].dataset
    model.predict(validation_dataset, fast_dev_run=True)
264+
265+
266+
def test_prediction_with_dataframe(model, data_with_covariates):
    """Predict directly from the covariate dataframe fixture."""
    predict_options = {"fast_dev_run": True}
    model.predict(data_with_covariates, **predict_options)
269+
270+
271+
def test_no_exogenous_variable():
    """Fit and predict with TiDEModel on data that has only a target column.

    The frame holds 16 groups of 100 time steps with a constant target of one;
    the only model input is the target itself.
    """
    from lightning.pytorch import Trainer

    n_groups, series_length = 16, 100
    data = pd.DataFrame(
        {
            "target": np.ones(n_groups * series_length),
            "group_id": np.repeat(np.arange(n_groups), series_length),
            "time_idx": np.tile(np.arange(series_length), n_groups),
        }
    )

    # fixed-length windows: 10 encoder steps followed by 5 prediction steps
    training_dataset = TimeSeriesDataSet(
        data=data,
        time_idx="time_idx",
        target="target",
        group_ids=["group_id"],
        max_encoder_length=10,
        max_prediction_length=5,
        min_encoder_length=10,
        min_prediction_length=5,
        time_varying_unknown_reals=["target"],
        time_varying_known_reals=[],
    )
    validation_dataset = TimeSeriesDataSet.from_dataset(
        training_dataset, data, stop_randomization=True, predict=True
    )

    loader_options = {"batch_size": 8, "num_workers": 0}
    training_data_loader = training_dataset.to_dataloader(
        train=True, **loader_options
    )
    validation_data_loader = validation_dataset.to_dataloader(
        train=False, **loader_options
    )

    forecaster = TiDEModel.from_dataset(training_dataset, log_interval=1)

    trainer = Trainer(max_epochs=2, limit_train_batches=8, limit_val_batches=8)
    trainer.fit(
        forecaster,
        train_dataloaders=training_data_loader,
        val_dataloaders=validation_data_loader,
    )

    # reload from the best checkpoint and make sure prediction still works
    checkpoint_path = trainer.checkpoint_callback.best_model_path
    restored = TiDEModel.load_from_checkpoint(checkpoint_path)
    restored.predict(
        validation_data_loader,
        return_x=True,
        return_y=True,
        return_index=True,
    )

0 commit comments

Comments
 (0)