1515from pytensor .scan .op import Scan
1616from pytensor .scan .rewriting import ScanInplaceOptimizer , ScanMerge
1717from pytensor .scan .utils import until
18+ from pytensor .tensor import stack
1819from pytensor .tensor .blas import Dot22
1920from pytensor .tensor .elemwise import Elemwise
2021from pytensor .tensor .math import Dot , dot , sigmoid
@@ -796,7 +797,13 @@ def inner_fct(seq1, seq2, seq3, previous_output):
796797
797798
798799class TestScanMerge :
799- mode = get_default_mode ().including ("scan" )
800+ mode = get_default_mode ().including ("scan" ).excluding ("scan_pushout_seqs_ops" )
801+
802+ @staticmethod
803+ def count_scans (fn ):
804+ nodes = fn .maker .fgraph .apply_nodes
805+ scans = [node for node in nodes if isinstance (node .op , Scan )]
806+ return len (scans )
800807
801808 def test_basic (self ):
802809 x = vector ()
@@ -808,56 +815,38 @@ def sum(s):
808815 sx , upx = scan (sum , sequences = [x ])
809816 sy , upy = scan (sum , sequences = [y ])
810817
811- f = function (
812- [x , y ], [sx , sy ], mode = self .mode .excluding ("scan_pushout_seqs_ops" )
813- )
814- topo = f .maker .fgraph .toposort ()
815- scans = [n for n in topo if isinstance (n .op , Scan )]
816- assert len (scans ) == 2
818+ f = function ([x , y ], [sx , sy ], mode = self .mode )
819+ assert self .count_scans (f ) == 2
817820
818821 sx , upx = scan (sum , sequences = [x ], n_steps = 2 )
819822 sy , upy = scan (sum , sequences = [y ], n_steps = 3 )
820823
821- f = function (
822- [x , y ], [sx , sy ], mode = self .mode .excluding ("scan_pushout_seqs_ops" )
823- )
824- topo = f .maker .fgraph .toposort ()
825- scans = [n for n in topo if isinstance (n .op , Scan )]
826- assert len (scans ) == 2
824+ f = function ([x , y ], [sx , sy ], mode = self .mode )
825+ assert self .count_scans (f ) == 2
827826
828827 sx , upx = scan (sum , sequences = [x ], n_steps = 4 )
829828 sy , upy = scan (sum , sequences = [y ], n_steps = 4 )
830829
831- f = function (
832- [x , y ], [sx , sy ], mode = self .mode .excluding ("scan_pushout_seqs_ops" )
833- )
834- topo = f .maker .fgraph .toposort ()
835- scans = [n for n in topo if isinstance (n .op , Scan )]
836- assert len (scans ) == 1
830+ f = function ([x , y ], [sx , sy ], mode = self .mode )
831+ assert self .count_scans (f ) == 1
837832
838833 sx , upx = scan (sum , sequences = [x ])
839834 sy , upy = scan (sum , sequences = [x ])
840835
841- f = function ([x ], [sx , sy ], mode = self .mode .excluding ("scan_pushout_seqs_ops" ))
842- topo = f .maker .fgraph .toposort ()
843- scans = [n for n in topo if isinstance (n .op , Scan )]
844- assert len (scans ) == 1
836+ f = function ([x ], [sx , sy ], mode = self .mode )
837+ assert self .count_scans (f ) == 1
845838
846839 sx , upx = scan (sum , sequences = [x ])
847840 sy , upy = scan (sum , sequences = [x ], mode = "FAST_COMPILE" )
848841
849- f = function ([x ], [sx , sy ], mode = self .mode .excluding ("scan_pushout_seqs_ops" ))
850- topo = f .maker .fgraph .toposort ()
851- scans = [n for n in topo if isinstance (n .op , Scan )]
852- assert len (scans ) == 1
842+ f = function ([x ], [sx , sy ], mode = self .mode )
843+ assert self .count_scans (f ) == 1
853844
854845 sx , upx = scan (sum , sequences = [x ])
855846 sy , upy = scan (sum , sequences = [x ], truncate_gradient = 1 )
856847
857- f = function ([x ], [sx , sy ], mode = self .mode .excluding ("scan_pushout_seqs_ops" ))
858- topo = f .maker .fgraph .toposort ()
859- scans = [n for n in topo if isinstance (n .op , Scan )]
860- assert len (scans ) == 2
848+ f = function ([x ], [sx , sy ], mode = self .mode )
849+ assert self .count_scans (f ) == 2
861850
862851 def test_three_scans (self ):
863852 r"""
@@ -877,12 +866,8 @@ def sum(s):
877866 sy , upy = scan (sum , sequences = [2 * y + 2 ], n_steps = 4 , name = "Y" )
878867 sz , upz = scan (sum , sequences = [sx ], n_steps = 4 , name = "Z" )
879868
880- f = function (
881- [x , y ], [sy , sz ], mode = self .mode .excluding ("scan_pushout_seqs_ops" )
882- )
883- topo = f .maker .fgraph .toposort ()
884- scans = [n for n in topo if isinstance (n .op , Scan )]
885- assert len (scans ) == 2
869+ f = function ([x , y ], [sy , sz ], mode = self .mode )
870+ assert self .count_scans (f ) == 2
886871
887872 rng = np .random .default_rng (utt .fetch_seed ())
888873 x_val = rng .uniform (size = (4 ,)).astype (config .floatX )
@@ -913,6 +898,112 @@ def test_belongs_to_set(self):
913898 assert not opt_obj .belongs_to_set (scan_node1 , [scan_node2 ])
914899 assert not opt_obj .belongs_to_set (scan_node2 , [scan_node1 ])
915900
901+ @config .change_flags (cxx = "" ) # Just for faster compilation
902+ def test_while_scan (self ):
903+ x = vector ("x" )
904+ y = vector ("y" )
905+
906+ def add (s ):
907+ return s + 1 , until (s > 5 )
908+
909+ def sub (s ):
910+ return s - 1 , until (s > 5 )
911+
912+ def sub_alt (s ):
913+ return s - 1 , until (s > 4 )
914+
915+ sx , upx = scan (add , sequences = [x ])
916+ sy , upy = scan (sub , sequences = [y ])
917+
918+ f = function ([x , y ], [sx , sy ], mode = self .mode )
919+ assert self .count_scans (f ) == 2
920+
921+ sx , upx = scan (add , sequences = [x ])
922+ sy , upy = scan (sub , sequences = [x ])
923+
924+ f = function ([x ], [sx , sy ], mode = self .mode )
925+ assert self .count_scans (f ) == 1
926+
927+ sx , upx = scan (add , sequences = [x ])
928+ sy , upy = scan (sub_alt , sequences = [x ])
929+
930+ f = function ([x ], [sx , sy ], mode = self .mode )
931+ assert self .count_scans (f ) == 2
932+
933+ @config .change_flags (cxx = "" ) # Just for faster compilation
934+ def test_while_scan_nominal_dependency (self ):
935+ """Test case where condition depends on nominal variables.
936+
937+ This is a regression test for #509
938+ """
939+ c1 = scalar ("c1" )
940+ c2 = scalar ("c2" )
941+ x = vector ("x" , shape = (5 ,))
942+ y = vector ("y" , shape = (5 ,))
943+ z = vector ("z" , shape = (5 ,))
944+
945+ def add (s1 , s2 , const ):
946+ return s1 + 1 , until (s2 > const )
947+
948+ def sub (s1 , s2 , const ):
949+ return s1 - 1 , until (s2 > const )
950+
951+ sx , _ = scan (add , sequences = [x , z ], non_sequences = [c1 ])
952+ sy , _ = scan (sub , sequences = [y , - z ], non_sequences = [c1 ])
953+
954+ f = pytensor .function (inputs = [x , y , z , c1 ], outputs = [sx , sy ], mode = self .mode )
955+ assert self .count_scans (f ) == 2
956+ res_sx , res_sy = f (
957+ x = [0 , 0 , 0 , 0 , 0 ],
958+ y = [0 , 0 , 0 , 0 , 0 ],
959+ z = [0 , 1 , 2 , 3 , 4 ],
960+ c1 = 0 ,
961+ )
962+ np .testing .assert_array_equal (res_sx , [1 , 1 ])
963+ np .testing .assert_array_equal (res_sy , [- 1 , - 1 , - 1 , - 1 , - 1 ])
964+
965+ sx , _ = scan (add , sequences = [x , z ], non_sequences = [c1 ])
966+ sy , _ = scan (sub , sequences = [y , z ], non_sequences = [c2 ])
967+
968+ f = pytensor .function (
969+ inputs = [x , y , z , c1 , c2 ], outputs = [sx , sy ], mode = self .mode
970+ )
971+ assert self .count_scans (f ) == 2
972+ res_sx , res_sy = f (
973+ x = [0 , 0 , 0 , 0 , 0 ],
974+ y = [0 , 0 , 0 , 0 , 0 ],
975+ z = [0 , 1 , 2 , 3 , 4 ],
976+ c1 = 3 ,
977+ c2 = 1 ,
978+ )
979+ np .testing .assert_array_equal (res_sx , [1 , 1 , 1 , 1 , 1 ])
980+ np .testing .assert_array_equal (res_sy , [- 1 , - 1 , - 1 ])
981+
982+ sx , _ = scan (add , sequences = [x , z ], non_sequences = [c1 ])
983+ sy , _ = scan (sub , sequences = [y , z ], non_sequences = [c1 ])
984+
985+ f = pytensor .function (inputs = [x , y , z , c1 ], outputs = [sx , sy ], mode = self .mode )
986+ assert self .count_scans (f ) == 1
987+
988+ def nested_scan (c , x , z ):
989+ sx , _ = scan (add , sequences = [x , z ], non_sequences = [c ])
990+ sy , _ = scan (sub , sequences = [x , z ], non_sequences = [c ])
991+ return sx .sum () + sy .sum ()
992+
993+ sz , _ = scan (
994+ nested_scan ,
995+ sequences = [stack ([c1 , c2 ])],
996+ non_sequences = [x , z ],
997+ mode = self .mode ,
998+ )
999+
1000+ f = pytensor .function (inputs = [x , z , c1 , c2 ], outputs = sz , mode = mode )
1001+ [scan_node ] = [
1002+ node for node in f .maker .fgraph .apply_nodes if isinstance (node .op , Scan )
1003+ ]
1004+ inner_f = scan_node .op .fn
1005+ assert self .count_scans (inner_f ) == 1
1006+
9161007
9171008class TestScanInplaceOptimizer :
9181009 mode = get_default_mode ().including ("scan_make_inplace" , "inplace" )
0 commit comments