Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1351,6 +1351,8 @@ class UnaryScalarOp(ScalarOp):
preserves_zero = False
monotonic_increasing = False
monotonic_decreasing = False
strictly_monotonic_increasing = False
strictly_monotonic_decreasing = False

def c_code_contiguous(self, node, name, inputs, outputs, sub):
(x,) = inputs
Expand Down Expand Up @@ -2470,6 +2472,7 @@ def pullback(self, inputs, outputs, gout):
class Identity(UnaryScalarOp):
preserves_zero = True
monotonic_increasing = True
strictly_monotonic_increasing = True

def impl(self, input):
return input
Expand Down Expand Up @@ -2900,6 +2903,7 @@ def c_code(self, node, name, inputs, outputs, sub):
class Neg(UnaryScalarOp):
preserves_zero = True
monotonic_decreasing = True
strictly_monotonic_decreasing = True
# We can use numpy.negative here, because even if it gives unexpected
# results on Boolean arrays, it will be passed other dtypes as PyTensor
# does not have a Boolean type for tensors.
Expand Down Expand Up @@ -2976,6 +2980,7 @@ class Log(UnaryScalarOp):
"""

monotonic_increasing = True
strictly_monotonic_increasing = True
nfunc_spec = ("log", 1, 1)
amd_float32 = "amd_vrsa_logf"
amd_float64 = "amd_vrda_log"
Expand Down Expand Up @@ -3023,6 +3028,7 @@ class Log2(UnaryScalarOp):
"""

monotonic_increasing = True
strictly_monotonic_increasing = True
nfunc_spec = ("log2", 1, 1)
amd_float32 = "amd_vrsa_log2f"
amd_float64 = "amd_vrda_log2"
Expand Down Expand Up @@ -3067,6 +3073,7 @@ class Log10(UnaryScalarOp):
"""

monotonic_increasing = True
strictly_monotonic_increasing = True
nfunc_spec = ("log10", 1, 1)
amd_float32 = "amd_vrsa_log10f"
amd_float64 = "amd_vrda_log10"
Expand Down Expand Up @@ -3107,6 +3114,7 @@ def c_code(self, node, name, inputs, outputs, sub):
class Log1p(UnaryScalarOp):
preserves_zero = True
monotonic_increasing = True
strictly_monotonic_increasing = True
"""
log(1+x).

Expand Down Expand Up @@ -3149,6 +3157,7 @@ def c_code(self, node, name, inputs, outputs, sub):

class Exp(UnaryScalarOp):
monotonic_increasing = True
strictly_monotonic_increasing = True
nfunc_spec = ("exp", 1, 1)
amd_float32 = "amd_vrsa_expf"
amd_float64 = "amd_vrda_exp"
Expand Down Expand Up @@ -3188,6 +3197,7 @@ def c_code(self, node, name, inputs, outputs, sub):

class Exp2(UnaryScalarOp):
monotonic_increasing = True
strictly_monotonic_increasing = True
nfunc_spec = ("exp2", 1, 1)

def impl(self, x):
Expand Down Expand Up @@ -3226,6 +3236,7 @@ def c_code(self, node, name, inputs, outputs, sub):
class Expm1(UnaryScalarOp):
preserves_zero = True
monotonic_increasing = True
strictly_monotonic_increasing = True
nfunc_spec = ("expm1", 1, 1)

def impl(self, x):
Expand Down Expand Up @@ -3296,6 +3307,7 @@ def c_code(self, node, name, inputs, outputs, sub):
class Sqrt(UnaryScalarOp):
preserves_zero = True
monotonic_increasing = True
strictly_monotonic_increasing = True

nfunc_spec = ("sqrt", 1, 1)

Expand Down Expand Up @@ -3335,6 +3347,7 @@ def c_code(self, node, name, inputs, outputs, sub):
class Deg2Rad(UnaryScalarOp):
preserves_zero = True
monotonic_increasing = True
strictly_monotonic_increasing = True

nfunc_spec = ("deg2rad", 1, 1)

