Skip to content

Commit a9f1188

Browse files
Modifying the tests
1 parent 9124c5d commit a9f1188

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

tests/test_models/test_temporal_fusion_transformer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,15 @@ def _integration(dataloader, tmp_path, loss=None, trainer_kwargs=None, **kwargs)
203203
return_index=True,
204204
return_x=True,
205205
return_y=True,
206-
fast_dev_run=True,
206+
fast_dev_run=2,
207207
trainer_kwargs=trainer_kwargs,
208208
)
209209
pred_len = len(predictions.index)
210-
210+
if isinstance(predictions.output, torch.Tensor):
211+
assert predictions.output.shape == predictions.y[0].shape, "shape of predictions should match shape of targets"
212+
else:
213+
for i in range(len(predictions.output)):
214+
assert predictions.output[i].shape == predictions.y[0][i].shape, "shape of predictions should match shape of targets"
211215
# check that output is of correct shape
212216
def check(x):
213217
if isinstance(x, (tuple, list)):

0 commit comments

Comments
 (0)