File tree Expand file tree Collapse file tree 2 files changed +28
-2
lines changed Expand file tree Collapse file tree 2 files changed +28
-2
lines changed Original file line number Diff line number Diff line change @@ -387,7 +387,7 @@ def on_predict_epoch_end(
387387 if self .return_decoder_lengths :
388388 output ["decoder_lengths" ] = torch .cat (self ._decode_lengths , dim = 0 )
389389 if self .return_y :
390- y = _torch_cat_na ([yi [0 ] for yi in self ._y ])
390+ y = concat_sequences ([yi [0 ] for yi in self ._y ])
391391 if self ._y [- 1 ][1 ] is None :
392392 weight = None
393393 else :
Original file line number Diff line number Diff line change @@ -272,7 +272,33 @@ def concat_sequences(
272272 if isinstance (sequences [0 ], rnn .PackedSequence ):
273273 return rnn .pack_sequence (sequences , enforce_sorted = False )
274274 elif isinstance (sequences [0 ], torch .Tensor ):
275- return torch .cat (sequences , dim = 1 )
275+ if sequences [0 ].ndim > 1 :
276+ first_lens = [xi .shape [1 ] for xi in sequences ]
277+ max_first_len = max (first_lens )
278+ if max_first_len > min (first_lens ):
279+ sequences = [
280+ (
281+ xi
282+ if xi .shape [1 ] == max_first_len
283+ else torch .cat (
284+ [
285+ xi ,
286+ torch .full (
287+ (
288+ xi .shape [0 ],
289+ max_first_len - xi .shape [1 ],
290+ * xi .shape [2 :],
291+ ),
292+ float ("nan" ),
293+ device = xi .device ,
294+ ),
295+ ],
296+ dim = 1 ,
297+ )
298+ )
299+ for xi in sequences
300+ ]
301+ return torch .cat (sequences , dim = 0 )
276302 elif isinstance (sequences [0 ], (tuple , list )):
277303 return tuple (
278304 concat_sequences ([sequences [ii ][i ] for ii in range (len (sequences ))])
You can’t perform that action at this time.
0 commit comments