@@ -3552,16 +3552,10 @@ class TestDiag:
35523552 """
35533553 Test that linalg.diag has the same behavior as numpy.diag.
35543554 numpy.diag has two behaviors:
3555- (1) when given a vector, it returns a matrix with that vector as the
3556- diagonal.
3557- (2) when given a matrix, returns a vector which is the diagonal of the
3558- matrix.
3555+ (1) when given a vector, it returns a matrix with that vector as the diagonal.
3556+ (2) when given a matrix, it returns a vector which is the diagonal of the matrix.
35593557
3560- (1) and (2) are tested by test_alloc_diag and test_extract_diag
3561- respectively.
3562-
3563- test_diag test makes sure that linalg.diag instantiates
3564- the right op based on the dimension of the input.
3558+ (1) and (2) are further tested by TestAllocDiag and TestExtractDiag, respectively.
35653559 """
35663560
35673561 def setup_method (self ):
@@ -3571,6 +3565,7 @@ def setup_method(self):
35713565 self .type = TensorType
35723566
35733567 def test_diag (self ):
3568+ """Makes sure that diag instantiates the right op based on the dimension of the input."""
35743569 rng = np .random .default_rng (utt .fetch_seed ())
35753570
35763571 # test vector input
@@ -3609,38 +3604,55 @@ def test_diag(self):
36093604 f = function ([], g )
36103605 assert np .array_equal (f (), np .diag (xx ))
36113606
3612- def test_infer_shape (self ):
3607+
3608+ class TestExtractDiag :
3609+ @pytest .mark .parametrize ("axis1, axis2" , [(0 , 1 ), (1 , 0 )])
3610+ @pytest .mark .parametrize ("offset" , (- 1 , 0 , 2 ))
3611+ def test_infer_shape (self , offset , axis1 , axis2 ):
36133612 rng = np .random .default_rng (utt .fetch_seed ())
36143613
3615- x = vector ()
3616- g = diag (x )
3617- f = pytensor .function ([x ], g .shape )
3618- topo = f .maker .fgraph .toposort ()
3619- if config .mode != "FAST_COMPILE" :
3620- assert sum (isinstance (node .op , AllocDiag ) for node in topo ) == 0
3621- for shp in [5 , 0 , 1 ]:
3622- m = rng .random (shp ).astype (self .floatX )
3623- assert (f (m ) == np .diag (m ).shape ).all ()
3624-
3625- x = matrix ()
3626- g = diag (x )
3614+ x = matrix ("x" )
3615+ g = ExtractDiag (offset = offset , axis1 = axis1 , axis2 = axis2 )(x )
36273616 f = pytensor .function ([x ], g .shape )
36283617 topo = f .maker .fgraph .toposort ()
36293618 if config .mode != "FAST_COMPILE" :
36303619 assert sum (isinstance (node .op , ExtractDiag ) for node in topo ) == 0
36313620 for shp in [(5 , 3 ), (3 , 5 ), (5 , 1 ), (1 , 5 ), (5 , 0 ), (0 , 5 ), (1 , 0 ), (0 , 1 )]:
3632- m = rng .random (shp ).astype (self .floatX )
3633- assert (f (m ) == np .diag (m ).shape ).all ()
3621+ m = rng .random (shp ).astype (config .floatX )
3622+ assert (
3623+ f (m ) == np .diagonal (m , offset = offset , axis1 = axis1 , axis2 = axis2 ).shape
3624+ ).all ()
36343625
3635- def test_diag_grad (self ):
3626+ @pytest .mark .parametrize ("axis1, axis2" , [(0 , 1 ), (1 , 0 )])
3627+ @pytest .mark .parametrize ("offset" , (0 , 1 , - 1 ))
3628+ def test_grad_2d (self , offset , axis1 , axis2 ):
3629+ diag_fn = ExtractDiag (offset = offset , axis1 = axis1 , axis2 = axis2 )
36363630 rng = np .random .default_rng (utt .fetch_seed ())
3637- x = rng .random (5 )
3638- utt .verify_grad (diag , [x ], rng = rng )
36393631 x = rng .random ((5 , 3 ))
3640- utt .verify_grad (diag , [x ], rng = rng )
3632+ utt .verify_grad (diag_fn , [x ], rng = rng )
3633+
3634+ @pytest .mark .parametrize (
3635+ "axis1, axis2" ,
3636+ [
3637+ (0 , 1 ),
3638+ (1 , 0 ),
3639+ (1 , 2 ),
3640+ (2 , 1 ),
3641+ (0 , 2 ),
3642+ (2 , 0 ),
3643+ ],
3644+ )
3645+ @pytest .mark .parametrize ("offset" , (0 , 1 , - 1 ))
3646+ def test_grad_3d (self , offset , axis1 , axis2 ):
3647+ diag_fn = ExtractDiag (offset = offset , axis1 = axis1 , axis2 = axis2 )
3648+ rng = np .random .default_rng (utt .fetch_seed ())
3649+ x = rng .random ((5 , 4 , 3 ))
3650+ utt .verify_grad (diag_fn , [x ], rng = rng )
36413651
36423652
36433653class TestAllocDiag :
3654+ # TODO: Separate perform, grad and infer_shape tests
3655+
36443656 def setup_method (self ):
36453657 self .alloc_diag = AllocDiag
36463658 self .mode = pytensor .compile .mode .get_default_mode ()
@@ -3674,7 +3686,7 @@ def test_alloc_diag_values(self):
36743686 (- 2 , 0 , 1 ),
36753687 (- 1 , 1 , 2 ),
36763688 ]:
3677- # Test AllocDiag values
3689+ # Test perform
36783690 if np .maximum (axis1 , axis2 ) > len (test_val .shape ):
36793691 continue
36803692 adiag_op = self .alloc_diag (offset = offset , axis1 = axis1 , axis2 = axis2 )
@@ -3688,7 +3700,6 @@ def test_alloc_diag_values(self):
36883700 # Test infer_shape
36893701 f_shape = pytensor .function ([x ], adiag_op (x ).shape , mode = "FAST_RUN" )
36903702
3691- # pytensor.printing.debugprint(f_shape.maker.fgraph.outputs[0])
36923703 output_shape = f_shape (test_val )
36933704 assert not any (
36943705 isinstance (node .op , self .alloc_diag )
@@ -3699,6 +3710,7 @@ def test_alloc_diag_values(self):
36993710 ).shape
37003711 assert np .all (rediag_shape == test_val .shape )
37013712
3713+ # Test grad
37023714 diag_x = adiag_op (x )
37033715 sum_diag_x = at_sum (diag_x )
37043716 grad_x = pytensor .grad (sum_diag_x , x )
@@ -3710,7 +3722,6 @@ def test_alloc_diag_values(self):
37103722 true_grad_input = np .diagonal (
37113723 grad_diag_input , offset = offset , axis1 = axis1 , axis2 = axis2
37123724 )
3713-
37143725 assert np .all (true_grad_input == grad_input )
37153726
37163727
0 commit comments