From 2ec0189dde4650344a556553d022fa73fc3d9786 Mon Sep 17 00:00:00 2001 From: Alessandro Gentili Date: Wed, 29 Apr 2026 23:15:57 +0200 Subject: [PATCH 1/2] rewrite log of det of inverse and log of reciprocal --- pytensor/tensor/rewriting/linalg/summary.py | 36 +++++++++- tests/tensor/rewriting/linalg/test_summary.py | 66 +++++++++++++++++++ 2 files changed, 100 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg/summary.py b/pytensor/tensor/rewriting/linalg/summary.py index ae3c6fb374..8cc969666d 100644 --- a/pytensor/tensor/rewriting/linalg/summary.py +++ b/pytensor/tensor/rewriting/linalg/summary.py @@ -5,7 +5,7 @@ copy_stack_trace, node_rewriter, ) -from pytensor.scalar.basic import Abs, Exp, Log, Sign, Sqr +from pytensor.scalar.basic import Abs, Exp, Log, Reciprocal, Sign, Sqr from pytensor.tensor.basic import AllocDiag, ones from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import Elemwise @@ -13,8 +13,9 @@ from pytensor.tensor.linalg.decomposition.lu import LU, LUFactor from pytensor.tensor.linalg.decomposition.qr import QR from pytensor.tensor.linalg.decomposition.svd import SVD +from pytensor.tensor.linalg.inverse import MatrixInverse from pytensor.tensor.linalg.summary import SLogDet, det -from pytensor.tensor.math import Prod, log, prod +from pytensor.tensor.math import Prod, log, prod, sign from pytensor.tensor.rewriting.basic import ( register_canonicalize, register_specialize, @@ -49,6 +50,24 @@ def local_log_prod_to_sum_log(fgraph, node): return [log(abs(x)).sum(axis=axis)] +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([Elemwise]) +def local_reciprocal_linalg_special_cases(fgraph, node): + """Special cases for log(reciprocal(x)) and sign(reciprocal(x)).""" + if len(node.inputs) != 1: + return None + + [p] = node.inputs + match p.owner_op_and_inputs: + case (Elemwise(Reciprocal()), x): + if isinstance(node.op.scalar_op, Log): + return [-log(x)] + if isinstance(node.op.scalar_op, Sign): + return [sign(x)] + + @register_stabilize("shape_unsafe") @register_specialize("shape_unsafe") @node_rewriter([det]) @@ -225,6 +244,19 @@ def det_of_diag(fgraph, node): return [det_val] +@register_canonicalize +@register_stabilize +@node_rewriter([det]) +def det_of_inv(fgraph, node): + """Replace det(matrix_inverse(X)) with reciprocal(det(X)). + + Since det(inv(X)) = 1/det(X), we avoid computing the inverse. + """ + match node.inputs[0].owner_op_and_inputs: + case (Blockwise(MatrixInverse()), X): + return [1 / det(X)] + + @register_specialize @node_rewriter([det]) def slogdet_specialization(fgraph, node): diff --git a/tests/tensor/rewriting/linalg/test_summary.py b/tests/tensor/rewriting/linalg/test_summary.py index 3fe35712e2..e97fe58d83 100644 --- a/tests/tensor/rewriting/linalg/test_summary.py +++ b/tests/tensor/rewriting/linalg/test_summary.py @@ -426,3 +426,69 @@ def test_det_of_factorized_matrix_special_cases(original_fn, expected_fn): expected = expected_fn(x) rewritten = rewrite_graph(out, include=["stabilize", "specialize"]) assert_equal_computations([rewritten], [expected]) + + +def test_det_of_inv(): + x = pt.tensor("x", shape=(3, 3)) + out = det(pt.linalg.inv(x)) + expected = pt.as_tensor(1.0, dtype="float64") / det(x) + rewritten = rewrite_graph(out, include=["canonicalize", "stabilize"]) + assert_equal_computations([rewritten], [expected]) + + +def test_log_reciprocal(): + x = pt.dscalar("x") + out = pt.log(pt.reciprocal(x)) + expected = -pt.log(x) + rewritten = rewrite_graph(out, include=["stabilize", "specialize"]) + assert_equal_computations([rewritten], [expected]) + + +@pytest.mark.parametrize( + "original_fn, expected_fn", + [ + pytest.param( + lambda x: pt.log(pt.reciprocal(pt.abs(x))), + lambda x: -pt.log(pt.abs(x)), + id="log_reciprocal_abs", + ), + pytest.param( + lambda x: pt.log(pt.reciprocal(pt.exp(x))), + lambda x: -x, + id="log_reciprocal_exp", + ), + ], +) +def test_log_reciprocal_composed(original_fn, expected_fn): + x = pt.dscalar("x") + out = original_fn(x) + expected = expected_fn(x) + rewritten = rewrite_graph(out, include=["stabilize", "specialize"]) + assert_equal_computations([rewritten], [expected]) + + +def test_slogdet_of_inv(): + x = pt.dmatrix("x") + # slogdet(inv(x)) -> (sign, logabsdet) + sign_inv, logabsdet_inv = pt.linalg.slogdet(pt.linalg.inv(x)) + + # expected: (sign(det(x)), -logabsdet(det(x))) + # det(inv(x)) = 1/det(x), so sign is same. + # logabsdet(inv(x)) = log(abs(1/det(x))) = -log(abs(det(x))) + sign_x, logabsdet_x = pt.linalg.slogdet(x) + expected_sign = sign_x + expected_logabsdet = -logabsdet_x + + # We need stabilize for det_of_inv and log_reciprocal + # and specialize for slogdet_specialization + rewritten_sign, rewritten_logabsdet = rewrite_graph( + [sign_inv, logabsdet_inv], include=["canonicalize", "stabilize", "specialize"] + ) + + expected_sign_opt, expected_logabsdet_opt = rewrite_graph( + [expected_sign, expected_logabsdet], + include=["canonicalize", "stabilize", "specialize"], + ) + + assert_equal_computations([rewritten_sign], [expected_sign_opt]) + assert_equal_computations([rewritten_logabsdet], [expected_logabsdet_opt]) From b9a8b91047e419f15b5784e753d5f7d5baee8ae3 Mon Sep 17 00:00:00 2001 From: Alessandro Gentili Date: Fri, 1 May 2026 15:12:41 +0200 Subject: [PATCH 2/2] refactor log of inverse rewrites and add new Op properties --- pytensor/scalar/basic.py | 21 +++++++ pytensor/scalar/math.py | 8 +++ pytensor/tensor/rewriting/linalg/summary.py | 22 +------- pytensor/tensor/rewriting/math.py | 56 +++++++++++++++++++ tests/tensor/rewriting/linalg/test_summary.py | 31 ---------- tests/tensor/rewriting/test_math.py | 53 ++++++++++++++++++ 6 files changed, 140 insertions(+), 51 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index d3a7f2f650..90c5b9477e 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -1351,6 +1351,8 @@ class UnaryScalarOp(ScalarOp): preserves_zero = False monotonic_increasing = False monotonic_decreasing = False + strictly_monotonic_increasing = False + strictly_monotonic_decreasing = False def c_code_contiguous(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -2470,6 +2472,7 @@ def pullback(self, inputs, outputs, gout): class Identity(UnaryScalarOp): preserves_zero = True monotonic_increasing = True + strictly_monotonic_increasing = True def impl(self, input): return input @@ -2900,6 +2903,7 @@ def c_code(self, node, name, inputs, outputs, sub): class Neg(UnaryScalarOp): preserves_zero = True monotonic_decreasing = True + strictly_monotonic_decreasing = True # We can use numpy.negative here, because even if it gives unexpected # results on Boolean arrays, it will be passed other dtypes as PyTensor # does not have a Boolean type for tensors. @@ -2976,6 +2980,7 @@ class Log(UnaryScalarOp): """ monotonic_increasing = True + strictly_monotonic_increasing = True nfunc_spec = ("log", 1, 1) amd_float32 = "amd_vrsa_logf" amd_float64 = "amd_vrda_log" @@ -3023,6 +3028,7 @@ class Log2(UnaryScalarOp): """ monotonic_increasing = True + strictly_monotonic_increasing = True nfunc_spec = ("log2", 1, 1) amd_float32 = "amd_vrsa_log2f" amd_float64 = "amd_vrda_log2" @@ -3067,6 +3073,7 @@ class Log10(UnaryScalarOp): """ monotonic_increasing = True + strictly_monotonic_increasing = True nfunc_spec = ("log10", 1, 1) amd_float32 = "amd_vrsa_log10f" amd_float64 = "amd_vrda_log10" @@ -3107,6 +3114,7 @@ def c_code(self, node, name, inputs, outputs, sub): class Log1p(UnaryScalarOp): preserves_zero = True monotonic_increasing = True + strictly_monotonic_increasing = True """ log(1+x). @@ -3149,6 +3157,7 @@ def c_code(self, node, name, inputs, outputs, sub): class Exp(UnaryScalarOp): monotonic_increasing = True + strictly_monotonic_increasing = True nfunc_spec = ("exp", 1, 1) amd_float32 = "amd_vrsa_expf" amd_float64 = "amd_vrda_exp" @@ -3188,6 +3197,7 @@ def c_code(self, node, name, inputs, outputs, sub): class Exp2(UnaryScalarOp): monotonic_increasing = True + strictly_monotonic_increasing = True nfunc_spec = ("exp2", 1, 1) def impl(self, x): @@ -3226,6 +3236,7 @@ def c_code(self, node, name, inputs, outputs, sub): class Expm1(UnaryScalarOp): preserves_zero = True monotonic_increasing = True + strictly_monotonic_increasing = True nfunc_spec = ("expm1", 1, 1) def impl(self, x): @@ -3296,6 +3307,7 @@ def c_code(self, node, name, inputs, outputs, sub): class Sqrt(UnaryScalarOp): preserves_zero = True monotonic_increasing = True + strictly_monotonic_increasing = True nfunc_spec = ("sqrt", 1, 1) @@ -3335,6 +3347,7 @@ def c_code(self, node, name, inputs, outputs, sub): class Deg2Rad(UnaryScalarOp): preserves_zero = True monotonic_increasing = True + strictly_monotonic_increasing = True nfunc_spec = ("deg2rad", 1, 1) @@ -3373,6 +3386,7 @@ def c_code(self, node, name, inputs, outputs, sub): class Rad2Deg(UnaryScalarOp): preserves_zero = True monotonic_increasing = True + strictly_monotonic_increasing = True nfunc_spec = ("rad2deg", 1, 1) @@ -3448,6 +3462,7 @@ def c_code(self, node, name, inputs, outputs, sub): class ArcCos(UnaryScalarOp): monotonic_decreasing = True + strictly_monotonic_decreasing = True nfunc_spec = ("arccos", 1, 1) def impl(self, x): @@ -3525,6 +3540,7 @@ def c_code(self, node, name, inputs, outputs, sub): class ArcSin(UnaryScalarOp): preserves_zero = True monotonic_increasing = True + strictly_monotonic_increasing = True nfunc_spec = ("arcsin", 1, 1) def impl(self, x): @@ -3600,6 +3616,7 @@ def c_code(self, node, name, inputs, outputs, sub): class ArcTan(UnaryScalarOp): preserves_zero = True monotonic_increasing = True + strictly_monotonic_increasing = True nfunc_spec = ("arctan", 1, 1) def impl(self, x): @@ -3762,6 +3779,7 @@ def c_code(self, node, name, inputs, outputs, sub): class Sinh(UnaryScalarOp): preserves_zero = True monotonic_increasing = True + strictly_monotonic_increasing = True """ sinh(x) = (exp(x) - exp(-x)) / 2. @@ -3805,6 +3823,7 @@ def c_code(self, node, name, inputs, outputs, sub): class ArcSinh(UnaryScalarOp): preserves_zero = True monotonic_increasing = True + strictly_monotonic_increasing = True nfunc_spec = ("arcsinh", 1, 1) def impl(self, x): @@ -3843,6 +3862,7 @@ def c_code(self, node, name, inputs, outputs, sub): class Tanh(UnaryScalarOp): preserves_zero = True monotonic_increasing = True + strictly_monotonic_increasing = True """ tanh(x) = sinh(x) / cosh(x) = (exp(2*x) - 1) / (exp(2*x) + 1). @@ -3887,6 +3907,7 @@ def c_code(self, node, name, inputs, outputs, sub): class ArcTanh(UnaryScalarOp): preserves_zero = True monotonic_increasing = True + strictly_monotonic_increasing = True nfunc_spec = ("arctanh", 1, 1) def impl(self, x): diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index bc5c221fe2..bfce2cdc15 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -51,6 +51,7 @@ class Erf(UnaryScalarOp): preserves_zero = True monotonic_increasing = True + strictly_monotonic_increasing = True nfunc_spec = ("scipy.special.erf", 1, 1) def impl(self, x): @@ -86,6 +87,7 @@ def c_code(self, node, name, inp, out, sub): class Erfc(UnaryScalarOp): monotonic_decreasing = True + strictly_monotonic_decreasing = True nfunc_spec = ("scipy.special.erfc", 1, 1) def impl(self, x): @@ -136,6 +138,7 @@ class Erfcx(UnaryScalarOp): """ monotonic_decreasing = True + strictly_monotonic_decreasing = True nfunc_spec = ("scipy.special.erfcx", 1, 1) def impl(self, x): @@ -183,6 +186,7 @@ def c_code(self, node, name, inp, out, sub): class Erfinv(UnaryScalarOp): preserves_zero = True monotonic_increasing = True + strictly_monotonic_increasing = True """ Implements the inverse error function. @@ -230,6 +234,7 @@ def c_code(self, node, name, inp, out, sub): class Erfcinv(UnaryScalarOp): monotonic_decreasing = True + strictly_monotonic_decreasing = True nfunc_spec = ("scipy.special.erfcinv", 1, 1) def impl(self, x): @@ -1188,6 +1193,7 @@ class Sigmoid(UnaryScalarOp): """ monotonic_increasing = True + strictly_monotonic_increasing = True nfunc_spec = ("scipy.special.expit", 1, 1) def impl(self, x): @@ -1243,6 +1249,7 @@ class Softplus(UnaryScalarOp): """ monotonic_increasing = True + strictly_monotonic_increasing = True def impl(self, x): # If x is an int8 or uint8, numpy.exp will compute the result in @@ -1326,6 +1333,7 @@ class Log1mexp(UnaryScalarOp): """ monotonic_decreasing = True + strictly_monotonic_decreasing = True def impl(self, x): if x < np.log(0.5): diff --git a/pytensor/tensor/rewriting/linalg/summary.py b/pytensor/tensor/rewriting/linalg/summary.py index 8cc969666d..bd5db3ee26 100644 --- a/pytensor/tensor/rewriting/linalg/summary.py +++ b/pytensor/tensor/rewriting/linalg/summary.py @@ -5,7 +5,7 @@ copy_stack_trace, node_rewriter, ) -from pytensor.scalar.basic import Abs, Exp, Log, Reciprocal, Sign, Sqr +from pytensor.scalar.basic import Abs, Exp, Log, Sign, Sqr from pytensor.tensor.basic import AllocDiag, ones from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import Elemwise @@ -15,7 +15,7 @@ from pytensor.tensor.linalg.decomposition.svd import SVD from pytensor.tensor.linalg.inverse import MatrixInverse from pytensor.tensor.linalg.summary import SLogDet, det -from pytensor.tensor.math import Prod, log, prod, sign +from pytensor.tensor.math import Prod, log, prod from pytensor.tensor.rewriting.basic import ( register_canonicalize, register_specialize, @@ -50,24 +50,6 @@ def local_log_prod_to_sum_log(fgraph, node): return [log(abs(x)).sum(axis=axis)] -@register_canonicalize -@register_stabilize -@register_specialize -@node_rewriter([Elemwise]) -def local_reciprocal_linalg_special_cases(fgraph, node): - """Special cases for log(reciprocal(x)) and sign(reciprocal(x)).""" - if len(node.inputs) != 1: - return None - - [p] = node.inputs - match p.owner_op_and_inputs: - case (Elemwise(Reciprocal()), x): - if isinstance(node.op.scalar_op, Log): - return [-log(x)] - if isinstance(node.op.scalar_op, Sign): - return [sign(x)] - - @register_stabilize("shape_unsafe") @register_specialize("shape_unsafe") @node_rewriter([det]) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 66cf21bdcc..fc190c1ab5 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -651,6 +651,62 @@ def local_exp_log_nan_switch(fgraph, node): return [new_out] +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([log]) +def local_log_reciprocal(fgraph, node): + """Rewrite log(reciprocal(x)) -> -log(x).""" + (inp,) = node.inputs + if ( + inp.owner + and isinstance(inp.owner.op, Elemwise) + and isinstance(inp.owner.op.scalar_op, ps.Reciprocal) + ): + return [neg(log(inp.owner.inputs[0]))] + + +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([sign]) +def local_sign_reciprocal(fgraph, node): + """Rewrite sign(reciprocal(x)) -> sign(x).""" + (inp,) = node.inputs + if ( + inp.owner + and isinstance(inp.owner.op, Elemwise) + and isinstance(inp.owner.op.scalar_op, ps.Reciprocal) + ): + return [sign(inp.owner.inputs[0])] + + +@register_canonicalize +@register_stabilize +@register_specialize +@node_rewriter([sign]) +def local_sign_of_monotonic(fgraph, node): + """Rewrite sign(f(x)) to sign(x) or -sign(x) based on monotonicity. + + If f is strictly monotonic increasing and preserves zero, then sign(f(x)) == sign(x). + If f is strictly monotonic decreasing and preserves zero, then sign(f(x)) == -sign(x). + """ + (inp,) = node.inputs + if not (inp.owner and isinstance(inp.owner.op, Elemwise)): + return + + scalar_op = inp.owner.op.scalar_op + + if not getattr(scalar_op, "preserves_zero", False): + return + + if getattr(scalar_op, "strictly_monotonic_increasing", False): + return [sign(inp.owner.inputs[0])] + + if getattr(scalar_op, "strictly_monotonic_decreasing", False): + return [neg(sign(inp.owner.inputs[0]))] + + @register_canonicalize @register_specialize @node_rewriter([Sum]) diff --git a/tests/tensor/rewriting/linalg/test_summary.py b/tests/tensor/rewriting/linalg/test_summary.py index e97fe58d83..1bd8e2d953 100644 --- a/tests/tensor/rewriting/linalg/test_summary.py +++ b/tests/tensor/rewriting/linalg/test_summary.py @@ -436,37 +436,6 @@ def test_det_of_inv(): assert_equal_computations([rewritten], [expected]) -def test_log_reciprocal(): - x = pt.dscalar("x") - out = pt.log(pt.reciprocal(x)) - expected = -pt.log(x) - rewritten = rewrite_graph(out, include=["stabilize", "specialize"]) - assert_equal_computations([rewritten], [expected]) - - -@pytest.mark.parametrize( - "original_fn, expected_fn", - [ - pytest.param( - lambda x: pt.log(pt.reciprocal(pt.abs(x))), - lambda x: -pt.log(pt.abs(x)), - id="log_reciprocal_abs", - ), - pytest.param( - lambda x: pt.log(pt.reciprocal(pt.exp(x))), - lambda x: -x, - id="log_reciprocal_exp", - ), - ], -) -def test_log_reciprocal_composed(original_fn, expected_fn): - x = pt.dscalar("x") - out = original_fn(x) - expected = expected_fn(x) - rewritten = rewrite_graph(out, include=["stabilize", "specialize"]) - assert_equal_computations([rewritten], [expected]) - - def test_slogdet_of_inv(): x = pt.dmatrix("x") # slogdet(inv(x)) -> (sign, logabsdet) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 6d0b0d978c..97434bc7df 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -5024,3 +5024,56 @@ def test_rewrite_does_not_apply(self): original, include=("canonicalize", "stabilize", "specialize") ) assert_equal_computations([rewritten], [original]) + + +def test_log_reciprocal(): + x = pt.dscalar("x") + out = pt.log(pt.reciprocal(x)) + expected = -pt.log(x) + rewritten = rewrite_graph(out, include=["stabilize", "specialize"]) + assert_equal_computations([rewritten], [expected]) + + +def test_sign_reciprocal(): + x = pt.dscalar("x") + out = pt.sign(pt.reciprocal(x)) + expected = pt.sign(x) + rewritten = rewrite_graph(out, include=["stabilize", "specialize"]) + assert_equal_computations([rewritten], [expected]) + + +@pytest.mark.parametrize( + "op, expected_fn", + [ + (pt.neg, lambda x: -pt.sign(x)), + (pt.tanh, lambda x: pt.sign(x)), + (pt.expm1, lambda x: pt.sign(x)), + (pt.log1p, lambda x: pt.sign(x)), + ], +) +def test_sign_of_monotonic(op, expected_fn): + x = pt.dscalar("x") + out = pt.sign(op(x)) + expected = expected_fn(x) + rewritten = rewrite_graph(out, include=["stabilize", "specialize"]) + assert_equal_computations([rewritten], [expected]) + + +def test_sign_of_non_strict_monotonic(): + # ceil is monotonic but not strictly. It should NOT be rewritten. + x = pt.dscalar("x") + out = pt.sign(pt.ceil(x)) + rewritten = rewrite_graph(out, include=["stabilize", "specialize"]) + # Should still have ceil + # Should still have ceil + from pytensor.graph.traversal import ancestors + + nodes = [v.owner for v in ancestors([rewritten]) if v.owner] + assert any( + isinstance( + getattr(node.op, "scalar_op", None), type(pt.ceil(x).owner.op.scalar_op) + ) + for node in nodes + ) + # Wait, simpler check: rewritten should be out (or equivalent) + assert_equal_computations([rewritten], [out])