Skip to content

Commit 9124c5d

Browse files
Changing back to old function with modifications
1 parent 8f310b9 commit 9124c5d

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

pytorch_forecasting/models/base/_base_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def on_predict_epoch_end(
387387
if self.return_decoder_lengths:
388388
output["decoder_lengths"] = torch.cat(self._decode_lengths, dim=0)
389389
if self.return_y:
390-
y = _torch_cat_na([yi[0] for yi in self._y])
390+
y = concat_sequences([yi[0] for yi in self._y])
391391
if self._y[-1][1] is None:
392392
weight = None
393393
else:

pytorch_forecasting/utils/_utils.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,33 @@ def concat_sequences(
272272
if isinstance(sequences[0], rnn.PackedSequence):
273273
return rnn.pack_sequence(sequences, enforce_sorted=False)
274274
elif isinstance(sequences[0], torch.Tensor):
275-
return torch.cat(sequences, dim=1)
275+
if sequences[0].ndim > 1:
276+
first_lens = [xi.shape[1] for xi in sequences]
277+
max_first_len = max(first_lens)
278+
if max_first_len > min(first_lens):
279+
sequences = [
280+
(
281+
xi
282+
if xi.shape[1] == max_first_len
283+
else torch.cat(
284+
[
285+
xi,
286+
torch.full(
287+
(
288+
xi.shape[0],
289+
max_first_len - xi.shape[1],
290+
*xi.shape[2:],
291+
),
292+
float("nan"),
293+
device=xi.device,
294+
),
295+
],
296+
dim=1,
297+
)
298+
)
299+
for xi in sequences
300+
]
301+
return torch.cat(sequences, dim=0)
276302
elif isinstance(sequences[0], (tuple, list)):
277303
return tuple(
278304
concat_sequences([sequences[ii][i] for ii in range(len(sequences))])

0 commit comments

Comments
 (0)