Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 0 additions & 146 deletions src/twinkle/model/transformers/strategy/accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,150 +7,6 @@
from .load_context import fsdp_pretrained_load_context


def _patch_accelerate_fsdp2_load_full_state_dict():
"""Allow Accelerate FSDP2 state-dict loading to handle unsharded buffers.

Some Transformers models keep persistent buffers in `state_dict`. FSDP2
shards parameters as DTensors, but those buffers can remain ordinary
tensors; older Accelerate versions assume every state-dict entry has
`device_mesh` and fail on such buffers.
"""
import accelerate.utils.fsdp_utils as fsdp_utils
import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Partial, Replicate, Shard, distribute_tensor

if getattr(fsdp_utils.fsdp2_load_full_state_dict, '_twinkle_patched', False):
return

original = fsdp_utils.fsdp2_load_full_state_dict

def patched_fsdp2_load_full_state_dict(accelerator, model, full_sd, cpu_offload=False):
meta_sharded_sd = model.state_dict()
sharded_sd = {}

def _infer_parameter_dtype(model, param_name, empty_param):
old_param = _get_state_dict_param_for_dtype_inference(model, param_name)
is_torch_e4m3fn_available = hasattr(torch, 'float8_e4m3fn')
is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
casting_dtype = None
if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
casting_dtype = old_param.dtype
return old_param is not None and old_param.is_contiguous(), casting_dtype

def _cast_and_contiguous(tensor, to_contiguous, dtype):
if dtype is not None:
tensor = tensor.to(dtype=dtype)
if to_contiguous:
tensor = tensor.contiguous()
return tensor

def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements):
if device_mesh.device_type == 'cuda':
return distribute_tensor(full_tensor, device_mesh, placements)

local_tensor = full_tensor
for mesh_dim, placement in enumerate(placements):
if isinstance(placement, Shard):
# All ranks already received the full tensor via broadcast.
# Split locally to avoid distribute_tensor's scatter path,
# which is fragile on some torch_npu/HCCL versions.
local_tensor = placement._shard_tensor(
local_tensor,
device_mesh,
mesh_dim,
src_data_rank=None,
)
elif isinstance(placement, Replicate):
continue
elif isinstance(placement, Partial):
raise NotImplementedError('FSDP2 full-state loading does not support Partial placements.')
else:
raise NotImplementedError(f'Unsupported DTensor placement: {placement}')
return DTensor.from_local(
local_tensor,
device_mesh=device_mesh,
placements=placements,
run_check=False,
shape=full_tensor.shape,
stride=full_tensor.stride(),
)

def _load_full_value(param_name, sharded_param):
if param_name not in full_sd:
raise KeyError(
f"Parameter '{param_name}' found in sharded model state dict but missing from full state dict. "
f'Full state dict has {len(full_sd)} keys, sharded has {len(meta_sharded_sd)} keys.')
full_value = full_sd[param_name].detach()
if isinstance(full_value, DTensor):
full_value = full_value.to_local()
device = sharded_param.device_mesh.device_type if isinstance(sharded_param, DTensor) else accelerator.device
return full_value.to(device).contiguous()

def _tensor_debug(tensor):
if isinstance(tensor, DTensor):
return (f'type=DTensor shape={tuple(tensor.size())} dtype={tensor.dtype} '
f'placements={tensor.placements} mesh={tensor.device_mesh}')
if hasattr(tensor, 'size') and hasattr(tensor, 'dtype'):
return f'type={type(tensor).__name__} shape={tuple(tensor.size())} dtype={tensor.dtype}'
return f'type={type(tensor).__name__}'

for param_name, sharded_param in meta_sharded_sd.items():
if isinstance(sharded_param, DTensor):
device_mesh = sharded_param.device_mesh
placements = sharded_param.placements
if accelerator.is_main_process:
full_param = _load_full_value(param_name, sharded_param)
else:
full_param = torch.empty(
sharded_param.size(),
device=device_mesh.device_type,
dtype=sharded_param.dtype,
)

