Skip to content

Ep memory eff init#207

Draft
meichangsu1 wants to merge 3 commits into
modelscope:mainfrom
meichangsu1:ep_memory_eff_init
Draft

Ep memory eff init#207
meichangsu1 wants to merge 3 commits into
modelscope:mainfrom
meichangsu1:ep_memory_eff_init

Conversation

@meichangsu1
Copy link
Copy Markdown
Collaborator

@meichangsu1 meichangsu1 commented May 29, 2026

PR type

  • Bug Fix
  • New Feature
  • Document Updates
  • More Models or Datasets Support

PR information

This PR optimizes Native FSDP memory_efficient_init weight loading for multi-node EP/FSDP jobs.

Previously, global rank 0 was responsible for distributing pretrained weights and EP expert shards to all ranks. When world_size is large, especially with ep_fsdp_size=1, this creates heavy communication pressure on global rank 0.

This change uses each node's local rank 0 as the source rank for pretrained weights. Local rank 0 loads/captures the full state on its node and distributes ordinary tensors, EP expert tensors, and EP LoRA adapter tensors only to ranks on the same node. EP shard selection still uses rank_to_ep_rank, so each target rank receives the correct EP slice before FSDP/DTensor placement is applied.

It also removes the Twinkle-side monkey patch for Accelerate FSDP2 state-dict loading. AccelerateStrategy continues to rely on native Accelerate behavior for memory_efficient_init / cpu_ram_efficient_loading.

Experiment results

qq_30035749 added 3 commits May 29, 2026 15:47
Load full pretrained weights on each node's local rank0 and distribute shards
only within the node, reducing global rank0 pressure for large EP/FSDP jobs.
@meichangsu1 meichangsu1 marked this pull request as draft May 29, 2026 10:45
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request removes the FSDP2 state-dict loading patch from the Accelerate strategy and implements node-local state-dict loading and scattering in the Native FSDP strategy using local rank topology. The review feedback suggests optimizing the node-local communication by creating a node-local process group once and using dist.broadcast instead of inefficient point-to-point dist.send and dist.recv calls for every parameter. Additionally, it is recommended to verify that distributed training is initialized before retrieving the local rank to prevent runtime errors in non-distributed environments.

Comment on lines +823 to +825
rank, _, local_source_rank, local_ranks = _get_local_rank_info()
is_rank0 = (rank == 0)
is_source_rank = rank == local_source_rank
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using point-to-point dist.send and dist.recv in a loop over all local ranks for every single parameter is highly inefficient and can cause significant performance bottlenecks during model initialization, especially for large models with many parameters and large local world sizes (e.g., 8 GPUs per node).

Instead, we can collectively create a node-local process group once and use dist.broadcast to leverage NCCL's highly optimized collective communication algorithms.

Suggested change
rank, _, local_source_rank, local_ranks = _get_local_rank_info()
is_rank0 = (rank == 0)
is_source_rank = rank == local_source_rank
rank, world_size, local_source_rank, local_ranks = _get_local_rank_info()
is_rank0 = (rank == 0)
is_source_rank = rank == local_source_rank
local_group = None
local_world_size = len(local_ranks)
if local_world_size > 1:
num_nodes = world_size // local_world_size
for node_idx in range(num_nodes):
node_ranks_i = list(range(node_idx * local_world_size, (node_idx + 1) * local_world_size))
group = dist.new_group(ranks=node_ranks_i)
if rank in node_ranks_i:
local_group = group

Comment on lines +851 to +861
def _broadcast_from_local_source(full_tensor):
if is_source_rank:
if full_tensor is None:
raise RuntimeError(f'Local source rank {local_source_rank} does not have full state_dict tensor.')
for target_rank in local_ranks:
if target_rank == rank:
continue
dist.send(full_tensor, dst=target_rank)
else:
dist.recv(full_tensor, src=local_source_rank)
return full_tensor
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Update _broadcast_from_local_source to use the collectively created local_group and dist.broadcast instead of sequential point-to-point dist.send and dist.recv calls.

    def _broadcast_from_local_source(full_tensor):
        if local_group is not None:
            if is_source_rank and full_tensor is None:
                raise RuntimeError(f'Local source rank {local_source_rank} does not have full state_dict tensor.')
            dist.broadcast(full_tensor, src=local_source_rank, group=local_group)
        return full_tensor

Comment on lines +57 to +60
local_rank = Platform.get_local_rank()
if local_rank < 0:
raise RuntimeError('Native FSDP node-local pre-EP state capture requires LOCAL_RANK.')
is_source_rank = dist.is_available() and dist.is_initialized() and local_rank == 0
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Check if distributed training is initialized before attempting to access Platform.get_local_rank(). This prevents unnecessary RuntimeError exceptions when running in non-distributed or single-process environments where LOCAL_RANK is not set.

Suggested change
local_rank = Platform.get_local_rank()
if local_rank < 0:
raise RuntimeError('Native FSDP node-local pre-EP state capture requires LOCAL_RANK.')
is_source_rank = dist.is_available() and dist.is_initialized() and local_rank == 0
if not (dist.is_available() and dist.is_initialized()):
return
local_rank = Platform.get_local_rank()
if local_rank < 0:
raise RuntimeError('Native FSDP node-local pre-EP state capture requires LOCAL_RANK.')
is_source_rank = local_rank == 0

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant