Add MultiDot op and rewrites for optimal contraction #2060
jessegrabowski wants to merge 7 commits into
My gut doesn't love this approach.
Adding multi_dot in the IR is going to make us miss / complicate regular dot graphs.
The flattening by default may break the original associativity, which may have been optimal in the absence of statically known information.
My suggestion: After specialize, have a single GraphRewriter that collects nested matmuls and "re-associates" them if it can prove the new order is strictly better than the old one. It doesn't need an OpFromGraph imo.
Something like (bot generated):

    class ReassociateMatmulChain(GraphRewriter):
        """Post-specialize: find matmul chains and reassociate if provably cheaper."""

        def apply(self, fgraph):
            visited = set()
            for node in fgraph.toposort():
                if node in visited or not _is_matmul_node(node):
                    continue

                # 1. Extend chain through single-client intermediates only.
                # This should ignore expand_dims / squeeze (maybe even transposes somehow?)
                inputs, chain_nodes = self._extend_chain(node, fgraph, visited)
                visited.update(chain_nodes)
                if len(inputs) < 3:
                    continue

                # 2. Symbolic shapes for every input. Each is a tuple of dim
                # expressions (batch dims..., m_i, k_i) built from static shape
                # where available and shape_of(var) otherwise.
                shapes = [_symbolic_shape(x, fgraph) for x in inputs]

                # 2b. Canonicalize all dim entries via a single shape-unification pass.
                shapes = _unify_shapes(shapes, fgraph.shape_feature)

                # 3. DP over parenthesizations.
                # dp[i, j] = (cost_expr, split_k, result_shape) for chain[i..j]
                n = len(inputs)
                dp = {(i, i): (_zero(), None, shapes[i]) for i in range(n)}
                for length in range(2, n + 1):
                    for i in range(n - length + 1):
                        j = i + length - 1
                        best = None
                        for k in range(i, j):
                            lc, _, ls = dp[i, k]
                            rc, _, rs = dp[k + 1, j]
                            step = _contract_cost(ls, rs)
                            total = lc + rc + step
                            result = _matmul_result_shape(ls, rs)
                            if best is None or _provably_less(total, best[0]):
                                best = (total, k, result)
                        dp[i, j] = best

                new_cost, *_ = dp[0, n - 1]
                old_cost = _current_order_cost(chain_nodes, shapes)

                # 4. Only replace when provably strictly cheaper.
                if not _provably_less(new_cost, old_cost):
                    continue

                # _build_tree should return all nodes so we can add them to `seen`
                new_out = _build_tree(inputs, dp, 0, n - 1)  # plain matmul nodes
                copy_stack_trace(chain_nodes[-1].outputs[0], new_out)
                fgraph.replace(chain_nodes[-1].outputs[0], new_out,
                               reason="reassoc_matmul")

Helpers: batch-aware shape & cost

    def _matmul_result_shape(left, right):
        """left = (*bl, m, k), right = (*br, k, n) -> (*broadcast(bl, br), m, n).

        Align batch dims from the right; missing dims on the shorter side are
        treated as literal 1. After _unify_shapes, each aligned pair is either
        (1, x), (x, 1), or (x, x); pick the non-literal-1 side.
        """
        batch = []
        for da, db in zip_longest_right(left[:-2], right[:-2], fill=ONE):
            if _is_literal_one(da):
                batch.append(db)
            elif _is_literal_one(db):
                batch.append(da)
            else:
                assert _same_symbol(da, db)  # unification guarantees this
                batch.append(da)
        return (*batch, left[-2], right[-1])


    def _contract_cost(left, right):
        """FLOPs of (left @ right). Batch broadcast enters as a multiplier."""
        result = _matmul_result_shape(left, right)
        m, k, n = left[-2], left[-1], right[-1]
        return _prod(result[:-2]) * m * k * n


    def _unify_shapes(shapes, shape_feature):
        """Canonicalize dim entries for a matmul chain using all known equalities.

        Three sources of equality feed in:

        1. Contracting dims (matmul semantics): shapes[i][-1] == shapes[i+1][-2]
           for every adjacent pair. Applies to *every* chain, adjacent only.
        2. Batch dims required equal at runtime: for any pair (i, j), align
           their batch dims from the right. Dims that are both non-literal-1
           MUST be equal (broadcasting rule), so unify them for costing.
           Applies to non-adjacent pairs too, transitively.
        3. ShapeFeature same_shape classes: if the fgraph's ShapeFeature
           already knows two shape entries are equal (from earlier rewrites
           or op-level declarations), use it directly; no need to re-derive.
           This is why the helper takes `shape_feature` rather than just
           looking at the raw shape graphs.

        Strategy: union-find over all dim entries in the chain. Add edges
        from (1), (2), (3). Pick a representative per class preferring
        literal ints > static-shape ints > shape_of symbols. Rewrite every
        shape tuple with representatives.

        TODO: ideally ShapeFeature itself carries the edges from (1) and (2)
        (matmul's Op declares "my input ks are equal"; blockwise declares
        "my batch dims broadcast"), so `same_shape` works everywhere and this
        helper collapses to "read canonical reps from ShapeFeature." For now
        we do it locally, but the long-term home is ShapeFeature.
        """

Proving a < b symbolically

    def _provably_less(a, b):
        """Expand a, b into sum-of-monomials in positive dim symbols.

        Use the invariant: every dim symbol >= 1 (matmul dims are positive).
        Return True iff we can match each monomial of `a` to a DISTINCT
        monomial of `b` s.t. the a-term is dominated term-wise by its b-term,
        and b has at least one unmatched monomial (strict). Otherwise False.
        False means 'not provable'; it does NOT claim b <= a."""

Term-wise dominance: monomial c·x1^a1·x2^a2… is dominated by d·x1^b1·x2^b2… when c ≤ d and every ai ≤ bi, given all xi ≥ 1. Cheap, and catches the common wins (a factor of m·k·p dominates k·p for any positive m). Won't decide genuinely shape-dependent ties; that's fine, bail and keep the original order.
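
For concreteness, the matching rule described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the PR's implementation: a polynomial is represented as a list of `(coefficient, {symbol: exponent})` pairs with positive coefficients, and the names `monomial`, `dominated`, and `provably_less` are hypothetical. The distinct-matching requirement is handled with a standard augmenting-path bipartite matching.

```python
def monomial(coef, **exponents):
    """Hypothetical representation: (coefficient, {symbol: exponent})."""
    return (coef, {s: e for s, e in exponents.items() if e > 0})


def dominated(a_term, b_term):
    """True if a_term <= b_term whenever every symbol is >= 1."""
    (c, a_exp), (d, b_exp) = a_term, b_term
    return c <= d and all(e <= b_exp.get(s, 0) for s, e in a_exp.items())


def provably_less(a, b):
    """Match every monomial of `a` to a distinct dominating monomial of `b`,
    and require at least one monomial of `b` to be left over (strictness)."""
    owner = {}  # index of a b-term -> index of the a-term matched to it

    def try_assign(i, seen):
        for j, b_term in enumerate(b):
            if j in seen or not dominated(a[i], b_term):
                continue
            seen.add(j)
            # Take a free b-term, or re-route the a-term currently holding it.
            if j not in owner or try_assign(owner[j], seen):
                owner[j] = i
                return True
        return False

    for i in range(len(a)):
        if not try_assign(i, set()):
            return False  # not provable; says nothing about b <= a
    return len(owner) < len(b)


kp = monomial(1, k=1, p=1)
mkp = monomial(1, m=1, k=1, p=1)
mp = monomial(1, m=1, p=1)

print(provably_less([kp], [mkp, mp]))  # True: kp <= mkp, and mp is left over
print(provably_less([kp], [mkp]))      # False: a tie is possible when m == 1
```

The second call shows the conservative side of the rule: with `m == 1` both costs are equal, so nothing strict can be proven and the original order would be kept.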
Force-pushed from edcb675 to 7f53de5.

@ricardoV94 took another cut at this, lmk if it looks better to you

does it look better to you?

it's definitely thinking about the problem in a much different way. You can see from my first pass how I think about pytensor in a very Op-oriented way. Here we have a pure graph-reasoning implementation. I would never have come up with the monomial stuff by myself, so it's hard for me to assess whether it's doing the right thing at the design level. It's also being very cute by trying to work 100% with shapes instead of passing around variables, so you end up with stuff like the broadcast checkers. I think you can argue those are bloat, but only if I refactor it again to work with variables (so we get access to .broadcastable and whatnot)

Could be lack of familiarity with graph-level rewrites? Maybe it would be a similar weirdness if you were implementing a fusion rewriter from scratch: you start with an eager local thing that just tries to expand one node at a time (Composite A -> Composite A + 1), and then I come and suggest you analyze the graph all at once, break it into all the Composites you can see, and try to use some advanced data structure to verify convexity cheaply

Or it's just not the right answer for this problem ...

I think it's nicer, and I think you're right that I've mostly only been working with node rewriters. I don't love that it's a +1000 line PR for a feature that won't be relevant much of the time. On the plus side, though, it's contained to a single file, so it's easy to iterate on if it ends up sucking. I want to run it through the ASV benchmarks to make sure it doesn't drag on graphs overall.

Curious if it fires anywhere in the CI or pymc model catalogue

If it's rare but the bail out is fast, that's also ok

    if not (0 <= x < len(shape)):
        raise _BailOutError(
            f"DimShuffle.new_order references index {x} outside operand shape "
            f"of length {len(shape)}; lift cannot legally apply."
        )

yeah idk, i got a bit lazy letting all the bot slop through

    return False, False


    def _is_chain_link(node: Apply) -> bool:

inline function (fine to still be a function) in `_find_chain_top`

it's used in _decompose_operand, _find_chain_top, and ReassociateMatmulChain. I think this one is defensible as a helper.

ricardoV94 left a comment:
pretty nice, some nits. I'll check if this fires anywhere in the pymc catalogue out of curiosity

Curious what you find. If we have a SEM example it might? It's pretty unusual for a GLM to do fat chains of matmuls

Gets called in 5 models (but two are just for demoing), in none does it apply:

I imagine this will be the most common case tbh. But statespace cares!

Force-pushed from e435b76 to 9d1a298.

For statespace it will help because the package always puts the dots in the same order, but that order is not necessarily optimal from the get-go. Regardless of this rewrite, do we have the best default order there (say, if static shapes never trigger)?

Yes, in principle the right answer is known and I could just go in and add parentheses. But I want my software to magically do it for me :( Also, in expressions like these it gets uglier and uglier to actually do that in code. These types of huge dot chains also appear in optimal control problems. Again, the shapes of all those objects are known statically ahead of time (even if pytensor doesn't know that), so one could simply optimize it oneself.
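
To put rough numbers on what hand-placed parentheses buy, here is a small illustrative sketch. It assumes plain 2D operands and the usual m*k*n multiplication count per product; the helper names `left_to_right_cost` and `right_to_left_cost` are hypothetical and not taken from the PR.

```python
def left_to_right_cost(shapes):
    """Scalar multiplications for ((A0 @ A1) @ A2) @ ... given 2D shapes."""
    rows, inner = shapes[0]
    total = 0
    for k, n in shapes[1:]:
        assert inner == k, "inner dimensions must agree"
        total += rows * k * n  # (rows x k) @ (k x n)
        inner = n
    return total


def right_to_left_cost(shapes):
    """Scalar multiplications for A0 @ (A1 @ (A2 @ ...)) given 2D shapes."""
    inner, cols = shapes[-1]
    total = 0
    for m, k in reversed(shapes[:-1]):
        assert k == inner, "inner dimensions must agree"
        total += m * k * cols  # (m x k) @ (k x cols)
        inner = m
    return total


# Two square matrices followed by a column vector: A @ B @ x
shapes = [(100, 100), (100, 100), (100, 1)]
print(left_to_right_cost(shapes))  # 1010000
print(right_to_left_cost(shapes))  # 20000
```

With a trailing vector, evaluating from the right never forms the 100 x 100 intermediate, which is where the roughly 50x gap in this toy case comes from.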

I believe it's useful. I'm a bit bummed that we haven't seen it be useful yet. Do you have any STS example where it triggers (not one purposely built now to prove it)? Not a blocker regardless, as long as this rewrite is cheap to bail out of (which I believe it will be).

Ok so a place we really, really should be getting gains is in low-rank projections, like inducing-point approximation for GP. Here's some code due to @bwengals:

    Kuf = self.kernel(Z, X_train)  # (M, N)
    Kus = self.kernel(Z, X_new)  # (M, N*)
    Kss_diag = self.kernel.diag(X_new)

    # Sigma = Kuu + Kuf @ Kuf.T / sigma^2
    Sigma = Kuu + Kuf @ Kuf.T / sigma2
    Sigma_inv = pt.linalg.inv(Sigma)

    mu_train = self.mean(X_train)
    alpha = Sigma_inv @ Kuf @ (y_train - mu_train) / sigma2
    fmean = self.mean(X_new) + Kus.T @ alpha

    # fvar = Kss - Kus.T @ (Kuu^{-1} - Sigma^{-1}) @ Kus
    Kuu_inv = pt.linalg.inv(Kuu)
    diff_inv = Kuu_inv - Sigma_inv
    fvar = Kss_diag - pt.sum(Kus * (diff_inv @ Kus), axis=0)

Bot analysis:

Solve / Dot seems like a hard ordering problem. You could re-canonicalize as MatrixInverse here and be sure to always reintroduce Solve when done. As for the BLAS thing, it may be time to pull the plug: it should def be a very late-stage rewrite (after specialize, before fusion). It hurts us every time we want to work with matmuls (all the time), and I don't see why it should/must be eager at all

I've wanted this for a while. Adds a `MultiDot` Op that we can track with rewrites. We look for sequences of matrix multiplications in the graph and fuse them into a MultiDot during canonicalization. For example: `A @ B @ C -> MultiDot(A, B, C)`.

By default, MultiDot is just an OpFromGraph that does simple left-to-right matrix multiplication. So `MultiDot(A, B, C) -> A @ B @ C` during inlining. If all shapes of A, B, C are statically known, however, we solve the dynamic programming problem to figure out the optimal ordering of matmuls. For details see the wiki here: https://en.wikipedia.org/wiki/Matrix_chain_multiplication

We could probably try to do something more heroic, but I think this is a good start.
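
The dynamic program mentioned in the description is the classic matrix-chain ordering problem. A minimal pure-Python sketch, assuming fully static integer shapes and 2D operands; `matrix_chain_order`, `parenthesize`, and `optimal_chain` are illustrative names, not the functions in this PR.

```python
def matrix_chain_order(dims):
    """Classic O(n^3) DP. Matrix i has shape (dims[i], dims[i + 1]).

    Returns (cost, split) tables, where cost[i][j] is the minimal number of
    scalar multiplications for the sub-chain i..j and split[i][j] is the
    index k at which that sub-chain is best divided.
    """
    n = len(dims) - 1
    cost = [[0] * n for _ in range(n)]
    split = [[None] * n for _ in range(n)]
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            best = None
            for k in range(i, j):
                total = (
                    cost[i][k]
                    + cost[k + 1][j]
                    + dims[i] * dims[k + 1] * dims[j + 1]
                )
                if best is None or total < best:
                    best, split[i][j] = total, k
            cost[i][j] = best
    return cost, split


def parenthesize(split, i, j):
    """Render the optimal evaluation order as a string."""
    if i == j:
        return f"A{i}"
    k = split[i][j]
    return f"({parenthesize(split, i, k)} @ {parenthesize(split, k + 1, j)})"


def optimal_chain(dims):
    cost, split = matrix_chain_order(dims)
    n = len(dims) - 1
    return cost[0][n - 1], parenthesize(split, 0, n - 1)


# A0: (10, 30), A1: (30, 5), A2: (5, 60)
print(optimal_chain([10, 30, 5, 60]))  # (4500, '((A0 @ A1) @ A2)')
```

Here left-to-right happens to be optimal (4500 multiplications versus 27000 for `A0 @ (A1 @ A2)`), so a rewrite using this table would leave the graph unchanged.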