11import pickle
22import shutil
3- import sys
43
54import lightning .pytorch as pl
65from lightning .pytorch .callbacks import EarlyStopping
76from lightning .pytorch .loggers import TensorBoardLogger
87import numpy as np
98import pandas as pd
109import pytest
11- from test_models .conftest import make_dataloaders
12- import torch
1310
14- from pytorch_forecasting import TimeSeriesDataSet
15- from pytorch_forecasting .data .encoders import (
16- GroupNormalizer ,
17- MultiNormalizer ,
18- NaNLabelEncoder ,
19- )
20- from pytorch_forecasting .metrics import (
21- MAE ,
22- MAPE ,
23- SMAPE ,
24- CrossEntropy ,
25- MultiLoss ,
26- PoissonLoss ,
27- QuantileLoss ,
28- )
29- from pytorch_forecasting .metrics .distributions import NegativeBinomialDistributionLoss
11+ from pytorch_forecasting .data .timeseries import TimeSeriesDataSet
12+ from pytorch_forecasting .metrics import MAE , SMAPE , QuantileLoss
3013from pytorch_forecasting .models import TiDEModel
3114from pytorch_forecasting .utils ._dependencies import _get_installed_packages
3215
3316
34- def _integration (dataloader , tmp_path , loss = None , trainer_kwargs = None , ** kwargs ):
35- "Integration test for TiDEModel functionality."
36-
17+ def _integration (dataloader , tmp_path , trainer_kwargs = None , ** kwargs ):
3718 train_dataloader = dataloader ["train" ]
3819 val_dataloader = dataloader ["val" ]
3920 test_dataloader = dataloader ["test" ]
4021
41- early_stop = EarlyStopping (
42- monitor = "val_loss" ,
43- patience = 1 ,
44- verbose = False ,
45- mode = "min" ,
22+ early_stop_callback = EarlyStopping (
23+ monitor = "val_loss" , min_delta = 1e-4 , patience = 1 , verbose = False , mode = "min"
4624 )
4725
4826 logger = TensorBoardLogger (tmp_path )
49-
5027 if trainer_kwargs is None :
5128 trainer_kwargs = {}
52-
5329 trainer = pl .Trainer (
5430 max_epochs = 2 ,
5531 gradient_clip_val = 0.1 ,
56- callbacks = [early_stop ],
32+ callbacks = [early_stop_callback ],
5733 enable_checkpointing = True ,
5834 default_root_dir = tmp_path ,
5935 limit_train_batches = 2 ,
@@ -63,213 +39,118 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs)
6339 ** trainer_kwargs ,
6440 )
6541
66- kwargs .setdefault ("learning_rate" , 0.15 )
67-
68- if loss is not None :
69- pass
70- elif isinstance (train_dataloader .dataset .target_normalizer , NaNLabelEncoder ):
71- loss = CrossEntropy ()
72- elif isinstance (train_dataloader .dataset .target_normalizer , MultiNormalizer ):
73- loss = MultiLoss (
74- [
75- (
76- (
77- CrossEntropy ()
78- if isinstance (normalizer , NaNLabelEncoder )
79- else QuantileLoss ()
80- ),
81- )
82- for normalizer in train_dataloader .dataset .target_normalizer .normalizers
83- ]
84- )
85- else :
86- loss = QuantileLoss ()
42+ kwargs .setdefault ("hidden_size" , 16 )
43+ kwargs .setdefault ("temporal_decoder_hidden" , 8 )
44+ kwargs .setdefault ("temporal_width_future" , 4 )
45+ kwargs .setdefault ("dropout" , 0.1 )
46+ kwargs .setdefault ("learning_rate" , 0.01 )
8747
8848 net = TiDEModel .from_dataset (
8949 train_dataloader .dataset ,
90- hidden_size = 4 ,
91- decoder_output_dim = 4 ,
92- num_encoder_layers = 2 ,
93- num_decoder_layers = 2 ,
94- dropout = 0.2 ,
95- loss = loss ,
96- add_relative_time_idx = False ,
97- temporal_decoder_hidden = 4 ,
98- temporal_width_future = 2 ,
99- temporal_hidden_size_future = 4 ,
100- log_interval = 5 ,
101- log_val_interval = 1 ,
10250 ** kwargs ,
10351 )
104-
10552 net .size ()
106-
10753 try :
10854 trainer .fit (
10955 net ,
11056 train_dataloaders = train_dataloader ,
11157 val_dataloaders = val_dataloader ,
11258 )
113-
114- test_outputs = trainer .test (
115- net ,
116- test_dataloaders = test_dataloader ,
117- )
59+ test_outputs = trainer .test (net , dataloaders = test_dataloader )
11860 assert len (test_outputs ) > 0
119-
61+ # check loading
12062 net = TiDEModel .load_from_checkpoint (
12163 trainer .checkpoint_callback .best_model_path
12264 )
12365
124- predictions = net .predict (
66+ # check prediction
67+ net .predict (
12568 val_dataloader ,
126- return_index = True ,
127- return_x = True ,
128- return_y = True ,
12969 fast_dev_run = True ,
70+ return_index = True ,
71+ return_decoder_lengths = True ,
13072 trainer_kwargs = trainer_kwargs ,
13173 )
74+ finally :
75+ shutil .rmtree (tmp_path , ignore_errors = True )
13276
133- pred_len = len (predictions .index )
134-
135- def check (x ):
136- if isinstance (x , (tuple , list )):
137- for xi in x :
138- check (xi )
139- elif isinstance (x , dict ):
140- for xi in x .values ():
141- check (xi )
142- else :
143- assert (
144- pred_len == x .shape [0 ]
145- ), "first dimension should be prediction length"
146-
147- check (predictions .output )
148- if isinstance (predictions .output , torch .Tensor ):
149- assert (
150- predictions .output .ndim == 2
151- ), "shape of predictions should be batch_size x timesteps"
152- else :
153- assert all (
154- p .ndim == 2 for p in predictions .output
155- ), "shape of predictions should be batch_size x timesteps"
77+ predictions = net .predict (
78+ val_dataloader ,
79+ fast_dev_run = True ,
80+ return_index = True ,
81+ return_decoder_lengths = True ,
82+ )
83+ return predictions
15684
157- check (predictions .output )
15885
159- if isinstance (predictions .output , torch .Tensor ):
160- assert (
161- predictions .output .ndim == 2
162- ), "shape of predictions should be batch_size x timesteps"
163- else :
164- assert all (
165- p .ndim == 2 for p in predictions .output
166- ), "shape of predictions should be batch_size x timesteps"
167- check (predictions .x )
168- check (predictions .index )
169- finally :
170- shutil .rmtree (tmp_path , ignore_errors = True )
86+ @pytest .mark .parametrize (
87+ "kwargs" ,
88+ [
89+ {},
90+ {"loss" : SMAPE ()},
91+ {"hidden_size" : 32 , "temporal_decoder_hidden" : 16 },
92+ {"dropout" : 0.2 , "use_layer_norm" : True },
93+ ],
94+ )
95+ def test_integration (dataloaders_with_covariates , tmp_path , kwargs ):
96+ _integration (dataloaders_with_covariates , tmp_path , ** kwargs )
17197
17298
173- def test_integration (multiple_dataloaders_with_covariates , tmp_path ):
174- """Test basic integration of model with covariates."""
175- _integration (
176- multiple_dataloaders_with_covariates ,
177- tmp_path ,
178- trainer_kwargs = dict (accelerator = "cpu" ),
179- )
99+ @pytest .mark .parametrize (
100+ "kwargs" ,
101+ [
102+ {}, # Default settings for multi-target
103+ ],
104+ )
105+ def test_multi_target_integration (dataloaders_multi_target , tmp_path , kwargs ):
106+ _integration (dataloaders_multi_target , tmp_path , ** kwargs )
180107
181108
182109@pytest .fixture
183110def model (dataloaders_with_covariates ):
184- """Create a model for testing."""
185-
186111 dataset = dataloaders_with_covariates ["train" ].dataset
187-
188112 net = TiDEModel .from_dataset (
189- dataset = dataset ,
190- learning_rate = 0.15 ,
191- hidden_size = 4 ,
192- num_encoder_layers = 2 ,
193- num_decoder_layers = 2 ,
194- decoder_output_dim = 4 ,
195- dropout = 0.2 ,
196- temporal_decoder_hidden = 4 ,
197- temporal_width_future = 2 ,
198- temporal_hidden_size_future = 4 ,
199- loss = PoissonLoss (),
200- output_size = 1 ,
201- log_interval = 5 ,
202- log_val_interval = 1 ,
113+ dataset ,
114+ hidden_size = 16 ,
115+ dropout = 0.1 ,
116+ temporal_width_future = 4 ,
203117 )
204118 return net
205119
206120
207- def test_tensorboard_graph_log (dataloaders_with_covariates , model , tmp_path ):
208- """Test if tensorboard graph can be logged."""
209- d = next (iter (dataloaders_with_covariates ["train" ]))
210- logger = TensorBoardLogger ("test" , str (tmp_path ), log_graph = True )
211- logger .log_graph (model , d [0 ])
212-
213-
214121def test_pickle (model ):
215- """Test that model can be pickled and unpickled."""
216122 pkl = pickle .dumps (model )
217123 pickle .loads (pkl ) # noqa: S301
218124
219125
220- @pytest .mark .parametrize (
221- "kwargs" , [dict (mode = "dataframe" ), dict (mode = "series" ), dict (mode = "raw" )]
126+ @pytest .mark .skipif (
127+ "matplotlib" not in _get_installed_packages (),
128+ reason = "skip test if required package matplotlib not installed" ,
222129)
223- def test_predict_dependency (
224- model , dataloaders_with_covariates , data_with_covariates , kwargs
225- ):
226- """Test if predict_dependency works correctly."""
227- train_dataset = dataloaders_with_covariates ["train" ].dataset
228- data_with_covariates = data_with_covariates .copy ()
229- dataset = TimeSeriesDataSet .from_dataset (
230- train_dataset ,
231- data_with_covariates [lambda x : x .agency == data_with_covariates .agency .iloc [0 ]],
232- predict = True ,
233- )
234- model .predict_dependency (dataset , variable = "discount" , values = [0.1 , 0.0 ], ** kwargs )
235- model .predict_dependency (
236- dataset ,
237- variable = "agency" ,
238- values = data_with_covariates .agency .unique ()[:2 ],
239- ** kwargs ,
130+ def test_prediction_visualization (model , dataloaders_with_covariates ):
131+ raw_predictions = model .predict (
132+ dataloaders_with_covariates ["val" ],
133+ mode = "raw" ,
134+ return_x = True ,
135+ fast_dev_run = True ,
240136 )
137+ model .plot_prediction (raw_predictions .x , raw_predictions .output , idx = 0 )
241138
242139
243- @pytest .mark .parametrize (
244- "kwargs" ,
245- [
246- dict (mode = "raw" ),
247- dict (mode = "quantiles" ),
248- dict (return_index = True ),
249- dict (return_decoder_lengths = True ),
250- dict (return_x = True ),
251- dict (return_y = True ),
252- ],
253- )
254- def test_prediction_with_dataloader (model , dataloaders_with_covariates , kwargs ):
255- """Test prediction with dataloader."""
256- val_dataloader = dataloaders_with_covariates ["val" ]
257- model .predict (val_dataloader , fast_dev_run = True , ** kwargs )
258-
259-
260- def test_prediction_with_dataset (model , dataloaders_with_covariates ):
261- """Test prediction with dataset."""
262- val_dataloader = dataloaders_with_covariates ["val" ]
263- model .predict (val_dataloader .dataset , fast_dev_run = True )
264-
265-
266- def test_prediction_with_dataframe (model , data_with_covariates ):
267- """Test the prediction with dataframe."""
268- model .predict (data_with_covariates , fast_dev_run = True )
140+ def test_prediction_with_kwargs (model , dataloaders_with_covariates ):
141+ # Tests prediction works with different keyword arguments
142+ model .predict (
143+ dataloaders_with_covariates ["val" ], return_index = True , fast_dev_run = True
144+ )
145+ model .predict (
146+ dataloaders_with_covariates ["val" ],
147+ return_x = True ,
148+ return_y = True ,
149+ fast_dev_run = True ,
150+ )
269151
270152
271153def test_no_exogenous_variable ():
272- """Test whether model works without exogenous variables."""
273154 data = pd .DataFrame (
274155 {
275156 "target" : np .ones (1600 ),
@@ -284,8 +165,6 @@ def test_no_exogenous_variable():
284165 group_ids = ["group_id" ],
285166 max_encoder_length = 10 ,
286167 max_prediction_length = 5 ,
287- min_encoder_length = 10 ,
288- min_prediction_length = 5 ,
289168 time_varying_unknown_reals = ["target" ],
290169 time_varying_known_reals = [],
291170 )
@@ -300,7 +179,6 @@ def test_no_exogenous_variable():
300179 )
301180 forecaster = TiDEModel .from_dataset (
302181 training_dataset ,
303- log_interval = 1 ,
304182 )
305183 from lightning .pytorch import Trainer
306184
0 commit comments