diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index 027cca16db..30942dfb4a 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -178,11 +178,11 @@ def test_local_useless_inc_subtensor_increment_zeros(): s = pt.zeros((2, 2))[:, :] o_shape = inc_subtensor(s, specify_shape(y, s.shape)) - mode = get_default_mode().including("local_useless_inc_subtensor") - f_shape = function([y], o_shape, mode=mode) - - topo = f_shape.maker.fgraph.toposort() - assert not any(isinstance(n.op, IncSubtensor) for n in topo) + result = utt.rewrite_test([y], [o_shape]) + result.assert_equivalent_computations( + [specify_shape(y, (2, 2))], strict_dtype=False + ) + result.assert_numerical_close([np.ones((2, 2))]) def test_local_useless_inc_subtensor_no_opt(): @@ -190,41 +190,30 @@ def test_local_useless_inc_subtensor_no_opt(): x = matrix("x") y = matrix("y") + # Stepped slice — can't be removed. s = x[:, ::2] o_shape = set_subtensor(s, specify_shape(y, s.shape)) - - mode = get_default_mode().including("local_useless_inc_subtensor") - f_shape = function([x, y], o_shape, mode=mode) - - topo = f_shape.maker.fgraph.toposort() - assert any(isinstance(n.op, IncSubtensor) for n in topo) - - out = f_shape([[2, 3, 6, 7]], [[8, 9]]) - assert np.array_equal(out, np.asarray([[8, 3, 9, 7]])) + result = utt.rewrite_test([x, y], [o_shape]) + result.assert_equivalent_computations([o_shape], strict_dtype=False) + result.assert_numerical_close([[[2, 3, 6, 7]], [[8, 9]]]) # Increment with a non-constant target array, full slices collapse to x + y. s = x[:, :] o_shape = inc_subtensor(s, specify_shape(y, s.shape)) - - f_shape = function([x, y], o_shape, mode=mode) - - topo = f_shape.maker.fgraph.toposort() - assert not any(isinstance(n.op, IncSubtensor) for n in topo) - - out = f_shape([[1, 2], [3, 4]], [[10, 20], [30, 40]]) - assert np.array_equal(out, np.asarray([[11, 22], [33, 44]])) + result = utt.rewrite_test([x, y], [o_shape]) + result.assert_equivalent_computations( + [x + specify_shape(y, x.shape)], strict_dtype=False + ) + result.assert_numerical_close([[[1, 2], [3, 4]], [[10, 20], [30, 40]]]) # Increment with a non-zero constant target array, same collapse to x + y. s = pt.ones((2, 2))[:, :] o_shape = inc_subtensor(s, specify_shape(y, s.shape)) - - f_shape = function([y], o_shape, mode=mode) - - topo = f_shape.maker.fgraph.toposort() - assert not any(isinstance(n.op, IncSubtensor) for n in topo) - - out = f_shape([[10, 20], [30, 40]]) - assert np.array_equal(out, np.asarray([[11, 21], [31, 41]])) + result = utt.rewrite_test([y], [o_shape]) + result.assert_equivalent_computations( + [np.ones((1, 1)) + specify_shape(y, (2, 2))], strict_dtype=False + ) + result.assert_numerical_close([[[10, 20], [30, 40]]]) def test_local_add_of_sparse_write(): diff --git a/tests/unittest_tools.py b/tests/unittest_tools.py index c63da8eff3..ecdd275cee 100644 --- a/tests/unittest_tools.py +++ b/tests/unittest_tools.py @@ -1,20 +1,28 @@ import logging import sys +from collections.abc import Sequence from copy import copy, deepcopy from functools import wraps from numbers import Number +from typing import TYPE_CHECKING import numpy as np import pytest import pytensor from pytensor.compile.debug.debugmode import str_diagnostic +from pytensor.compile.mode import Mode from pytensor.gradient import verify_grad as orig_verify_grad -from pytensor.graph.basic import equal_computations +from pytensor.graph.basic import Variable, equal_computations +from pytensor.graph.fg import FunctionGraph from pytensor.tensor.math import _allclose from pytensor.tensor.math import add as pt_add +if TYPE_CHECKING: + from pytensor.graph.rewriting.basic import GraphRewriter + + _logger = logging.getLogger("tests.unittest_tools") @@ -402,3 +410,146 @@ def test_with_assert(*args, **kwargs): return test_with_assert else: return f + + +class RewriteTester: + NO_OPT = Mode(linker="py", optimizer=None) + + def __init__(self, orig_fg, rewr_fg, orig_inputs): + self.orig_fg = orig_fg + self.rewr_fg = rewr_fg + self._orig_inputs = orig_inputs + self._orig_fn = None + self._rewr_fn = None + + @property + def orig_fn(self): + if self._orig_fn is None: + self._orig_fn = pytensor.function( + self.orig_fg.inputs, + self.orig_fg.outputs, + mode=self.NO_OPT, + on_unused_input="ignore", + ) + return self._orig_fn + + @property + def rewr_fn(self): + if self._rewr_fn is None: + self._rewr_fn = pytensor.function( + self.rewr_fg.inputs, + self.rewr_fg.outputs, + mode=self.NO_OPT, + on_unused_input="ignore", + ) + return self._rewr_fn + + @staticmethod + def _match(predicate): + from pytensor.graph.rewriting.unify import OpPattern + + if isinstance(predicate, OpPattern): + pattern = predicate + return lambda n: pattern.match_op(n.op) + if isinstance(predicate, type): + op_type = predicate + return lambda n: isinstance(n.op, op_type) + if isinstance(predicate, tuple) and all(isinstance(t, type) for t in predicate): + op_types = predicate + return lambda n: isinstance(n.op, op_types) + return predicate + + @staticmethod + def _count(fgraph, predicate): + from pytensor.graph.op import HasInnerGraph + + count = 0 + for node in fgraph.toposort(): + if predicate(node): + count += 1 + if isinstance(node.op, HasInnerGraph): + count += RewriteTester._count(node.op.fgraph, predicate) + return count + + def count_nodes(self, predicate): + return self._count(self.rewr_fg, self._match(predicate)) + + def count_nodes_before(self, predicate): + return self._count(self.orig_fg, self._match(predicate)) + + def assert_node_count(self, predicate, count): + __tracebackhide__ = True + actual = self.count_nodes(predicate) + if actual != count: + raise AssertionError( + f"Expected {count} matching node(s) after rewrite, got {actual}\n" + f"Rewritten graph:\n" + f"{pytensor.dprint(self.rewr_fg, print_type=True, file='str')}" + ) + + def assert_node_count_before(self, predicate, count): + __tracebackhide__ = True + actual = self.count_nodes_before(predicate) + if actual != count: + raise AssertionError( + f"Expected {count} matching node(s) before rewrite, got {actual}\n" + f"Original graph:\n" + f"{pytensor.dprint(self.orig_fg, print_type=True, file='str')}" + ) + + def assert_numerical_close(self, test_values, rtol=None, atol=None): + __tracebackhide__ = True + orig_out = self.orig_fn(*test_values) + rewr_out = self.rewr_fn(*test_values) + if not isinstance(orig_out, list | tuple): + orig_out = [orig_out] + if not isinstance(rewr_out, list | tuple): + rewr_out = [rewr_out] + for i, (a, b) in enumerate(zip(orig_out, rewr_out, strict=True)): + np.testing.assert_allclose( + a, + b, + rtol=rtol or 1e-7, + atol=atol or 0, + err_msg=f"Output {i} mismatch between original and rewritten graph", + ) + + def assert_equivalent_computations(self, expected_outputs, **kwargs): + __tracebackhide__ = True + if not isinstance(expected_outputs, list): + expected_outputs = [expected_outputs] + assert_equal_computations( + self.rewr_fg.outputs, + expected_outputs, + in_xs=list(self.rewr_fg.inputs), + in_ys=self._orig_inputs, + original=self.orig_fg.outputs, + **kwargs, + ) + + +def rewrite_test( + inputs: Sequence[Variable], + outputs: Sequence[Variable], + *, + include: Sequence[str] = ("useless", "canonicalize", "stabilize", "specialize"), + custom_rewrite: "GraphRewriter | None" = None, + **kwargs, +) -> RewriteTester: + inputs = list(inputs) + outputs = list(outputs) + orig_fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=True) + rewr_fg = FunctionGraph(inputs=inputs, outputs=outputs, clone=True) + + if include: + from pytensor.compile import optdb + from pytensor.graph.rewriting.db import RewriteDatabaseQuery + + optdb.query(RewriteDatabaseQuery(include=list(include), **kwargs)).rewrite( + rewr_fg + ) + + if custom_rewrite is not None: + custom_rewrite.rewrite(rewr_fg) + + return RewriteTester(orig_fg, rewr_fg, inputs)