Rewrite solve(matrix_inverse(X), b) → X @ b#2101
Conversation
| # Get all nodes in the rewritten graph | ||
| all_nodes = io_toposort([], [rewritten_out]) | ||
|
|
||
| assert not any( | ||
| isinstance(getattr(node.op, "core_op", node.op), Solve | MatrixInverse) | ||
| for node in all_nodes | ||
| ) |
There was a problem hiding this comment.
Can replace this with assert_equal_computation
There was a problem hiding this comment.
this way should works
@pytest.mark.parametrize("b_ndim", [1, 2], ids=lambda x: f"b_ndim={x}")
def test_solve_of_inv_to_matmul(b_ndim):
X = pt.dmatrix("X")
b = pt.dvector("b") if b_ndim == 1 else pt.dmatrix("b")
out = solve(pt.linalg.inv(X), b, b_ndim=b_ndim)
# Just include the rewrite we are testing
rewriter = WalkingGraphRewriter(solve_of_inv_to_matmul)
rewritten_out = rewrite_graph(out, custom_rewrite=rewriter)
# Verify the rewrite
expected = X @ b
assert_equal_computations([rewritten_out], [expected])
# Numerical check
rng = np.random.default_rng(42)
X_val = (rng.random((4, 4)) + np.eye(4) * 4).astype(X.type.dtype)
b_val = rng.random((4,) if b_ndim == 1 else (4, 3)).astype(b.type.dtype)
f_opt = function([X, b], rewritten_out)
res_opt = f_opt(X_val, b_val)
res_expected = np.linalg.solve(np.linalg.inv(X_val), b_val)
np.testing.assert_allclose(res_opt, res_expected, rtol=1e-7)| # Graph rewrite test | ||
| # We include 'stabilize' because solve_of_inv_to_matmul is registered there. | ||
| # This avoids dependency on the global config.mode (e.g. FAST_COMPILE). | ||
| rewritten_out = rewrite_graph(out, include=["stabilize"]) |
There was a problem hiding this comment.
You can directly include just the rewrite you're testing , its a bit more clear that way
| # Numerical check | ||
| rng = np.random.default_rng(42) | ||
| X_val = (rng.random((4, 4)) + np.eye(4) * 4).astype(X.type.dtype) | ||
| b_val = rng.random((4,) if b_ndim == 1 else (4, 3)).astype(b.type.dtype) | ||
|
|
||
| f_opt = function([X, b], rewritten_out) | ||
| res_opt = f_opt(X_val, b_val) | ||
| res_expected = np.linalg.solve(np.linalg.inv(X_val), b_val) | ||
|
|
||
| np.testing.assert_allclose(res_opt, res_expected, rtol=1e-7) |
There was a problem hiding this comment.
We don't need the numerical check if the structural check passes (you're just testing BLAS at that point -- i promise you BLAS works)
| rewriter = WalkingGraphRewriter(solve_of_inv_to_matmul) | ||
| rewritten_out = rewrite_graph(out, custom_rewrite=rewriter) |
There was a problem hiding this comment.
I don't love the import + custom_rewrite. The pattern I had in mind was to just use rewrite_graph(out, include=('your_rewrite_name', )). If that doesn't work I'd rather it be reverted to what you had before for simplicity. But also fine with it staying this way if you don't want to keep going back and forth.
There was a problem hiding this comment.
Again I think this PR is a great discussion for #2103 which is still open-ended and in ask of feedback. So we standardize how we want to test this sort of rewrites and don't need to waste future time discussing it
There was a problem hiding this comment.
I have updated the test suite for solve_of_inv_to_matmul to address the feedback regarding custom rewriters and numerical checks. During the refactoring, I encountered a failure that highlights why we need to stabilize both sides of the assertion when using assert_equal_computations.
The Issue
Even when the rewrite was triggered, X @ b initially produces a Matmul op. However, during the stabilize phase, PyTensor lowers this to a more specific Dot op.
Failure Example (b_ndim=1):
- Rewritten Graph:
Squeeze(Dot(X, ExpandDims(b))) - Expected Graph (Raw):
Squeeze(Matmul(X, ExpandDims(b))) - Result:
AssertionError: equal_computations failed(Dot vs Matmul).
The Fix
I updated the test to apply the stabilize group to both the output and the expected result. This ensures we are comparing the graphs in their final, canonicalized form:
# Restore stabilization to trigger the rewrite and canonicalize the internal Ops (Matmul -> Dot)
rewritten_out = rewrite_graph(out, include=["stabilize"])
expected = rewrite_graph(X @ b, include=["stabilize"])
assert_equal_computations([rewritten_out], [expected])There was a problem hiding this comment.
Yes matmul is eagerly rewritten as dot because old rewrites handled dot directly, we still need to transition to matmul by default. But now test is good?
| b = pt.dvector("b") if b_ndim == 1 else pt.dmatrix("b") | ||
| out = solve(pt.linalg.inv(X), b, b_ndim=b_ndim) | ||
|
|
||
| # Just include the rewrite we are testing |
There was a problem hiding this comment.
Comment is not true, rewrite_graph includes canonicalize by default
| f_opt = function([X, b], rewritten_out) | ||
| res_opt = f_opt(X_val, b_val) |
There was a problem hiding this comment.
EDIT (if you saw it): NVM
f_opt, no need to compile aggressively, you already locked the rewrite in, so now a simple mode=Mode(linker="py", optimizer=None) allows the fastest numerical eval.
In this case I'm okay with not testing it numerically, up to you
Description
Related Issue
Checklist
Type of change