
Optimize CAReduce of Join by pushing reduction through concatenation #2130

Open
williambdean wants to merge 3 commits into pymc-devs:main from williambdean:careduce-join-optimization

Conversation

@williambdean
Contributor

Closes #59

Pushes CAReduce (Sum, Prod, Max, Min) through Join to reduce each input separately before combining, avoiding a large concatenated intermediate. Works for axis=None and any axis that includes the join axis. For binary-only ops (Max/Min), limited to 2 inputs.

Benchmark results: sum 1.6x, prod 1.4x, max 1.2x, min 1.5x speedup.
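
For readers unfamiliar with the rewrite, the identity it relies on can be checked numerically. The following is a minimal sketch using NumPy as a stand-in for the PyTensor graph (an assumption made for illustration; the array names and shapes are hypothetical and this is not the code in the PR):

```python
import numpy as np

rng = np.random.default_rng(0)
# Two inputs joined along axis 0 (hypothetical shapes).
a = rng.uniform(0.5, 1.5, size=(3, 4))
b = rng.uniform(0.5, 1.5, size=(5, 4))

# Reference: materialize the concatenated intermediate, then reduce.
joined = np.concatenate([a, b], axis=0)

# axis=None: reduce each input to a scalar, then combine the scalars.
sum_pushed = a.sum() + b.sum()
prod_pushed = a.prod() * b.prod()
max_pushed = np.maximum(a.max(), b.max())
min_pushed = np.minimum(a.min(), b.min())

# axis=0 (the join axis): reduce each input along it, then combine elementwise.
sum_axis0_pushed = a.sum(axis=0) + b.sum(axis=0)
max_axis0_pushed = np.maximum(a.max(axis=0), b.max(axis=0))

assert np.isclose(joined.sum(), sum_pushed)
assert np.isclose(joined.prod(), prod_pushed)
assert joined.max() == max_pushed
assert joined.min() == min_pushed
assert np.allclose(joined.sum(axis=0), sum_axis0_pushed)
assert np.array_equal(joined.max(axis=0), max_axis0_pushed)
```

In both cases the `(8, 4)` intermediate is never needed on the pushed-through path, which is where the saving would come from.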

Member

@ricardoV94 ricardoV94 left a comment


Looks pretty good already 😍

As always, my nits are mostly with the tests. Tbf, this is something we need to standardize. But in general we want to assert the specific expected graph; it's easy to think we're optimizing or writing a good test when we're really not.

Comment thread pytensor/tensor/rewriting/math.py Outdated
expected_out = add(exp(x), log(x))
assert equal_computations([rewritten_out], [expected_out])

def test_careduce_join_sum_2(self):
Member


Check the approach to writing rewrite tests that we're trying to settle on: #2103

assert any(isinstance(n.op, Join) for n in topo)

def test_careduce_join_not_applied_axis_excludes_join(self):
"""Sum(concat(mat_a, mat_b), axis=1) should NOT trigger (axis excludes join axis 0)"""
Member


For a follow-up, we can still optimize this case: it's generally better to reduce before joining even if the join is still needed; we just need to change the axis. My comment here is to make the docstring less authoritative, so it doesn't sound like this would be a problem. Mention it as not currently supported instead.
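
A rough illustration of that follow-up, again using NumPy as a stand-in for the PyTensor ops (an assumption; the names and shapes are hypothetical, and nothing here is implemented in this PR). When the reduced axis excludes the join axis, the reduction can still be applied to each input first, with the join axis adjusted for the dimension that was removed:

```python
import numpy as np

# Case 1: join on axis 0, reduce on axis 1.
# The join axis comes before the reduced axis, so it stays at 0.
a = np.arange(6.0).reshape(2, 3)
b = np.arange(12.0).reshape(4, 3)
reduce_after_join = np.concatenate([a, b], axis=0).sum(axis=1)
join_after_reduce = np.concatenate([a.sum(axis=1), b.sum(axis=1)], axis=0)
assert np.array_equal(reduce_after_join, join_after_reduce)

# Case 2: join on axis 1, reduce on axis 0.
# Dropping axis 0 shifts the join axis from 1 down to 0.
c = np.arange(6.0).reshape(3, 2)
d = np.arange(12.0).reshape(3, 4)
reduce_after_join_2 = np.concatenate([c, d], axis=1).sum(axis=0)
join_after_reduce_2 = np.concatenate([c.sum(axis=0), d.sum(axis=0)], axis=0)
assert np.array_equal(reduce_after_join_2, join_after_reduce_2)
```

The join is still present in the result, but it now operates on the smaller, already-reduced inputs.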


fg = FunctionGraph([vx, vy, vz], [out], clone=False)
[rewritten_out] = local_careduce_join.transform(fg, out.owner)
expected_out = add(add(pt_sum(vx[None]), pt_sum(vy[None])), pt_sum(vz[None]))
Member


We are missing CAReduce(DimShuffle(x)) -> CAReduce(x) for when the DimShuffle has no effect due to the reduction (or when the DimShuffle is still needed, but only for a subset of its behavior). In this case it isn't needed.

Does not need to be a blocker for this PR, but we should open an issue. I thought there was one already
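
To make the missing simplification concrete, here is a small numerical sketch in NumPy (a stand-in for the PyTensor graph, with `x[None]` playing the role of the DimShuffle that adds a leading length-1 axis; the variable names are hypothetical):

```python
import numpy as np

x = np.arange(5.0)

# Full reduction: the added length-1 axis is reduced away,
# so the expand-dims step contributes nothing.
assert np.sum(x[None]) == np.sum(x)

# Partial reduction: reducing only the original axis keeps the new axis,
# so the expand-dims is still needed but can be applied after the reduction.
assert np.array_equal(x[None].sum(axis=1), x.sum(axis=0)[None])
```

The first assertion corresponds to the `pt_sum(vx[None])` terms in the expected graph above, which could be written as plain sums of `vx`, `vy`, and `vz` once such a rewrite exists.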

Comment thread tests/tensor/rewriting/test_math.py Outdated
@ricardoV94
Member

The numba failure should be incidentally fixed by #1961, which sidesteps the numba bug. Can you rebase and check?


Development

Successfully merging this pull request may close these issues.

Optimize Sums of MakeVectors and Joins

2 participants