dist.broadcast(full_param, src=0, group=dist.group.WORLD)
sharded_tensor = _dtensor_from_replicated_full_tensor(full_param, device_mesh, placements)
to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_param)
sharded_tensor = _cast_and_contiguous(sharded_tensor, to_contiguous, casting_dtype)
if cpu_offload:
sharded_tensor = sharded_tensor.to('cpu')
sharded_sd[param_name] = sharded_tensor
continue

if accelerator.is_main_process:
full_value = _load_full_value(param_name, sharded_param)
else:
full_value = torch.empty(
sharded_param.size(),
device=accelerator.device,
dtype=sharded_param.dtype,
)

dist.broadcast(full_value, src=0, group=dist.group.WORLD)
to_contiguous, casting_dtype = _infer_parameter_dtype(model, param_name, full_value)
full_value = _cast_and_contiguous(full_value, to_contiguous, casting_dtype)
if cpu_offload:
full_value = full_value.to('cpu')
sharded_sd[param_name] = full_value

model.load_state_dict(sharded_sd, assign=True)
return model

patched_fsdp2_load_full_state_dict._twinkle_patched = True
patched_fsdp2_load_full_state_dict._twinkle_original = original
fsdp_utils.fsdp2_load_full_state_dict = patched_fsdp2_load_full_state_dict


def _get_state_dict_param_for_dtype_inference(model, param_name: str):
try:
return model.get_parameter_or_buffer(param_name)
except AttributeError:
if '.' in param_name:
base_param_name, param_name = param_name.rsplit('.', 1)
model = model.get_submodule(base_param_name)
return getattr(model, param_name)