Expand Down Expand Up @@ -3373,6 +3386,7 @@ def c_code(self, node, name, inputs, outputs, sub):
class Rad2Deg(UnaryScalarOp):
preserves_zero = True
monotonic_increasing = True
strictly_monotonic_increasing = True

nfunc_spec = ("rad2deg", 1, 1)

Expand Down Expand Up @@ -3448,6 +3462,7 @@ def c_code(self, node, name, inputs, outputs, sub):

class ArcCos(UnaryScalarOp):
monotonic_decreasing = True
strictly_monotonic_decreasing = True
nfunc_spec = ("arccos", 1, 1)

def impl(self, x):
Expand Down Expand Up @@ -3525,6 +3540,7 @@ def c_code(self, node, name, inputs, outputs, sub):
class ArcSin(UnaryScalarOp):
preserves_zero = True
monotonic_increasing = True
strictly_monotonic_increasing = True
nfunc_spec = ("arcsin", 1, 1)

def impl(self, x):
Expand Down Expand Up @@ -3600,6 +3616,7 @@ def c_code(self, node, name, inputs, outputs, sub):
class ArcTan(UnaryScalarOp):
preserves_zero = True
monotonic_increasing = True
strictly_monotonic_increasing = True
nfunc_spec = ("arctan", 1, 1)

def impl(self, x):
Expand Down Expand Up @@ -3762,6 +3779,7 @@ def c_code(self, node, name, inputs, outputs, sub):
class Sinh(UnaryScalarOp):
preserves_zero = True
monotonic_increasing = True
strictly_monotonic_increasing = True
"""
sinh(x) = (exp(x) - exp(-x)) / 2.

Expand Down Expand Up @@ -3805,6 +3823,7 @@ def c_code(self, node, name, inputs, outputs, sub):
class ArcSinh(UnaryScalarOp):
preserves_zero = True
monotonic_increasing = True
strictly_monotonic_increasing = True
nfunc_spec = ("arcsinh", 1, 1)

def impl(self, x):
Expand Down Expand Up @@ -3843,6 +3862,7 @@ def c_code(self, node, name, inputs, outputs, sub):
class Tanh(UnaryScalarOp):
preserves_zero = True
monotonic_increasing = True
strictly_monotonic_increasing = True
"""
tanh(x) = sinh(x) / cosh(x)
= (exp(2*x) - 1) / (exp(2*x) + 1).
Expand Down Expand Up @@ -3887,6 +3907,7 @@ def c_code(self, node, name, inputs, outputs, sub):
class ArcTanh(UnaryScalarOp):
preserves_zero = True
monotonic_increasing = True
strictly_monotonic_increasing = True
nfunc_spec = ("arctanh", 1, 1)

def impl(self, x):
Expand Down
8 changes: 8 additions & 0 deletions pytensor/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
class Erf(UnaryScalarOp):
preserves_zero = True
monotonic_increasing = True
strictly_monotonic_increasing = True
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't love this, what cases disagree right now between the two?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ceil and floor, for example

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't mark those instead?

Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 May 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I don't get it, any discrete input version to these ops is also not strictly monotonic.

nvm

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't mark those instead?

I'm not against that, but then we should use the strictly_ language everywhere (drop the shorter one) to be clear what is going on

Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 May 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Circling back to the sign thing, if that's the motivation, I don't you can apply it based on monotonicity, strict or not. sign(exp(x)) is obviously not sign(x).

We are adding these properties for specific uses, not for mathematical idealism, so they need not be verbose nor geberalized besides the problems we want to solve with them. sctrict vs non strict is more a question of invertible 1-1 map not the direction. wouldn't a combination ot those 2 poperties + zero preserving be a better way to achieve the goal?

Copy link
Copy Markdown
Member

@jessegrabowski jessegrabowski May 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

exp isn't zero preserving, so the rule doesn't apply. Both things are important. It has to be strictly monotonic increasing and zero preserving. I think taking strict monotonicity as our canonical form is nice, because who cares about ceil/floor anyway. But I also think it's important to be clear in language, otherwise someone can come along in a few years and say "Well technicaly BitwiseInverse is monotonic_increasing, why isn't it marked" and the answer is "because we define monotonicity as strict monotonicity but it isn't written anywhere". We lose nothing by just writing it.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I know sign thing doesn't apply to zero, did I say so?

