-
Notifications
You must be signed in to change notification settings - Fork 187
Rewrite det(inv(X)) → 1/det(X) #2102
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |||
| from pytensor.tensor.linalg.decomposition.lu import LU, LUFactor | ||||
| from pytensor.tensor.linalg.decomposition.qr import QR | ||||
| from pytensor.tensor.linalg.decomposition.svd import SVD | ||||
| from pytensor.tensor.linalg.inverse import MatrixInverse | ||||
| from pytensor.tensor.linalg.summary import SLogDet, det | ||||
| from pytensor.tensor.math import Prod, log, prod | ||||
| from pytensor.tensor.rewriting.basic import ( | ||||
|
|
@@ -225,6 +226,19 @@ def det_of_diag(fgraph, node): | |||
| return [det_val] | ||||
|
|
||||
|
|
||||
| @register_canonicalize | ||||
| @register_stabilize | ||||
| @node_rewriter([det]) | ||||
| def det_of_inv(fgraph, node): | ||||
| """Replace det(matrix_inverse(X)) with reciprocal(det(X)). | ||||
|
|
||||
| Since det(inv(X)) = 1/det(X), we avoid computing the inverse. | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
| """ | ||||
| match node.inputs[0].owner_op_and_inputs: | ||||
| case (Blockwise(MatrixInverse()), X): | ||||
| return [1 / det(X)] | ||||
|
|
||||
|
|
||||
| @register_specialize | ||||
| @node_rewriter([det]) | ||||
| def slogdet_specialization(fgraph, node): | ||||
|
|
||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -651,6 +651,62 @@ def local_exp_log_nan_switch(fgraph, node): | |
| return [new_out] | ||
|
|
||
|
|
||
| @register_canonicalize | ||
| @register_stabilize | ||
| @register_specialize | ||
| @node_rewriter([log]) | ||
| def local_log_reciprocal(fgraph, node): | ||
| """Rewrite log(reciprocal(x)) -> -log(x).""" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should do the more general as well (reciprocal is fine): log(a/b), where a or b is a non-negative constant -> log(a) - log(b) (the constant constant-folded already). |
||
| (inp,) = node.inputs | ||
| if ( | ||
| inp.owner | ||
| and isinstance(inp.owner.op, Elemwise) | ||
| and isinstance(inp.owner.op.scalar_op, ps.Reciprocal) | ||
| ): | ||
| return [neg(log(inp.owner.inputs[0]))] | ||
|
|
||
|
|
||
| @register_canonicalize | ||
| @register_stabilize | ||
| @register_specialize | ||
| @node_rewriter([sign]) | ||
| def local_sign_reciprocal(fgraph, node): | ||
| """Rewrite sign(reciprocal(x)) -> sign(x).""" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here sign a/b, where one is a positive constant -> sign of the other term. If the constant is negative, 1-sign of the other. If it's mixed, can't do anything |
||
| (inp,) = node.inputs | ||
| if ( | ||
| inp.owner | ||
| and isinstance(inp.owner.op, Elemwise) | ||
| and isinstance(inp.owner.op.scalar_op, ps.Reciprocal) | ||
| ): | ||
| return [sign(inp.owner.inputs[0])] | ||
|
|
||
|
|
||
| @register_canonicalize | ||
| @register_stabilize | ||
| @register_specialize | ||
| @node_rewriter([sign]) | ||
| def local_sign_of_monotonic(fgraph, node): | ||
| """Rewrite sign(f(x)) to sign(x) or -sign(x) based on monotonicity. | ||
|
|
||
| If f is strictly monotonic increasing and preserves zero, then sign(f(x)) == sign(x). | ||
| If f is strictly monotonic decreasing and preserves zero, then sign(f(x)) == -sign(x). | ||
| """ | ||
| (inp,) = node.inputs | ||
| if not (inp.owner and isinstance(inp.owner.op, Elemwise)): | ||
| return | ||
|
|
||
| scalar_op = inp.owner.op.scalar_op | ||
|
|
||
| if not getattr(scalar_op, "preserves_zero", False): | ||
| return | ||
|
|
||
| if getattr(scalar_op, "strictly_monotonic_increasing", False): | ||
| return [sign(inp.owner.inputs[0])] | ||
|
|
||
| if getattr(scalar_op, "strictly_monotonic_decreasing", False): | ||
| return [neg(sign(inp.owner.inputs[0]))] | ||
|
|
||
|
|
||
| @register_canonicalize | ||
| @register_specialize | ||
| @node_rewriter([Sum]) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't love this, what cases disagree right now between the two?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ceilandfloor, for exampleThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't mark those instead?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also I don't get it, any discrete input version to these ops is also not strictly monotonic.nvm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not against that, but then we should use the
strictly_language everywhere (drop the shorter one) to be clear what is going onUh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Circling back to the sign thing, if that's the motivation, I don't you can apply it based on monotonicity, strict or not. sign(exp(x)) is obviously not sign(x).
We are adding these properties for specific uses, not for mathematical idealism, so they need not be verbose nor geberalized besides the problems we want to solve with them. sctrict vs non strict is more a question of invertible 1-1 map not the direction. wouldn't a combination ot those 2 poperties + zero preserving be a better way to achieve the goal?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
exp isn't zero preserving, so the rule doesn't apply. Both things are important. It has to be strictly monotonic increasing and zero preserving. I think taking strict monotonicity as our canonical form is nice, because who cares about ceil/floor anyway. But I also think it's important to be clear in language, otherwise someone can come along in a few years and say "Well technicaly BitwiseInverse is monotonic_increasing, why isn't it marked" and the answer is "because we define monotonicity as strict monotonicity but it isn't written anywhere". We lose nothing by just writing it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I know sign thing doesn't apply to zero, did I say so?
Otherwise okay, we can go with verbose, don't love it but it's strictly more precise.
Can we stop there and not at strictly_monotonic_increasing_over_defined_domain?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was reacting to this: sign(exp(x)) is obviously not sign(x).
Maybe i misunderstood your point.
I'm not being dogmatic about the name change, if we want to just define monotonic to mean "strictly monotonic" and put it in the docs somewhere, I have no objection. I just want it to be written down, and I like self-documenting code. Agreed that there is a limit.