Skip to content

Commit b27b7dc

Browse files
author
PranavBhatP
committed
[ENH] update tests for tide - tests passing
1 parent 3445305 commit b27b7dc

File tree

1 file changed

+70
-192
lines changed

1 file changed

+70
-192
lines changed

tests/test_models/test_tide.py

Lines changed: 70 additions & 192 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,35 @@
11
import pickle
22
import shutil
3-
import sys
43

54
import lightning.pytorch as pl
65
from lightning.pytorch.callbacks import EarlyStopping
76
from lightning.pytorch.loggers import TensorBoardLogger
87
import numpy as np
98
import pandas as pd
109
import pytest
11-
from test_models.conftest import make_dataloaders
12-
import torch
1310

14-
from pytorch_forecasting import TimeSeriesDataSet
15-
from pytorch_forecasting.data.encoders import (
16-
GroupNormalizer,
17-
MultiNormalizer,
18-
NaNLabelEncoder,
19-
)
20-
from pytorch_forecasting.metrics import (
21-
MAE,
22-
MAPE,
23-
SMAPE,
24-
CrossEntropy,
25-
MultiLoss,
26-
PoissonLoss,
27-
QuantileLoss,
28-
)
29-
from pytorch_forecasting.metrics.distributions import NegativeBinomialDistributionLoss
11+
from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
12+
from pytorch_forecasting.metrics import MAE, SMAPE, QuantileLoss
3013
from pytorch_forecasting.models import TiDEModel
3114
from pytorch_forecasting.utils._dependencies import _get_installed_packages
3215

3316

34-
def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs):
35-
"Integration test for TiDEModel functionality."
36-
17+
def _integration(dataloader, tmp_path, trainer_kwargs=None, **kwargs):
3718
train_dataloader = dataloader["train"]
3819
val_dataloader = dataloader["val"]
3920
test_dataloader = dataloader["test"]
4021

41-
early_stop = EarlyStopping(
42-
monitor="val_loss",
43-
patience=1,
44-
verbose=False,
45-
mode="min",
22+
early_stop_callback = EarlyStopping(
23+
monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min"
4624
)
4725

4826
logger = TensorBoardLogger(tmp_path)
49-
5027
if trainer_kwargs is None:
5128
trainer_kwargs = {}
52-
5329
trainer = pl.Trainer(
5430
max_epochs=2,
5531
gradient_clip_val=0.1,
56-
callbacks=[early_stop],
32+
callbacks=[early_stop_callback],
5733
enable_checkpointing=True,
5834
default_root_dir=tmp_path,
5935
limit_train_batches=2,
@@ -63,213 +39,118 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs)
6339
**trainer_kwargs,
6440
)
6541

66-
kwargs.setdefault("learning_rate", 0.15)
67-
68-
if loss is not None:
69-
pass
70-
elif isinstance(train_dataloader.dataset.target_normalizer, NaNLabelEncoder):
71-
loss = CrossEntropy()
72-
elif isinstance(train_dataloader.dataset.target_normalizer, MultiNormalizer):
73-
loss = MultiLoss(
74-
[
75-
(
76-
(
77-
CrossEntropy()
78-
if isinstance(normalizer, NaNLabelEncoder)
79-
else QuantileLoss()
80-
),
81-
)
82-
for normalizer in train_dataloader.dataset.target_normalizer.normalizers
83-
]
84-
)
85-
else:
86-
loss = QuantileLoss()
42+
kwargs.setdefault("hidden_size", 16)
43+
kwargs.setdefault("temporal_decoder_hidden", 8)
44+
kwargs.setdefault("temporal_width_future", 4)
45+
kwargs.setdefault("dropout", 0.1)
46+
kwargs.setdefault("learning_rate", 0.01)
8747

8848
net = TiDEModel.from_dataset(
8949
train_dataloader.dataset,
90-
hidden_size=4,
91-
decoder_output_dim=4,
92-
num_encoder_layers=2,
93-
num_decoder_layers=2,
94-
dropout=0.2,
95-
loss=loss,
96-
add_relative_time_idx=False,
97-
temporal_decoder_hidden=4,
98-
temporal_width_future=2,
99-
temporal_hidden_size_future=4,
100-
log_interval=5,
101-
log_val_interval=1,
10250
**kwargs,
10351
)
104-
10552
net.size()
106-
10753
try:
10854
trainer.fit(
10955
net,
11056
train_dataloaders=train_dataloader,
11157
val_dataloaders=val_dataloader,
11258
)
113-
114-
test_outputs = trainer.test(
115-
net,
116-
test_dataloaders=test_dataloader,
117-
)
59+
test_outputs = trainer.test(net, dataloaders=test_dataloader)
11860
assert len(test_outputs) > 0
119-
61+
# check loading
12062
net = TiDEModel.load_from_checkpoint(
12163
trainer.checkpoint_callback.best_model_path
12264
)
12365

