@@ -500,12 +500,13 @@ def test_infer_shape(self):
500500
501501
502502class ApplyDefaultTestOp (Op ):
503- def __init__ (self , id ):
503+ def __init__ (self , id , n_outs = 1 ):
504504 self .default_output = id
505+ self .n_outs = n_outs
505506
506507 def make_node (self , x ):
507508 x = at .as_tensor_variable (x )
508- return Apply (self , [x ], [x .type ()])
509+ return Apply (self , [x ], [x .type () for _ in range ( self . n_outs ) ])
509510
510511 def perform (self , * args , ** kwargs ):
511512 raise NotImplementedError ()
@@ -556,16 +557,26 @@ def test_tensor_from_scalar(self):
556557 y = as_tensor_variable (aes .int8 ())
557558 assert isinstance (y .owner .op , TensorFromScalar )
558559
559- def test_multi_outputs (self ):
560- good_apply_var = ApplyDefaultTestOp (0 ).make_node (self .x )
561- as_tensor_variable (good_apply_var )
560+ def test_default_output (self ):
561+ good_apply_var = ApplyDefaultTestOp (0 , n_outs = 1 ).make_node (self .x )
562+ as_tensor_variable (good_apply_var ) is good_apply_var
562563
563- bad_apply_var = ApplyDefaultTestOp (- 1 ).make_node (self .x )
564- with pytest .raises (ValueError ):
564+ good_apply_var = ApplyDefaultTestOp (- 1 , n_outs = 1 ).make_node (self .x )
565+ as_tensor_variable (good_apply_var ) is good_apply_var
566+
567+ bad_apply_var = ApplyDefaultTestOp (1 , n_outs = 1 ).make_node (self .x )
568+ with pytest .raises (IndexError ):
565569 _ = as_tensor_variable (bad_apply_var )
566570
567- bad_apply_var = ApplyDefaultTestOp (2 ).make_node (self .x )
568- with pytest .raises (ValueError ):
571+ bad_apply_var = ApplyDefaultTestOp (2.0 , n_outs = 1 ).make_node (self .x )
572+ with pytest .raises (TypeError ):
573+ _ = as_tensor_variable (bad_apply_var )
574+
575+ good_apply_var = ApplyDefaultTestOp (1 , n_outs = 2 ).make_node (self .x )
576+ as_tensor_variable (good_apply_var ) is good_apply_var .outputs [1 ]
577+
578+ bad_apply_var = ApplyDefaultTestOp (None , n_outs = 2 ).make_node (self .x )
579+ with pytest .raises (TypeError , match = "Multi-output Op without default_output" ):
569580 _ = as_tensor_variable (bad_apply_var )
570581
571582 def test_list (self ):
@@ -578,7 +589,7 @@ def test_list(self):
578589 _ = as_tensor_variable (y )
579590
580591 bad_apply_var = ApplyDefaultTestOp ([0 , 1 ]).make_node (self .x )
581- with pytest .raises (ValueError ):
592+ with pytest .raises (TypeError ):
582593 as_tensor_variable (bad_apply_var )
583594
584595 def test_ndim_strip_leading_broadcastable (self ):
0 commit comments