-
Notifications
You must be signed in to change notification settings - Fork 32
Ep memory eff init #207
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Ep memory eff init #207
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,4 +1,5 @@ | ||||||||||||||||||||||||||||||||
| # Copyright (c) ModelScope Contributors. All rights reserved. | ||||||||||||||||||||||||||||||||
| import os | ||||||||||||||||||||||||||||||||
| import torch | ||||||||||||||||||||||||||||||||
| import torch.distributed as dist | ||||||||||||||||||||||||||||||||
| from torch import nn | ||||||||||||||||||||||||||||||||
|
|
@@ -53,8 +54,11 @@ def capture_pre_ep_state_if_needed(self, model, *, enable_ep: bool) -> None: | |||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||
| if not (enable_ep and self.use_rank0_pretrained_broadcast()): | ||||||||||||||||||||||||||||||||
| return | ||||||||||||||||||||||||||||||||
| is_rank0 = dist.is_available() and dist.is_initialized() and dist.get_rank() == 0 | ||||||||||||||||||||||||||||||||
| self.set_rank0_pre_ep_full_state_dict(clone_state_dict_to_cpu(model.state_dict()) if is_rank0 else {}) | ||||||||||||||||||||||||||||||||
| local_rank = Platform.get_local_rank() | ||||||||||||||||||||||||||||||||
| if local_rank < 0: | ||||||||||||||||||||||||||||||||
| raise RuntimeError('Native FSDP node-local pre-EP state capture requires LOCAL_RANK.') | ||||||||||||||||||||||||||||||||
| is_source_rank = dist.is_available() and dist.is_initialized() and local_rank == 0 | ||||||||||||||||||||||||||||||||
| self.set_rank0_pre_ep_full_state_dict(clone_state_dict_to_cpu(model.state_dict()) if is_source_rank else {}) | ||||||||||||||||||||||||||||||||
| self._pre_ep_state_captured = True | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def prepare_adapter_config(self, config_or_dir, *, enable_ep: bool): | ||||||||||||||||||||||||||||||||
|
|
@@ -131,15 +135,19 @@ def wrap_model(self, model, optimizer=None): | |||||||||||||||||||||||||||||||
| adapter_source_sd = {} | ||||||||||||||||||||||||||||||||
| adapter_full_sd = {} | ||||||||||||||||||||||||||||||||
| if use_meta: | ||||||||||||||||||||||||||||||||
| is_rank0 = (dist.get_rank() == 0) | ||||||||||||||||||||||||||||||||
| local_rank = Platform.get_local_rank() | ||||||||||||||||||||||||||||||||
| if local_rank < 0: | ||||||||||||||||||||||||||||||||
| raise RuntimeError('Native FSDP node-local state loading requires LOCAL_RANK.') | ||||||||||||||||||||||||||||||||
| is_source_rank = local_rank == 0 | ||||||||||||||||||||||||||||||||
| if ep_enabled and self._rank0_pre_ep_full_state_dict is not None: | ||||||||||||||||||||||||||||||||
| original_sd = self._rank0_pre_ep_full_state_dict if is_rank0 else {} | ||||||||||||||||||||||||||||||||
| original_sd = self._rank0_pre_ep_full_state_dict if is_source_rank else {} | ||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||
| original_sd = model.state_dict() if is_rank0 else {} | ||||||||||||||||||||||||||||||||
| original_sd = model.state_dict() if is_source_rank else {} | ||||||||||||||||||||||||||||||||
| adapter_source_sd = _collect_adapter_source_state(model.state_dict()) | ||||||||||||||||||||||||||||||||
| adapter_full_sd = self._adapter_full_state_dict if is_rank0 and self._adapter_full_state_dict else {} | ||||||||||||||||||||||||||||||||
| saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {} | ||||||||||||||||||||||||||||||||
| if is_rank0: | ||||||||||||||||||||||||||||||||
| adapter_full_sd = ( | ||||||||||||||||||||||||||||||||
| self._adapter_full_state_dict if is_source_rank and self._adapter_full_state_dict else {}) | ||||||||||||||||||||||||||||||||
| saved_buffers = _get_non_persistent_buffers(model) if is_source_rank else {} | ||||||||||||||||||||||||||||||||
| if is_source_rank: | ||||||||||||||||||||||||||||||||
| model = model.to(torch.device('meta')) | ||||||||||||||||||||||||||||||||
| if hasattr(model, 'tie_weights'): | ||||||||||||||||||||||||||||||||
| model.tie_weights() | ||||||||||||||||||||||||||||||||
|
|
@@ -534,6 +542,25 @@ def _build_rank_to_ep_rank(ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> Di | |||||||||||||||||||||||||||||||
| return rank_to_ep_rank | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _get_local_rank_info() -> tuple[int, int, int, List[int]]: | ||||||||||||||||||||||||||||||||
| """Return local-rank topology for node-local state-dict fanout.""" | ||||||||||||||||||||||||||||||||
| rank = dist.get_rank() | ||||||||||||||||||||||||||||||||
| world_size = dist.get_world_size() | ||||||||||||||||||||||||||||||||
| local_rank = Platform.get_local_rank() | ||||||||||||||||||||||||||||||||
| if 'LOCAL_WORLD_SIZE' not in os.environ and 'LOCAL_SIZE' not in os.environ: | ||||||||||||||||||||||||||||||||
| raise RuntimeError('Native FSDP node-local state loading requires LOCAL_WORLD_SIZE or LOCAL_SIZE.') | ||||||||||||||||||||||||||||||||
| local_world_size = Platform.get_local_world_size() | ||||||||||||||||||||||||||||||||
| if local_rank < 0 or local_world_size <= 0 or world_size % local_world_size != 0: | ||||||||||||||||||||||||||||||||
| raise RuntimeError(f'Invalid local rank topology: rank={rank}, world_size={world_size}, ' | ||||||||||||||||||||||||||||||||
| f'local_rank={local_rank}, local_world_size={local_world_size}.') | ||||||||||||||||||||||||||||||||
| node_start = rank - local_rank | ||||||||||||||||||||||||||||||||
| node_ranks = list(range(node_start, min(node_start + local_world_size, world_size))) | ||||||||||||||||||||||||||||||||
| if rank not in node_ranks or len(node_ranks) != local_world_size: | ||||||||||||||||||||||||||||||||
| raise RuntimeError(f'Invalid local rank group: rank={rank}, local_rank={local_rank}, ' | ||||||||||||||||||||||||||||||||
| f'local_world_size={local_world_size}, node_ranks={node_ranks}.') | ||||||||||||||||||||||||||||||||
| return rank, world_size, node_start, node_ranks | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _find_experts_in_layer(layer_mod: nn.Module, experts_map: Dict[str, nn.Module]) -> Optional[nn.Module]: | ||||||||||||||||||||||||||||||||
| """Find the experts module inside a decoder layer, if any.""" | ||||||||||||||||||||||||||||||||
| for module in layer_mod.modules(): | ||||||||||||||||||||||||||||||||
|
|
@@ -793,7 +820,9 @@ def _broadcast_sharded_state_dict( | |||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| meta_sharded_sd = model.state_dict() | ||||||||||||||||||||||||||||||||
| sharded_sd = {} | ||||||||||||||||||||||||||||||||
| is_rank0 = (dist.get_rank() == 0) | ||||||||||||||||||||||||||||||||
| rank, _, local_source_rank, local_ranks = _get_local_rank_info() | ||||||||||||||||||||||||||||||||
| is_rank0 = (rank == 0) | ||||||||||||||||||||||||||||||||
| is_source_rank = rank == local_source_rank | ||||||||||||||||||||||||||||||||
|
Comment on lines
+823
to
+825
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using point-to-point Instead, we can collectively create a node-local process group once and use
Suggested change
|
||||||||||||||||||||||||||||||||
| expert_shard_specs = expert_shard_specs or {} | ||||||||||||||||||||||||||||||||
| rank_to_ep_rank = rank_to_ep_rank or {} | ||||||||||||||||||||||||||||||||
| adapter_source_sd = adapter_source_sd or {} | ||||||||||||||||||||||||||||||||
|
|
@@ -819,6 +848,18 @@ def _broadcast_sharded_state_dict( | |||||||||||||||||||||||||||||||
| source_keys = metadata_holder[1] or {} | ||||||||||||||||||||||||||||||||
| adapter_metadata = metadata_holder[2] or {} | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _broadcast_from_local_source(full_tensor): | ||||||||||||||||||||||||||||||||
| if is_source_rank: | ||||||||||||||||||||||||||||||||
| if full_tensor is None: | ||||||||||||||||||||||||||||||||
| raise RuntimeError(f'Local source rank {local_source_rank} does not have full state_dict tensor.') | ||||||||||||||||||||||||||||||||
| for target_rank in local_ranks: | ||||||||||||||||||||||||||||||||
| if target_rank == rank: | ||||||||||||||||||||||||||||||||
| continue | ||||||||||||||||||||||||||||||||
| dist.send(full_tensor, dst=target_rank) | ||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||
| dist.recv(full_tensor, src=local_source_rank) | ||||||||||||||||||||||||||||||||
| return full_tensor | ||||||||||||||||||||||||||||||||
|
Comment on lines
+851
to
+861
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update def _broadcast_from_local_source(full_tensor):
if local_group is not None:
if is_source_rank and full_tensor is None:
raise RuntimeError(f'Local source rank {local_source_rank} does not have full state_dict tensor.')
dist.broadcast(full_tensor, src=local_source_rank, group=local_group)
return full_tensor |
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements): | ||||||||||||||||||||||||||||||||
| local_tensor = full_tensor | ||||||||||||||||||||||||||||||||
| for mesh_dim, placement in enumerate(placements): | ||||||||||||||||||||||||||||||||
|
|
@@ -849,36 +890,20 @@ def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements): | |||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _broadcast_adapter_source_tensor(full_tensor, sharded_param): | ||||||||||||||||||||||||||||||||
| if not isinstance(sharded_param, DTensor): | ||||||||||||||||||||||||||||||||
| dist.broadcast(full_tensor, src=0) | ||||||||||||||||||||||||||||||||
| return full_tensor | ||||||||||||||||||||||||||||||||
| mesh = sharded_param.device_mesh.mesh | ||||||||||||||||||||||||||||||||
| source_rank = int(mesh.flatten()[0].item()) | ||||||||||||||||||||||||||||||||
| dist.broadcast(full_tensor, src=source_rank, group=sharded_param.device_mesh.get_group()) | ||||||||||||||||||||||||||||||||
| return full_tensor | ||||||||||||||||||||||||||||||||
| return _broadcast_from_local_source(full_tensor) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _scatter_ep_adapter_tensor(param_name, full_tensor, sharded_param): | ||||||||||||||||||||||||||||||||
| local_shape = tuple(sharded_param.size()) | ||||||||||||||||||||||||||||||||
| _, source_dtype = adapter_metadata[param_name] | ||||||||||||||||||||||||||||||||
| local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if is_rank0: | ||||||||||||||||||||||||||||||||
| shard_dim = _ep_expert_state_dict_gather_dim(param_name) | ||||||||||||||||||||||||||||||||
| local_dim = local_shape[shard_dim] | ||||||||||||||||||||||||||||||||
| world_size = dist.get_world_size() | ||||||||||||||||||||||||||||||||
| for rank in range(world_size): | ||||||||||||||||||||||||||||||||
| if rank not in rank_to_ep_rank: | ||||||||||||||||||||||||||||||||
| raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.') | ||||||||||||||||||||||||||||||||
| ep_rank = rank_to_ep_rank[rank] | ||||||||||||||||||||||||||||||||
| start = ep_rank * local_dim | ||||||||||||||||||||||||||||||||
| chunk = full_tensor.narrow(shard_dim, start, local_dim).contiguous().to(device_type) | ||||||||||||||||||||||||||||||||
| if rank == 0: | ||||||||||||||||||||||||||||||||
| local_tensor.copy_(chunk) | ||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||
| dist.send(chunk, dst=rank) | ||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||
| dist.recv(local_tensor, src=0) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| shard_dim = _ep_expert_state_dict_gather_dim(param_name) | ||||||||||||||||||||||||||||||||
| local_dim = local_shape[shard_dim] | ||||||||||||||||||||||||||||||||
| local_tensor = _scatter_ep_tensor_from_source( | ||||||||||||||||||||||||||||||||
| full_tensor, | ||||||||||||||||||||||||||||||||
| local_tensor, | ||||||||||||||||||||||||||||||||
| shard_dim=shard_dim, | ||||||||||||||||||||||||||||||||
| shard_size=local_dim, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| return local_tensor | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _get_adapter_source(param_name): | ||||||||||||||||||||||||||||||||
|
|
@@ -903,27 +928,35 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param): | |||||||||||||||||||||||||||||||
| _, source_dtype = source_metadata[param_name] | ||||||||||||||||||||||||||||||||
| local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if is_rank0: | ||||||||||||||||||||||||||||||||
| if is_source_rank: | ||||||||||||||||||||||||||||||||
| if full_tensor.size(0) != num_experts: | ||||||||||||||||||||||||||||||||
| raise RuntimeError(f"EP expert parameter '{param_name}' expects {num_experts} experts, " | ||||||||||||||||||||||||||||||||
| f'but source state has shape {tuple(full_tensor.shape)}. ' | ||||||||||||||||||||||||||||||||
| 'Rank0 must capture the full pre-EP state_dict before apply_expert_parallel().') | ||||||||||||||||||||||||||||||||
| world_size = dist.get_world_size() | ||||||||||||||||||||||||||||||||
| for rank in range(world_size): | ||||||||||||||||||||||||||||||||
| if rank not in rank_to_ep_rank: | ||||||||||||||||||||||||||||||||
| raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.') | ||||||||||||||||||||||||||||||||
| ep_rank = rank_to_ep_rank[rank] | ||||||||||||||||||||||||||||||||
| start = ep_rank * experts_per_rank | ||||||||||||||||||||||||||||||||
| end = start + experts_per_rank | ||||||||||||||||||||||||||||||||
| chunk = full_tensor[start:end].contiguous() | ||||||||||||||||||||||||||||||||
| chunk_gpu = chunk.to(device_type) | ||||||||||||||||||||||||||||||||
| if rank == 0: | ||||||||||||||||||||||||||||||||
| local_tensor.copy_(chunk_gpu) | ||||||||||||||||||||||||||||||||
| local_tensor = _scatter_ep_tensor_from_source( | ||||||||||||||||||||||||||||||||
| full_tensor, | ||||||||||||||||||||||||||||||||
| local_tensor, | ||||||||||||||||||||||||||||||||
| shard_dim=0, | ||||||||||||||||||||||||||||||||
| shard_size=experts_per_rank, | ||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||
| return local_tensor | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def _scatter_ep_tensor_from_source(full_tensor, local_tensor, *, shard_dim: int, shard_size: int): | ||||||||||||||||||||||||||||||||
| if is_source_rank: | ||||||||||||||||||||||||||||||||
| if full_tensor is None: | ||||||||||||||||||||||||||||||||
| raise RuntimeError(f'Local source rank {local_source_rank} does not have full state_dict tensor.') | ||||||||||||||||||||||||||||||||
| for target_rank in local_ranks: | ||||||||||||||||||||||||||||||||
| if target_rank not in rank_to_ep_rank: | ||||||||||||||||||||||||||||||||
| raise RuntimeError(f'Missing EP rank mapping for global rank {target_rank}.') | ||||||||||||||||||||||||||||||||
| ep_rank = rank_to_ep_rank[target_rank] | ||||||||||||||||||||||||||||||||
| start = ep_rank * shard_size | ||||||||||||||||||||||||||||||||
| chunk = full_tensor.narrow(shard_dim, start, shard_size).contiguous().to(device_type) | ||||||||||||||||||||||||||||||||
| if target_rank == rank: | ||||||||||||||||||||||||||||||||
| local_tensor.copy_(chunk) | ||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||
| dist.send(chunk_gpu, dst=rank) | ||||||||||||||||||||||||||||||||
| dist.send(chunk, dst=target_rank) | ||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||
| dist.recv(local_tensor, src=0) | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| dist.recv(local_tensor, src=local_source_rank) | ||||||||||||||||||||||||||||||||
| return local_tensor | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| for param_name, sharded_param in meta_sharded_sd.items(): | ||||||||||||||||||||||||||||||||
|
|
@@ -950,7 +983,7 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param): | |||||||||||||||||||||||||||||||
| full_tensor = torch.empty(source_shape, device=device_type, dtype=source_dtype) | ||||||||||||||||||||||||||||||||
| if not is_ep_adapter_param: | ||||||||||||||||||||||||||||||||
| full_tensor = _broadcast_adapter_source_tensor(full_tensor, sharded_param) | ||||||||||||||||||||||||||||||||
| elif is_rank0: | ||||||||||||||||||||||||||||||||
| elif is_source_rank: | ||||||||||||||||||||||||||||||||
| source_key = source_keys[param_name] | ||||||||||||||||||||||||||||||||
| if source_key not in full_sd: | ||||||||||||||||||||||||||||||||
| raise KeyError( | ||||||||||||||||||||||||||||||||
|
|
@@ -979,7 +1012,7 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param): | |||||||||||||||||||||||||||||||
| raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: " | ||||||||||||||||||||||||||||||||
| f'sharded logical shape={tuple(shape)}, source shape={source_shape}.') | ||||||||||||||||||||||||||||||||
| if not is_adapter_param: | ||||||||||||||||||||||||||||||||
| dist.broadcast(full_tensor, src=0) | ||||||||||||||||||||||||||||||||
| full_tensor = _broadcast_from_local_source(full_tensor) | ||||||||||||||||||||||||||||||||
| torch_util.synchronize() | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| if isinstance(sharded_param, DTensor): | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check if distributed training is initialized before attempting to access
Platform.get_local_rank(). This prevents unnecessaryRuntimeErrorexceptions when running in non-distributed or single-process environments whereLOCAL_RANKis not set.