Rewrite solve(matrix_inverse(X), b) → X @ b #2101

Closed
alessandrogentili001 wants to merge 3 commits into pymc-devs:main from alessandrogentili001:rewrite-solve-matrix-inverse-as-mutmul

Conversation

@alessandrogentili001
Contributor

Description

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

Comment on lines +462 to +468
# Get all nodes in the rewritten graph
all_nodes = io_toposort([], [rewritten_out])

assert not any(
isinstance(getattr(node.op, "core_op", node.op), Solve | MatrixInverse)
for node in all_nodes
)
Member

You can replace this with assert_equal_computations.

Member

@ricardoV94 ricardoV94 May 1, 2026

See #2103 for an attempt to formalize this a bit better.

Contributor Author

This way should work:

@pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}")
def test_solve_of_inv_to_matmul(b_ndim):
    X = pt.dmatrix("X")
    b = pt.dvector("b") if b_ndim == 1 else pt.dmatrix("b")
    out = solve(pt.linalg.inv(X), b, b_ndim=b_ndim)

    # Just include the rewrite we are testing
    rewriter = WalkingGraphRewriter(solve_of_inv_to_matmul)
    rewritten_out = rewrite_graph(out, custom_rewrite=rewriter)

    # Verify the rewrite
    expected = X @ b
    assert_equal_computations([rewritten_out], [expected])

    # Numerical check
    rng = np.random.default_rng(42)
    X_val = (rng.random((4, 4)) + np.eye(4) * 4).astype(X.type.dtype)
    b_val = rng.random((4,) if b_ndim == 1 else (4, 3)).astype(b.type.dtype)

    f_opt = function([X, b], rewritten_out)
    res_opt = f_opt(X_val, b_val)
    res_expected = np.linalg.solve(np.linalg.inv(X_val), b_val)

    np.testing.assert_allclose(res_opt, res_expected, rtol=1e-7)
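
For context, the numerical check above leans on the identity solve(inv(X), b) = inv(inv(X)) @ b = X @ b. The following is a minimal NumPy-only sketch of that identity, independent of PyTensor. The helper name `solve_of_inv` and the test matrix are illustrative assumptions, not part of the PR.

```python
import numpy as np


def solve_of_inv(X, b):
    """Solve inv(X) @ y = b for y the long way: invert X, then solve."""
    return np.linalg.solve(np.linalg.inv(X), b)


rng = np.random.default_rng(42)
# Entries lie in [0, 1) and the diagonal is shifted by 4, so X is strictly
# diagonally dominant and therefore invertible.
X = rng.random((4, 4)) + 4 * np.eye(4)

# Cover both cases exercised by the test: a vector b and a matrix b.
for b in (rng.random(4), rng.random((4, 3))):
    direct = X @ b
    roundabout = solve_of_inv(X, b)
    assert np.allclose(direct, roundabout, rtol=1e-7)
```

The direct product avoids both the inversion and the linear solve, which is the motivation for rewriting the graph rather than evaluating it as written.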

# Graph rewrite test
# We include 'stabilize' because solve_of_inv_to_matmul is registered there.
# This avoids dependency on the global config.mode (e.g. FAST_COMPILE).
rewritten_out = rewrite_graph(out, include=["stabilize"])
Member

You can directly include just the rewrite you're testing; it's a bit clearer that way.

Contributor Author

Sure

Comment on lines +462 to +471
# Numerical check
rng = np.random.default_rng(42)
X_val = (rng.random((4, 4)) + np.eye(4) * 4).astype(X.type.dtype)
b_val = rng.random((4,) if b_ndim == 1 else (4, 3)).astype(b.type.dtype)

f_opt = function([X, b], rewritten_out)
res_opt = f_opt(X_val, b_val)
res_expected = np.linalg.solve(np.linalg.inv(X_val), b_val)

np.testing.assert_allclose(res_opt, res_expected, rtol=1e-7)
Member

We don't need the numerical check if the structural check passes (at that point you're just testing BLAS, and I promise you BLAS works).

Comment on lines +455 to +456
rewriter = WalkingGraphRewriter(solve_of_inv_to_matmul)
rewritten_out = rewrite_graph(out, custom_rewrite=rewriter)
Member

I don't love the import + custom_rewrite. The pattern I had in mind was to just use rewrite_graph(out, include=('your_rewrite_name', )). If that doesn't work I'd rather it be reverted to what you had before for simplicity. But also fine with it staying this way if you don't want to keep going back and forth.

Member

Again, I think this PR is a great discussion point for #2103, which is still open-ended and in need of feedback. That way we can standardize how we want to test this sort of rewrite and avoid spending time discussing it again in the future.

Contributor Author

I have updated the test suite for solve_of_inv_to_matmul to address the feedback regarding custom rewriters and numerical checks. During the refactoring, I encountered a failure that highlights why we need to stabilize both sides of the assertion when using assert_equal_computations.

The Issue

Even when the rewrite was triggered, X @ b initially produces a Matmul op. However, during the stabilize phase, PyTensor lowers this to a more specific Dot op.

Failure Example (b_ndim=1):

  • Rewritten Graph: Squeeze(Dot(X, ExpandDims(b)))
  • Expected Graph (Raw): Squeeze(Matmul(X, ExpandDims(b)))
  • Result: AssertionError: equal_computations failed (Dot vs Matmul).

The Fix

I updated the test to apply the stabilize group to both the output and the expected result. This ensures we are comparing the graphs in their final, canonicalized form:

# Restore stabilization to trigger the rewrite and canonicalize the internal Ops (Matmul -> Dot)
rewritten_out = rewrite_graph(out, include=["stabilize"])
expected = rewrite_graph(X @ b, include=["stabilize"])

assert_equal_computations([rewritten_out], [expected])
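
The failure mode described above is not specific to PyTensor: any structural comparison breaks when only one side has been normalized. Below is a toy sketch in plain Python. The `Node` class, the `canonicalize` function, and the Matmul-to-Dot rule are all hypothetical stand-ins for illustration, not PyTensor APIs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """A toy expression node: an op name plus its inputs (nodes or variable names)."""

    op: str
    inputs: tuple = ()


def canonicalize(expr):
    """Hypothetical normalization pass: lower every 'Matmul' node to 'Dot'."""
    if not isinstance(expr, Node):
        return expr  # leaf: a plain variable name
    op = "Dot" if expr.op == "Matmul" else expr.op
    return Node(op, tuple(canonicalize(i) for i in expr.inputs))


# Graph as it looks after a pipeline has already lowered Matmul to Dot.
rewritten = Node("Squeeze", (Node("Dot", ("X", Node("ExpandDims", ("b",)))),))
# Expectation built by hand, never passed through the pipeline.
expected_raw = Node("Squeeze", (Node("Matmul", ("X", Node("ExpandDims", ("b",)))),))

print(rewritten == expected_raw)                # False
print(rewritten == canonicalize(expected_raw))  # True
```

Normalizing the expectation with the same pass makes the comparison insensitive to which intermediate op the pipeline happens to prefer, which is the same idea as applying the rewrite groups to both graphs in the test.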

Member

Yes, matmul is eagerly rewritten as dot because old rewrites handled dot directly; we still need to transition to matmul by default. But the test is good now?

b = pt.dvector("b") if b_ndim == 1 else pt.dmatrix("b")
out = solve(pt.linalg.inv(X), b, b_ndim=b_ndim)

# Just include the rewrite we are testing
Member

The comment is not true: rewrite_graph includes canonicalize by default.

Comment on lines +467 to +468
f_opt = function([X, b], rewritten_out)
res_opt = f_opt(X_val, b_val)
Member

@ricardoV94 ricardoV94 May 2, 2026

EDIT (if you saw it): NVM

For f_opt there is no need to compile aggressively: you already locked the rewrite in, so a simple mode=Mode(linker="py", optimizer=None) allows the fastest numerical eval.

In this case I'm okay with not testing it numerically, up to you

@alessandrogentili001 alessandrogentili001 deleted the rewrite-solve-matrix-inverse-as-mutmul branch May 4, 2026 21:06
Development

Successfully merging this pull request may close these issues.

Rewrite solve(matrix_inverse(X), b) → X @ b

3 participants