Skip to content

Commit c04ebf3

Browse files
cngmidPranavBhatP
authored andcommitted
[BUG] fix incorrect concatenation dimension in concat_sequences (sktime#1827)
### Description This PR fixes [1823](sktime#1823)
1 parent 15ea3c3 commit c04ebf3

File tree

4 files changed

+52
-6
lines changed

4 files changed

+52
-6
lines changed

pytorch_forecasting/data/timeseries/_timeseries.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2348,7 +2348,7 @@ def __getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
23482348

23492349
@staticmethod
23502350
def _collate_fn(
2351-
batches: list[tuple[dict[str, torch.Tensor], torch.Tensor]]
2351+
batches: list[tuple[dict[str, torch.Tensor], torch.Tensor]],
23522352
) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
23532353
"""
23542354
Collate function to combine items into mini-batch for dataloader.

pytorch_forecasting/models/base/_base_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def _concatenate_output(
133133
str,
134134
List[Union[List[torch.Tensor], torch.Tensor, bool, int, str, np.ndarray]],
135135
]
136-
]
136+
],
137137
) -> Dict[
138138
str, Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, int, bool, str]]]
139139
]:

pytorch_forecasting/utils/_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def autocorrelation(input, dim=0):
233233

234234

235235
def unpack_sequence(
236-
sequence: Union[torch.Tensor, rnn.PackedSequence]
236+
sequence: Union[torch.Tensor, rnn.PackedSequence],
237237
) -> Tuple[torch.Tensor, torch.Tensor]:
238238
"""
239239
Unpack RNN sequence.
@@ -257,7 +257,7 @@ def unpack_sequence(
257257

258258

259259
def concat_sequences(
260-
sequences: Union[List[torch.Tensor], List[rnn.PackedSequence]]
260+
sequences: Union[List[torch.Tensor], List[rnn.PackedSequence]],
261261
) -> Union[torch.Tensor, rnn.PackedSequence]:
262262
"""
263263
Concatenate RNN sequences.
@@ -272,7 +272,7 @@ def concat_sequences(
272272
if isinstance(sequences[0], rnn.PackedSequence):
273273
return rnn.pack_sequence(sequences, enforce_sorted=False)
274274
elif isinstance(sequences[0], torch.Tensor):
275-
return torch.cat(sequences, dim=1)
275+
return torch.cat(sequences, dim=0)
276276
elif isinstance(sequences[0], (tuple, list)):
277277
return tuple(
278278
concat_sequences([sequences[ii][i] for ii in range(len(sequences))])

tests/test_models/test_temporal_fusion_transformer.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
import pytest
1111
import torch
1212

13-
from pytorch_forecasting import TimeSeriesDataSet
13+
from pytorch_forecasting import Baseline, TimeSeriesDataSet
1414
from pytorch_forecasting.data import NaNLabelEncoder
1515
from pytorch_forecasting.data.encoders import GroupNormalizer, MultiNormalizer
16+
from pytorch_forecasting.data.examples import generate_ar_data
1617
from pytorch_forecasting.metrics import (
1718
CrossEntropy,
1819
MQF2DistributionLoss,
@@ -521,3 +522,48 @@ def test_no_exogenous_variable():
521522
return_y=True,
522523
return_index=True,
523524
)
525+
526+
527+
def test_correct_prediction_concatenation():
528+
data = generate_ar_data(seasonality=10.0, timesteps=100, n_series=2, seed=42)
529+
data["static"] = 2
530+
data["date"] = pd.Timestamp("2020-01-01") + pd.to_timedelta(data.time_idx, "D")
531+
data.head()
532+
533+
# create dataset and dataloaders
534+
max_encoder_length = 20
535+
max_prediction_length = 5
536+
537+
training_cutoff = data["time_idx"].max() - max_prediction_length
538+
539+
context_length = max_encoder_length
540+
prediction_length = max_prediction_length
541+
542+
training = TimeSeriesDataSet(
543+
data[lambda x: x.time_idx <= training_cutoff],
544+
time_idx="time_idx",
545+
target="value",
546+
categorical_encoders={"series": NaNLabelEncoder().fit(data.series)},
547+
group_ids=["series"],
548+
# only unknown variable is "value"
549+
# and N-Beats can also not take any additional variables
550+
time_varying_unknown_reals=["value"],
551+
max_encoder_length=context_length,
552+
max_prediction_length=prediction_length,
553+
)
554+
555+
batch_size = 71
556+
train_dataloader = training.to_dataloader(
557+
train=True, batch_size=batch_size, num_workers=0
558+
)
559+
560+
baseline_model = Baseline()
561+
predictions = baseline_model.predict(
562+
train_dataloader,
563+
return_x=True,
564+
return_y=True,
565+
trainer_kwargs=dict(logger=None, accelerator="cpu"),
566+
)
567+
568+
# The predicted output and the target should have the same size.
569+
assert predictions.output.size() == predictions.y[0].size()

0 commit comments

Comments
 (0)