Skip to content

ZeRO 1/2: wait on all IPG-bucket producer streams in average_tensor (#8061)#8080

Open
arunshar wants to merge 1 commit into
deepspeedai:masterfrom
arunshar:fix/zero-overlap-comm-multistream
Open

ZeRO 1/2: wait on all IPG-bucket producer streams in average_tensor (#8061)#8080
arunshar wants to merge 1 commit into
deepspeedai:masterfrom
arunshar:fix/zero-overlap-comm-multistream

Conversation

@arunshar

Copy link
Copy Markdown

What

Fixes #8061. In ZeRO stage 1/2 with overlap_comm, average_tensor waits the reduction stream on only the current stream before reducing the contiguous IPG gradient bucket:

if self.overlap_comm:
    stream = self.reduction_stream
    if not get_accelerator().resolves_data_dependency():
        stream.wait_stream(get_accelerator().current_stream())   # only one stream

But the per-parameter gradient copies that fill the bucket (reduce_independent_p_g_buckets_and_remove_grads, the new_grad_tensor.copy_(...) into bucket.buffer[bucket.index]) can be issued on multiple streams. That is exactly the scenario #8061 reports under torch.compile, where gradient hooks run on different autograd streams and several device streams write slices into the same IPG buffer. Waiting on only the current stream lets the all-reduce read the bucket before the other producers finish, corrupting gradients (NaN loss from step 1).

Change

Implements the fix direction proposed in #8061 (record the streams used for the IPG bucket copies, then make the reduction stream wait on all of them):

  • IPGBucket gains a copy_streams set, cleared in clear().
  • After each copy into bucket.buffer[bucket.index], record current_stream() (overlap path only).
  • average_tensor waits the reduction stream on every recorded producer stream, falling back to current_stream() when the set is empty (e.g. the extra-large-param path that reduces without a bucket copy).

The single-stream case is unchanged: when all copies are on one stream, copy_streams == {current_stream}, so the wait is identical to before, and there is no behavior change for the common path.

Tests / validation

  • New CPU unit tests in tests/unit/v1/zero/test_overlap_comm_record_stream.py (fakes + monkeypatch, no GPU): the reduction stream waits on all producer streams, the empty-set fallback to current_stream, and IPGBucket.copy_streams reset. The two pre-existing tests in that file still pass (5 passed).
  • pre-commit run --files is green on both changed files (yapf, flake8, check-torchdist, check-license, codespell).
  • No-regression A/B on 2x A40 (seeded, identical data, ZeRO-2 + contiguous_gradients, overlap_comm on vs off, eager and torch.compile): before and after this change, suspect == baseline on every repeat with byte-identical final-param hashes, and the baseline-vs-baseline determinism gate passes. So the fix does not change results when grad-bucket copies stay on a single stream.

Honest scope on reproduction

I could not deterministically trigger the multi-stream NaN on the available hardware (A40, small MLP): neither the torch.compile A/B nor a synthetic two-stream microbenchmark surfaced it (PyTorch's caching allocator inserts implicit cross-stream syncs that mask the race in a microbenchmark, and a plain torch.compile(model) on this model kept the grad-bucket copies on one stream). This PR is therefore offered as the minimal correct synchronization for the clearly-missing producer-stream wait identified in average_tensor, validated for no-regression and with unit coverage; a reviewer with the original torch.compile multi-stream repro from #8061 can confirm the NaN is resolved.

Opened as a draft.

…eepspeedai#8061)

With overlap_comm, the per-parameter gradient copies into the contiguous IPG
bucket can be issued on multiple streams (e.g. under torch.compile, gradient
hooks run on different autograd streams). average_tensor waited the reduction
stream on only the current stream before reducing the bucket, so the reduction
could read the bucket before another producer finished, corrupting gradients
(NaN loss). Track the set of producer streams per IPG bucket and wait on all of
them. The single-stream path is unchanged (the set is just {current_stream}), so
there is no behavior change when overlap_comm copies stay on one stream.

Adds CPU unit tests in tests/unit/v1/zero/test_overlap_comm_record_stream.py for
the producer-stream wait, the empty fallback to current_stream, and the
IPGBucket.copy_streams reset.

Fixes deepspeedai#8061.

Signed-off-by: Arun Sharma <sharm485@umn.edu>
@arunshar arunshar force-pushed the fix/zero-overlap-comm-multistream branch from a2d31ca to 35b4262 Compare June 19, 2026 23:34
@arunshar arunshar marked this pull request as ready for review June 19, 2026 23:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG] ZeRO stage 1/2 overlap_comm only waits current stream, but contiguous gradient bucket copies may come from multiple streams

1 participant