Skip to content

[DeepCompile] fix gather params in dynamo skipped frames for ZeRO3#8059

Open
XAheli wants to merge 3 commits into
deepspeedai:masterfrom
XAheli:fix/deepcompile-skipped-frame-gather-7942
Open

[DeepCompile] fix gather params in dynamo skipped frames for ZeRO3#8059
XAheli wants to merge 3 commits into
deepspeedai:masterfrom
XAheli:fix/deepcompile-skipped-frame-gather-7942

Conversation

@XAheli

@XAheli XAheli commented Jun 11, 2026

Copy link
Copy Markdown

Fixes #7942

Root cause: When init_z3() initializes DeepCompile it removes all three parameter-gathering mechanisms (ZeROOrderedDict, module hooks, engine forward hooks) and relies entirely on compiled FX graph ops for allgather/release. but torch._dynamo may skip entire frames when it detects graph breaks in for/while loops. Skipped frames execute eagerly with no gathering mechanism, so parameters stay partitioned at shape [0].

Testing

Validated on 2× H200 with ZeRO3 + DeepCompile:

Test Result
Qwen2 MoE (actual failing model from #7942) PASS — 5 training steps, no crash
LLaMA (regression — already worked) PASS
Tied embeddings (shared param across frame types) PASS
Gradient correctness (loss decreases on fixed input) PASS — 12.09 → 10.28
Guard stability (no recompilation loops) PASS — 22 compilations with fix = 22 without
Existing test_compile.py (100 steps) PASS

Test plan

  • pre-commit run passes on all changed files
  • Existing tests/torch_compile/test_compile.py passes (2 GPU, ZeRO-3)
  • New regression test test_deepcompile_skipped_frame.py passes
  • Real Qwen2 MoE model trains without crash
  • Real LLaMA model trains without regression
  • Tied embedding model trains without crash
  • Compilation count unchanged vs upstream (no guard instability)

cc @tohtana @eternalNight

Signed-off-by: ahpoddar <ahpoddar@redhat.com>

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 04bf85696a

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +2390 to +2392
for p in self.module.parameters():
if hasattr(p, "ds_status") and p.ds_status == ZeroParamStatus.AVAILABLE and not p.ds_persist:
p.partition()

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Don't free eager-gathered weights before backward

When a Dynamo-skipped frame executes eagerly and touches a ZeRO-3 parameter whose backward needs the weight (for example a Linear after an embedding, where grad-input must be computed), this loop immediately calls p.partition() after forward. In DeepCompile the ZeRO module backward hooks have been removed, and deepcompile_backward_prologue() only starts the compiled runtime, so there is no eager fallback to re-gather that saved weight before the eager autograd node runs; backward will see the released [0] parameter/storage or compute from invalid state. The fallback-gathered params need to stay available until their eager backward use has completed, or get a matching backward-time gather/release path.

Useful? React with 👍 / 👎.

if _dc_z3_eager_fallback:
for p in self.module.parameters():
if hasattr(p, "ds_status") and p.ds_status == ZeroParamStatus.AVAILABLE and not p.ds_persist:
p.partition()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing this!

One question: I was wondering why we need to walk through all parameters and free those that are still gathered at this point? Does it mean the parameters gathered outside the compiled graphs are all alive till this point? If so, it can increase the peak GPU memory usage, which can hurt training efficiency in some cases.

@tohtana

tohtana commented Jun 13, 2026

Copy link
Copy Markdown
Collaborator

Thanks @XAheli for digging into this! The root-cause analysis is correct, and keeping ZeROOrderedDict as an eager fallback guarded by torch.compiler.is_compiling() is the right direction. But I think we still need some rework.

One issue is the post-forward release loop just moves the failure into backward (this confirms the earlier P1 review comment). With DeepCompile actually enabled and the parameters not persistent (stage3_param_persistence_threshold: 0), training still crashes in engine.backward().

RuntimeError: The size of tensor a (0) must match the size of tensor b (384) at non-singleton dimension 1

Autograd saves leaf parameters by reference and reads param.data at backward time; ZeRO-3's free_param swaps .data to an empty tensor. In eager ZeRO-3 the pre-backward hooks swap the data back, but DeepCompile removed those hooks, so partitioning every AVAILABLE non-persistent param right after forward frees exactly the weights the skipped frame's eager autograd nodes still need.

Another issue is that this PR releases gathered parameters based on global ds_status state that other components also manage. For example, selective_gather keeps chosen parameters gathered via the C++-side registry without updating Python's ds_persist, so the sweep's not p.ds_persist guard cannot exclude them. This kind of interference is reproducible: with DeepCompile actually enabled, the regression test fails on this branch with KeyError: wait_allgather_ds_param__arg0_1_0 (master fails with the expected 'weight' must be 2-D), and the failure disappears once the sweep is replaced by tracked-set release.

Your tests didn't catch these issues because test_deepcompile_skipped_frame.py never activates DeepCompile (ds_config_z3.json doesn't set "compile": {"deepcompile": true}). Also, all its parameters are below the default persistence threshold (100000 elements). So the release loop excludes them even when DeepCompile is on. The test passes on master without your fix.

To make this concrete, I've opened a PR against your branch implementing the rework: XAheli#1. In summary it:

  1. records the params the fallback actually gathers and releases only that set after their eager backward use completes (the all-parameter post-forward sweep is removed);
  2. skips register_external_parameter in the fallback when DeepCompile is active (its consumers are removed with the module hooks, so the registration is dead state — eager-mode behavior is unchanged);
  3. moves the pre/post logic out of DeepSpeedEngine.forward into DeepCompile-owned preprocess/postprocess in deepspeed/compile/;
  4. makes the regression test real: DeepCompile enabled, persistence defeated, asserts the frame was actually skipped, and fails on master without the fix.

On @eternalNight's question: This still increases the peak memory, though the cost is now bounded to the fallback-gathered set. But I think it would avoid errors and keep correctness.

Can you take a look at the rework, adjust it as needed, and merge it into your branch if it looks reasonable? (if you'd prefer to address these issues in your own way, that works just as well) Either way, once the backward-safe release Thanks again for working on this!

@XAheli

XAheli commented Jun 13, 2026

Copy link
Copy Markdown
Author

@tohtana @eternalNight thanks a lot for the detailed review :) I'll take a deeper look and push the changes soon!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[BUG][DeepCompile] DeepCompile fails on Qwen1.5-MoE-A2.7B-Chat with RuntimeError: 'weight' must be 2-D (LLaMA works in the same environment)

3 participants