Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 19 additions & 30 deletions tests/tensor/rewriting/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,53 +178,42 @@ def test_local_useless_inc_subtensor_increment_zeros():
s = pt.zeros((2, 2))[:, :]
o_shape = inc_subtensor(s, specify_shape(y, s.shape))

mode = get_default_mode().including("local_useless_inc_subtensor")
f_shape = function([y], o_shape, mode=mode)

topo = f_shape.maker.fgraph.toposort()
assert not any(isinstance(n.op, IncSubtensor) for n in topo)
result = utt.rewrite_test([y], [o_shape])
result.assert_equivalent_computations(
[specify_shape(y, (2, 2))], strict_dtype=False
)
result.assert_numerical_close([np.ones((2, 2))])


def test_local_useless_inc_subtensor_no_opt():
r"""Make sure we don't remove `IncSubtensor`\s that involve slices with steps that skip elements and non-zero increments."""
x = matrix("x")
y = matrix("y")

# Stepped slice — can't be removed.
s = x[:, ::2]
o_shape = set_subtensor(s, specify_shape(y, s.shape))

mode = get_default_mode().including("local_useless_inc_subtensor")
f_shape = function([x, y], o_shape, mode=mode)

topo = f_shape.maker.fgraph.toposort()
assert any(isinstance(n.op, IncSubtensor) for n in topo)

out = f_shape([[2, 3, 6, 7]], [[8, 9]])
assert np.array_equal(out, np.asarray([[8, 3, 9, 7]]))
result = utt.rewrite_test([x, y], [o_shape])
result.assert_equivalent_computations([o_shape], strict_dtype=False)
result.assert_numerical_close([[[2, 3, 6, 7]], [[8, 9]]])

# Increment with a non-constant target array, full slices collapse to x + y.
s = x[:, :]
o_shape = inc_subtensor(s, specify_shape(y, s.shape))

f_shape = function([x, y], o_shape, mode=mode)

topo = f_shape.maker.fgraph.toposort()
assert not any(isinstance(n.op, IncSubtensor) for n in topo)

out = f_shape([[1, 2], [3, 4]], [[10, 20], [30, 40]])
assert np.array_equal(out, np.asarray([[11, 22], [33, 44]]))
result = utt.rewrite_test([x, y], [o_shape])
result.assert_equivalent_computations(
[x + specify_shape(y, x.shape)], strict_dtype=False
)
result.assert_numerical_close([[[1, 2], [3, 4]], [[10, 20], [30, 40]]])

# Increment with a non-zero constant target array, same collapse to x + y.
s = pt.ones((2, 2))[:, :]
o_shape = inc_subtensor(s, specify_shape(y, s.shape))

f_shape = function([y], o_shape, mode=mode)

topo = f_shape.maker.fgraph.toposort()
assert not any(isinstance(n.op, IncSubtensor) for n in topo)

out = f_shape([[10, 20], [30, 40]])
assert np.array_equal(out, np.asarray([[11, 21], [31, 41]]))
result = utt.rewrite_test([y], [o_shape])
result.assert_equivalent_computations(
[np.ones((1, 1)) + specify_shape(y, (2, 2))], strict_dtype=False
)
result.assert_numerical_close([[[10, 20], [30, 40]]])


def test_local_add_of_sparse_write():
Expand Down
153 changes: 152 additions & 1 deletion tests/unittest_tools.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
import logging
import sys
from collections.abc import Sequence
from copy import copy, deepcopy
from functools import wraps
from numbers import Number
from typing import TYPE_CHECKING

import numpy as np
import pytest

import pytensor
from pytensor.compile.debug.debugmode import str_diagnostic
from pytensor.compile.mode import Mode
from pytensor.gradient import verify_grad as orig_verify_grad
from pytensor.graph.basic import equal_computations
from pytensor.graph.basic import Variable, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.math import _allclose
from pytensor.tensor.math import add as pt_add


if TYPE_CHECKING:
from pytensor.graph.rewriting.basic import GraphRewriter


_logger = logging.getLogger("tests.unittest_tools")


Expand Down Expand Up @@ -402,3 +410,146 @@ def test_with_assert(*args, **kwargs):
return test_with_assert
else:
return f


class RewriteTester:
NO_OPT = Mode(linker="py", optimizer=None)