Otherwise okay, we can go with verbose, don't love it but it's strictly more precise.

Can we stop there and not at strictly_monotonic_increasing_over_defined_domain?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was reacting to this: sign(exp(x)) is obviously not sign(x).

Maybe i misunderstood your point.

I'm not being dogmatic about the name change, if we want to just define monotonic to mean "strictly monotonic" and put it in the docs somewhere, I have no objection. I just want it to be written down, and I like self-documenting code. Agreed that there is a limit.

nfunc_spec = ("scipy.special.erf", 1, 1)

def impl(self, x):
Expand Down Expand Up @@ -86,6 +87,7 @@ def c_code(self, node, name, inp, out, sub):

class Erfc(UnaryScalarOp):
monotonic_decreasing = True
strictly_monotonic_decreasing = True
nfunc_spec = ("scipy.special.erfc", 1, 1)

def impl(self, x):
Expand Down Expand Up @@ -136,6 +138,7 @@ class Erfcx(UnaryScalarOp):
"""

monotonic_decreasing = True
strictly_monotonic_decreasing = True
nfunc_spec = ("scipy.special.erfcx", 1, 1)

def impl(self, x):
Expand Down Expand Up @@ -183,6 +186,7 @@ def c_code(self, node, name, inp, out, sub):
class Erfinv(UnaryScalarOp):
preserves_zero = True
monotonic_increasing = True
strictly_monotonic_increasing = True
"""
Implements the inverse error function.

Expand Down Expand Up @@ -230,6 +234,7 @@ def c_code(self, node, name, inp, out, sub):

class Erfcinv(UnaryScalarOp):
monotonic_decreasing = True
strictly_monotonic_decreasing = True
nfunc_spec = ("scipy.special.erfcinv", 1, 1)

def impl(self, x):
Expand Down Expand Up @@ -1188,6 +1193,7 @@ class Sigmoid(UnaryScalarOp):
"""

monotonic_increasing = True
strictly_monotonic_increasing = True
nfunc_spec = ("scipy.special.expit", 1, 1)

def impl(self, x):
Expand Down Expand Up @@ -1243,6 +1249,7 @@ class Softplus(UnaryScalarOp):
"""

monotonic_increasing = True
strictly_monotonic_increasing = True

def impl(self, x):
# If x is an int8 or uint8, numpy.exp will compute the result in
Expand Down Expand Up @@ -1326,6 +1333,7 @@ class Log1mexp(UnaryScalarOp):
"""

monotonic_decreasing = True
strictly_monotonic_decreasing = True

def impl(self, x):
if x < np.log(0.5):
Expand Down
14 changes: 14 additions & 0 deletions pytensor/tensor/rewriting/linalg/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pytensor.tensor.linalg.decomposition.lu import LU, LUFactor
from pytensor.tensor.linalg.decomposition.qr import QR
from pytensor.tensor.linalg.decomposition.svd import SVD
from pytensor.tensor.linalg.inverse import MatrixInverse
from pytensor.tensor.linalg.summary import SLogDet, det
from pytensor.tensor.math import Prod, log, prod
from pytensor.tensor.rewriting.basic import (
Expand Down Expand Up @@ -225,6 +226,19 @@ def det_of_diag(fgraph, node):
return [det_val]


@register_canonicalize
@register_stabilize
@node_rewriter([det])
def det_of_inv(fgraph, node):
"""Replace det(matrix_inverse(X)) with reciprocal(det(X)).

Since det(inv(X)) = 1/det(X), we avoid computing the inverse.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Since det(inv(X)) = 1/det(X), we avoid computing the inverse.

