Ep memory eff init #207
Load full pretrained weights on each node's local rank0 and distribute shards only within the node, reducing global rank0 pressure for large EP/FSDP jobs.
Code Review
This pull request removes the FSDP2 state-dict loading patch from the Accelerate strategy and implements node-local state-dict loading and scattering in the Native FSDP strategy using local rank topology. The review feedback suggests optimizing the node-local communication by creating a node-local process group once and using dist.broadcast instead of inefficient point-to-point dist.send and dist.recv calls for every parameter. Additionally, it is recommended to verify that distributed training is initialized before retrieving the local rank to prevent runtime errors in non-distributed environments.
    rank, _, local_source_rank, local_ranks = _get_local_rank_info()
    is_rank0 = (rank == 0)
    is_source_rank = rank == local_source_rank
Using point-to-point dist.send and dist.recv in a loop over all local ranks for every single parameter is highly inefficient and can cause significant performance bottlenecks during model initialization, especially for large models with many parameters and large local world sizes (e.g., 8 GPUs per node).
Instead, we can collectively create a node-local process group once and use dist.broadcast to leverage NCCL's highly optimized collective communication algorithms.
    - rank, _, local_source_rank, local_ranks = _get_local_rank_info()
    - is_rank0 = (rank == 0)
    - is_source_rank = rank == local_source_rank
    + rank, world_size, local_source_rank, local_ranks = _get_local_rank_info()
    + is_rank0 = (rank == 0)
    + is_source_rank = rank == local_source_rank
    + local_group = None
    + local_world_size = len(local_ranks)
    + if local_world_size > 1:
    +     num_nodes = world_size // local_world_size
    +     for node_idx in range(num_nodes):
    +         node_ranks_i = list(range(node_idx * local_world_size, (node_idx + 1) * local_world_size))
    +         group = dist.new_group(ranks=node_ranks_i)
    +         if rank in node_ranks_i:
    +             local_group = group
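The suggested loop relies on two things it does not state: that `world_size` is a multiple of the local world size, and that each node owns a contiguous block of global ranks. It also loops over every node on every rank because `dist.new_group` must be entered by all processes in the main group, including those that will not be members. A minimal pure-Python sketch of that partitioning, with the divisibility check made explicit, might look like this (`node_rank_groups` and `local_group_for` are hypothetical helper names, not part of the PR):

```python
from typing import List


def node_rank_groups(world_size: int, local_world_size: int) -> List[List[int]]:
    """Partition global ranks into per-node groups.

    Assumption (not verified by the PR): ranks are laid out contiguously,
    so node i owns ranks [i * local_world_size, (i + 1) * local_world_size).
    """
    if local_world_size <= 0 or world_size % local_world_size != 0:
        raise ValueError(
            f"world_size={world_size} is not a multiple of "
            f"local_world_size={local_world_size}"
        )
    return [
        list(range(start, start + local_world_size))
        for start in range(0, world_size, local_world_size)
    ]


def local_group_for(rank: int, groups: List[List[int]]) -> List[int]:
    """Return the group of ranks that shares a node with `rank`."""
    for ranks in groups:
        if rank in ranks:
            return ranks
    raise ValueError(f"rank {rank} is not in any node group")


# Every rank computes the same list, in the same order, which is what a
# collective call such as dist.new_group would need.
groups = node_rank_groups(world_size=8, local_world_size=4)
my_group = local_group_for(5, groups)
local_source_rank = my_group[0]  # first rank on the node acts as the source
```

In a real job each entry of `groups` would be passed to `dist.new_group(ranks=...)` on every rank, and only the group containing the caller would be kept as `local_group`.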
    def _broadcast_from_local_source(full_tensor):
        if is_source_rank:
            if full_tensor is None:
                raise RuntimeError(f'Local source rank {local_source_rank} does not have full state_dict tensor.')
            for target_rank in local_ranks:
                if target_rank == rank:
                    continue
                dist.send(full_tensor, dst=target_rank)
        else:
            dist.recv(full_tensor, src=local_source_rank)
        return full_tensor
Update _broadcast_from_local_source to use the collectively created local_group and dist.broadcast instead of sequential point-to-point dist.send and dist.recv calls.
    def _broadcast_from_local_source(full_tensor):
        if local_group is not None:
            if is_source_rank and full_tensor is None:
                raise RuntimeError(f'Local source rank {local_source_rank} does not have full state_dict tensor.')
            dist.broadcast(full_tensor, src=local_source_rank, group=local_group)
        return full_tensor

    local_rank = Platform.get_local_rank()
    if local_rank < 0:
        raise RuntimeError('Native FSDP node-local pre-EP state capture requires LOCAL_RANK.')
    is_source_rank = dist.is_available() and dist.is_initialized() and local_rank == 0
Check if distributed training is initialized before attempting to access Platform.get_local_rank(). This prevents unnecessary RuntimeError exceptions when running in non-distributed or single-process environments where LOCAL_RANK is not set.
    - local_rank = Platform.get_local_rank()
    - if local_rank < 0:
    -     raise RuntimeError('Native FSDP node-local pre-EP state capture requires LOCAL_RANK.')
    - is_source_rank = dist.is_available() and dist.is_initialized() and local_rank == 0
    + if not (dist.is_available() and dist.is_initialized()):
    +     return
    + local_rank = Platform.get_local_rank()
    + if local_rank < 0:
    +     raise RuntimeError('Native FSDP node-local pre-EP state capture requires LOCAL_RANK.')
    + is_source_rank = local_rank == 0
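The effect of reordering the checks can be written out as a small decision function. This is only an illustrative sketch: `resolve_is_source_rank` is a hypothetical helper, `dist_ready` stands in for `dist.is_available() and dist.is_initialized()`, `local_rank` stands in for the value returned by `Platform.get_local_rank()`, and a negative value is assumed to mean that `LOCAL_RANK` is unset, as the error message in the snippet implies.

```python
from typing import Optional


def resolve_is_source_rank(dist_ready: bool, local_rank: int) -> Optional[bool]:
    """Mirror the suggested check order.

    Returns None when there is nothing to capture (non-distributed run),
    otherwise whether this process is its node's source rank.
    """
    # Check the distributed state first, so a missing LOCAL_RANK is not an
    # error in single-process runs.
    if not dist_ready:
        return None
    # Only a distributed run without a usable local rank is a hard error.
    if local_rank < 0:
        raise RuntimeError("node-local pre-EP state capture requires LOCAL_RANK")
    return local_rank == 0
```

With this ordering, a single-process run without `LOCAL_RANK` skips the capture step quietly, while a distributed run with a missing local rank still fails loudly.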
PR type
PR information
This PR optimizes Native FSDP `memory_efficient_init` weight loading for multi-node EP/FSDP jobs.
Previously, global rank 0 was responsible for distributing pretrained weights and EP expert shards to all ranks. When `world_size` is large, especially with `ep_fsdp_size=1`, this creates heavy communication pressure on global rank 0.
This change uses each node's local rank 0 as the source rank for pretrained weights. Local rank 0 loads/captures the full state on its node and distributes ordinary tensors, EP expert tensors, and EP LoRA adapter tensors only to ranks on the same node. EP shard selection still uses `rank_to_ep_rank`, so each target rank receives the correct EP slice before FSDP/DTensor placement is applied.
It also removes the Twinkle-side monkey patch for Accelerate FSDP2 state-dict loading. AccelerateStrategy continues to rely on native Accelerate behavior for `memory_efficient_init`/`cpu_ram_efficient_loading`.
Experiment results