def __init__(self, orig_fg, rewr_fg, orig_inputs):
self.orig_fg = orig_fg
self.rewr_fg = rewr_fg
self._orig_inputs = orig_inputs
self._orig_fn = None
self._rewr_fn = None

@property
def orig_fn(self):
if self._orig_fn is None:
self._orig_fn = pytensor.function(
self.orig_fg.inputs,
self.orig_fg.outputs,
mode=self.NO_OPT,
on_unused_input="ignore",
)
return self._orig_fn

@property
def rewr_fn(self):
if self._rewr_fn is None:
self._rewr_fn = pytensor.function(
self.rewr_fg.inputs,
self.rewr_fg.outputs,
mode=self.NO_OPT,
on_unused_input="ignore",
)
return self._rewr_fn

@staticmethod
def _match(predicate):
from pytensor.graph.rewriting.unify import OpPattern

if isinstance(predicate, OpPattern):
pattern = predicate
return lambda n: pattern.match_op(n.op)
if isinstance(predicate, type):
op_type = predicate
return lambda n: isinstance(n.op, op_type)
if isinstance(predicate, tuple) and all(isinstance(t, type) for t in predicate):
op_types = predicate
return lambda n: isinstance(n.op, op_types)
return predicate

@staticmethod
def _count(fgraph, predicate):
from pytensor.graph.op import HasInnerGraph

count = 0
for node in fgraph.toposort():
if predicate(node):
count += 1
if isinstance(node.op, HasInnerGraph):
count += RewriteTester._count(node.op.fgraph, predicate)
return count

def count_nodes(self, predicate):
return self._count(self.rewr_fg, self._match(predicate))

def count_nodes_before(self, predicate):
return self._count(self.orig_fg, self._match(predicate))

def assert_node_count(self, predicate, count):
__tracebackhide__ = True
actual = self.count_nodes(predicate)
if actual != count:
raise AssertionError(
f"Expected {count} matching node(s) after rewrite, got {actual}\n"
f"Rewritten graph:\n"
f"{pytensor.dprint(self.rewr_fg, print_type=True, file='str')}"
)

def assert_node_count_before(self, predicate, count):
__tracebackhide__ = True
actual = self.count_nodes_before(predicate)
if actual != count:
raise AssertionError(
f"Expected {count} matching node(s) before rewrite, got {actual}\n"
f"Original graph:\n"
f"{pytensor.dprint(self.orig_fg, print_type=True, file='str')}"
)

def assert_numerical_close(self, test_values, rtol=None, atol=None):
__tracebackhide__ = True
orig_out = self.orig_fn(*test_values)
rewr_out = self.rewr_fn(*test_values)
if not isinstance(orig_out, list | tuple):
orig_out = [orig_out]
if not isinstance(rewr_out, list | tuple):
rewr_out = [rewr_out]
for i, (a, b) in enumerate(zip(orig_out, rewr_out, strict=True)):
np.testing.assert_allclose(
a,
b,
rtol=rtol or 1e-7,
atol=atol or 0,
err_msg=f"Output {i} mismatch between original and rewritten graph",
)

def assert_equivalent_computations(self, expected_outputs, **kwargs):
__tracebackhide__ = True
if not isinstance(expected_outputs, list):
expected_outputs = [expected_outputs]
assert_equal_computations(
self.rewr_fg.outputs,
expected_outputs,
in_xs=list(self.rewr_fg.inputs),
in_ys=self._orig_inputs,
original=self.orig_fg.outputs,
**kwargs,
)


def rewrite_test(
inputs: Sequence[Variable],
outputs: Sequence[Variable],
*,
include: Sequence[str] = ("useless", "canonicalize", "stabilize", "specialize"),
custom_rewrite: "GraphRewriter | None" = None,
**kwargs,
) -> RewriteTester:
inputs = list(inputs)
outputs = list(outputs)
orig_fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=True)
rewr_fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=True)

if include:
from pytensor.compile import optdb
from pytensor.graph.rewriting.db import RewriteDatabaseQuery

optdb.query(RewriteDatabaseQuery(include=list(include), **kwargs)).rewrite(
rewr_fg
)

if custom_rewrite is not None:
custom_rewrite.rewrite(rewr_fg)

return RewriteTester(orig_fg, rewr_fg, inputs)
Loading