Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion tests/scan/rewriting/test_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
import pytensor
import pytensor.tensor as pt
from pytensor import function, scan, shared
from pytensor.compile.mode import get_default_mode
from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config
from pytensor.gradient import grad
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import dfs_rewriter
from pytensor.link.basic import JITLinker
from pytensor.scan.op import Scan
from pytensor.scan.rewriting.trace import scan_reduce_trace_prealloc
from pytensor.scan.utils import until
from pytensor.tensor.math import dot
from pytensor.tensor.shape import reshape
Expand Down Expand Up @@ -598,6 +601,31 @@ def test_symbolic_stop_not_dropped(self):
np.testing.assert_allclose(fn(0, 8), ws_ref[3:8])
np.testing.assert_allclose(fn(0, 10), ws_ref[3:10])

def test_idempotent(self):
"""Applying scan_reduce_trace twice must not corrupt negative-step slices."""
x0 = scalar("x0")
acc = scan(
fn=lambda prev: prev + 1,
outputs_info=[x0],
n_steps=10,
return_updates=False,
mode=Mode(optimizer=None),
)
out = acc[7:2:-1]

rewrite = dfs_rewriter(scan_reduce_trace_prealloc, ignore_newtrees=True)
fgraph = FunctionGraph([x0], [out])
rewrite.apply(fgraph)
rewrite.apply(fgraph)

f = function(
fgraph.inputs,
fgraph.outputs[0],
accept_inplace=True,
mode=Mode(linker="py", optimizer=None),
)
np.testing.assert_allclose(f(0.0), [8, 7, 6, 5, 4])


def test_scan_sit_sot_to_untraced():
"""Test sit_sot to untraced_sit_sot conversion.
Expand Down
Loading