Skip to content

Commit f606d7a

Browse files
Minor changes
1 parent a9f1188 commit f606d7a

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

tests/test_models/test_temporal_fusion_transformer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,15 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs)
208208
)
209209
pred_len = len(predictions.index)
210210
if isinstance(predictions.output, torch.Tensor):
211-
assert predictions.output.shape == predictions.y[0].shape, "shape of predictions should match shape of targets"
211+
assert (
212+
predictions.output.shape == predictions.y[0].shape
213+
), "shape of predictions should match shape of targets"
212214
else:
213215
for i in range(len(predictions.output)):
214-
assert predictions.output[i].shape == predictions.y[0][i].shape, "shape of predictions should match shape of targets"
216+
assert (
217+
predictions.output[i].shape == predictions.y[0][i].shape
218+
), "shape of predictions should match shape of targets"
219+
215220
# check that output is of correct shape
216221
def check(x):
217222
if isinstance(x, (tuple, list)):

0 commit comments

Comments
 (0)