diff --git a/tests/scan/rewriting/test_trace.py b/tests/scan/rewriting/test_trace.py index 3cec9df6ae..6b923afe2a 100644 --- a/tests/scan/rewriting/test_trace.py +++ b/tests/scan/rewriting/test_trace.py @@ -4,12 +4,15 @@ import pytensor import pytensor.tensor as pt from pytensor import function, scan, shared -from pytensor.compile.mode import get_default_mode +from pytensor.compile.mode import Mode, get_default_mode from pytensor.configdefaults import config from pytensor.gradient import grad from pytensor.graph.basic import Constant +from pytensor.graph.fg import FunctionGraph +from pytensor.graph.rewriting.basic import dfs_rewriter from pytensor.link.basic import JITLinker from pytensor.scan.op import Scan +from pytensor.scan.rewriting.trace import scan_reduce_trace_prealloc from pytensor.scan.utils import until from pytensor.tensor.math import dot from pytensor.tensor.shape import reshape @@ -598,6 +601,31 @@ def test_symbolic_stop_not_dropped(self): np.testing.assert_allclose(fn(0, 8), ws_ref[3:8]) np.testing.assert_allclose(fn(0, 10), ws_ref[3:10]) + def test_idempotent(self): + """Applying scan_reduce_trace twice must not corrupt negative-step slices.""" + x0 = scalar("x0") + acc = scan( + fn=lambda prev: prev + 1, + outputs_info=[x0], + n_steps=10, + return_updates=False, + mode=Mode(optimizer=None), + ) + out = acc[7:2:-1] + + rewrite = dfs_rewriter(scan_reduce_trace_prealloc, ignore_newtrees=True) + fgraph = FunctionGraph([x0], [out]) + rewrite.apply(fgraph) + rewrite.apply(fgraph) + + f = function( + fgraph.inputs, + fgraph.outputs[0], + accept_inplace=True, + mode=Mode(linker="py", optimizer=None), + ) + np.testing.assert_allclose(f(0.0), [8, 7, 6, 5, 4]) + def test_scan_sit_sot_to_untraced(): """Test sit_sot to untraced_sit_sot conversion.