Block helper and rewrites#2117
Conversation
Pure-Python helper that walks a nested-list np.block-style structure and returns nested concatenate outputs directly, validating uniform leaf depth and promoting ranks via atleast_Nd.
Push dot inside a Join, splitting the other operand by leaf widths (or heights) and emitting per-leaf dots that sum or concat to the original result. Conservative: skips when partition dims are dynamic.
Matrix-transpose distributes through Join by transposing each leaf and swapping concatenation axis when it's one of the last two.
Recognize a square 2-D nested-Join with statically-zero off-diagonals and rewrite to BlockDiagonal so the existing block-diag rewrites can fire on user-written or rewrite-induced concat patterns.
Push Split through Join: matching axis with matching sizes returns the Join inputs directly; different axis distributes the Split per input. Unblocks Block@Block and X@S@X.T leaf-level decompositions.
Shared by upcoming block-triangular solve / det rewrites: is_static_zero predicate and match_2x2_nested_join structural matcher.
ricardoV94
left a comment
There was a problem hiding this comment.
I agree with the end goal but I have doubts on the optimality of most of these individual rewrites that pave the way to it.
If they aren't great in isolation you may need to do it as an integrated graph rewrite that bails out unless the end goal is achieved.
Also for the love of good can we tackle #1528 already?
| When ``Join`` runs along the matmul-contracted axis, ``Y`` is split by symbolic per-leaf sizes and | ||
| the per-leaf products are summed. Otherwise each leaf multiplies ``Y`` directly and the results are concatenated. | ||
|
|
||
| Walks through chains of left-``expand_dims`` ``DimShuffle`` nodes between the Join and the matmul |
There was a problem hiding this comment.
we need to start pushing Dimshuffle out of the way like we do with elemwise
|
|
||
| @register_stabilize | ||
| @node_rewriter([Join]) | ||
| def local_dot_of_join(fgraph, node): |
| return [new_out] | ||
|
|
||
|
|
||
| @register_canonicalize |
There was a problem hiding this comment.
not sure about this in canonicalize, seems pretty gnarly
|
Bot analysis: Rewrites in this PRThe branch adds 4 rewrites and 1 matching helper. Below: what each does, and whether it pays off on its own or only if downstream rewrites continue to fire. 1.
|
| Rewrite | Local effect | Needs downstream simplification? |
|---|---|---|
local_dot_of_join |
Worse (k small GEMMs > 1 big GEMM) | Yes — only stabilize |
local_transpose_of_join |
Neutral (free DimShuffles) | Mostly yes |
local_nested_join_to_block_diagonal |
Small win | Mostly — main payoff is BlockDiagonal specializations |
local_split_of_join (same axis) |
Strict identity removal | No |
local_split_of_join (diff axis) |
Worse (more nodes) | Yes |
pt.block(mirroringnumpy.block) that lowers to nestedJoincalls without a wrappingOpJoin-aware rewrites that decompose block-matrix algebra without materializing the full matrix:local_dot_of_join: pushdotinsideJoin.local_transpose_of_join: push matrix-transpose insideJoin, swapping-1 <-> -2for matrix axes and leaving batch axes alone. Used as a clean-up pass to expose the Join for further rewriting.local_split_of_join: liftSplitthroughJoin— return inputs directly when partitions match, distribute when split is on a different axis.The
local_dot_of_joinis especially handy on block triangular structures, like:By decomposing the dot into 4 smaller dots, we actually only end up with 3, because the 0 part will be eliminated by other rewrites. This pattern comes up a lot in (you guessed it) statespace.