Block helper and rewrites#2117

Open
jessegrabowski wants to merge 6 commits into
pymc-devs:mainfrom
jessegrabowski:block-helper-and-rewrites

@jessegrabowski
Member

@jessegrabowski jessegrabowski commented May 6, 2026

  • Add pt.block (mirroring numpy.block) that lowers to nested Join calls without a wrapping Op
  • Add a stack of Join-aware rewrites that decompose block-matrix algebra without materializing the full matrix:
    • local_dot_of_join: push dot inside Join.
    • local_transpose_of_join: push matrix-transpose inside Join, swapping -1 <-> -2 for matrix axes and leaving batch axes alone. Used as a clean-up pass to expose the Join for further rewriting.
    • local_split_of_join: lift Split through Join — return inputs directly when partitions match, distribute when split is on a different axis.

The local_dot_of_join rewrite is especially handy on block-triangular structures, like:

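import pytensor.tensor as pt

A, B, C = pt.dmatrices("A", "B", "C")
x = pt.block([[A, 0], [B, C]])  # pt.block is the helper added in this PR
y = pt.dvector("y")
z = x @ y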

By decomposing the dot into 4 smaller dots, we actually only end up with 3, because the 0 part will be eliminated by other rewrites. This pattern comes up a lot in (you guessed it) statespace.
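The arithmetic behind this can be checked outside PyTensor. The following is a minimal NumPy sketch, not the rewrite itself; the block sizes (2 and 3) are assumptions chosen only for illustration. It shows that the product with the zero block drops out, leaving three dots.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical block sizes: A is 2x2, B is 3x2, C is 3x3.
A = rng.normal(size=(2, 2))
B = rng.normal(size=(3, 2))
C = rng.normal(size=(3, 3))
y = rng.normal(size=5)

# Dense reference: materialize the full 5x5 block lower-triangular matrix.
x = np.block([[A, np.zeros((2, 3))], [B, C]])
z_dense = x @ y

# Blockwise version: split y to match the column partition (2, 3).
y1, y2 = np.split(y, [2])

# The top-right block is zero, so its product with y2 is dropped,
# leaving three dots instead of four.
z_top = A @ y1
z_bottom = B @ y1 + C @ y2
z_blockwise = np.concatenate([z_top, z_bottom])
```

The blockwise result matches the dense one up to floating-point error, without ever forming the 5x5 matrix.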

Pure-Python helper that walks a nested-list np.block-style structure and returns nested concatenate outputs directly, validating uniform leaf depth and promoting ranks via atleast_Nd.

Push dot inside a Join, splitting the other operand by leaf widths (or heights) and emitting per-leaf dots that sum or concat to the original result. Conservative: skips when partition dims are dynamic.

Matrix-transpose distributes through Join by transposing each leaf and swapping concatenation axis when it's one of the last two.

Recognize a square 2-D nested-Join with statically-zero off-diagonals and rewrite to BlockDiagonal so the existing block-diag rewrites can fire on user-written or rewrite-induced concat patterns.

Push Split through Join: matching axis with matching sizes returns the Join inputs directly; different axis distributes the Split per input. Unblocks Block@Block and X@S@X.T leaf-level decompositions.

Shared by upcoming block-triangular solve / det rewrites: is_static_zero predicate and match_2x2_nested_join structural matcher.
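As a rough illustration of the first commit's idea (walk the nested list, validate uniform leaf depth, promote ranks, and emit nested concatenations), here is a NumPy-only sketch. The names `block_sketch`, `_depth`, and `_atleast_nd` are hypothetical and are not the PR's code; the real helper builds symbolic Join nodes rather than arrays.

```python
import numpy as np


def _depth(node):
    """Return the list-nesting depth, requiring all leaves at the same depth."""
    if not isinstance(node, list):
        return 0
    if not node:
        raise ValueError("empty lists are not allowed")
    depths = {_depth(item) for item in node}
    if len(depths) != 1:
        raise ValueError("all leaves must sit at the same nesting depth")
    return depths.pop() + 1


def _leaves(node):
    """Yield every leaf of the nested list as an array."""
    if isinstance(node, list):
        for item in node:
            yield from _leaves(item)
    else:
        yield np.asarray(node)


def _atleast_nd(a, ndim):
    """Prepend length-1 axes until the array has at least `ndim` dimensions."""
    a = np.asarray(a)
    return a.reshape((1,) * (ndim - a.ndim) + a.shape)


def block_sketch(arrays):
    depth = _depth(arrays)
    ndim = max(depth, max(leaf.ndim for leaf in _leaves(arrays)))

    def build(node, level):
        if not isinstance(node, list):
            return _atleast_nd(node, ndim)
        parts = [build(item, level + 1) for item in node]
        # Innermost lists join on axis -1, the next level out on -2, and so on.
        return np.concatenate(parts, axis=-(depth - level))

    return build(arrays, 0)


# Usage with hypothetical 2x2, 3x2 and 3x3 blocks.
A, B, C = np.ones((2, 2)), np.full((3, 2), 2.0), np.full((3, 3), 3.0)
out = block_sketch([[A, np.zeros((2, 3))], [B, C]])
```

For these inputs the sketch agrees with `np.block`, and a structure with leaves at mixed depths raises `ValueError`.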
Member

@ricardoV94 ricardoV94 left a comment

I agree with the end goal, but I have doubts about the optimality of most of the individual rewrites that pave the way to it.

If they aren't great in isolation, you may need to do this as an integrated graph rewrite that bails out unless the end goal is achieved.

Also, for the love of god, can we tackle #1528 already?

When ``Join`` runs along the matmul-contracted axis, ``Y`` is split by symbolic per-leaf sizes and
the per-leaf products are summed. Otherwise each leaf multiplies ``Y`` directly and the results are concatenated.

Walks through chains of left-``expand_dims`` ``DimShuffle`` nodes between the Join and the matmul

We need to start pushing DimShuffle out of the way, like we do with Elemwise.


@register_stabilize
@node_rewriter([Join])
def local_dot_of_join(fgraph, node):

Is this rewrite dominant?

return [new_out]


@register_canonicalize

Not sure about this in canonicalize; it seems pretty gnarly.

@ricardoV94
Member

Bot analysis:

Rewrites in this PR

The branch adds 4 rewrites and 1 matching helper. Below: what each does, and whether it pays off on its own or only if downstream rewrites continue to fire.

1. local_dot_of_join (pytensor/tensor/rewriting/math.py)

Push dot inside a Join. If the join axis is the contracted axis, split the other operand by per-leaf sizes and sum per-leaf products; otherwise multiply each leaf separately and concat. Walks through left-expand_dims DimShuffles between the Join and the matmul.

Locally optimal? No — and the author registered it only as stabilize for that reason. One large GEMM almost always beats k smaller GEMMs of the same total FLOPs (better BLAS/cache utilization), and in the off-axis case it adds the same FLOPs plus extra wrapping ops. It only wins when leaves simplify (Eye/zero/triangular/BlockDiagonal blocks) or when the Join has no other consumers and disappears.
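The two cases can be illustrated with plain NumPy. This is a sketch of the underlying identities only, with arbitrary assumed shapes; it says nothing about how the rewrite is implemented or how it performs.

```python
import numpy as np

rng = np.random.default_rng(1)
Y = rng.normal(size=(5, 6))

# Case 1: the join runs along the contracted axis (columns of the left operand).
X1 = rng.normal(size=(4, 2))
X2 = rng.normal(size=(4, 3))
X = np.concatenate([X1, X2], axis=-1)      # shape (4, 5)
Y1, Y2 = np.split(Y, [2], axis=-2)         # split Y's rows by leaf widths (2, 3)
contracted = X1 @ Y1 + X2 @ Y2             # per-leaf products are summed

# Case 2: the join runs along a non-contracted axis (rows of the left operand).
W1 = rng.normal(size=(2, 5))
W2 = rng.normal(size=(3, 5))
W = np.concatenate([W1, W2], axis=-2)      # shape (5, 5)
off_axis = np.concatenate([W1 @ Y, W2 @ Y], axis=-2)  # per-leaf products are concatenated
```

Both decompositions do the same number of multiply-adds as the single large product, which is why the payoff depends on individual leaves simplifying afterwards.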

2. local_transpose_of_join (pytensor/tensor/rewriting/math.py)

Join(axis, *xs).mT → Join(swapped_axis, *[x.mT for x in xs]), where a join axis of -1 or -2 is swapped to the other and batch axes pass through unchanged.

Locally optimal? Roughly neutral. mT is a free DimShuffle (strided view), so trading 1 outer mT for k inner ones costs nothing at runtime but adds nodes. The point is to expose each leaf's transpose so it can fold ((A.mT).mT → A, solve_triangular/cholesky patterns, etc.). Pure pipeline enabler. canonicalize + stabilize is fine because the trade is essentially free.
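The axis bookkeeping can be verified with NumPy, using `np.swapaxes(x, -1, -2)` as a stand-in for `.mT`. The shapes below are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(2)


def mT(x):
    """Matrix transpose: swap the last two axes and leave batch axes alone."""
    return np.swapaxes(x, -1, -2)


a = rng.normal(size=(2, 3, 4))
b = rng.normal(size=(2, 3, 5))
c = rng.normal(size=(3, 3, 4))

# Join on a matrix axis: axis -1 outside the transpose becomes -2 inside it.
outer_matrix = mT(np.concatenate([a, b], axis=-1))
inner_matrix = np.concatenate([mT(a), mT(b)], axis=-2)

# Join on the batch axis: the join axis is unchanged.
outer_batch = mT(np.concatenate([a, c], axis=0))
inner_batch = np.concatenate([mT(a), mT(c)], axis=0)
```

In both cases the outer and inner forms produce identical arrays; only the position of the transpose in the graph changes.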

3. local_nested_join_to_block_diagonal (pytensor/tensor/rewriting/math.py)

Recognize a square Join(-2, *Join(-1, ...)) whose off-diagonals are statically zero and rewrite to block_diag(diag_blocks...).

Locally optimal? Mild local win — collapses many Join nodes and drops the zero-leaf broadcast/materialization into the dedicated BlockDiagonal Op. The bigger payoff is unlocking the existing BlockDiagonal-aware specializations (det, diag, trace, block_diag @ x, solve(block_diag, ...)). So mostly pipeline-driven, with a small local cleanup as a bonus.
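To see why recognizing the block-diagonal structure matters, the sketch below builds the nested-join pattern in NumPy and compares dense results against their blockwise equivalents. It only illustrates the algebra that block-diagonal specializations rely on; the block sizes are assumed, and no PyTensor Op is involved.

```python
import numpy as np

rng = np.random.default_rng(3)
A = rng.normal(size=(2, 2))
B = rng.normal(size=(3, 3))

# Nested-join pattern: rows joined on axis -1, then stacked on axis -2,
# with zero off-diagonal blocks.
top = np.concatenate([A, np.zeros((2, 3))], axis=-1)
bottom = np.concatenate([np.zeros((3, 2)), B], axis=-1)
M = np.concatenate([top, bottom], axis=-2)

v = rng.normal(size=5)
v1, v2 = np.split(v, [2])

# Blockwise equivalents that never touch the zero blocks.
det_blockwise = np.linalg.det(A) * np.linalg.det(B)
trace_blockwise = np.trace(A) + np.trace(B)
matvec_blockwise = np.concatenate([A @ v1, B @ v2])
```

Each blockwise quantity matches its dense counterpart computed from `M`, up to floating-point error.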

4. local_split_of_join (pytensor/tensor/rewriting/math.py)

Two cases:

  • Same axis, matching static sizes: Split(Join(a, X_i), [|X_i|_a], a) → [X_i] — strict identity removal.
  • Different axis: distribute the split through the join (Split(Join(a, X_i), s, b) → [Join(a, Split(X_i, s, b)[k] for i)]).

Locally optimal? The same-axis case is unconditionally good (pure no-op elimination). The different-axis case is pipeline-driven: it multiplies the number of Split nodes by k and adds new Joins; it only pays off when the assembled intermediate would be re-cut or cancel out, which is exactly the cascade local_dot_of_join produces.
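Both cases have direct NumPy analogues, sketched below with assumed shapes. `np.split` takes split indices rather than sizes, so the sizes (2, 3) and (1, 3) appear as the cut points `[2]` and `[1]`.

```python
import numpy as np

rng = np.random.default_rng(4)
X1 = rng.normal(size=(2, 4))
X2 = rng.normal(size=(3, 4))
J = np.concatenate([X1, X2], axis=0)   # join on axis 0, shape (5, 4)

# Same axis, matching sizes: the split pieces are exactly the join inputs.
p1, p2 = np.split(J, [2], axis=0)

# Different axis: splitting the join equals joining the per-input splits.
s1, s2 = np.split(J, [1], axis=1)
X1_parts = np.split(X1, [1], axis=1)
X2_parts = np.split(X2, [1], axis=1)
d1 = np.concatenate([X1_parts[0], X2_parts[0]], axis=0)
d2 = np.concatenate([X1_parts[1], X2_parts[1]], axis=0)
```

The different-axis form uses two splits and two joins where the original used one of each, which matches the node-count concern raised above.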

Summary table

| Rewrite | Local effect | Needs downstream simplification? |
| --- | --- | --- |
| local_dot_of_join | Worse (k small GEMMs > 1 big GEMM) | Yes (registered only as stabilize) |
| local_transpose_of_join | Neutral (free DimShuffles) | Mostly yes |
| local_nested_join_to_block_diagonal | Small win | Mostly (main payoff is BlockDiagonal specializations) |
| local_split_of_join (same axis) | Strict identity removal | No |
| local_split_of_join (diff axis) | Worse (more nodes) | Yes |
