Optimize CAReduce of Join by pushing reduction through concatenation#2130
Optimize CAReduce of Join by pushing reduction through concatenation#2130williambdean wants to merge 3 commits into
Conversation
…r join_axis in reduce_axis
ricardoV94
left a comment
There was a problem hiding this comment.
Looks pretty good already 😍
As always my nits are mostly with tests. Tbf this is something we need to standardize. But in general we want to assert the specific expected graph, it's easy to think were optimizing or writing a good tests when we're really not.
| expected_out = add(exp(x), log(x)) | ||
| assert equal_computations([rewritten_out], [expected_out]) | ||
|
|
||
| def test_careduce_join_sum_2(self): |
There was a problem hiding this comment.
check the kind of approach to writing rewrite tests we're trying to settle on: #2103
| assert any(isinstance(n.op, Join) for n in topo) | ||
|
|
||
| def test_careduce_join_not_applied_axis_excludes_join(self): | ||
| """Sum(concat(mat_a, mat_b), axis=1) should NOT trigger (axis excludes join axis 0)""" |
There was a problem hiding this comment.
for follow up we can still optimize, it's still generally better to reduce before joining even if the join is still needed. Just need to change the axis then. My comment here is to make the docstring not so authoritative that sounds like this would be a problem. Mention it as not currently supported instead
|
|
||
| fg = FunctionGraph([vx, vy, vz], [out], clone=False) | ||
| [rewritten_out] = local_careduce_join.transform(fg, out.owner) | ||
| expected_out = add(add(pt_sum(vx[None]), pt_sum(vy[None])), pt_sum(vz[None])) |
There was a problem hiding this comment.
We are missing CAReduce(Dimshuffle(x)) -> CAReduce(x), when the DimShuffle has no effect due to reduction (or DimShuffle may still be needed but only a subset of its behavior). In this case it isn't needed.
Does not need to be a blocker for this PR, but we should open an issue. I thought there was one already
|
The numba failure should be incidentally fixed by #1961 which sidesteps the numba bug. Can you rebase and check? |
Closes #59
Pushes
CAReduce(Sum, Prod, Max, Min) throughJointo reduce each input separately before combining, avoiding a large concatenated intermediate. Works foraxis=Noneand any axis that includes the join axis. For binary-only ops (Max/Min), limited to 2 inputs.Benchmark results: sum 1.6x, prod 1.4x, max 1.2x, min 1.5x speedup.