124-
predictions = net.predict(
66+
# check prediction
67+
net.predict(
12568
val_dataloader,
126-
return_index=True,
127-
return_x=True,
128-
return_y=True,
12969
fast_dev_run=True,
70+
return_index=True,
71+
return_decoder_lengths=True,
13072
trainer_kwargs=trainer_kwargs,
13173
)
74+
finally:
75+
shutil.rmtree(tmp_path, ignore_errors=True)
13276

133-
pred_len = len(predictions.index)
134-
135-
def check(x):
136-
if isinstance(x, (tuple, list)):
137-
for xi in x:
138-
check(xi)
139-
elif isinstance(x, dict):
140-
for xi in x.values():
141-
check(xi)
142-
else:
143-
assert (
144-
pred_len == x.shape[0]
145-
), "first dimension should be prediction length"
146-
147-
check(predictions.output)
148-
if isinstance(predictions.output, torch.Tensor):
149-
assert (
150-
predictions.output.ndim == 2
151-
), "shape of predictions should be batch_size x timesteps"
152-
else:
153-
assert all(
154-
p.ndim == 2 for p in predictions.output
155-
), "shape of predictions should be batch_size x timesteps"
77+
predictions = net.predict(
78+
val_dataloader,
79+
fast_dev_run=True,
80+
return_index=True,
81+
return_decoder_lengths=True,
82+
)
83+
return predictions
15684

157-
check(predictions.output)
15885

159-
if isinstance(predictions.output, torch.Tensor):
160-
assert (
161-
predictions.output.ndim == 2
162-
), "shape of predictions should be batch_size x timesteps"
163-
else:
164-
assert all(
165-
p.ndim == 2 for p in predictions.output
166-
), "shape of predictions should be batch_size x timesteps"
167-
check(predictions.x)
168-
check(predictions.index)
169-
finally:
170-
shutil.rmtree(tmp_path, ignore_errors=True)
86+
@pytest.mark.parametrize(
87+
"kwargs",
88+
[
89+
{},
90+
{"loss": SMAPE()},
91+
{"hidden_size": 32, "temporal_decoder_hidden": 16},
92+
{"dropout": 0.2, "use_layer_norm": True},
93+
],
94+
)
95+
def test_integration(dataloaders_with_covariates, tmp_path, kwargs):
96+
_integration(dataloaders_with_covariates, tmp_path, **kwargs)
17197

17298

173-
def test_integration(multiple_dataloaders_with_covariates, tmp_path):
174-
"""Test basic integration of model with covariates."""
175-
_integration(
176-
multiple_dataloaders_with_covariates,
177-
tmp_path,
178-
trainer_kwargs=dict(accelerator="cpu"),
179-
)
99+
@pytest.mark.parametrize(
100+
"kwargs",
101+
[
102+
{}, # Default settings for multi-target
103+
],
104+
)
105+
def test_multi_target_integration(dataloaders_multi_target, tmp_path, kwargs):
106+
_integration(dataloaders_multi_target, tmp_path, **kwargs)
180107

181108

182109
@pytest.fixture
183110
def model(dataloaders_with_covariates):
184-
"""Create a model for testing."""
185-
186111
dataset = dataloaders_with_covariates["train"].dataset
187-
188112
net = TiDEModel.from_dataset(
189-
dataset=dataset,
190-
learning_rate=0.15,
191-
hidden_size=4,
192-
num_encoder_layers=2,
193-
num_decoder_layers=2,
194-
decoder_output_dim=4,
195-
dropout=0.2,
196-
temporal_decoder_hidden=4,
197-
temporal_width_future=2,
198-
temporal_hidden_size_future=4,
199-
loss=PoissonLoss(),
200-
output_size=1,
201-
log_interval=5,
202-
log_val_interval=1,
113+
dataset,
114+
hidden_size=16,
115+
dropout=0.1,
116+
temporal_width_future=4,
203117
)
204118
return net
205119

206120

207-
def test_tensorboard_graph_log(dataloaders_with_covariates, model, tmp_path):
208-
"""Test if tensorboard graph can be logged."""
209-
d = next(iter(dataloaders_with_covariates["train"]))
210-
logger = TensorBoardLogger("test", str(tmp_path), log_graph=True)
211-
logger.log_graph(model, d[0])
212-
213-
214121
def test_pickle(model):
215-
"""Test that model can be pickled and unpickled."""
216122
pkl = pickle.dumps(model)
217123
pickle.loads(pkl) # noqa: S301
218124

219125

220-
@pytest.mark.parametrize(
221-
"kwargs", [dict(mode="dataframe"), dict(mode="series"), dict(mode="raw")]
126+
@pytest.mark.skipif(
127+
"matplotlib" not in _get_installed_packages(),
128+
reason="skip test if required package matplotlib not installed",
222129
)
223-
def test_predict_dependency(
224-
model, dataloaders_with_covariates, data_with_covariates, kwargs
225-
):
226-
"""Test if predict_dependency works correctly."""
227-
train_dataset = dataloaders_with_covariates["train"].dataset
228-
data_with_covariates = data_with_covariates.copy()
229-
dataset = TimeSeriesDataSet.from_dataset(
230-
train_dataset,
231-
data_with_covariates[lambda x: x.agency == data_with_covariates.agency.iloc[0]],
232-
predict=True,
233-
)
234-
model.predict_dependency(dataset, variable="discount", values=[0.1, 0.0], **kwargs)
235-
model.predict_dependency(
236-
dataset,
237-
variable="agency",
238-
values=data_with_covariates.agency.unique()[:2],
239-
**kwargs,
130+
def test_prediction_visualization(model, dataloaders_with_covariates):
131+
raw_predictions = model.predict(
132+
dataloaders_with_covariates["val"],
133+
mode="raw",
134+
return_x=True,
135+
fast_dev_run=True,
240136
)
137+
model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=0)
241138

242139

243-
@pytest.mark.parametrize(
244-
"kwargs",
245-
[
246-
dict(mode="raw"),
247-
dict(mode="quantiles"),
248-
dict(return_index=True),
249-
dict(return_decoder_lengths=True),
250-
dict(return_x=True),
251-
dict(return_y=True),
252-
],
253-
)
254-
def test_prediction_with_dataloader(model, dataloaders_with_covariates, kwargs):
255-
"""Test prediction with dataloader."""
256-
val_dataloader = dataloaders_with_covariates["val"]
257-
model.predict(val_dataloader, fast_dev_run=True, **kwargs)
258-
259-
260-
def test_prediction_with_dataset(model, dataloaders_with_covariates):
261-
"""Test prediction with dataset."""
262-
val_dataloader = dataloaders_with_covariates["val"]
263-
model.predict(val_dataloader.dataset, fast_dev_run=True)
264-
265-
266-
def test_prediction_with_dataframe(model, data_with_covariates):
267-
"""Test the prediction with dataframe."""
268-
model.predict(data_with_covariates, fast_dev_run=True)
140+
def test_prediction_with_kwargs(model, dataloaders_with_covariates):
141+
# Tests prediction works with different keyword arguments
142+
model.predict(
143+
dataloaders_with_covariates["val"], return_index=True, fast_dev_run=True
144+
)
145+
model.predict(
146+
dataloaders_with_covariates["val"],
147+
return_x=True,
148+
return_y=True,
149+
fast_dev_run=True,
150+
)
269151

270152

271153
def test_no_exogenous_variable():
272-
"""Test whether model works without exogenous variables."""
273154
data = pd.DataFrame(
274155
{
275156
"target": np.ones(1600),
@@ -284,8 +165,6 @@ def test_no_exogenous_variable():
284165
group_ids=["group_id"],
285166
max_encoder_length=10,
286167
max_prediction_length=5,
287-
min_encoder_length=10,
288-
min_prediction_length=5,
289168
time_varying_unknown_reals=["target"],
290169
time_varying_known_reals=[],
291170
)
@@ -300,7 +179,6 @@ def test_no_exogenous_variable():
300179
)
301180
forecaster = TiDEModel.from_dataset(
302181
training_dataset,
303-
log_interval=1,
304182
)
305183
from lightning.pytorch import Trainer
306184

0 commit comments

Comments
 (0)