Add AutoEP + AutoTP parallel folding #8064
Allow tensor parallelism (AutoTP) for the dense/attention path to coexist with expert parallelism (AutoEP) for routed experts on the same rank set, without requiring EP to be a subset of DP.

- Treat dense and MoE as independent partitionings: dense view tp*dp, expert view ep*etp*edp, with dp/edp derived so tp*dp == ep*etp*edp == stage_size. expert_tensor_parallel_size is reserved (must currently be 1).
- Express folding via the existing tensor_parallel/expert_parallel config sections, with divisibility, TP/sequence-parallel exclusivity, and preset_model consistency validation.
- Add the route-full / partition-dispatch MoE path and AutoTP skipping of AutoEP subtrees; derive folded process groups via the generalized expert/data-parallel group creation.
- Reduce TP-replicated router/gate gradients mode-aware (sum when tokens are partitioned, average when replicated); record per-parameter-family ZeRO checkpoint metadata and handle folded ZeRO-1/2 optimizer state.
- Add folding unit tests (config, groups, dispatch, runtime, gradient parity, checkpoint), including multi-rank GPU-gated cases.

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
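The size derivation described in the commit message can be illustrated with a minimal sketch. The function name `derive_folded_sizes` and its error messages are hypothetical and are not taken from the PR; only the invariant tp*dp == ep*etp*edp == stage_size and the etp == 1 restriction come from the text above.

```python
from typing import Tuple


def derive_folded_sizes(stage_size: int, tp: int, ep: int, etp: int = 1) -> Tuple[int, int]:
    """Hypothetical helper: derive (dp, edp) so that
    tp * dp == ep * etp * edp == stage_size."""
    if etp != 1:
        # The PR text says expert-internal TP is reserved and must currently be 1.
        raise ValueError("expert_tensor_parallel_size must currently be 1")
    if stage_size % tp != 0:
        raise ValueError(f"stage_size={stage_size} is not divisible by tp={tp}")
    expert_width = ep * etp
    if stage_size % expert_width != 0:
        raise ValueError(f"stage_size={stage_size} is not divisible by ep*etp={expert_width}")
    dp = stage_size // tp             # dense view: tp * dp
    edp = stage_size // expert_width  # expert view: ep * etp * edp
    return dp, edp


# Eight ranks, autotp_size=2, autoep_size=4 (the sizes used in the reviewer's test below).
dp, edp = derive_folded_sizes(stage_size=8, tp=2, ep=4)
```

Because dp and edp are computed rather than accepted as inputs, a caller cannot supply a combination that violates the invariant; it can only supply sizes that fail the divisibility checks.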
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 278c919489
chunks = torch.split(grad_output, ctx.counts, dim=0)
grad_padded = grad_output.new_zeros((ctx.max_rows, *grad_output.shape[1:]))
if local_count:
    grad_padded[:local_count].copy_(chunks[ctx.group_rank])
return grad_padded[:local_count].contiguous(), None, None, None
Sum gathered-row gradients across TP lanes
When folded MoE output is consumed differently on each TP lane (for example by a row-parallel/lm-head layer that slices the hidden dimension), every gathered row participates in the loss on every lane. This backward path only returns chunks[ctx.group_rank] from the local rank's grad_output, so contributions from peer lanes to this rank's local expert outputs and routing weights are dropped; the padded local gradient needs to be accumulated across ctx.group before returning.
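The concern can be illustrated with a small single-process simulation in plain Python. This is a sketch under simplifying assumptions: each gathered row is reduced to one scalar, and the per-lane `grad_outputs` lists stand in for what each TP lane would see as `grad_output`. The function names are hypothetical and do not appear in the PR.

```python
def backward_local_only(grad_outputs, counts, lane):
    """Mimics the reviewed backward: keep only this lane's own grad_output
    slice for the rows this lane contributed to the gather."""
    start = sum(counts[:lane])
    return grad_outputs[lane][start:start + counts[lane]]


def backward_summed(grad_outputs, counts, lane):
    """Accumulates the same row slice across all lanes, as the comment suggests."""
    start = sum(counts[:lane])
    stop = start + counts[lane]
    return [sum(g[i] for g in grad_outputs) for i in range(start, stop)]


counts = [2, 1]          # lane 0 contributed 2 rows, lane 1 contributed 1 row
grad_outputs = [
    [1, 2, 3],           # gradient w.r.t. the 3 gathered rows as seen on lane 0
    [10, 20, 30],        # gradient w.r.t. the same rows as seen on lane 1
]

local = backward_local_only(grad_outputs, counts, lane=0)   # [1, 2]
summed = backward_summed(grad_outputs, counts, lane=0)      # [11, 22]
```

When every lane uses the gathered rows differently, the two results diverge: the local-only version silently drops lane 1's contribution to lane 0's rows.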
grad_reduc = self.get_gradient_for_reduction(param)
self._maybe_reduce_autoep_folding_tp_gradient(param, grad_reduc)
Honor ds_grad_is_ready before TP reduction
In ZeRO-2 folded runs, parameters with ds_grad_is_ready=False are intentionally skipped until their transient/tiled gradient is complete, as the guard immediately below documents. Calling the new TP reduction before that guard mutates and all-reduces incomplete gradients for those parameters, which can corrupt the final accumulated gradient once the ready shard is eventually reduced.
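The ordering the comment asks for can be sketched as follows. `FakeParam`, `reduce_ready_gradients`, and the `tp_reduce` callback are hypothetical stand-ins for illustration only; they are not the ZeRO-2 optimizer's actual structures.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class FakeParam:
    """Hypothetical stand-in for a parameter with a readiness flag."""
    name: str
    grad: List[float]
    ds_grad_is_ready: bool = True


def reduce_ready_gradients(params: List[FakeParam],
                           tp_reduce: Callable[[FakeParam], None]) -> List[str]:
    """Apply the TP reduction only after the readiness guard has passed."""
    reduced = []
    for p in params:
        if not getattr(p, "ds_grad_is_ready", True):
            # Gradient is still being assembled: leave it untouched for now.
            continue
        tp_reduce(p)
        reduced.append(p.name)
    return reduced


def fake_tp_sum(p: FakeParam) -> None:
    # Pretend a 2-lane all-reduce(sum) where both lanes hold the same values.
    p.grad = [2 * g for g in p.grad]


params = [
    FakeParam("router.gate.weight", [1.0, 2.0]),
    FakeParam("tiled.weight", [0.5, 0.0], ds_grad_is_ready=False),
]
done = reduce_ready_gradients(params, fake_tp_sum)
```

With the guard first, the incomplete gradient of `tiled.weight` is never mutated or communicated; it would be reduced only on a later call, once its flag is set.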
"is planned as follow-up work.")

expert_width = spec.ep_size * spec.etp_size
if spec.tp_size > 1 and expert_width > spec.dp_size:
What will happen if expert_width % spec.dp_size != 0?
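To make the question concrete, the sketch below enumerates size combinations that satisfy tp*dp == ep*etp*edp == stage_size and take the `tp_size > 1 and expert_width > dp_size` branch, yet have `expert_width % dp_size != 0`. The function name is hypothetical, and the sketch says nothing about how the PR handles such configurations; it only shows that they exist.

```python
def non_tiling_configs(stage_size, etp=1):
    """Return (tp, dp, ep, edp) tuples where expert_width > dp but
    expert_width is not a multiple of dp."""
    divisors = [d for d in range(1, stage_size + 1) if stage_size % d == 0]
    hits = []
    for tp in divisors:
        dp = stage_size // tp
        for ep in divisors:
            expert_width = ep * etp
            if stage_size % expert_width != 0:
                continue  # would violate ep * etp * edp == stage_size
            edp = stage_size // expert_width
            if tp > 1 and expert_width > dp and expert_width % dp != 0:
                hits.append((tp, dp, ep, edp))
    return hits


power_of_two = non_tiling_configs(8)    # no such case when all sizes are powers of two
twelve_ranks = non_tiling_configs(12)   # e.g. tp=3, dp=4, ep=6, edp=2
```

With 8 ranks no such combination exists, but with 12 ranks tp=3, ep=6 gives dp=4 and expert_width=6, which is the case the question is about.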
if not param.requires_grad or param.grad is None:
    continue
if is_moe_param(param) or is_model_parallel_parameter(param):
    continue
This filter cannot distinguish router gradients from layernorm gradients. Router gradients need SUM because tokens are dispatched to experts, whereas layernorm gradients do not need SUM, so the two need different reduction strategies.
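A toy calculation shows why a single reduction op cannot serve both families. It assumes two TP lanes, four tokens with scalar per-token gradients, router gradients computed on a partitioned half of the tokens per lane, and layernorm gradients computed on all tokens on every lane. The name-based `pick_reduce_op` rule is hypothetical and is not the PR's filter.

```python
def reduce_across_lanes(per_lane_grads, op):
    """Simulate an all-reduce over TP lanes with 'sum' or 'avg'."""
    total = sum(per_lane_grads)
    return total if op == "sum" else total / len(per_lane_grads)


def pick_reduce_op(param_name, tokens_partitioned):
    """Hypothetical rule: sum router grads when tokens are partitioned,
    average everything that is computed on replicated tokens."""
    is_router = ".router." in param_name
    return "sum" if (is_router and tokens_partitioned) else "avg"


token_grads = [1.0, 2.0, 3.0, 4.0]
full_grad = sum(token_grads)                                   # reference value: 10.0

# Router: each lane sees half of the tokens.
router_lanes = [sum(token_grads[:2]), sum(token_grads[2:])]    # [3.0, 7.0]
# Layernorm: each lane sees every token.
ln_lanes = [full_grad, full_grad]                              # [10.0, 10.0]

router_op = pick_reduce_op("model.layers.0.mlp.router.gate.weight", tokens_partitioned=True)
ln_op = pick_reduce_op("model.layers.0.input_layernorm.weight", tokens_partitioned=True)

router_grad = reduce_across_lanes(router_lanes, router_op)     # matches full_grad
ln_grad = reduce_across_lanes(ln_lanes, ln_op)                 # matches full_grad
ln_grad_if_summed = reduce_across_lanes(ln_lanes, "sum")       # twice full_grad
```

Under these assumptions, applying SUM to both families doubles the layernorm gradient, while applying AVG to both halves the router gradient.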
Here is the test. I verified on a CPU system that this test case does not pass.
# Copyright (c) DeepSpeed Team.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""Engine-path (zero_stage=0) parity: does folded layernorm/gate match the non-folded
DP baseline in the FULL flow? (The author only tested ZeRO-2 parity; the engine path
that runs at zero_stage=0/1 is untested.) CPU/Gloo, world=8.
"""

import deepspeed
from unit.v1.moe.autoep_test_utils import make_autoep_config, run_cpu_gloo_test, seed_everything
from unit.v1.moe.test_autoep_autotp_grad_parity import (
    _router_grad_model,
    _run_router_grad_boundary,
    _full_grad_by_suffix,
)

GATE_BASELINE = "model.layers.0.mlp.gate.weight"
GATE_FOLDED = "model.layers.0.mlp.router.gate.weight"
LN = "model.layers.0.input_layernorm.weight"


def _baseline_cfg():
    c = {k: v for k, v in make_autoep_config(zero_stage=0, ep_size=1, mixed_precision=False).items()
         if k != "expert_parallel"}
    c["gradient_accumulation_steps"] = 2
    c["gradient_clipping"] = 0.0
    c["communication_data_type"] = "fp32"
    c["optimizer"]["params"]["torch_adam"] = True
    return c


def _folded_cfg():
    c = make_autoep_config(zero_stage=0, ep_size=4, mixed_precision=False)
    c["gradient_accumulation_steps"] = 2
    c["gradient_clipping"] = 0.0
    c["communication_data_type"] = "fp32"
    c["optimizer"]["params"]["torch_adam"] = True
    c["expert_parallel"]["autoep_size"] = 4
    c["tensor_parallel"] = {"autotp_size": 2, "partition_config": {
        "use_default_specs": False, "layer_specs": [{"patterns": [r".*\.weight$"], "partition_type": "skip"}]}}
    return c


def _worker(rank, world_size, tmpdir):
    seed = 1234
    tp_size = 2
    logical_dp_world_size = world_size // tp_size
    logical_dp_rank = rank // tp_size
    seed_everything(seed)
    reference_state = _router_grad_model().state_dict()

    baseline_model = _router_grad_model()
    baseline_model.load_state_dict(reference_state)
    baseline_engine, *_ = deepspeed.initialize(model=baseline_model, config=_baseline_cfg())
    _run_router_grad_boundary(baseline_engine,
                              logical_dp_world_size=logical_dp_world_size,
                              logical_dp_rank=logical_dp_rank,
                              seed=seed)
    base_gate = _full_grad_by_suffix(baseline_engine, GATE_BASELINE)
    base_ln = _full_grad_by_suffix(baseline_engine, LN)

    folded_model = _router_grad_model()
    folded_model.load_state_dict(reference_state)
    folded_engine, *_ = deepspeed.initialize(model=folded_model, config=_folded_cfg())
    _run_router_grad_boundary(folded_engine,
                              logical_dp_world_size=logical_dp_world_size,
                              logical_dp_rank=logical_dp_rank,
                              seed=seed)
    folded_gate = _full_grad_by_suffix(folded_engine, GATE_FOLDED)
    folded_ln = _full_grad_by_suffix(folded_engine, LN)

    gate_ratio = (folded_gate.norm() / base_gate.norm()).item()
    ln_ratio = (folded_ln.norm() / base_ln.norm()).item()
    print(f"[rank {rank}] ENGINE(zero0) gate_ratio={gate_ratio:.4f} ln_ratio={ln_ratio:.4f}")
    if rank == 0:
        assert abs(gate_ratio - 1.0) <= 5e-3, f"gate parity: {gate_ratio}"
        assert abs(ln_ratio - 1.0) <= 5e-3, f"ln parity: {ln_ratio}"


def test_b1_engine_path_parity(tmpdir):
    run_cpu_gloo_test(_worker, tmpdir, world_size=8)
This PR adds parallel folding for AutoEP: tensor parallelism (AutoTP) for the dense/attention path can now coexist with expert parallelism (AutoEP) for the routed-expert path on the same set of ranks, without forcing EP to be a subset of DP.
(This PR should be adjusted for ZeRO-3 support after #8060 is merged.)
Design
Attention/dense and MoE are treated as two independent partitionings of the same rank set, parameterized per parameter family:
- Dense/attention view: stage_size = tp * dp
- Expert (MoE) view: stage_size = ep * etp * edp
dp and edp are always derived, never user-configured, so the invariant tp * dp == ep * etp * edp == stage_size cannot be broken from config.
Configuration
No new config section. Folding is expressed by the coexistence of the existing tensor_parallel and expert_parallel sections:
{
  "tensor_parallel": { "autotp_size": 2 },
  "expert_parallel": { "enabled": true, "autoep_size": 4, "expert_tensor_parallel_size": 1 }
}
expert_tensor_parallel_size is carried as a config field but currently must be 1 (expert-internal TP is reserved as follow-up and rejected fail-fast). Validation enforces divisibility, TP/sequence-parallel exclusivity, and preset_model consistency between the two sections.
What's included
- […] mp_mode TP-strided vs SP-consecutive ordering).
- […] (deepspeed/moe/ep_tp_dispatch.py), with AutoTP skipping AutoEP subtrees.
Correctness & validation
- […] aws-torch-latest-full) on H100 GPUs.
Scope / follow-ups
- Expert-internal TP (expert_tensor_parallel_size > 1) is reserved for a follow-up.