"""
match node.inputs[0].owner_op_and_inputs:
case (Blockwise(MatrixInverse()), X):
return [1 / det(X)]


@register_specialize
@node_rewriter([det])
def slogdet_specialization(fgraph, node):
Expand Down
56 changes: 56 additions & 0 deletions pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,62 @@ def local_exp_log_nan_switch(fgraph, node):
return [new_out]


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([log])
def local_log_reciprocal(fgraph, node):
"""Rewrite log(reciprocal(x)) -> -log(x)."""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should do the more general as well (reciprocal is fine): log(a/b), where a or b is a non-negative constant -> log(a) - log(b) (the constant constant-folded already).

(inp,) = node.inputs
if (
inp.owner
and isinstance(inp.owner.op, Elemwise)
and isinstance(inp.owner.op.scalar_op, ps.Reciprocal)
):
return [neg(log(inp.owner.inputs[0]))]


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([sign])
def local_sign_reciprocal(fgraph, node):
"""Rewrite sign(reciprocal(x)) -> sign(x)."""
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here sign a/b, where one is a positive constant -> sign of the other term. If the constant is negative, 1-sign of the other. If it's mixed, can't do anything

(inp,) = node.inputs
if (
inp.owner
and isinstance(inp.owner.op, Elemwise)
and isinstance(inp.owner.op.scalar_op, ps.Reciprocal)
):
return [sign(inp.owner.inputs[0])]


@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([sign])
def local_sign_of_monotonic(fgraph, node):
"""Rewrite sign(f(x)) to sign(x) or -sign(x) based on monotonicity.

If f is strictly monotonic increasing and preserves zero, then sign(f(x)) == sign(x).
If f is strictly monotonic decreasing and preserves zero, then sign(f(x)) == -sign(x).
"""
(inp,) = node.inputs
if not (inp.owner and isinstance(inp.owner.op, Elemwise)):
return

scalar_op = inp.owner.op.scalar_op

if not getattr(scalar_op, "preserves_zero", False):
return

if getattr(scalar_op, "strictly_monotonic_increasing", False):
return [sign(inp.owner.inputs[0])]

if getattr(scalar_op, "strictly_monotonic_decreasing", False):
return [neg(sign(inp.owner.inputs[0]))]


@register_canonicalize
@register_specialize
@node_rewriter([Sum])
Expand Down
35 changes: 35 additions & 0 deletions tests/tensor/rewriting/linalg/test_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,38 @@ def test_det_of_factorized_matrix_special_cases(original_fn, expected_fn):
expected = expected_fn(x)
rewritten = rewrite_graph(out, include=["stabilize", "specialize"])
assert_equal_computations([rewritten], [expected])


def test_det_of_inv():
x = pt.tensor("x", shape=(3, 3))
out = det(pt.linalg.inv(x))
expected = pt.as_tensor(1.0, dtype="float64") / det(x)
rewritten = rewrite_graph(out, include=["canonicalize", "stabilize"])
assert_equal_computations([rewritten], [expected])


def test_slogdet_of_inv():
x = pt.dmatrix("x")
# slogdet(inv(x)) -> (sign, logabsdet)
sign_inv, logabsdet_inv = pt.linalg.slogdet(pt.linalg.inv(x))

# expected: (sign(det(x)), -logabsdet(det(x)))
# det(inv(x)) = 1/det(x), so sign is same.
# logabsdet(inv(x)) = log(abs(1/det(x))) = -log(abs(det(x)))
sign_x, logabsdet_x = pt.linalg.slogdet(x)
expected_sign = sign_x
expected_logabsdet = -logabsdet_x

# We need stabilize for det_of_inv and log_reciprocal
# and specialize for slogdet_specialization
rewritten_sign, rewritten_logabsdet = rewrite_graph(
[sign_inv, logabsdet_inv], include=["canonicalize", "stabilize", "specialize"]
)

expected_sign_opt, expected_logabsdet_opt = rewrite_graph(
[expected_sign, expected_logabsdet],
include=["canonicalize", "stabilize", "specialize"],
)

assert_equal_computations([rewritten_sign], [expected_sign_opt])
assert_equal_computations([rewritten_logabsdet], [expected_logabsdet_opt])
Loading
Loading