class AccelerateStrategy:
"""A training strategy that uses `accelerate` to wrap models.

Expand All @@ -172,8 +28,6 @@ def __init__(
from accelerate import Accelerator
from accelerate.utils import InitProcessGroupKwargs

_patch_accelerate_fsdp2_load_full_state_dict()

self.device_mesh = device_mesh
self.mixed_precision = mixed_precision
self._memory_efficient_init = memory_efficient_init
Expand Down
135 changes: 84 additions & 51 deletions src/twinkle/model/transformers/strategy/native_fsdp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import os
import torch
import torch.distributed as dist
from torch import nn
Expand Down Expand Up @@ -53,8 +54,11 @@ def capture_pre_ep_state_if_needed(self, model, *, enable_ep: bool) -> None:
return
if not (enable_ep and self.use_rank0_pretrained_broadcast()):
return
is_rank0 = dist.is_available() and dist.is_initialized() and dist.get_rank() == 0
self.set_rank0_pre_ep_full_state_dict(clone_state_dict_to_cpu(model.state_dict()) if is_rank0 else {})
local_rank = Platform.get_local_rank()
if local_rank < 0:
raise RuntimeError('Native FSDP node-local pre-EP state capture requires LOCAL_RANK.')
is_source_rank = dist.is_available() and dist.is_initialized() and local_rank == 0
Comment on lines +57 to +60
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Check if distributed training is initialized before attempting to access Platform.get_local_rank(). This prevents unnecessary RuntimeError exceptions when running in non-distributed or single-process environments where LOCAL_RANK is not set.

Suggested change
local_rank = Platform.get_local_rank()
if local_rank < 0:
raise RuntimeError('Native FSDP node-local pre-EP state capture requires LOCAL_RANK.')
is_source_rank = dist.is_available() and dist.is_initialized() and local_rank == 0
if not (dist.is_available() and dist.is_initialized()):
return
local_rank = Platform.get_local_rank()
if local_rank < 0:
raise RuntimeError('Native FSDP node-local pre-EP state capture requires LOCAL_RANK.')
is_source_rank = local_rank == 0

self.set_rank0_pre_ep_full_state_dict(clone_state_dict_to_cpu(model.state_dict()) if is_source_rank else {})
self._pre_ep_state_captured = True

def prepare_adapter_config(self, config_or_dir, *, enable_ep: bool):
Expand Down Expand Up @@ -131,15 +135,19 @@ def wrap_model(self, model, optimizer=None):
adapter_source_sd = {}
adapter_full_sd = {}
if use_meta:
is_rank0 = (dist.get_rank() == 0)
local_rank = Platform.get_local_rank()
if local_rank < 0:
raise RuntimeError('Native FSDP node-local state loading requires LOCAL_RANK.')
is_source_rank = local_rank == 0
if ep_enabled and self._rank0_pre_ep_full_state_dict is not None:
original_sd = self._rank0_pre_ep_full_state_dict if is_rank0 else {}
original_sd = self._rank0_pre_ep_full_state_dict if is_source_rank else {}
else:
original_sd = model.state_dict() if is_rank0 else {}
original_sd = model.state_dict() if is_source_rank else {}
adapter_source_sd = _collect_adapter_source_state(model.state_dict())
adapter_full_sd = self._adapter_full_state_dict if is_rank0 and self._adapter_full_state_dict else {}
saved_buffers = _get_non_persistent_buffers(model) if is_rank0 else {}
if is_rank0:
adapter_full_sd = (
self._adapter_full_state_dict if is_source_rank and self._adapter_full_state_dict else {})
saved_buffers = _get_non_persistent_buffers(model) if is_source_rank else {}
if is_source_rank:
model = model.to(torch.device('meta'))
if hasattr(model, 'tie_weights'):
model.tie_weights()
Expand Down Expand Up @@ -534,6 +542,25 @@ def _build_rank_to_ep_rank(ep_fsdp_device_mesh: Optional[TorchDeviceMesh]) -> Di
return rank_to_ep_rank


def _get_local_rank_info() -> tuple[int, int, int, List[int]]:
"""Return local-rank topology for node-local state-dict fanout."""
rank = dist.get_rank()
world_size = dist.get_world_size()
local_rank = Platform.get_local_rank()
if 'LOCAL_WORLD_SIZE' not in os.environ and 'LOCAL_SIZE' not in os.environ:
raise RuntimeError('Native FSDP node-local state loading requires LOCAL_WORLD_SIZE or LOCAL_SIZE.')
local_world_size = Platform.get_local_world_size()
if local_rank < 0 or local_world_size <= 0 or world_size % local_world_size != 0:
raise RuntimeError(f'Invalid local rank topology: rank={rank}, world_size={world_size}, '
f'local_rank={local_rank}, local_world_size={local_world_size}.')
node_start = rank - local_rank
node_ranks = list(range(node_start, min(node_start + local_world_size, world_size)))
if rank not in node_ranks or len(node_ranks) != local_world_size:
raise RuntimeError(f'Invalid local rank group: rank={rank}, local_rank={local_rank}, '
f'local_world_size={local_world_size}, node_ranks={node_ranks}.')
return rank, world_size, node_start, node_ranks


def _find_experts_in_layer(layer_mod: nn.Module, experts_map: Dict[str, nn.Module]) -> Optional[nn.Module]:
"""Find the experts module inside a decoder layer, if any."""
for module in layer_mod.modules():
Expand Down Expand Up @@ -793,7 +820,9 @@ def _broadcast_sharded_state_dict(

meta_sharded_sd = model.state_dict()
sharded_sd = {}
is_rank0 = (dist.get_rank() == 0)
rank, _, local_source_rank, local_ranks = _get_local_rank_info()
is_rank0 = (rank == 0)
is_source_rank = rank == local_source_rank
Comment on lines +823 to +825
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using point-to-point dist.send and dist.recv in a loop over all local ranks for every single parameter is highly inefficient and can cause significant performance bottlenecks during model initialization, especially for large models with many parameters and large local world sizes (e.g., 8 GPUs per node).

Instead, we can collectively create a node-local process group once and use dist.broadcast to leverage NCCL's highly optimized collective communication algorithms.

Suggested change
rank, _, local_source_rank, local_ranks = _get_local_rank_info()
is_rank0 = (rank == 0)
is_source_rank = rank == local_source_rank
rank, world_size, local_source_rank, local_ranks = _get_local_rank_info()
is_rank0 = (rank == 0)
is_source_rank = rank == local_source_rank
local_group = None
local_world_size = len(local_ranks)
if local_world_size > 1:
num_nodes = world_size // local_world_size
for node_idx in range(num_nodes):
node_ranks_i = list(range(node_idx * local_world_size, (node_idx + 1) * local_world_size))
group = dist.new_group(ranks=node_ranks_i)
if rank in node_ranks_i:
local_group = group

expert_shard_specs = expert_shard_specs or {}
rank_to_ep_rank = rank_to_ep_rank or {}
adapter_source_sd = adapter_source_sd or {}
Expand All @@ -819,6 +848,18 @@ def _broadcast_sharded_state_dict(
source_keys = metadata_holder[1] or {}
adapter_metadata = metadata_holder[2] or {}

def _broadcast_from_local_source(full_tensor):
if is_source_rank:
if full_tensor is None:
raise RuntimeError(f'Local source rank {local_source_rank} does not have full state_dict tensor.')
for target_rank in local_ranks:
if target_rank == rank:
continue
dist.send(full_tensor, dst=target_rank)
else:
dist.recv(full_tensor, src=local_source_rank)
return full_tensor
Comment on lines +851 to +861
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Update _broadcast_from_local_source to use the collectively created local_group and dist.broadcast instead of sequential point-to-point dist.send and dist.recv calls.

    def _broadcast_from_local_source(full_tensor):
        if local_group is not None:
            if is_source_rank and full_tensor is None:
                raise RuntimeError(f'Local source rank {local_source_rank} does not have full state_dict tensor.')
            dist.broadcast(full_tensor, src=local_source_rank, group=local_group)
        return full_tensor


def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements):
local_tensor = full_tensor
for mesh_dim, placement in enumerate(placements):
Expand Down Expand Up @@ -849,36 +890,20 @@ def _dtensor_from_replicated_full_tensor(full_tensor, device_mesh, placements):
)

def _broadcast_adapter_source_tensor(full_tensor, sharded_param):
if not isinstance(sharded_param, DTensor):
dist.broadcast(full_tensor, src=0)
return full_tensor
mesh = sharded_param.device_mesh.mesh
source_rank = int(mesh.flatten()[0].item())
dist.broadcast(full_tensor, src=source_rank, group=sharded_param.device_mesh.get_group())
return full_tensor
return _broadcast_from_local_source(full_tensor)

def _scatter_ep_adapter_tensor(param_name, full_tensor, sharded_param):
local_shape = tuple(sharded_param.size())
_, source_dtype = adapter_metadata[param_name]
local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype)

if is_rank0:
shard_dim = _ep_expert_state_dict_gather_dim(param_name)
local_dim = local_shape[shard_dim]
world_size = dist.get_world_size()
for rank in range(world_size):
if rank not in rank_to_ep_rank:
raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.')
ep_rank = rank_to_ep_rank[rank]
start = ep_rank * local_dim
chunk = full_tensor.narrow(shard_dim, start, local_dim).contiguous().to(device_type)
if rank == 0:
local_tensor.copy_(chunk)
else:
dist.send(chunk, dst=rank)
else:
dist.recv(local_tensor, src=0)

shard_dim = _ep_expert_state_dict_gather_dim(param_name)
local_dim = local_shape[shard_dim]
local_tensor = _scatter_ep_tensor_from_source(
full_tensor,
local_tensor,
shard_dim=shard_dim,
shard_size=local_dim,
)
return local_tensor

def _get_adapter_source(param_name):
Expand All @@ -903,27 +928,35 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
_, source_dtype = source_metadata[param_name]
local_tensor = torch.empty(local_shape, device=device_type, dtype=source_dtype)

if is_rank0:
if is_source_rank:
if full_tensor.size(0) != num_experts:
raise RuntimeError(f"EP expert parameter '{param_name}' expects {num_experts} experts, "
f'but source state has shape {tuple(full_tensor.shape)}. '
'Rank0 must capture the full pre-EP state_dict before apply_expert_parallel().')
world_size = dist.get_world_size()
for rank in range(world_size):
if rank not in rank_to_ep_rank:
raise RuntimeError(f'Missing EP rank mapping for global rank {rank}.')
ep_rank = rank_to_ep_rank[rank]
start = ep_rank * experts_per_rank
end = start + experts_per_rank
chunk = full_tensor[start:end].contiguous()
chunk_gpu = chunk.to(device_type)
if rank == 0:
local_tensor.copy_(chunk_gpu)
local_tensor = _scatter_ep_tensor_from_source(
full_tensor,
local_tensor,
shard_dim=0,
shard_size=experts_per_rank,
)
return local_tensor

def _scatter_ep_tensor_from_source(full_tensor, local_tensor, *, shard_dim: int, shard_size: int):
if is_source_rank:
if full_tensor is None:
raise RuntimeError(f'Local source rank {local_source_rank} does not have full state_dict tensor.')
for target_rank in local_ranks:
if target_rank not in rank_to_ep_rank:
raise RuntimeError(f'Missing EP rank mapping for global rank {target_rank}.')
ep_rank = rank_to_ep_rank[target_rank]
start = ep_rank * shard_size
chunk = full_tensor.narrow(shard_dim, start, shard_size).contiguous().to(device_type)
if target_rank == rank:
local_tensor.copy_(chunk)
else:
dist.send(chunk_gpu, dst=rank)
dist.send(chunk, dst=target_rank)
else:
dist.recv(local_tensor, src=0)

dist.recv(local_tensor, src=local_source_rank)
return local_tensor

for param_name, sharded_param in meta_sharded_sd.items():
Expand All @@ -950,7 +983,7 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
full_tensor = torch.empty(source_shape, device=device_type, dtype=source_dtype)
if not is_ep_adapter_param:
full_tensor = _broadcast_adapter_source_tensor(full_tensor, sharded_param)
elif is_rank0:
elif is_source_rank:
source_key = source_keys[param_name]
if source_key not in full_sd:
raise KeyError(
Expand Down Expand Up @@ -979,7 +1012,7 @@ def _scatter_ep_expert_tensor(param_name, full_tensor, sharded_param):
raise RuntimeError(f"Parameter '{param_name}' shape mismatch before broadcast: "
f'sharded logical shape={tuple(shape)}, source shape={source_shape}.')
if not is_adapter_param:
dist.broadcast(full_tensor, src=0)
full_tensor = _broadcast_from_local_source(full_tensor)
torch_util.synchronize()

if isinstance(sharded_param, DTensor):
Expand Down
7 changes: 6 additions & 1 deletion src/twinkle/model/transformers/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,12 @@ def __init__(

def _should_init_empty_pretrained_model_on_this_rank(self) -> bool:
use_rank0_broadcast = getattr(self.strategy, 'use_rank0_pretrained_broadcast', lambda: False)
return bool(use_rank0_broadcast() and dist.is_available() and dist.is_initialized() and dist.get_rank() != 0)
if not (use_rank0_broadcast() and dist.is_available() and dist.is_initialized()):
return False
local_rank = Platform.get_local_rank()
if local_rank < 0:
raise RuntimeError('Native FSDP memory_efficient_init requires LOCAL_RANK.')
return local_rank != 0

def _init_empty_model_from_config(self, model_cls, **kwargs):
from accelerate import init_empty_weights
Expand Down
Loading