diff --git a/src/twinkle/model/transformers/strategy/accelerate.py b/src/twinkle/model/transformers/strategy/accelerate.py index 7b3d414b..3ffc429f 100644 --- a/src/twinkle/model/transformers/strategy/accelerate.py +++ b/src/twinkle/model/transformers/strategy/accelerate.py @@ -7,150 +7,6 @@ from .load_context import fsdp_pretrained_load_context -def _patch_accelerate_fsdp2_load_full_state_dict(): - """Allow Accelerate FSDP2 state-dict loading to handle unsharded buffers. - - Some Transformers models keep persistent buffers in `state_dict`. FSDP2 - shards parameters as DTensors, but those buffers can remain ordinary - tensors; older Accelerate versions assume every state-dict entry has - `device_mesh` and fail on such buffers. - """ - import accelerate.utils.fsdp_utils as fsdp_utils - import torch - import torch.distributed as dist - from torch.distributed.tensor import DTensor, Partial, Replicate, Shard, distribute_tensor - - if getattr(fsdp_utils.fsdp2_load_full_state_dict, '_twinkle_patched', False): - return - - original = fsdp_utils.fsdp2_load_full_state_dict - - def patched_fsdp2_load_full_state_dict(accelerator, model, full_sd, cpu_offload=False): - meta_sharded_sd = model.state_dict() - sharded_sd = {} - - def _infer_parameter_dtype(model, param_name, empty_param): - old_param = _get_state_dict_param_for_dtype_inference(model, param_name) - is_torch_e4m3fn_available = hasattr(torch, 'float8_e4m3fn') - is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn - casting_dtype = None - if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn: - casting_dtype = old_param.dtype - return old_param is not None and old_param.is_contiguous(), casting_dtype - - def _cast_and_contiguous(tensor, to_contiguous, dtype): - if dtype is not None: - tensor = tensor.to(dtype=dtype) - if to_contiguous: - tensor = tensor.contiguous() - return tensor - - def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements): - if device_mesh.device_type == 'cuda': - return distribute_tensor(full_tensor, device_mesh, placements) - - local_tensor = full_tensor - for mesh_dim, placement in enumerate(placements): - if isinstance(placement, Shard): - # All ranks already received the full tensor via broadcast. - # Split locally to avoid distribute_tensor's scatter path, - # which is fragile on some torch_npu/HCCL versions. - local_tensor = placement._shard_tensor( - local_tensor, - device_mesh, - mesh_dim, - src_data_rank=None, - ) - elif isinstance(placement, Replicate): - continue - elif isinstance(placement, Partial): - raise NotImplementedError('FSDP2 full-state loading does not support Partial placements.') - else: - raise NotImplementedError(f'Unsupported DTensor placement: {placement}') - return DTensor.from_local( - local_tensor, - device_mesh=device_mesh, - placements=placements, - run_check=False, - shape=full_tensor.shape, - stride=full_tensor.stride(), - ) - - def _load_full_value(param_name, sharded_param): - if param_name not in full_sd: - raise KeyError( - f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict. " - f'Full state dict has {len(full_sd)} keys, sharded has {len(meta_sharded_sd)} keys.') - full_value = full_sd[param_name].detach() - if isinstance(full_value, DTensor): - full_value = full_value.to_local() - device = sharded_param.device_mesh.device_type if isinstance(sharded_param, DTensor) else accelerator.device - return full_value.to(device).contiguous() - - def _tensor_debug(tensor): - if isinstance(tensor, DTensor): - return (f'type=DTensor shape={tuple(tensor.size())} dtype={tensor.dtype} ' - f'placements={tensor.placements} mesh={tensor.device_mesh}') - if hasattr(tensor, 'size') and hasattr(tensor, 'dtype'): - return f'type={type(tensor).__name__} shape={tuple(tensor.size())} dtype={tensor.dtype}' - return f'type={type(tensor).__name__}' - - for param_name, sharded_param in meta_sharded_sd.items(): - if isinstance(sharded_param, DTensor): - device_mesh = sharded_param.device_mesh - placements = sharded_param.placements - if accelerator.is_main_process: - full_param = _load_full_value(param_name, sharded_param) - else: - full_param = torch.empty( - sharded_param.size(), - device=device_mesh.device_type, - dtype=sharded_param.dtype, - ) - - dist.broadcast(full_param, src=0, group=dist.group.WORLD) - sharded_tensor = _dtensor_from_replicated_full_tensor(full_param, device_mesh, placements) - to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_param) - sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype) - if cpu_offload: - sharded_tensor = sharded_tensor.to('cpu') - sharded_sd[param_name] = sharded_tensor - continue - - if accelerator.is_main_process: - full_value = _load_full_value(param_name, sharded_param) - else: - full_value = torch.empty( - sharded_param.size(), - device=accelerator.device, - dtype=sharded_param.dtype, - ) - - dist.broadcast(full_value, src=0, group=dist.group.WORLD) - to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_value) - full_value = _cast_and_contiguous(full_value, to_contiguous, casting_dtype) - if cpu_offload: - full_value = full_value.to('cpu') - sharded_sd[param_name] = full_value - - model.load_state_dict(sharded_sd, assign=True) - return model - - patched_fsdp2_load_full_state_dict._twinkle_patched = True - patched_fsdp2_load_full_state_dict._twinkle_original = original - fsdp_utils.fsdp2_load_full_state_dict = patched_fsdp2_load_full_state_dict - - -def _get_state_dict_param_for_dtype_inference(model, param_name: str): - try: - return model.get_parameter_or_buffer(param_name) - except AttributeError: - if '.' in param_name: - base_param_name, param_name = param_name.rsplit('.', 1) - model = model.get_submodule(base_param_name) - return getattr(model, param_name) - - class AccelerateStrategy: """A training strategy that uses `accelerate` to wrap models. @@ -172,8 +28,6 @@ def __init__( from accelerate import Accelerator from accelerate.utils import InitProcessGroupKwargs - _patch_accelerate_fsdp2_load_full_state_dict() - self.device_mesh = device_mesh self.mixed_precision = mixed_precision self._memory_efficient_init = memory_efficient_init diff --git a/src/twinkle/model/transformers/strategy/native_fsdp.py b/src/twinkle/model/transformers/strategy/native_fsdp.py index b695f50b..ba47e96d 100644 --- a/src/twinkle/model/transformers/strategy/native_fsdp.py +++ b/src/twinkle/model/transformers/strategy/native_fsdp.py @@ -1,4 +1,5 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import os import torch import torch.distributed as dist from torch import nn @@ -53,8 +54,11 @@ def capture_pre_ep_state_if_needed(self, model, *, enable_ep: bool) -> None: return if not (enable_ep and self.use_rank0_pretrained_broadcast()): return - is_rank0 = dist.is_available() and dist.is_initialized() and dist.get_rank() == 0 - self.set_rank0_pre_ep_full_state_dict(clone_state_dict_to_cpu(model.state_dict()) if is_rank0 else {}) + local_rank = Platform.get_local_rank() + if local_rank < 0: + raise RuntimeError('Native FSDP node-local pre-EP state capture requires LOCAL_RANK.') + is_source_rank = dist.is_available() and dist.is_initialized() and local_rank == 0 + self.set_rank0_pre_ep_full_state_dict(clone_state_dict_to_cpu(model.state_dict()) if is_source_rank else {}) self._pre_ep_state_captured = True def prepare_adapter_config(self, config_or_dir, *, enable_ep: bool): @@ -131,15 +135,19 @@ def wrap_model(self, model, optimizer=None): adapter_source_sd = {} adapter_full_sd = {} if use_meta: - is_rank0 = (dist.get_rank() == 0) + local_rank = Platform.get_local_rank() + if local_rank < 0: + raise RuntimeError('Native FSDP node-local state loading requires LOCAL_RANK.') + is_source_rank = local_rank == 0 if ep_enabled and self._rank0_pre_ep_full_state_dict is not None: - original_sd = self._rank0_pre_ep_full_state_dict if is_rank0 else {} + original_sd = self._rank0_pre_ep_full_state_dict if is_source_rank else {} else: - original_sd = model.state_dict() if is_rank0 else {} + original_sd = model.state_dict() if is_source_rank else {} adapter_source_sd = _collect_adapter_source_state(model.state_dict()) - adapter_full_sd = self._adapter_full_state_dict if is_rank0 and self._adapter_full_state_dict else {} - saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {} - if is_rank0: + adapter_full_sd = ( + self._adapter_full_state_dict if is_source_rank and self._adapter_full_state_dict else {}) + saved_buffers = _get_non_persistent_buffers(model) if is_source_rank else {} + if is_source_rank: model = model.to(torch.device('meta')) if hasattr(model, 'tie_weights'): model.tie_weights() @@ -534,6 +542,25 @@ def _build_rank_to_ep_rank(ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> Di return rank_to_ep_rank +def _get_local_rank_info() -> tuple[int, int, int, List[int]]: + """Return local-rank topology for node-local state-dict fanout.""" + rank = dist.get_rank() + world_size = dist.get_world_size() + local_rank = Platform.get_local_rank() + if 'LOCAL_WORLD_SIZE' not in os.environ and 'LOCAL_SIZE' not in os.environ: + raise RuntimeError('Native FSDP node-local state loading requires LOCAL_WORLD_SIZE or LOCAL_SIZE.') + local_world_size = Platform.get_local_world_size() + if local_rank < 0 or local_world_size <= 0 or world_size % local_world_size != 0: + raise RuntimeError(f'Invalid local rank topology: rank={rank}, world_size={world_size}, ' + f'local_rank={local_rank}, local_world_size={local_world_size}.') + node_start = rank - local_rank + node_ranks = list(range(node_start, min(node_start + local_world_size, world_size))) + if rank not in node_ranks or len(node_ranks) != local_world_size: + raise RuntimeError(f'Invalid local rank group: rank={rank}, local_rank={local_rank}, ' + f'local_world_size={local_world_size}, node_ranks={node_ranks}.') + return rank, world_size, node_start, node_ranks + + def _find_experts_in_layer(layer_mod: nn.Module, experts_map: Dict[str, nn.Module]) -> Optional[nn.Module]: """Find the experts module inside a decoder layer, if any.""" for module in layer_mod.modules(): @@ -793,7 +820,9 @@ def _broadcast_sharded_state_dict( meta_sharded_sd = model.state_dict() sharded_sd = {} - is_rank0 = (dist.get_rank() == 0) + rank, _, local_source_rank, local_ranks = _get_local_rank_info() + is_rank0 = (rank == 0) + is_source_rank = rank == local_source_rank expert_shard_specs = expert_shard_specs or {} rank_to_ep_rank = rank_to_ep_rank or {} adapter_source_sd = adapter_source_sd or {} @@ -819,6 +848,18 @@ def _broadcast_sharded_state_dict( source_keys = metadata_holder[1] or {} adapter_metadata = metadata_holder[2] or {} + def _broadcast_from_local_source(full_tensor): + if is_source_rank: + if full_tensor is None: + raise RuntimeError(f'Local source rank {local_source_rank} does not have full state_dict tensor.') + for target_rank in local_ranks: + if target_rank == rank: + continue + dist.send(full_tensor, dst=target_rank) + else: + dist.recv(full_tensor, src=local_source_rank) + return full_tensor + def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements): local_tensor = full_tensor for mesh_dim, placement in enumerate(placements): @@ -849,36 +890,20 @@ def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements): ) def _broadcast_adapter_source_tensor(full_tensor, sharded_param): - if not isinstance(sharded_param, DTensor): - dist.broadcast(full_tensor, src=0) - return full_tensor - mesh = sharded_param.device_mesh.mesh - source_rank = int(mesh.flatten()[0].item()) - dist.broadcast(full_tensor, src=source_rank, group=sharded_param.device_mesh.get_group()) - return full_tensor + return _broadcast_from_local_source(full_tensor) def _scatter_ep_adapter_tensor(param_name, full_tensor, sharded_param): local_shape = tuple(sharded_param.size()) _, source_dtype = adapter_metadata[param_name] local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype) - - if is_rank0: - shard_dim = _ep_expert_state_dict_gather_dim(param_name) - local_dim = local_shape[shard_dim] - world_size = dist.get_world_size() - for rank in range(world_size): - if rank not in rank_to_ep_rank: - raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.') - ep_rank = rank_to_ep_rank[rank] - start = ep_rank * local_dim - chunk = full_tensor.narrow(shard_dim, start, local_dim).contiguous().to(device_type) - if rank == 0: - local_tensor.copy_(chunk) - else: - dist.send(chunk, dst=rank) - else: - dist.recv(local_tensor, src=0) - + shard_dim = _ep_expert_state_dict_gather_dim(param_name) + local_dim = local_shape[shard_dim] + local_tensor = _scatter_ep_tensor_from_source( + full_tensor, + local_tensor, + shard_dim=shard_dim, + shard_size=local_dim, + ) return local_tensor def _get_adapter_source(param_name): @@ -903,27 +928,35 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param): _, source_dtype = source_metadata[param_name] local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype) - if is_rank0: + if is_source_rank: if full_tensor.size(0) != num_experts: raise RuntimeError(f"EP expert parameter '{param_name}' expects {num_experts} experts, " f'but source state has shape {tuple(full_tensor.shape)}. ' 'Rank0 must capture the full pre-EP state_dict before apply_expert_parallel().') - world_size = dist.get_world_size() - for rank in range(world_size): - if rank not in rank_to_ep_rank: - raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.') - ep_rank = rank_to_ep_rank[rank] - start = ep_rank * experts_per_rank - end = start + experts_per_rank - chunk = full_tensor[start:end].contiguous() - chunk_gpu = chunk.to(device_type) - if rank == 0: - local_tensor.copy_(chunk_gpu) + local_tensor = _scatter_ep_tensor_from_source( + full_tensor, + local_tensor, + shard_dim=0, + shard_size=experts_per_rank, + ) + return local_tensor + + def _scatter_ep_tensor_from_source(full_tensor, local_tensor, *, shard_dim: int, shard_size: int): + if is_source_rank: + if full_tensor is None: + raise RuntimeError(f'Local source rank {local_source_rank} does not have full state_dict tensor.') + for target_rank in local_ranks: + if target_rank not in rank_to_ep_rank: + raise RuntimeError(f'Missing EP rank mapping for global rank {target_rank}.') + ep_rank = rank_to_ep_rank[target_rank] + start = ep_rank * shard_size + chunk = full_tensor.narrow(shard_dim, start, shard_size).contiguous().to(device_type) + if target_rank == rank: + local_tensor.copy_(chunk) else: - dist.send(chunk_gpu, dst=rank) + dist.send(chunk, dst=target_rank) else: - dist.recv(local_tensor, src=0) - + dist.recv(local_tensor, src=local_source_rank) return local_tensor for param_name, sharded_param in meta_sharded_sd.items(): @@ -950,7 +983,7 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param): full_tensor = torch.empty(source_shape, device=device_type, dtype=source_dtype) if not is_ep_adapter_param: full_tensor = _broadcast_adapter_source_tensor(full_tensor, sharded_param) - elif is_rank0: + elif is_source_rank: source_key = source_keys[param_name] if source_key not in full_sd: raise KeyError( @@ -979,7 +1012,7 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param): raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: " f'sharded logical shape={tuple(shape)}, source shape={source_shape}.') if not is_adapter_param: - dist.broadcast(full_tensor, src=0) + full_tensor = _broadcast_from_local_source(full_tensor) torch_util.synchronize() if isinstance(sharded_param, DTensor): diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index a9a80b8f..05291df8 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -209,7 +209,12 @@ def __init__( def _should_init_empty_pretrained_model_on_this_rank(self) -> bool: use_rank0_broadcast = getattr(self.strategy, 'use_rank0_pretrained_broadcast', lambda: False) - return bool(use_rank0_broadcast() and dist.is_available() and dist.is_initialized() and dist.get_rank() != 0) + if not (use_rank0_broadcast() and dist.is_available() and dist.is_initialized()): + return False + local_rank = Platform.get_local_rank() + if local_rank < 0: + raise RuntimeError('Native FSDP memory_efficient_init requires LOCAL_RANK.') + return local_rank != 0 def _init_empty_model_from_config(self, model_cls, **kwargs): from accelerate import init_empty_weights