diff --git a/cookbook/transformers/sp_fsdp_dense.py b/cookbook/transformers/sp_fsdp_dense.py index 868b61c0..8725481e 100644 --- a/cookbook/transformers/sp_fsdp_dense.py +++ b/cookbook/transformers/sp_fsdp_dense.py @@ -19,9 +19,10 @@ device_type=Platform.get_platform().device_prefix(), )] -# FSDP + SP validation over 4 GPUs: dp=2, fsdp=2 (SP only affects input slicing) +# FSDP + sequence-parallel validation over 4 GPUs: dp=2, fsdp=2. +# In Transformers route, ulysses_size is the total sequence-parallel degree. device_mesh = DeviceMesh( - device_type='cuda', + device_type=Platform.get_platform().device_prefix(), mesh=np.arange(4).reshape(2, 2), mesh_dim_names=('dp', 'fsdp'), ulysses_size=2, diff --git a/cookbook/transformers/sp_fsdp_dense.sh b/cookbook/transformers/sp_fsdp_dense.sh index dd04a2b0..2a8bcf08 100644 --- a/cookbook/transformers/sp_fsdp_dense.sh +++ b/cookbook/transformers/sp_fsdp_dense.sh @@ -1,5 +1,6 @@ #!/bin/bash -# To enabele sequence parallelism, please set ulysses_size > 1 +# To enable Transformers sequence parallelism, please set ulysses_size > 1. +# ulysses_size is interpreted as the total sequence-parallel degree. # device_mesh = DeviceMesh( # device_type="cuda", # mesh=np.arange(4).reshape(2, 2), diff --git a/src/twinkle/metric/loss.py b/src/twinkle/metric/loss.py index 8f4ad0c9..5466e537 100644 --- a/src/twinkle/metric/loss.py +++ b/src/twinkle/metric/loss.py @@ -25,6 +25,7 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M return loss = outputs['loss'] loss_reduction = kwargs.get('loss_reduction', 'mean') + ulysses_size = getattr(self.device_mesh, 'ulysses_size', None) or 1 if loss_reduction == 'sum': if not isinstance(inputs, list): inputs = [inputs] @@ -32,6 +33,11 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M # `Transformers` models may use reduction=sum, to average grads before step labels = input['labels'] self.num_tokens += (labels >= 0).sum().item() + # Sequence-parallel gathered loss is replicated on each ulysses rank, while + # local labels still count only the shard-local tokens. Normalize the loss + # contribution here so metric-side averaging matches the non-SP path. + if ulysses_size > 1: + loss = loss / float(ulysses_size) grad_norm = kwargs.get('grad_norm') if grad_norm is not None: self.grad_norm = grad_norm diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel.py b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py similarity index 55% rename from src/twinkle/model/transformers/strategy/sequence_parallel.py rename to src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py index 64ea34f3..3c6f1a4e 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py @@ -1,331 +1,31 @@ # Copyright (c) ModelScope Contributors. All rights reserved. 
+import math import torch import torch.distributed as dist +from copy import copy from dataclasses import asdict, dataclass, is_dataclass from functools import partial from transformers import PreTrainedTokenizer -from typing import Any, Dict, Optional, Tuple, Union +from types import MethodType, SimpleNamespace +from typing import Any, Dict, List, Optional, Tuple, Union +from twinkle.patch import apply_patch from twinkle.utils import DeviceMesh from twinkle.utils.transformers_utils import get_llm_model +from .linear_attention_sp import Qwen3_5GatedDeltaNetUlyssesPatch +from .utils import (DistributedAttention, GatherLoss, _derive_sequence_parallel_sizes, _get_seq_groups_from_device_mesh, + _get_ulysses_size, _SeqAllToAll, get_config_attr, get_cu_seqlens_from_position_ids, is_hccl_backend, + is_moe_config, post_all2all) -def get_config_attr(config, key, default=None): - return getattr(config, key, default) - - -def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor): - position_ids = position_ids[0] - seq_start_indices = torch.where(position_ids == 0)[0] - seq_end_indices = torch.cat([seq_start_indices[1:], torch.tensor([len(position_ids)], device=position_ids.device)]) - seq_lengths = seq_end_indices - seq_start_indices - cu_seqlens = torch.cumsum(torch.cat([torch.tensor([0], device=position_ids.device), seq_lengths]), dim=0) - return cu_seqlens - - -def _get_raw_data_world_size(device_mesh: DeviceMesh) -> int: - dp_world_size = device_mesh.dp_world_size or 1 - fsdp_world_size = device_mesh.fsdp_world_size or 1 - if dp_world_size <= 0: - dp_world_size = 1 - if fsdp_world_size <= 0: - fsdp_world_size = 1 - return dp_world_size * fsdp_world_size - - -def _get_raw_data_rank(device_mesh: DeviceMesh, rank: int) -> Optional[int]: - coord = device_mesh._get_coord_for_rank(rank) - if coord is None: - return None - - dp_rank = None - fsdp_rank = None - if device_mesh.has_dim('dp'): - dp_rank = coord[device_mesh._get_dim_index('dp')] - if device_mesh.has_dim('fsdp'): - fsdp_rank = coord[device_mesh._get_dim_index('fsdp')] - - fsdp_world_size = device_mesh.fsdp_world_size - data_rank = dp_rank if dp_rank is not None else None - if fsdp_world_size is not None and fsdp_world_size > 1: - if dp_rank is not None and fsdp_rank is not None: - data_rank = dp_rank * fsdp_world_size + fsdp_rank - elif fsdp_rank is not None: - data_rank = fsdp_rank - - if data_rank is None: - data_rank = 0 - return int(data_rank) - - -def _get_sp_group_from_device_mesh( - device_mesh: Optional[DeviceMesh], - sp_size: int, -) -> Optional[dist.ProcessGroup]: - """Return the SP (sequence-parallel) process group for the current rank. - - If the mesh defines an explicit "sp" dimension, use it directly. Otherwise, - derive SP groups by chunking data-parallel ranks (dp/fsdp) while keeping - all other mesh dimensions (tp/pp/ep/etc.) fixed. - - Example (no explicit "sp" dim, sp_size=2): - mesh_dim_names = ("dp", "fsdp", "tp") - mesh = np.arange(8).reshape(2, 2, 2) - # coords are (dp, fsdp, tp). dp/fsdp are "data" dims; tp is "non-data". - # raw_data_rank = dp * fsdp_world_size + fsdp, so ranges [0..3]. - # group_id = raw_data_rank // sp_size partitions data ranks into 2 groups. 
- # - # For tp=0: - # data ranks 0,1 -> group_id=0 => ranks at coords: - # (dp=0,fsdp=0,tp=0) -> rank 0 - # (dp=0,fsdp=1,tp=0) -> rank 2 - # data ranks 2,3 -> group_id=1 => ranks at coords: - # (dp=1,fsdp=0,tp=0) -> rank 4 - # (dp=1,fsdp=1,tp=0) -> rank 6 - # - # For tp=1: - # data ranks 0,1 -> group_id=0 => ranks at coords: - # (dp=0,fsdp=0,tp=1) -> rank 1 - # (dp=0,fsdp=1,tp=1) -> rank 3 - # data ranks 2,3 -> group_id=1 => ranks at coords: - # (dp=1,fsdp=0,tp=1) -> rank 5 - # (dp=1,fsdp=1,tp=1) -> rank 7 - # - # Final SP groups (keyed by (group_id, non_data_key)): - # (0, (tp=0)) -> [0, 2] - # (1, (tp=0)) -> [4, 6] - # (0, (tp=1)) -> [1, 3] - # (1, (tp=1)) -> [5, 7] - # - # Each SP group has size=2 and never crosses tp. - """ - if device_mesh is None or sp_size <= 1: - return None - if device_mesh.has_dim('sp'): - return device_mesh.create_process_group(['sp']) - if not dist.is_available() or not dist.is_initialized(): - return None - - raw_data_world_size = _get_raw_data_world_size(device_mesh) - if raw_data_world_size % sp_size != 0: - raise ValueError(f'data_world_size ({raw_data_world_size}) must be divisible by sp_size ({sp_size}).') - - rank = dist.get_rank() - ref_coord = device_mesh._get_coord_for_rank(rank) - if ref_coord is None: - return None - - non_data_indices = [] - if device_mesh.mesh_dim_names is not None: - for i, name in enumerate(device_mesh.mesh_dim_names): - if name in ('dp', 'fsdp'): - continue - non_data_indices.append(i) - - # Group ranks by (data-parallel chunk, non-data mesh coordinates). - groups: Dict[Tuple[int, Tuple[int, ...]], list[int]] = {} - for r in device_mesh.mesh.flatten().tolist(): - r = int(r) - coord = device_mesh._get_coord_for_rank(r) - if coord is None: - continue - raw_rank = _get_raw_data_rank(device_mesh, r) - if raw_rank is None: - continue - group_id = raw_rank // sp_size - non_data_key = tuple(coord[i] for i in non_data_indices) - key = (group_id, non_data_key) - groups.setdefault(key, []).append(r) - - group_list = [] - for key, ranks in groups.items(): - ranks = sorted(ranks) - if len(ranks) != sp_size: - raise ValueError(f'SP group size mismatch for key={key}: expected {sp_size}, got {len(ranks)}') - group_list.append((key, ranks)) - - group_list.sort(key=lambda item: item[0]) - - sp_group = None - for _, ranks in group_list: - pg = dist.new_group(ranks=ranks) - if rank in ranks: - sp_group = pg - return sp_group - - -class GatherLoss(torch.autograd.Function): - """Gather loss from sequence group.""" - - @staticmethod - def forward(ctx, loss, labels, gather_idx=None, position_ids=None): - """ - Args: - loss: loss tensor after splitting - labels: labels tensor after splitting - gather_idx: gather the tensors on this dim - """ - ctx.scatter_shape = loss.shape[gather_idx or 0] - ctx.gather_idx = gather_idx or 0 - if position_ids is not None: - position_ids = sequence_parallel.pad(position_ids, padding_value=-1, position_ids=position_ids) - ctx.position_ids = position_ids - # Gather split losses/labels to compute aux losses on full sequence length. - output = sequence_parallel.gather(loss, dim=ctx.gather_idx, position_ids=position_ids) - if labels is not None: - labels_output = sequence_parallel.gather(labels, dim=ctx.gather_idx, position_ids=position_ids) - else: - labels_output = None - return output, labels_output - - @staticmethod - def backward(ctx, *grad_output): - # Split grads back to local sequence chunk. 
- _grad = grad_output[0] - if sequence_parallel.world_size > 1 and sequence_parallel._sp_group is not None: - # Gather replicates the sequence dimension across SP ranks. Scale once here - # so downstream FSDP avg does not shrink this path by an extra SP factor. - _grad = _grad * sequence_parallel.world_size - _grad = sequence_parallel.split(_grad, dim=ctx.gather_idx, position_ids=ctx.position_ids).contiguous() - return _grad, None, None, None - - -# Code borrowed from deepspeed, here is why: -# 1. Reduce the dependency -# 2. The original code is complex -def _generate_layout_params(scatter_idx, seq_world_size, input): - if scatter_idx < 2: - bs, global_seq_len, num_local_head, head_dim = input.shape - pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim] - pre_all2all_permute_idx = (1, 0, 2, 3, 4) - - post_all2all_permute_idx = (1, 2, 0, 3, 4) - post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim] - else: - bs, local_seq_len, num_total_head, head_dim = input.shape - assert num_total_head % seq_world_size == 0, (f'Number of heads ({num_total_head}) must be divisible ' - f'by the sequence parallel size ({seq_world_size})!') - pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim] - pre_all2all_permute_idx = (2, 0, 1, 3, 4) - - post_all2all_permute_idx = (1, 0, 2, 3, 4) - post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim] - - return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape - - -def post_all2all(permute_idx, res_shape): - """ - Post-processing function for `all2all` communication. - """ - - def post_func(input): - if permute_idx is not None: - input = input.permute(permute_idx).contiguous() - output = input.reshape(res_shape).contiguous() - - return output - - return post_func - - -def pre_all2all_fun(permute_idx, inp_shape, input): - """ - Pre-processing function for `all2all` communication. 
- """ - input_t = input.reshape(inp_shape).contiguous() - if permute_idx is not None: - input_t = input_t.permute(permute_idx).contiguous() - return input_t - - -def single_all_to_all(input, scatter_idx, gather_idx, group, **kwargs): - seq_world_size = dist.get_world_size(group) - num_heads = input.shape[2] - if num_heads % seq_world_size != 0 and not scatter_idx < 2: - raise NotImplementedError(f'num_heads {num_heads} cannot be split by sp world size {seq_world_size}') - pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = ( - _generate_layout_params(scatter_idx, seq_world_size, input)) - - input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input) - - post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape) - output = torch.empty_like(input_t) - dist.all_to_all_single(output, input_t, group=group) +def is_qwen3_vl(model): + mt = getattr(getattr(model, 'config', None), 'model_type', '') + return 'qwen3_vl' in mt - res = post_all2all_fun(output) - return res - -class _SeqAllToAll(torch.autograd.Function): - - @staticmethod - def forward( - ctx: Any, - group: dist.ProcessGroup, - input: torch.Tensor, - scatter_idx: int, - gather_idx: int, - ) -> torch.Tensor: - ctx.group = group - ctx.scatter_idx = scatter_idx - ctx.gather_idx = gather_idx - res = single_all_to_all(input, scatter_idx, gather_idx, group) - return res - - @staticmethod - def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None]: - # Reverse scatter/gather in backward to match forward layout transform. - return None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None - - -class DistributedAttention(torch.nn.Module): - - def __init__( - self, - local_attention, - sequence_parallel, - scatter_idx: int = 2, - gather_idx: int = 1, - ) -> None: - super().__init__() - self.local_attn = local_attention - self.sequence_parallel = sequence_parallel - self.scatter_idx = scatter_idx - self.gather_idx = gather_idx - - def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor, *args: - Any, **kwargs) -> torch.Tensor: - if self.sequence_parallel.world_size == 1: - return self.local_attn(query, key, value, attention_mask, *args, **kwargs) - - # All-to-all to assemble full sequence for attention, then split back after. 
- if self.sequence_parallel.sp_world_size > 1: - query_layer = _SeqAllToAll.apply(self.sequence_parallel._sp_group, query, self.scatter_idx, self.gather_idx) - key_layer = _SeqAllToAll.apply(self.sequence_parallel._sp_group, key, self.scatter_idx, self.gather_idx) - value_layer = _SeqAllToAll.apply(self.sequence_parallel._sp_group, value, self.scatter_idx, self.gather_idx) - else: - query_layer, key_layer, value_layer = query, key, value - - position_ids = kwargs.pop('position_ids') - if position_ids is not None: - shape0 = position_ids.shape[0] - position_ids_output = torch.empty((shape0 * self.sequence_parallel.sp_world_size, position_ids.shape[1]), - dtype=position_ids.dtype, - device=position_ids.device) - dist.all_gather_into_tensor(position_ids_output, position_ids, group=self.sequence_parallel._sp_group) - position_ids = torch.cat(position_ids_output.split(shape0, dim=0), dim=1) - - context_layer = self.local_attn( - query_layer, key_layer, value_layer, attention_mask, *args, position_ids=position_ids, **kwargs) - - if self.sequence_parallel.sp_world_size > 1: - output = _SeqAllToAll.apply(self.sequence_parallel._sp_group, context_layer, self.gather_idx, - self.scatter_idx) - else: - output = context_layer - - return output +def is_qwen3_omni(model): + mt = getattr(getattr(model, 'config', None), 'model_type', '') + return 'qwen3_omni' in mt # main content copied from ms-swift @@ -334,13 +34,20 @@ class SequenceParallel: _global_inited: bool = False def __init__(self): + self.seq_world_size = None self.sp_world_size = None + self.rp_world_size = None self.dp_world_size = None self.world_size = None + self.attn_implementation = None self.model_dtype = None self.tokenizer = None self.device_mesh = None self._sp_group = None + self._rp_group = None + self._data_rank_group = None + self._sp_rank = 0 + self._rp_rank = 0 self.num_heads = None self.causal_mask_func = None self.extra_kwargs = {} @@ -350,6 +57,38 @@ def real_position_ids(self) -> torch.Tensor: """The real position ids, this is different from the position_ids in mrope""" return self.extra_kwargs.get('position_ids') + @staticmethod + def _extract_real_position_ids(position_ids: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if position_ids is None or not torch.is_tensor(position_ids): + return position_ids + if position_ids.dim() == 3: + return position_ids[0] + return position_ids + + def _update_packed_varlen_metadata(self, real_position_ids: Optional[torch.Tensor]) -> None: + self.extra_kwargs.pop('cu_seq_lens_q', None) + if real_position_ids is None or not self._is_packed_position_ids(real_position_ids): + return + position_ids = self._extract_real_position_ids(real_position_ids) + if position_ids is None or not torch.is_tensor(position_ids): + return + if position_ids.dim() == 1: + position_ids = position_ids.unsqueeze(0) + if position_ids.shape[0] != 1: + raise ValueError('Packed sequence-parallel inputs require batch_size == 1 when deriving cu_seq_lens_q from ' + 'position_ids. 
Please populate cu_seq_lens_q explicitly for batched packed inputs.') + safe_position_ids = position_ids.clone() + safe_position_ids[safe_position_ids < 0] = 0 + self.extra_kwargs['cu_seq_lens_q'] = get_cu_seqlens_from_position_ids(safe_position_ids).to(torch.int32) + + @property + def sp_rank(self) -> int: + return self._sp_rank + + @property + def rp_rank(self) -> int: + return self._rp_rank + def _prepare_flash_attn(self, base_model: torch.nn.Module): try: from transformers import masking_utils @@ -382,7 +121,6 @@ def sdpa_mask(batch_size, cache_position, kv_length, *args, **kwargs): cache_position, kv_length, *args, **kwargs) - # Rebuild cache positions from real (full) position ids. device = cache_position.device cache_position = self.real_position_ids[0] cache_position = self.pad(cache_position, padding_value=-1, position_ids=self.real_position_ids, dim=0) @@ -455,42 +193,56 @@ def _attention(query, key, value, *args, **kwargs): query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) - # Packed batches (produced by PackingDataset + padding_free collate) require FA2 varlen - # semantics to avoid cross-subsequence attention. We derive cu_seqlens from position_ids - # resets (0,1,...) and pass cu_seq_lens_* to FA2. - if self.extra_kwargs.get('is_packed', False): - position_ids = kwargs.get('position_ids') - if position_ids is None: - position_ids = self.real_position_ids - # Treat SP-alignment padding (-1) as separate 1-token sequences by mapping -1 -> 0. - pos = position_ids - if pos.dim() == 1: - pos = pos.unsqueeze(0) - pos = pos.clone() - pos[pos < 0] = 0 - - cu_seqlens = get_cu_seqlens_from_position_ids(pos).to(torch.int32) - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - assert query.shape[2] == cu_seqlens[-1] - kwargs['cu_seq_lens_q'] = cu_seqlens - kwargs['cu_seq_lens_k'] = cu_seqlens - kwargs['max_length_q'] = max_seqlen - kwargs['max_length_k'] = max_seqlen - # Do not use attention_mask-based unpadding when using explicit cu_seqlens. 
- if len(args) > 0: - args = (None, *args[1:]) - elif 'cu_seq_lens_q' in kwargs: + if self.rp_world_size > 1: + from .zigzag_ring_attn import zigzag_ring_flash_attn_varlen_func + position_ids = kwargs.get('position_ids') if position_ids is None: position_ids = self.real_position_ids - position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids) + position_ids = self._extract_real_position_ids(position_ids) cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(torch.int32) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - assert query.shape[2] == cu_seqlens[-1] + position_ids = self._split_packed(position_ids, cu_seqlens, dim=-1) + mask = position_ids != -1 + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + query, key, value = self._mask_qkv(query, key, value, mask) + return zigzag_ring_flash_attn_varlen_func( + query, + key, + value, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + dropout_p=kwargs.get('dropout', 0.0), + softmax_scale=kwargs.get('scaling'), + causal=module.is_causal, + window_size=kwargs.get('sliding_window') or (-1, -1), + group=self._rp_group, + ) + elif self.extra_kwargs.get('is_packed', False) or 'cu_seq_lens_q' in kwargs: + cu_seqlens = kwargs.get('cu_seq_lens_q') + if cu_seqlens is None: + position_ids = kwargs.get('position_ids') + if position_ids is None: + position_ids = self.real_position_ids + position_ids = self._extract_real_position_ids(position_ids) + position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids) + cu_seqlens = get_cu_seqlens_from_position_ids(position_ids).to(torch.int32) + else: + cu_seqlens = cu_seqlens.to(dtype=torch.int32, device=query.device) + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + total_tokens = int(cu_seqlens[-1].item()) + if query.shape[2] != total_tokens: + raise ValueError('Packed/varlen flash_attention_2 expects query sequence length to match ' + f'cu_seqlens total tokens, got query_seq_len={query.shape[2]} ' + f'and cu_seqlens_total={total_tokens}.') kwargs['cu_seq_lens_q'] = cu_seqlens kwargs['cu_seq_lens_k'] = cu_seqlens kwargs['max_length_q'] = max_seqlen kwargs['max_length_k'] = max_seqlen + if self.extra_kwargs.get('is_packed', False) and len(args) > 0: + args = (None, *args[1:]) return ALL_ATTENTION_FUNCTIONS['flash_attention_2_origin'](module, query, key, value, *args, **kwargs)[0] @@ -519,6 +271,8 @@ def _attention(query, key, value, *args, **kwargs): query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) + if self.rp_world_size > 1: + raise NotImplementedError('SDPA does not support derived ring attention.') return ALL_ATTENTION_FUNCTIONS['sdpa_origin'](module, query, key, value, *args, **kwargs)[0] dist_attn.local_attn = _attention @@ -541,7 +295,9 @@ def pre_forward_split_hook(_self, args, kwargs): input_ids = kwargs.get('input_ids', None) inputs_embeds = kwargs.get('inputs_embeds', None) position_ids = kwargs['position_ids'] + real_position_ids = self._extract_real_position_ids(position_ids) attention_mask = kwargs.get('attention_mask', None) + cache_position = kwargs.get('cache_position', None) if hasattr(_self, 'language_model'): embed_tokens = getattr(_self.language_model, 'embed_tokens', None) else: @@ -554,7 +310,8 @@ def pre_forward_split_hook(_self, args, kwargs): attention_mask, None, embed_tokens=embed_tokens, - real_position_ids=self.real_position_ids) + real_position_ids=real_position_ids, + cache_position=cache_position) kwargs['input_ids'] = 
input_ids kwargs['inputs_embeds'] = inputs_embeds kwargs['position_ids'] = position_ids @@ -563,6 +320,71 @@ def pre_forward_split_hook(_self, args, kwargs): base_model.register_forward_pre_hook(pre_forward_split_hook, with_kwargs=True) + def _prepare_multimodal_deepstack(self, base_model: torch.nn.Module): + if not is_qwen3_vl(base_model): + return + + def _patch_deepstack_process(module: torch.nn.Module) -> bool: + origin = getattr(module, '_deepstack_process', None) + if not callable(origin): + return False + if getattr(module, '_twinkle_sp_mm_patched', False): + return False + + def _deepstack_process(_self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, + visual_embeds: torch.Tensor): + world_size = sequence_parallel.world_size + if world_size and world_size > 1 and visual_pos_masks is not None: + visual_pos_masks, visual_embeds = sequence_parallel.pad_and_split_mm_tokens( + visual_pos_masks, visual_embeds) + if visual_pos_masks is None: + return hidden_states + visual_embeds.mean() * 0 + visual_pos_masks = visual_pos_masks.to(hidden_states.device) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + if hidden_states.ndim == 3 and visual_pos_masks.ndim == 3: + visual_pos_masks = visual_pos_masks[..., 0] + local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds + hidden_states[visual_pos_masks, :] = local_this + return hidden_states + + module._deepstack_process = MethodType(_deepstack_process, module) + module._twinkle_sp_mm_patched = True + return True + + for submodule in base_model.modules(): + _patch_deepstack_process(submodule) + _patch_deepstack_process(base_model) + + @staticmethod + def _is_qwen35_model(model: torch.nn.Module) -> bool: + config = getattr(model, 'config', None) + model_type = str(getattr(config, 'model_type', '') or '') + if model_type == 'qwen3_5': + return True + + architectures = getattr(config, 'architectures', None) or [] + if any('Qwen3_5' in str(arch) for arch in architectures): + return True + + model_module = getattr(model.__class__, '__module__', '') or '' + return 'transformers.models.qwen3_5' in model_module + + def _prepare_qwen35_linear_attention(self, model: torch.nn.Module): + has_qwen35_linear_attention = self._is_qwen35_model(model) + if not has_qwen35_linear_attention: + try: + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5GatedDeltaNet + except Exception: + return + has_qwen35_linear_attention = any(isinstance(module, Qwen3_5GatedDeltaNet) for module in model.modules()) + if not has_qwen35_linear_attention: + return + if int(self.rp_world_size or 1) > 1: + raise NotImplementedError( + 'SequenceParallel: Qwen3.5 linear attention sequence parallel does not support rp_world_size > 1 ' + '(derived ring attention).') + apply_patch(None, Qwen3_5GatedDeltaNetUlyssesPatch, sequence_parallel=self) + def _prepare_moe_aux_loss(self, base_model: torch.nn.Module): def moe_aux_loss_hook(module, args, kwargs, output): @@ -596,15 +418,6 @@ def moe_aux_loss_hook(module, args, kwargs, output): base_model.register_forward_hook(moe_aux_loss_hook, with_kwargs=True) - @staticmethod - def _is_moe_model(config) -> bool: - if 'Moe' in config.__class__.__name__: - return True - for key in ['num_experts', 'num_experts_per_tok', 'moe_intermediate_size']: - if get_config_attr(config, key): - return True - return False - def prepare( self, sp_size: int, @@ -612,16 +425,36 @@ def prepare( tokenizer: PreTrainedTokenizer, device_mesh: Optional[DeviceMesh] = None, ): - self.num_heads = 
get_config_attr(model.config, 'num_key_value_heads') - if self.num_heads is None: - self.num_heads = get_config_attr(model.config, 'num_attention_heads') - assert self.num_heads is not None, 'Cannot find num_heads config in config.json' - if sp_size > 1 and self.num_heads % sp_size != 0: - raise ValueError( - f'sp_size ({sp_size}) must divide num_heads ({self.num_heads}) for ulysses sequence parallel.') - self.world_size = sp_size - llm_model = get_llm_model(model) + config_candidates = [getattr(model, 'config', None)] + llm_config = getattr(llm_model, 'config', None) + if llm_config is not None and llm_config not in config_candidates: + config_candidates.append(llm_config) + text_config = getattr(getattr(model, 'config', None), 'text_config', None) + if text_config is not None and text_config not in config_candidates: + config_candidates.append(text_config) + + self.num_heads = None + for config in config_candidates: + if config is None: + continue + self.num_heads = get_config_attr(config, 'num_key_value_heads') + if self.num_heads is None: + self.num_heads = get_config_attr(config, 'num_attention_heads') + if self.num_heads is not None: + break + assert self.num_heads is not None, 'Cannot find num_attention_heads/num_key_value_heads in model config' + self.seq_world_size = sp_size + self.sp_world_size, self.rp_world_size = _derive_sequence_parallel_sizes(self.num_heads, self.seq_world_size) + self.world_size = self.seq_world_size + + self.attn_implementation = None + for config in config_candidates: + if config is None: + continue + self.attn_implementation = getattr(config, '_attn_implementation', None) + if self.attn_implementation is not None: + break if hasattr(llm_model, 'language_model'): if hasattr(llm_model.language_model, '_update_causal_mask'): @@ -630,71 +463,202 @@ def prepare( if hasattr(llm_model, '_update_causal_mask'): self.causal_mask_func = llm_model._update_causal_mask + self._init_device_mesh(device_mesh) if not SequenceParallel._global_inited: # these operations are global initializations and patches - self._init_device_mesh(device_mesh) self._prepare_flash_attn(llm_model) SequenceParallel._global_inited = True + self._prepare_qwen35_linear_attention(llm_model) self._prepare_forward_hook(llm_model) + self._prepare_multimodal_deepstack(llm_model) - if SequenceParallel._is_moe_model(getattr(model, 'config', None)): + if is_moe_config(getattr(model, 'config', None)): self._prepare_moe_aux_loss(llm_model) self.model_dtype = next(model.parameters()).dtype self.tokenizer = tokenizer + if self.rp_world_size > 1: + attn_impl = getattr(model.config, '_attn_implementation', None) + if attn_impl != 'flash_attention_2': + raise NotImplementedError('Derived ring attention only supports flash_attention_2 backend.') + + def _mask_qkv(self, query, key, value, mask): + mask = mask.unsqueeze(2).unsqueeze(3) + query = query * mask + value = value * mask + key = key + ((~mask) * -1e5).to(key.dtype) + return query, key, value def pad(self, tensor, padding_value, position_ids=None, dim=1): - """Pad tensor for sequence parallel""" - world_size = self.world_size + """Pad tensor for sequence parallel.""" + if tensor is None: + return None + if self.rp_world_size and self.rp_world_size > 1: + world_size = self.world_size * 2 + else: + world_size = self.world_size + + dim = dim if dim >= 0 else tensor.dim() + dim def _do_pad(tensor): - # Ensure seq length is divisible by SP size to allow even split. 
length = tensor.shape[dim] pad_num = world_size - (length % world_size) if pad_num == 0 or pad_num == world_size: return tensor if not isinstance(padding_value, torch.Tensor): - # ids - pad_shape = ((*tensor.shape[:dim], pad_num, *tensor.shape[dim + 1:]) if dim != -1 else - (*tensor.shape[:dim], pad_num)) + pad_shape = (*tensor.shape[:dim], pad_num, *tensor.shape[dim + 1:]) pad = torch.full(pad_shape, padding_value, dtype=tensor.dtype, device=tensor.device) - tensor = torch.cat([tensor, pad], dim=dim) - else: - # For embeddings - tensor = torch.cat([tensor, padding_value.unsqueeze(0).repeat(tensor.shape[0], pad_num, 1)], dim=dim) - return tensor - + return torch.cat([tensor, pad], dim=dim) + pad = padding_value.unsqueeze(0).repeat(tensor.shape[0], pad_num, 1) + return torch.cat([tensor, pad], dim=dim) + + if position_ids is not None and self.rp_world_size > 1: + cu_seqlens = get_cu_seqlens_from_position_ids(position_ids) + padded = [] + for i in range(len(cu_seqlens) - 1): + start, end = int(cu_seqlens[i].item()), int(cu_seqlens[i + 1].item()) + slices = [slice(None)] * tensor.dim() + slices[dim] = slice(start, end) + padded.append(_do_pad(tensor[tuple(slices)])) + return torch.cat(padded, dim=dim) return _do_pad(tensor) + def pad_and_split_mm_tokens(self, visual_mask, mm_embeds): + input_ids = self.extra_kwargs['input_ids'] + empty_embeds = torch.empty( + (input_ids.shape[0], input_ids.shape[1], mm_embeds.shape[-1])).to(mm_embeds.device).to(mm_embeds.dtype) + empty_embeds[visual_mask] = mm_embeds + + embeds = SimpleNamespace(weight=mm_embeds) + + _, split_input_embeds, _, _, _, _, extra_values = self.pad_and_split_inputs( + None, + empty_embeds, + None, + None, + None, + None, + embeds, + self.real_position_ids, + extra_split_values=[(visual_mask, 0, -1)]) + visual_mask = extra_values[0] + return visual_mask, split_input_embeds[visual_mask] + def gather(self, local_output, dim: int, position_ids=None): - """Gather tensor for sequence parallel - reverse of split""" + """Gather tensor for sequence parallel - reverse of split.""" if self.world_size == 1: return local_output - # Gather local chunks from each SP rank and concatenate along sequence dim. 
- gathered_sp = torch.empty( - [local_output.shape[0] * self.sp_world_size] + list(local_output.shape[1:]), - dtype=local_output.dtype, - device=local_output.device) - dist.all_gather_into_tensor(gathered_sp, local_output, group=self._sp_group) - gathered_sp = torch.cat(gathered_sp.split(local_output.shape[0], dim=0), dim=dim) - return gathered_sp.contiguous() + dim = dim if dim >= 0 else local_output.dim() + dim + + def _slice(value, start, end): + slices = [slice(None)] * value.dim() + slices[dim] = slice(start, end) + return value[tuple(slices)] + + def _assign(dst, start, end, src): + slices = [slice(None)] * dst.dim() + slices[dim] = slice(start, end) + dst[tuple(slices)] = src + + if self.rp_world_size > 1: + if position_ids is None: + raise ValueError('position_ids are required to gather derived ring outputs.') + position_ids = self.pad(position_ids, padding_value=-1, position_ids=position_ids) + + if self.sp_world_size > 1: + gathered_sp = [torch.zeros_like(local_output) for _ in range(self.sp_world_size)] + dist.all_gather(gathered_sp, local_output.contiguous(), group=self._sp_group) + rp_chunk = torch.cat(gathered_sp, dim=dim) + else: + rp_chunk = local_output.contiguous() + + gathered_rp = [torch.zeros_like(rp_chunk) for _ in range(self.rp_world_size)] + dist.all_gather(gathered_rp, rp_chunk, group=self._rp_group) + + cu_seqlens = get_cu_seqlens_from_position_ids(position_ids) + padded_lengths = [] + for i in range(len(cu_seqlens) - 1): + length = int((cu_seqlens[i + 1] - cu_seqlens[i]).item()) + padded_length = math.ceil(length / (self.world_size * 2)) * (self.world_size * 2) + padded_lengths.append(padded_length) + + full_shape = list(rp_chunk.shape) + full_shape[dim] = sum(padded_lengths) + full_output = torch.zeros(full_shape, dtype=local_output.dtype, device=local_output.device) + for idx_rp, rp_tensor in enumerate(gathered_rp): + accumulated_local_length = 0 + for padded_length in padded_lengths: + local_length = padded_length // self.rp_world_size + local_tensor = _slice(rp_tensor, accumulated_local_length, accumulated_local_length + local_length) + chunk_size = local_length // 2 + full_start = accumulated_local_length * self.rp_world_size + idx_rp * chunk_size + _assign(full_output, full_start, full_start + chunk_size, _slice(local_tensor, 0, chunk_size)) + full_start = accumulated_local_length * self.rp_world_size + (2 * self.rp_world_size - idx_rp + - 1) * chunk_size + _assign( + full_output, + full_start, + full_start + chunk_size, + _slice(local_tensor, chunk_size, local_length), + ) + accumulated_local_length += local_length + return full_output.contiguous() + + if self.sp_world_size > 1: + if is_hccl_backend(self._sp_group): + gathered_sp_chunks = [torch.zeros_like(local_output) for _ in range(self.sp_world_size)] + dist.all_gather(gathered_sp_chunks, local_output.contiguous(), group=self._sp_group) + gathered_sp = torch.cat(gathered_sp_chunks, dim=dim) + else: + gathered_sp = torch.empty( + [local_output.shape[0] * self.sp_world_size] + list(local_output.shape[1:]), + dtype=local_output.dtype, + device=local_output.device) + dist.all_gather_into_tensor(gathered_sp, local_output, group=self._sp_group) + gathered_sp = torch.cat(gathered_sp.split(local_output.shape[0], dim=0), dim=dim) + return gathered_sp.contiguous() + return local_output + + def _split_packed(self, value, cu_seqlens, dim=1): + dim = dim if dim >= 0 else value.dim() + dim + local_values = [] + for i in range(len(cu_seqlens) - 1): + start, end = int(cu_seqlens[i].item()), int(cu_seqlens[i + 
1].item()) + slices = [slice(None)] * value.dim() + slices[dim] = slice(start, end) + sub_value = value[tuple(slices)] + local_value = sub_value.chunk(2 * self.rp_world_size, dim=dim) + local_values.extend([ + local_value[self.rp_rank], + local_value[2 * self.rp_world_size - 1 - self.rp_rank], + ]) + return torch.cat(local_values, dim=dim).contiguous() def split(self, input, dim: int, position_ids=None): - """Split tensor for sequence parallel""" + """Split tensor for sequence parallel.""" if self.world_size == 1: return input - # Split along sequence dimension; each rank keeps its local slice. - rank = dist.get_rank(self._sp_group) if self._sp_group is not None else 0 + dim = dim if dim >= 0 else input.dim() + dim + if self.rp_world_size > 1: + if position_ids is None: + raise ValueError('position_ids are required to split derived ring inputs.') + cu_seqlens = get_cu_seqlens_from_position_ids(position_ids) + if not torch.all(cu_seqlens % (2 * self.rp_world_size) == 0): + raise ValueError( + f'Each packed sequence length must be divisible by {2 * self.rp_world_size} after padding.') + value_chunks = self._split_packed(input, cu_seqlens, dim=dim) + if self.sp_world_size > 1: + return value_chunks.chunk(self.sp_world_size, dim=dim)[self.sp_rank].contiguous() + return value_chunks.contiguous() + dim_size = input.size(dim) assert dim_size % self.sp_world_size == 0, (f'The dimension to split ({dim_size}) is not a multiple of ' f'world size ({self.sp_world_size}), cannot split tensor evenly') - tensor_list = torch.split(input, dim_size // self.sp_world_size, dim=dim) - output = tensor_list[rank].contiguous() - return output + return tensor_list[self.sp_rank].contiguous() def pad_and_split_inputs(self, input_ids, @@ -705,6 +669,7 @@ def pad_and_split_inputs(self, loss_scale, embed_tokens=None, real_position_ids=None, + cache_position=None, extra_split_values=None): """Common implementation for padding and splitting inputs @@ -725,6 +690,7 @@ def pad_and_split_inputs(self, real_position_ids = real_position_ids if real_position_ids is not None else position_ids # Track packed batches to drive attention backend behavior (packed => require flash_attention_2 varlen). 
self.extra_kwargs['is_packed'] = self._is_packed_position_ids(real_position_ids) + self._update_packed_varlen_metadata(real_position_ids) extra_values = [] batch_size = input_ids.shape[ 0] if input_ids is not None else input_embeds.shape[0] if input_embeds is not None else None @@ -740,6 +706,9 @@ def pad_and_split_inputs(self, input_embeds = self.pad(input_embeds, padding_value=pad_emb, position_ids=real_position_ids) batch_size = input_ids.shape[ 0] if input_ids is not None else input_embeds.shape[0] if input_embeds is not None else 1 + if self.rp_world_size > 1 and batch_size > 1: + raise NotImplementedError( + 'Derived ring attention only supports padding-free / packed batches with batch_size == 1.') if position_ids is not None: position_ids = self.pad(position_ids, padding_value=-1, position_ids=real_position_ids, dim=-1) if labels is not None: @@ -763,11 +732,13 @@ def pad_and_split_inputs(self, # no need position_ids here, because padding_free does not need attention_mask, # so this is not ring-attention attention_mask = self.pad(attention_mask, padding_value=0) - cache_position = torch.arange(0, attn_shape, device=inputs.device) - # pad attention mask to 4d to avoid calculation errors - if hasattr(self, 'causal_mask_func') and self.causal_mask_func is not None: - attention_mask = self.causal_mask_func(attention_mask, inputs.to(self.model_dtype), cache_position, - None, None) + local_cache_position = torch.arange(0, attn_shape, device=inputs.device) + # FlashAttention2 expects a 2D padding mask (or None). Converting it to a 4D causal mask here breaks + # the later per-rank sequence split and changes the attention contract relative to the baseline path. + if (cache_position is None and hasattr(self, 'causal_mask_func') and self.causal_mask_func is not None + and self.attn_implementation != 'flash_attention_2'): + attention_mask = self.causal_mask_func(attention_mask, inputs.to(self.model_dtype), + local_cache_position, None, None) if extra_split_values is not None: for (tensor, pad_value, split_dim) in extra_split_values: extra_values.append( @@ -777,23 +748,6 @@ def pad_and_split_inputs(self, if input_embeds is not None: input_embeds = self.split(input_embeds, dim=1, position_ids=real_position_ids) if labels is not None: - if self.extra_kwargs.get('is_packed', False) and real_position_ids is not None: - # PackingDataset + padding_free collate concatenates multiple sequences into a single token stream. - # `position_ids` resets to 0 at each boundary, but our labels are already next-token aligned by - # Template._roll_labels(). Therefore the cross-subsequence supervision term lives at the *previous* - # token index (the token right before a boundary start). - # - # Example (boundary at index b where position_ids[b] == 0): - # - Bad term is: token[b-1] predicting token[b] - # - In next-token-aligned labels, this appears at labels[b-1] - boundary_starts = (real_position_ids == 0) - prev = torch.zeros_like(boundary_starts, dtype=torch.bool) - # Mask token b-1 when boundary starts at b. - prev[..., :-1] = boundary_starts[..., 1:] - labels = labels.clone() - labels[prev] = -100 - # Also avoid any potential wrap-around supervision at the end of the concatenated stream. 
- labels[..., -1] = -100 labels = self.split(labels, dim=-1, position_ids=real_position_ids) if loss_scale is not None: loss_scale = torch.roll(loss_scale, shifts=-1, dims=-1) @@ -801,6 +755,8 @@ def pad_and_split_inputs(self, if position_ids is not None: position_ids = self.split(position_ids, dim=-1, position_ids=real_position_ids) + # if attention_mask is not None and torch.is_tensor(attention_mask) and attention_mask.dim() == 2: + # attention_mask = self.split(attention_mask, dim=1, position_ids=real_position_ids) if extra_split_values is not None: for i in range(len(extra_values)): extra_values[i] = self.split( @@ -813,11 +769,10 @@ def _init_device_mesh(self, device_mesh: Optional[DeviceMesh] = None): raise RuntimeError('SequenceParallel requires a twinkle DeviceMesh for initialization.') self.device_mesh = device_mesh - self.sp_world_size = self.world_size self.dp_world_size = device_mesh.data_world_size or 1 - self._sp_group = _get_sp_group_from_device_mesh(device_mesh, self.sp_world_size) - if self._sp_group is None and self.sp_world_size > 1: - raise RuntimeError('Failed to create sequence-parallel group from DeviceMesh.') + (self._sp_group, self._rp_group, self._data_rank_group, self._sp_rank, + self._rp_rank) = _get_seq_groups_from_device_mesh(device_mesh, self.seq_world_size, self.sp_world_size, + self.rp_world_size) @staticmethod def _is_packed_position_ids(position_ids: Optional[torch.Tensor]) -> bool: @@ -844,20 +799,29 @@ def prepare_inputs(self, inputs): """Prepare inputs 1. set extra_kwargs['position_ids'] - 2. split labels + 2. cache packed/varlen metadata + 3. split labels """ - position_ids = None input_ids = inputs.get('input_ids') position_ids = inputs.get('position_ids') - if position_ids is not None and input_ids is not None and position_ids.shape[0] == input_ids.shape[0]: - self.extra_kwargs['position_ids'] = position_ids.clone() - self.extra_kwargs['is_packed'] = self._is_packed_position_ids(position_ids) + real_position_ids = self._extract_real_position_ids(position_ids) + if real_position_ids is not None and input_ids is not None and real_position_ids.shape[0] == input_ids.shape[0]: + self.extra_kwargs['position_ids'] = real_position_ids.clone() + self.extra_kwargs['is_packed'] = self._is_packed_position_ids(real_position_ids) + self._update_packed_varlen_metadata(real_position_ids) if input_ids is not None: self.extra_kwargs['input_ids'] = input_ids.clone() if 'labels' in inputs: - labels = inputs['labels'] + labels = inputs.get('labels') _, _, labels, _, _, _, _ = self.pad_and_split_inputs( - None, None, labels, None, None, None, real_position_ids=position_ids) + None, + None, + labels, + None, + None, + None, + real_position_ids=real_position_ids, + ) inputs['labels'] = labels return inputs @@ -870,20 +834,6 @@ class SequenceParallelConfig: enabled: bool = True ulysses_size: Optional[int] = None gather_logits: bool = True - loss_reduction: str = 'mean' - compensate_fsdp_avg: bool = False - - -def _get_ulysses_size(device_mesh, sp_config: Optional[Dict[str, Any]] = None) -> int: - if sp_config: - cfg_size = sp_config.get('ulysses_size') - if cfg_size is not None: - return int(cfg_size) - if device_mesh is None: - return 1 - if getattr(device_mesh, 'ulysses_size', None) is not None: - return int(device_mesh.ulysses_size) - return 1 class SequenceParallelStrategy: @@ -969,69 +919,29 @@ def postprocess_outputs(self, outputs: Any) -> Any: outputs['logits'] = gathered return outputs - def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], 
ignore_index: int = -100) -> torch.Tensor: + def gather_loss_tensors( + self, + inputs: Dict[str, Any], + outputs: Dict[str, Any], + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + if inputs is None or outputs is None: + return inputs, outputs if not self.enabled or self.ulysses_size <= 1: - return loss - if labels is None or sequence_parallel._sp_group is None: - return loss - # Compute global loss via autograd-aware all-reduce. - reduction = str(self.sp_config.get('loss_reduction', 'mean')).lower() - if reduction == 'none': - raise ValueError("SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. " - 'Please aggregate per-token losses before calling reduce_loss.') - compensate_fsdp_avg = bool(self.sp_config.get('compensate_fsdp_avg', False)) - compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0) - sum_metric_scale = float(self.ulysses_size) - - class _ReduceSequenceParallelLoss(torch.autograd.Function): - - @staticmethod - def forward(ctx, local_mean: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor: - local_tokens = num_valid_tokens.detach().clone() - local_sum = local_mean * local_tokens - if local_tokens.item() == 0: - local_sum = torch.nan_to_num(local_sum) - global_sum = local_sum.detach().clone() - dist.all_reduce(global_sum, group=sequence_parallel._sp_group) - global_tokens = num_valid_tokens.detach().clone() - dist.all_reduce(global_tokens, group=sequence_parallel._sp_group) - ctx.save_for_backward(local_tokens, global_tokens) - if global_tokens.item() == 0: - return local_sum - return global_sum / global_tokens - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - local_tokens, global_tokens = ctx.saved_tensors - if global_tokens.item() == 0: - return torch.zeros_like(grad_output), None - # d(global_mean)/d(local_mean) = local_tokens / global_tokens. - grad_local_mean = grad_output * (local_tokens / global_tokens) * compensate_factor - return grad_local_mean, None - - class _ReduceSequenceParallelSum(torch.autograd.Function): - - @staticmethod - def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor: - ctx.sum_metric_scale = sum_metric_scale - global_sum = local_sum.detach().clone() - dist.all_reduce(global_sum, group=sequence_parallel._sp_group) - # Keep logging/metric value aligned with non-SP sum semantics under - # outer collect='mean' by removing one SP replication factor. - return global_sum / ctx.sum_metric_scale - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - # Keep training gradient scale unchanged; forward-side scaling is for - # logging/metric alignment under outer collect='mean'. - return grad_output - - if reduction == 'sum': - return _ReduceSequenceParallelSum.apply(loss) - - # Default to mean reduction: `loss` is local mean. 
- num_valid_tokens = (labels != ignore_index).sum().to(loss.device) - return _ReduceSequenceParallelLoss.apply(loss, num_valid_tokens) + return inputs, outputs + labels = inputs.get('labels') + logps = outputs.get('logps') + if labels is None or logps is None: + return inputs, outputs + if not torch.is_tensor(logps) or logps.dim() < 2: + raise TypeError('SequenceParallelStrategy.gather_loss_inputs expects outputs[\"logps\"] to be a ' + f'sequence tensor, got type={type(logps)} shape={getattr(logps, "shape", None)}') + inputs = copy(inputs) + outputs = copy(outputs) + real_position_ids = sequence_parallel.real_position_ids + gathered_logps, gathered_labels = GatherLoss.apply(logps, labels, 1, real_position_ids) + outputs['logps'] = gathered_logps + inputs['labels'] = gathered_labels + return inputs, outputs def wrap_model(self, model, optimizer=None): self.initialize() diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py new file mode 100644 index 00000000..1656b670 --- /dev/null +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/linear_attention_sp.py @@ -0,0 +1,283 @@ +import os +import torch +import torch.distributed as dist +import torch.nn.functional as F +from transformers.utils.import_utils import is_flash_linear_attention_available +from typing import Any, Optional, Tuple + +from twinkle.model.transformers.strategy.sequence_parallel.utils import head_to_seq_shard, seq_to_head_shard +from twinkle.patch import Patch + +if is_flash_linear_attention_available(): + from fla.modules.convolution import causal_conv1d as _FLA_CAUSAL_CONV1D_FN + from fla.ops.gated_delta_rule import chunk_gated_delta_rule as _FLA_CHUNK_GATED_DELTA_RULE +else: + _FLA_CAUSAL_CONV1D_FN = None + _FLA_CHUNK_GATED_DELTA_RULE = None + + +def _sp_is_enabled(sequence_parallel_context) -> bool: + return bool(sequence_parallel_context is not None and getattr(sequence_parallel_context, 'world_size', 1) > 1) + + +def _get_sp_rank(sequence_parallel_context) -> int: + if not _sp_is_enabled(sequence_parallel_context): + return 0 + if getattr(sequence_parallel_context, '_sp_group', None) is None: + return 0 + return dist.get_rank(group=sequence_parallel_context._sp_group) + + +def _get_local_padding_mask( + attention_mask: torch.Tensor, + local_seq_len: int, + sequence_parallel_context, +) -> torch.Tensor: + if attention_mask.shape[-1] == local_seq_len or not _sp_is_enabled(sequence_parallel_context): + return attention_mask + return sequence_parallel_context.split( + attention_mask, + dim=1, + position_ids=sequence_parallel_context.real_position_ids, + ) + + +def _ensure_linear_attention_kernels(mod: torch.nn.Module): + mod.causal_conv1d_fn = getattr(mod, 'causal_conv1d_fn', None) or _FLA_CAUSAL_CONV1D_FN + mod.chunk_gated_delta_rule = getattr(mod, 'chunk_gated_delta_rule', None) or _FLA_CHUNK_GATED_DELTA_RULE + if mod.chunk_gated_delta_rule is None: + raise ImportError('Qwen3.5 linear attention sequence parallel requires chunk gated delta rule implementations.') + if mod.causal_conv1d_fn is None: + raise ImportError( + 'Qwen3.5 linear attention sequence parallel requires fla.modules.convolution.causal_conv1d for ' + 'training/prefill.') + + +def _get_local_conv_weights( + mod: torch.nn.Module, + *, + sp_rank: int, + local_num_k_heads: int, + local_num_v_heads: int, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + local_key_dim = local_num_k_heads * mod.head_k_dim + local_value_dim = 
local_num_v_heads * mod.head_v_dim + conv_weight = mod.conv1d.weight.squeeze(1) + if conv_weight.shape[0] != (2 * mod.key_dim + mod.value_dim): + raise ValueError( + f'Unexpected conv weight dim {conv_weight.shape[0]}, expected {2 * mod.key_dim + mod.value_dim}.') + key_offset = sp_rank * local_key_dim + value_offset = sp_rank * local_value_dim + local_q_weight = conv_weight[key_offset:key_offset + local_key_dim] + local_k_weight = conv_weight[mod.key_dim + key_offset:mod.key_dim + key_offset + local_key_dim] + local_v_weight = conv_weight[2 * mod.key_dim + value_offset:2 * mod.key_dim + value_offset + local_value_dim] + local_conv_weight = torch.cat([local_q_weight, local_k_weight, local_v_weight], dim=0) + + conv_bias = getattr(mod.conv1d, 'bias', None) + if conv_bias is None: + return local_conv_weight, None + local_q_bias = conv_bias[key_offset:key_offset + local_key_dim] + local_k_bias = conv_bias[mod.key_dim + key_offset:mod.key_dim + key_offset + local_key_dim] + local_v_bias = conv_bias[2 * mod.key_dim + value_offset:2 * mod.key_dim + value_offset + local_value_dim] + return local_conv_weight, torch.cat([local_q_bias, local_k_bias, local_v_bias], dim=0) + + +class Qwen3_5GatedDeltaNetUlyssesPatch(Patch): + + @staticmethod + def _run_forward( + mod: torch.nn.Module, + hidden_states: torch.Tensor, + *, + cache_params=None, + cache_position=None, + attention_mask: Optional[torch.Tensor] = None, + cu_seq_lens_q: Optional[torch.Tensor] = None, + sequence_parallel_context=None, + ) -> torch.Tensor: + _ensure_linear_attention_kernels(mod) + from transformers.models.qwen3_5.modeling_qwen3_5 import apply_mask_to_padding_states + + local_attention_mask = attention_mask + if torch.is_tensor(attention_mask) and attention_mask.dim() == 2: + local_attention_mask = _get_local_padding_mask( + attention_mask, + hidden_states.shape[1], + sequence_parallel_context, + ) + hidden_states = apply_mask_to_padding_states(hidden_states, local_attention_mask) + batch_size, seq_len, _ = hidden_states.shape + + has_previous_state = bool(cache_params is not None and getattr(cache_params, 'has_previous_state', False)) + use_precomputed_states = has_previous_state and seq_len == 1 and cache_position is not None + if use_precomputed_states: + raise NotImplementedError( + 'Qwen3.5 linear attention sequence parallel only supports training/prefill paths; decode with ' + 'cached states is not supported.') + + mixed_qkv = mod.in_proj_qkv(hidden_states) + z = mod.in_proj_z(hidden_states).reshape(batch_size, seq_len, mod.num_v_heads, mod.head_v_dim) + b = mod.in_proj_b(hidden_states) + a = mod.in_proj_a(hidden_states) + + sp_enabled = _sp_is_enabled(sequence_parallel_context) + if sp_enabled: + sp_world_size = int(sequence_parallel_context.sp_world_size) + if mod.num_k_heads % sp_world_size != 0 or mod.num_v_heads % sp_world_size != 0: + raise RuntimeError( + 'Qwen3.5 linear attention sequence parallel requires sp_world_size to divide both ' + f'linear_num_key_heads ({mod.num_k_heads}) and linear_num_value_heads ({mod.num_v_heads}).') + local_num_k_heads = mod.num_k_heads // sp_world_size + local_num_v_heads = mod.num_v_heads // sp_world_size + q_proj, k_proj, v_proj = torch.split(mixed_qkv, [mod.key_dim, mod.key_dim, mod.value_dim], dim=-1) + q_proj = q_proj.reshape(batch_size, seq_len, mod.num_k_heads, mod.head_k_dim) + k_proj = k_proj.reshape(batch_size, seq_len, mod.num_k_heads, mod.head_k_dim) + v_proj = v_proj.reshape(batch_size, seq_len, mod.num_v_heads, mod.head_v_dim) + q_proj = seq_to_head_shard(q_proj, 
sequence_parallel_context) + k_proj = seq_to_head_shard(k_proj, sequence_parallel_context) + v_proj = seq_to_head_shard(v_proj, sequence_parallel_context) + b = seq_to_head_shard(b.reshape(batch_size, seq_len, mod.num_v_heads, 1), + sequence_parallel_context).squeeze(-1) + a = seq_to_head_shard(a.reshape(batch_size, seq_len, mod.num_v_heads, 1), + sequence_parallel_context).squeeze(-1) + seq_after_shard = q_proj.shape[1] + mixed_qkv = torch.cat( + ( + q_proj.reshape(batch_size, seq_after_shard, local_num_k_heads * mod.head_k_dim), + k_proj.reshape(batch_size, seq_after_shard, local_num_k_heads * mod.head_k_dim), + v_proj.reshape(batch_size, seq_after_shard, local_num_v_heads * mod.head_v_dim), + ), + dim=-1, + ) + sp_rank = _get_sp_rank(sequence_parallel_context) + conv_weight, conv_bias = _get_local_conv_weights( + mod, sp_rank=sp_rank, local_num_k_heads=local_num_k_heads, local_num_v_heads=local_num_v_heads) + else: + local_num_k_heads = mod.num_k_heads + local_num_v_heads = mod.num_v_heads + sp_rank = 0 + b = b.reshape(batch_size, seq_len, mod.num_v_heads) + a = a.reshape(batch_size, seq_len, mod.num_v_heads) + conv_weight = mod.conv1d.weight.squeeze(1) + conv_bias = getattr(mod.conv1d, 'bias', None) + + packed_cu_seqlens = None + if cu_seq_lens_q is not None: + packed_cu_seqlens = cu_seq_lens_q.to(dtype=torch.int32, device=mixed_qkv.device) + elif sequence_parallel_context is not None: + packed_cu_seqlens = getattr(sequence_parallel_context, 'extra_kwargs', {}).get('cu_seq_lens_q') + if packed_cu_seqlens is not None: + packed_cu_seqlens = packed_cu_seqlens.to(dtype=torch.int32, device=mixed_qkv.device) + if bool(getattr(sequence_parallel_context, 'extra_kwargs', {}).get('is_packed', + False)) and packed_cu_seqlens is None: + raise ValueError( + 'Packed Qwen3.5 linear attention sequence parallel requires cu_seq_lens_q to be populated by ' + 'sequence parallel input preparation.') + + if cache_params is not None: + cache_params.conv_states[mod.layer_idx] = F.pad( + mixed_qkv.transpose(1, 2).contiguous(), (mod.conv_kernel_size - mixed_qkv.shape[1], 0)) + mixed_qkv, _ = mod.causal_conv1d_fn( + x=mixed_qkv, + weight=conv_weight, + bias=conv_bias, + activation=mod.activation, + seq_idx=None, + backend='triton', + cu_seqlens=packed_cu_seqlens, + ) + if mixed_qkv.dim() == 2: + mixed_qkv = mixed_qkv.unsqueeze(0) + if mixed_qkv.dim() != 3: + raise ValueError(f'Unexpected conv output dims: {tuple(mixed_qkv.shape)}') + + local_key_dim = local_num_k_heads * mod.head_k_dim + local_value_dim = local_num_v_heads * mod.head_v_dim + query, key, value = torch.split(mixed_qkv, [local_key_dim, local_key_dim, local_value_dim], dim=-1) + query = query.reshape(batch_size, query.shape[1], local_num_k_heads, mod.head_k_dim) + key = key.reshape(batch_size, key.shape[1], local_num_k_heads, mod.head_k_dim) + value = value.reshape(batch_size, value.shape[1], local_num_v_heads, mod.head_v_dim) + + beta = b.sigmoid() + head_slice = slice(sp_rank * local_num_v_heads, + (sp_rank + 1) * local_num_v_heads) if sp_enabled else slice(None) + g = -mod.A_log[head_slice].float().exp() * F.softplus(a.float() + mod.dt_bias[head_slice]) + + if local_num_v_heads // local_num_k_heads > 1: + repeat = local_num_v_heads // local_num_k_heads + query = query.repeat_interleave(repeat, dim=2) + key = key.repeat_interleave(repeat, dim=2) + + chunk_kwargs = { + 'g': g, + 'beta': beta, + 'initial_state': None, + 'output_final_state': cache_params is not None, + 'use_qk_l2norm_in_kernel': True, + } + if packed_cu_seqlens is not None: + 
chunk_kwargs['cu_seqlens'] = packed_cu_seqlens + core_attn_out, last_recurrent_state = mod.chunk_gated_delta_rule(query, key, value, **chunk_kwargs) + + if cache_params is not None: + cache_params.recurrent_states[mod.layer_idx] = last_recurrent_state + + if sp_enabled: + core_attn_out = head_to_seq_shard(core_attn_out, sequence_parallel_context) + core_attn_out = mod.norm(core_attn_out.reshape(-1, mod.head_v_dim), z.reshape(-1, mod.head_v_dim)) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, local_value_dim if not sp_enabled else mod.value_dim) + return mod.out_proj(core_attn_out) + + def __call__(self, module, *args, **kwargs): + del module, args + sequence_parallel = kwargs.get('sequence_parallel', None) + if sequence_parallel is None: + return + if int(getattr(sequence_parallel, 'rp_world_size', 1) or 1) > 1: + raise NotImplementedError('Qwen3.5 linear attention sequence parallel does not support rp_world_size > 1 ' + '(derived ring attention).') + if os.environ.get('QWEN35_SP_LINEAR_HEAD_PARALLEL', '1') != '1': + return + + try: + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5GatedDeltaNet + except Exception: + return + + if getattr(Qwen3_5GatedDeltaNet, '_twinkle_sp_linear_patched', False): + return + + origin_forward = Qwen3_5GatedDeltaNet.forward + + def sp_linear_forward( + mod, + hidden_states: torch.Tensor, + cache_params=None, + cache_position=None, + attention_mask: Optional[torch.Tensor] = None, + **extra_kwargs, + ): + sequence_parallel_context = extra_kwargs.pop('sequence_parallel_context', sequence_parallel) + cu_seq_lens_q = extra_kwargs.pop('cu_seq_lens_q', None) + if cu_seq_lens_q is None and sequence_parallel_context is not None: + cu_seq_lens_q = getattr(sequence_parallel_context, 'extra_kwargs', {}).get('cu_seq_lens_q') + if not _sp_is_enabled(sequence_parallel_context): + return origin_forward( + mod, + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + ) + return Qwen3_5GatedDeltaNetUlyssesPatch._run_forward( + mod, + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + cu_seq_lens_q=cu_seq_lens_q, + sequence_parallel_context=sequence_parallel_context, + ) + + Qwen3_5GatedDeltaNet.forward = sp_linear_forward + Qwen3_5GatedDeltaNet._twinkle_sp_linear_patched = True diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/utils.py b/src/twinkle/model/transformers/strategy/sequence_parallel/utils.py new file mode 100644 index 00000000..adbb945d --- /dev/null +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/utils.py @@ -0,0 +1,383 @@ +import math +import torch +import torch.distributed as dist +from typing import Any, Dict, List, Optional, Tuple + +from twinkle.utils import DeviceMesh + + +def get_config_attr(config, key, default=None): + return getattr(config, key, default) + + +def is_hccl_backend(group=None) -> bool: + return dist.get_backend(group) == 'hccl' + + +def is_moe_config(config) -> bool: + if config is None: + return False + if 'Moe' in config.__class__.__name__: + return True + for key in ['num_experts', 'num_experts_per_tok', 'moe_intermediate_size']: + if get_config_attr(config, key): + return True + return False + + +def get_cu_seqlens_from_position_ids(position_ids: torch.LongTensor): + position_ids = position_ids[0] + seq_start_indices = torch.where(position_ids == 0)[0] + seq_end_indices = torch.cat([seq_start_indices[1:], torch.tensor([len(position_ids)], 
device=position_ids.device)]) + seq_lengths = seq_end_indices - seq_start_indices + cu_seqlens = torch.cumsum(torch.cat([torch.tensor([0], device=position_ids.device), seq_lengths]), dim=0) + return cu_seqlens + + +def _get_raw_data_world_size(device_mesh: DeviceMesh) -> int: + dp_world_size = device_mesh.dp_world_size or 1 + fsdp_world_size = device_mesh.fsdp_world_size or 1 + if dp_world_size <= 0: + dp_world_size = 1 + if fsdp_world_size <= 0: + fsdp_world_size = 1 + return dp_world_size * fsdp_world_size + + +def _get_raw_data_rank(device_mesh: DeviceMesh, rank: int) -> Optional[int]: + coord = device_mesh._get_coord_for_rank(rank) + if coord is None: + return None + + dp_rank = None + fsdp_rank = None + if device_mesh.has_dim('dp'): + dp_rank = coord[device_mesh._get_dim_index('dp')] + if device_mesh.has_dim('fsdp'): + fsdp_rank = coord[device_mesh._get_dim_index('fsdp')] + + fsdp_world_size = device_mesh.fsdp_world_size + data_rank = dp_rank if dp_rank is not None else None + if fsdp_world_size is not None and fsdp_world_size > 1: + if dp_rank is not None and fsdp_rank is not None: + data_rank = dp_rank * fsdp_world_size + fsdp_rank + elif fsdp_rank is not None: + data_rank = fsdp_rank + + if data_rank is None: + data_rank = 0 + return int(data_rank) + + +def _derive_sequence_parallel_sizes(num_heads: int, seq_world_size: int) -> Tuple[int, int]: + if seq_world_size <= 1: + return 1, 1 + sp_world_size = math.gcd(int(num_heads), int(seq_world_size)) + sp_world_size = max(1, sp_world_size) + if seq_world_size % sp_world_size != 0: + raise ValueError( + f'seq_world_size ({seq_world_size}) must be divisible by derived sp_world_size ({sp_world_size}).') + rp_world_size = seq_world_size // sp_world_size + return sp_world_size, rp_world_size + + +def _get_sequence_group_specs( + device_mesh: Optional[DeviceMesh], + seq_world_size: int, + sp_world_size: int, + rp_world_size: int, +) -> List[Dict[str, Any]]: + if device_mesh is None or seq_world_size <= 1: + return [] + + if seq_world_size != sp_world_size * rp_world_size: + raise ValueError(f'seq_world_size ({seq_world_size}) must equal sp_world_size ({sp_world_size}) * ' + f'rp_world_size ({rp_world_size}).') + + raw_data_world_size = _get_raw_data_world_size(device_mesh) + if raw_data_world_size % seq_world_size != 0: + raise ValueError( + f'data_world_size ({raw_data_world_size}) must be divisible by seq_world_size ({seq_world_size}).') + + non_data_indices = [] + if device_mesh.mesh_dim_names is not None: + for i, name in enumerate(device_mesh.mesh_dim_names): + if name in ('dp', 'fsdp'): + continue + non_data_indices.append(i) + + groups: Dict[Tuple[int, Tuple[int, ...]], List[Tuple[int, int]]] = {} + for r in device_mesh.mesh.flatten().tolist(): + rank = int(r) + coord = device_mesh._get_coord_for_rank(rank) + if coord is None: + continue + raw_rank = _get_raw_data_rank(device_mesh, rank) + if raw_rank is None: + continue + group_id = raw_rank // seq_world_size + seq_local_rank = raw_rank % seq_world_size + non_data_key = tuple(coord[i] for i in non_data_indices) + key = (group_id, non_data_key) + groups.setdefault(key, []).append((seq_local_rank, rank)) + + group_specs = [] + for key, items in groups.items(): + items = sorted(items, key=lambda item: item[0]) + local_ranks = [local_rank for local_rank, _ in items] + if local_ranks != list(range(seq_world_size)): + raise ValueError(f'Invalid sequence-parallel rank layout for key={key}: {local_ranks}') + seq_ranks = [rank for _, rank in items] + if len(seq_ranks) != seq_world_size: + 
raise ValueError( + f'Sequence-parallel group size mismatch for key={key}: expected {seq_world_size}, got {len(seq_ranks)}') + sp_groups = [seq_ranks[i * sp_world_size:(i + 1) * sp_world_size] for i in range(rp_world_size)] + rp_groups = [[sp_groups[rp_idx][sp_idx] for rp_idx in range(rp_world_size)] for sp_idx in range(sp_world_size)] + group_specs.append({ + 'key': key, + 'seq_ranks': seq_ranks, + 'sp_groups': sp_groups, + 'rp_groups': rp_groups, + }) + + group_specs.sort(key=lambda item: item['key']) + return group_specs + + +def _get_seq_groups_from_device_mesh( + device_mesh: Optional[DeviceMesh], + seq_world_size: int, + sp_world_size: int, + rp_world_size: int, +) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup], Optional[dist.ProcessGroup], int, int]: + if device_mesh is None or seq_world_size <= 1: + return None, None, None, 0, 0 + if not dist.is_available() or not dist.is_initialized(): + return None, None, None, 0, 0 + + rank = dist.get_rank() + sp_group = None + rp_group = None + data_rank_group = None + sp_rank = 0 + rp_rank = 0 + group_specs = _get_sequence_group_specs(device_mesh, seq_world_size, sp_world_size, rp_world_size) + + for spec in group_specs: + seq_pg = dist.new_group(ranks=spec['seq_ranks']) + if rank in spec['seq_ranks']: + data_rank_group = seq_pg + + if sp_world_size > 1: + for ranks in spec['sp_groups']: + pg = dist.new_group(ranks=ranks) + if rank in ranks: + sp_group = pg + sp_rank = ranks.index(rank) + + if rp_world_size > 1: + for ranks in spec['rp_groups']: + pg = dist.new_group(ranks=ranks) + if rank in ranks: + rp_group = pg + rp_rank = ranks.index(rank) + + if data_rank_group is None: + raise RuntimeError('Failed to create sequence-parallel data group from DeviceMesh.') + if sp_world_size > 1 and sp_group is None: + raise RuntimeError('Failed to create sequence-parallel SP group from DeviceMesh.') + if rp_world_size > 1 and rp_group is None: + raise RuntimeError('Failed to create sequence-parallel ring group from DeviceMesh.') + + return sp_group, rp_group, data_rank_group, sp_rank, rp_rank + + +def _get_ulysses_size(device_mesh, sp_config: Optional[Dict[str, Any]] = None) -> int: + if sp_config: + cfg_size = sp_config.get('ulysses_size') + if cfg_size is not None: + return int(cfg_size) + if device_mesh is None: + return 1 + if getattr(device_mesh, 'ulysses_size', None) is not None: + return int(device_mesh.ulysses_size) + return 1 + + +def seq_to_head_shard(tensor: torch.Tensor, sequence_parallel) -> torch.Tensor: + if getattr(sequence_parallel, 'sp_world_size', 1) <= 1: + return tensor + # [B, local_S, H, D] -> [B, global_S, local_H, D] + return _SeqAllToAll.apply(sequence_parallel._sp_group, tensor, 2, 1) + + +def head_to_seq_shard(tensor: torch.Tensor, sequence_parallel) -> torch.Tensor: + if getattr(sequence_parallel, 'sp_world_size', 1) <= 1: + return tensor + # [B, global_S, local_H, D] -> [B, local_S, H, D] + return _SeqAllToAll.apply(sequence_parallel._sp_group, tensor, 1, 2) + + +class GatherLoss(torch.autograd.Function): + """Gather loss from sequence group.""" + + @staticmethod + def forward(ctx, loss, labels, gather_idx=None, position_ids=None): + from . 
import sequence_parallel + ctx.scatter_shape = loss.shape[gather_idx or 0] + ctx.gather_idx = gather_idx or 0 + if position_ids is not None: + position_ids = sequence_parallel.pad(position_ids, padding_value=-1, position_ids=position_ids) + ctx.position_ids = position_ids + output = sequence_parallel.gather(loss, dim=ctx.gather_idx, position_ids=position_ids) + if labels is not None: + labels_output = sequence_parallel.gather(labels, dim=ctx.gather_idx, position_ids=position_ids) + else: + labels_output = None + return output, labels_output + + @staticmethod + def backward(ctx, *grad_output): + from . import sequence_parallel + _grad = grad_output[0] * sequence_parallel.world_size + if sequence_parallel.rp_world_size > 1: + _grad = sequence_parallel.split(_grad, dim=ctx.gather_idx, position_ids=ctx.position_ids).contiguous() + else: + _grad = _grad.split( + ctx.scatter_shape, dim=ctx.gather_idx)[dist.get_rank(group=sequence_parallel._sp_group)].contiguous() + return _grad, None, None, None + + +def _generate_layout_params(scatter_idx, seq_world_size, input): + if scatter_idx < 2: + bs, global_seq_len, num_local_head, head_dim = input.shape + pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim] + pre_all2all_permute_idx = (1, 0, 2, 3, 4) + + post_all2all_permute_idx = (1, 2, 0, 3, 4) + post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim] + else: + bs, local_seq_len, num_total_head, head_dim = input.shape + assert num_total_head % seq_world_size == 0, (f'Number of heads ({num_total_head}) must be divisible ' + f'by the sequence parallel size ({seq_world_size})!') + pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim] + pre_all2all_permute_idx = (2, 0, 1, 3, 4) + + post_all2all_permute_idx = (1, 0, 2, 3, 4) + post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim] + + return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape + + +def post_all2all(permute_idx, res_shape): + + def post_func(input): + if permute_idx is not None: + input = input.permute(permute_idx).contiguous() + output = input.reshape(res_shape).contiguous() + return output + + return post_func + + +def pre_all2all_fun(permute_idx, inp_shape, input): + input_t = input.reshape(inp_shape).contiguous() + if permute_idx is not None: + input_t = input_t.permute(permute_idx).contiguous() + return input_t + + +def single_all_to_all(input, scatter_idx, gather_idx, group, **kwargs): + seq_world_size = dist.get_world_size(group) + num_heads = input.shape[2] + if num_heads % seq_world_size != 0 and not scatter_idx < 2: + raise NotImplementedError(f'num_heads {num_heads} cannot be split by sp world size {seq_world_size}') + pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = ( + _generate_layout_params(scatter_idx, seq_world_size, input)) + + input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input) + post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape) + output = torch.empty_like(input_t) + dist.all_to_all_single(output, input_t, group=group) + res = post_all2all_fun(output) + return res + + +class _SeqAllToAll(torch.autograd.Function): + + @staticmethod + def forward( + ctx: Any, + group: dist.ProcessGroup, + input: torch.Tensor, + scatter_idx: int, + gather_idx: int, + ) -> torch.Tensor: + 
ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + res = single_all_to_all(input, scatter_idx, gather_idx, group) + return res + + @staticmethod + def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None]: + return None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None + + +class DistributedAttention(torch.nn.Module): + + def __init__( + self, + local_attention, + sequence_parallel, + scatter_idx: int = 2, + gather_idx: int = 1, + ) -> None: + super().__init__() + self.local_attn = local_attention + self.sequence_parallel = sequence_parallel + self.scatter_idx = scatter_idx + self.gather_idx = gather_idx + + def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: torch.Tensor, *args: + Any, **kwargs) -> torch.Tensor: + if self.sequence_parallel.world_size == 1: + return self.local_attn(query, key, value, attention_mask, *args, **kwargs) + if self.sequence_parallel.rp_world_size > 1 and attention_mask is not None: + if torch.is_tensor(attention_mask) and not attention_mask.all(): + raise NotImplementedError( + 'Derived ring attention only supports padding-free / packed inputs without masked padding.') + + if self.sequence_parallel.sp_world_size > 1: + query_layer = _SeqAllToAll.apply(self.sequence_parallel._sp_group, query, self.scatter_idx, self.gather_idx) + key_layer = _SeqAllToAll.apply(self.sequence_parallel._sp_group, key, self.scatter_idx, self.gather_idx) + value_layer = _SeqAllToAll.apply(self.sequence_parallel._sp_group, value, self.scatter_idx, self.gather_idx) + else: + query_layer, key_layer, value_layer = query, key, value + + if (self.sequence_parallel.sp_world_size > 1 and torch.is_tensor(attention_mask) and attention_mask.dim() == 4): + if attention_mask.shape[-1] != key_layer.shape[1]: + attention_mask = self.sequence_parallel.gather(attention_mask, dim=-1, position_ids=None) + if attention_mask.shape[-2] != query_layer.shape[1]: + attention_mask = self.sequence_parallel.gather(attention_mask, dim=-2, position_ids=None) + + if self.sequence_parallel.rp_world_size > 1: + kwargs.pop('position_ids', None) + position_ids = self.sequence_parallel.real_position_ids + position_ids = self.sequence_parallel.pad(position_ids, padding_value=-1, position_ids=position_ids) + else: + position_ids = kwargs.pop('position_ids') + if position_ids is not None and self.sequence_parallel.sp_world_size > 1: + # Reuse the generic gather path to support both 2D and 3D position_ids (e.g. mrope). + position_ids = self.sequence_parallel.gather(position_ids.contiguous(), dim=-1, position_ids=None) + + context_layer = self.local_attn( + query_layer, key_layer, value_layer, attention_mask, *args, position_ids=position_ids, **kwargs) + + if self.sequence_parallel.sp_world_size > 1: + output = _SeqAllToAll.apply(self.sequence_parallel._sp_group, context_layer, self.gather_idx, + self.scatter_idx) + else: + output = context_layer + + return output diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/zigzag_ring_attn.py b/src/twinkle/model/transformers/strategy/sequence_parallel/zigzag_ring_attn.py new file mode 100644 index 00000000..0d4c79a6 --- /dev/null +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/zigzag_ring_attn.py @@ -0,0 +1,642 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
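+# Zigzag ring attention over packed (varlen) sequences. The global sequence is
+# assumed to be laid out in zigzag order by the caller's input preparation:
+# with W ring ranks the sequence is cut into 2 * W chunks and rank i holds
+# chunk i plus chunk (2 * W - 1 - i), which balances the causal-attention work
+# per rank. K/V blocks rotate around the ring (RingComm); at step 0 each rank
+# attends causally to its own block, at steps <= rank only the front half of
+# the received K/V is needed, and at steps > rank only the back half of the
+# local queries attends to the received block. The underlying kernels are
+# flash_attn's varlen forward/backward; zigzag_ring_flash_attn_varlen_func at
+# the bottom of this file is the public entry point.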
+import inspect +import torch +import torch.distributed as dist +import torch.nn.functional as F +from functools import cache +from typing import Optional, Tuple + + +class RingComm: + + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + self.send_rank = (self.rank + 1) % self.world_size + self.recv_rank = (self.rank - 1) % self.world_size + + if process_group is not None: + self.send_rank = dist.get_global_rank(self._process_group, self.send_rank) + self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank) + + def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(to_send) + else: + res = recv_tensor + + send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not None: + raise RuntimeError('commit called twice') + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError('wait called before commit') + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] + + def send_recv_kv( + self, + k: torch.Tensor, + v: torch.Tensor, + k_buffer: Optional[torch.Tensor] = None, + v_buffer: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + next_k, next_v = self.send_recv(k, k_buffer), self.send_recv(v, v_buffer) + self.commit() + return next_k, next_v + + +def get_half_index(cu_seqlens, *, front: bool): + if len(cu_seqlens) == 2: + if front: + return slice(None, cu_seqlens[-1] // 2) + return slice(cu_seqlens[-1] // 2, None) + + index = torch.zeros((cu_seqlens[-1].item(), ), dtype=torch.bool) + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + if front: + end = (start + end) // 2 + else: + start = (start + end) // 2 + index[start:end] = True + return index + + +@torch.jit.script +def get_half_lse(lse, cu_seqlens, *, front: bool): + new_lse = torch.empty( + (lse.shape[0], lse.shape[1] // 2), + dtype=lse.dtype, + device=lse.device, + ) + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item() + new_start, new_end = start // 2, end // 2 + if front: + end -= (end - start) // 2 + else: + start += (end - start) // 2 + new_lse[:, new_start:new_end] = lse[:, start:end] + return new_lse + + +def update_out_and_lse(out, lse, block_out, block_lse): + if out is None: + out = block_out.to(torch.float32) + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + sig_diff = None + else: + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + diff = block_lse - lse + sig_diff = torch.sigmoid(diff) + + out = out - sig_diff * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + return out, lse, sig_diff + + +@cache +def _get_default_args(func): + spec = inspect.getfullargspec(func) + defaults = spec.defaults if spec.defaults is not None else () + padded_defaults = (None, ) * (len(spec.args) - len(defaults)) + defaults + args = dict(zip(spec.args, padded_defaults)) + if 'softcap' in args: + args['softcap'] = 0.0 + return args + + +def get_default_args(func): + if 
inspect.isfunction(func): + return _get_default_args(func) + return _get_default_args(func._init_fn) + + +def squeeze_batch(*tensors): + squeezed = [] + for sub in tensors: + if sub.shape[0] == 1: + squeezed.append(sub.squeeze(0)) + else: + squeezed.append(sub) + return tuple(squeezed) + + +def padding(tensor, cu_seqlens, padding_value, front): + if len(cu_seqlens) == 2: + if front: + return torch.cat((tensor, torch.full_like(tensor, padding_value).to(tensor.dtype).to(tensor.device)), dim=0) + return torch.cat((torch.full_like(tensor, padding_value).to(tensor.dtype).to(tensor.device), tensor), dim=0) + + output = [] + acc = 0 + for i in range(len(cu_seqlens) - 1): + start, end = cu_seqlens[i], cu_seqlens[i + 1] + half_len = (end - start) // 2 + acc += half_len + half_start = start // 2 + local_tensor = tensor[half_start:half_start + half_len] + if front: + output.append(local_tensor) + output.append(torch.full_like(local_tensor, padding_value).to(local_tensor.dtype).to(local_tensor.device)) + else: + output.append(torch.full_like(local_tensor, padding_value).to(local_tensor.dtype).to(local_tensor.device)) + output.append(local_tensor) + assert acc == tensor.shape[0] + return torch.cat(output) + + +def forward( + q, + k, + v, + causal, + cu_seqlens, + max_seqlen, + block_seq_len, + dropout_p, + softmax_scale, + alibi_slopes, + window_size, +): + seqlen_q = q.shape[0] + seqlen_kv = k.shape[0] + half_cu_seqlens = cu_seqlens // 2 + half_max_seqlen = max_seqlen // 2 + cu_seqlens_q = half_cu_seqlens if seqlen_q == block_seq_len else cu_seqlens + max_seqlen_q = half_max_seqlen if seqlen_q == block_seq_len else max_seqlen + cu_seqlens_kv = half_cu_seqlens if seqlen_kv == block_seq_len else cu_seqlens + max_seqlen_kv = half_max_seqlen if seqlen_kv == block_seq_len else max_seqlen + from flash_attn.flash_attn_interface import _flash_attn_varlen_forward + + params = get_default_args(_flash_attn_varlen_forward).copy() + params.update({ + 'q': q, + 'k': k, + 'v': v, + 'cu_seqlens_q': cu_seqlens_q, + 'cu_seqlens_k': cu_seqlens_kv, + 'max_seqlen_q': max_seqlen_q, + 'max_seqlen_k': max_seqlen_kv, + 'dropout_p': dropout_p, + 'softmax_scale': softmax_scale, + 'causal': causal, + 'alibi_slopes': alibi_slopes, + 'return_softmax': True and dropout_p > 0, + }) + if 'window_size' in params: + params.update({'window_size': window_size}) + else: + params.update({ + 'window_size_left': window_size[0], + 'window_size_right': window_size[1], + }) + assert k.shape[-0] == cu_seqlens_kv[-1] + assert q.shape[-0] == cu_seqlens_q[-1] + assert max_seqlen_q == (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + assert max_seqlen_kv == (cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]).max().item() + outputs = _flash_attn_varlen_forward(**params) + if len(outputs) == 8: + block_out, _, _, _, _, block_lse, _, _ = outputs + else: + assert len(outputs) == 4 + block_out, block_lse, _, _ = outputs + return block_out, block_lse + + +def backward( + dout, + q, + k, + v, + out, + softmax_lse, + causal, + cu_seqlens, + max_seqlen, + block_seq_len, + dq_buffer, + dk_buffer, + dv_buffer, + dropout_p, + softmax_scale, + alibi_slopes, + deterministic, + window_size, +): + seqlen_q = q.shape[0] + seqlen_kv = k.shape[0] + + half_cu_seqlens = cu_seqlens // 2 + half_max_seqlen = max_seqlen // 2 + cu_seqlens_q = half_cu_seqlens if seqlen_q == block_seq_len else cu_seqlens + max_seqlen_q = half_max_seqlen if seqlen_q == block_seq_len else max_seqlen + cu_seqlens_kv = half_cu_seqlens if seqlen_kv == block_seq_len else cu_seqlens + max_seqlen_kv = 
half_max_seqlen if seqlen_kv == block_seq_len else max_seqlen + from flash_attn.flash_attn_interface import _flash_attn_varlen_backward + + params = get_default_args(_flash_attn_varlen_backward).copy() + params.update({ + 'dout': dout, + 'q': q, + 'k': k, + 'v': v, + 'out': out, + 'softmax_lse': softmax_lse, + 'dq': dq_buffer[:seqlen_q], + 'dk': dk_buffer[:seqlen_kv], + 'dv': dv_buffer[:seqlen_kv], + 'cu_seqlens_q': cu_seqlens_q, + 'cu_seqlens_k': cu_seqlens_kv, + 'max_seqlen_q': max_seqlen_q, + 'max_seqlen_k': max_seqlen_kv, + 'dropout_p': dropout_p, + 'softmax_scale': softmax_scale, + 'causal': causal, + 'alibi_slopes': alibi_slopes, + 'deterministic': deterministic, + }) + assert dout.shape[0] == q.shape[0] + assert dout.shape[0] == out.shape[0] + assert softmax_lse.shape[1] == q.shape[0] + assert k.shape[0] == cu_seqlens_kv[-1] + assert q.shape[0] == cu_seqlens_q[-1] + assert max_seqlen_q == (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + assert max_seqlen_kv == (cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]).max().item() + if 'window_size' in params: + params.update({'window_size': window_size}) + else: + params.update({ + 'window_size_left': window_size[0], + 'window_size_right': window_size[1], + }) + _flash_attn_varlen_backward(**params) + + +def lse_grad(out, lse, block_out, block_lse, sig, grad_out, grad_lse): + grad_out_input = grad_out * (1 - sig) + grad_block_out = grad_out * sig + d_new_out_d_lse = (out - block_out) * (sig * (1 - sig)) + grad_lse_input = (grad_out * d_new_out_d_lse).sum(dim=-1, keepdim=True) + grad_lse_input_final = grad_lse_input + grad_lse * torch.sigmoid(lse - block_lse) + grad_block_lse = -grad_lse_input_final + grad_lse + return grad_out_input, grad_lse_input_final, grad_block_out, grad_block_lse + + +def zigzag_ring_flash_attn_varlen_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens, + max_seqlen, + half_index0, + half_index1, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal, 'zigzag ring is meaningless for causal=False' + comm = RingComm(process_group) + q, k, v = squeeze_batch(q, k, v) + q1 = q[half_index1] + cu_seqlens = cu_seqlens // comm.world_size + max_seqlen = max_seqlen // comm.world_size + block_seq_len = q.shape[0] // 2 + out = None + lse = None + next_k, next_v = None, None + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k, next_v = comm.send_recv_kv(k, v) + if step == 0: + block_out, block_lse = forward(q, k, v, True, cu_seqlens, max_seqlen, block_seq_len, dropout_p, + softmax_scale, alibi_slopes, window_size) + out, lse, _ = update_out_and_lse(out, lse, block_out, block_lse) + elif step <= comm.rank: + k0 = k[half_index0] + v0 = v[half_index0] + block_out, block_lse = forward(q, k0, v0, False, cu_seqlens, max_seqlen, block_seq_len, dropout_p, + softmax_scale, alibi_slopes, window_size) + out, lse, _ = update_out_and_lse(out, lse, block_out, block_lse) + else: + block_out, block_lse = forward(q1, k, v, False, cu_seqlens, max_seqlen, block_seq_len, dropout_p, + softmax_scale, alibi_slopes, window_size) + out[half_index1], lse[half_index1], _ = update_out_and_lse(out[half_index1], lse[half_index1], block_out, + block_lse) + + if step + 1 != comm.world_size: + comm.wait() + k, v = next_k, next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(0, 1) + return out.unsqueeze(0), lse.unsqueeze(0) + + +def zigzag_ring_flash_attn_varlen_backward( + process_group, + dout, + 
q, + k, + v, + out, + softmax_lse, + cu_seqlens, + max_seqlen, + half_index0, + half_index1, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal, 'zigzag ring is meaningless for causal=False' + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dk_comm_buffer = dv_comm_buffer = None + dq = dk = dv = None + next_dk = next_dv = None + next_k = next_v = None + + dout, q, k, v, out, softmax_lse = squeeze_batch(dout, q, k, v, out, softmax_lse) + q1 = q[half_index1] + cu_seqlens = cu_seqlens // kv_comm.world_size + max_seqlen = max_seqlen // kv_comm.world_size + block_seq_len = q.shape[0] // 2 + + dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + origin_q, origin_k, origin_v = q, k, v + + out_lse = [] + fout = None + flse = None + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + + if step == 0: + block_out, block_lse = forward(q, k, v, True, cu_seqlens, max_seqlen, block_seq_len, dropout_p, + softmax_scale, alibi_slopes, window_size) + fout, flse, sig_diff = update_out_and_lse(fout, flse, block_out, block_lse) + elif step <= kv_comm.rank: + k0 = k[half_index0] + v0 = v[half_index0] + block_out, block_lse = forward(q, k0, v0, False, cu_seqlens, max_seqlen, block_seq_len, dropout_p, + softmax_scale, alibi_slopes, window_size) + fout, flse, sig_diff = update_out_and_lse(fout, flse, block_out, block_lse) + else: + block_out, block_lse = forward(q1, k, v, False, cu_seqlens, max_seqlen, block_seq_len, dropout_p, + softmax_scale, alibi_slopes, window_size) + fout[half_index1], flse[half_index1], sig_diff = update_out_and_lse(fout[half_index1], flse[half_index1], + block_out, block_lse) + + block_lse = block_lse.transpose(0, 1).unsqueeze(-1) + if step > kv_comm.rank: + block_out = padding(block_out, cu_seqlens, 0, front=False) + block_lse = padding(block_lse, cu_seqlens, -1e5, front=False) + sig_diff = padding(sig_diff, cu_seqlens, 0, front=False) + + out_lse.append((fout, flse, block_out, block_lse, sig_diff)) + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + + current_dout = dout + current_dlse = torch.zeros_like(softmax_lse.transpose(0, 1).unsqueeze(-1)) + block_gradients = {} + + for i in reversed(range(len(out_lse))): + if i == 0: + continue + stored_out, stored_lse, stored_block_out, stored_block_lse, stored_sig = out_lse[i] + grad_out_input, grad_lse_input, grad_block_out, grad_block_lse = lse_grad( + stored_out, + stored_lse, + stored_block_out, + stored_block_lse, + stored_sig, + current_dout, + current_dlse, + ) + current_dout = grad_out_input + current_dlse = grad_lse_input + block_gradients[i] = {'grad_block_out': grad_block_out, 'grad_block_lse': grad_block_lse} + + q, k, v = origin_q, origin_k, origin_v + + for step in range(kv_comm.world_size): + _, _, block_out, block_lse, _ = out_lse[step] + if block_out.isnan().any() or block_lse.isnan().any(): + raise RuntimeError('NaN detected in ring attention backward recompute.') + block_lse = block_lse.transpose(0, 1).squeeze(2) + + if step + 1 != kv_comm.world_size: + next_k, next_v = kv_comm.send_recv_kv(k, v) + + if step == 0: + block_dout = current_dout + else: + block_dout = block_gradients[step]['grad_block_out'] + + if block_dout.isnan().any(): + raise RuntimeError('NaN detected in 
ring attention dout.') + + if step == 0: + backward( + block_dout.to(dout.dtype), q, k, v, block_out, block_lse, True, cu_seqlens, max_seqlen, block_seq_len, + dq_buffer, dk_buffer, dv_buffer, dropout_p, softmax_scale, alibi_slopes, deterministic, window_size) + dq = dq_buffer.to(torch.float32) + dk = dk_buffer.to(torch.float32) + dv = dv_buffer.to(torch.float32) + if dq.isnan().any() or dk.isnan().any() or dv.isnan().any(): + raise RuntimeError('NaN detected in ring attention gradients.') + else: + if step <= kv_comm.rank: + k0 = k[half_index0] + v0 = v[half_index0] + backward( + block_dout.to(dout.dtype), q, k0, v0, block_out, block_lse, False, cu_seqlens, max_seqlen, + block_seq_len, dq_buffer, dk_buffer, dv_buffer, dropout_p, softmax_scale, alibi_slopes, + deterministic, window_size) + dq += dq_buffer + else: + backward(block_dout[half_index1].to(dout.dtype), q1, k, v, block_out[half_index1], + get_half_lse(block_lse, cu_seqlens, + front=False), False, cu_seqlens, max_seqlen, block_seq_len, dq_buffer, dk_buffer, + dv_buffer, dropout_p, softmax_scale, alibi_slopes, deterministic, window_size) + dq[half_index1] += dq_buffer[:block_seq_len] + + d_kv_comm.wait() + dk_comm_buffer = torch.empty_like(dk) + dv_comm_buffer = torch.empty_like(dv) + dk_comm_buffer.copy_(dk) + dv_comm_buffer.copy_(dv) + dk, dv = next_dk, next_dv + + if step <= kv_comm.rank: + dk[half_index0] += dk_buffer[:block_seq_len] + dv[half_index0] += dv_buffer[:block_seq_len] + else: + dk += dk_buffer + dv += dv_buffer + if dq.isnan().any() or dk.isnan().any() or dv.isnan().any(): + raise RuntimeError('NaN detected in accumulated ring attention gradients.') + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k, v = next_k, next_v + + next_dk, next_dv = d_kv_comm.send_recv_kv(dk, dv, dk_comm_buffer, dv_comm_buffer) + + d_kv_comm.wait() + return dq.to(q.dtype).unsqueeze(0), next_dk.to(q.dtype).unsqueeze(0), next_dv.to(q.dtype).unsqueeze(0) + + +class ZigZagRingFlashAttnVarlenFunc(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1]**(-0.5) + + assert alibi_slopes is None + k = k.contiguous() + v = v.contiguous() + rp_world_size = dist.get_world_size(group) + half_index0 = get_half_index(cu_seqlens // rp_world_size, front=True) + half_index1 = get_half_index(cu_seqlens // rp_world_size, front=False) + out, softmax_lse = zigzag_ring_flash_attn_varlen_forward( + group, + q, + k, + v, + cu_seqlens, + max_seqlen, + half_index0, + half_index1, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + is_half_index_tensor = isinstance(half_index0, torch.Tensor) + ctx.is_half_index_tensor = is_half_index_tensor + if is_half_index_tensor: + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1) + else: + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens) + ctx.half_index0 = half_index0 + ctx.half_index1 = half_index1 + ctx.max_seqlen = max_seqlen + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, 
*args): + if ctx.is_half_index_tensor: + q, k, v, out, softmax_lse, cu_seqlens, half_index0, half_index1 = ctx.saved_tensors + else: + q, k, v, out, softmax_lse, cu_seqlens = ctx.saved_tensors + half_index0 = ctx.half_index0 + half_index1 = ctx.half_index1 + dq, dk, dv = zigzag_ring_flash_attn_varlen_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens, + ctx.max_seqlen, + half_index0, + half_index1, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None + + +def zigzag_ring_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 6097ffe2..3aa384f1 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -42,6 +42,18 @@ from twinkle.utils.grad_clip import normalize_and_clip_grad_norm +def _get_raw_dp_fsdp_world_size(device_mesh: Optional[DeviceMesh]) -> int: + if device_mesh is None: + return 1 + dp_world_size = device_mesh.dp_world_size or 1 + fsdp_world_size = device_mesh.fsdp_world_size or 1 + if dp_world_size <= 0: + dp_world_size = 1 + if fsdp_world_size <= 0: + fsdp_world_size = 1 + return dp_world_size * fsdp_world_size + + @dataclass class OptimizerGroup(BaseOptimizerGroup): """Optimizer group for Transformers training.""" @@ -74,12 +86,13 @@ def _build_metrics(self): def _ensure_dp_group(self): if self._dp_group is not None or self._device_mesh is None: return - if self._device_mesh.data_world_size <= 1: + raw_world_size = _get_raw_dp_fsdp_world_size(self._device_mesh) + if raw_world_size <= 1: return if not dist.is_available() or not dist.is_initialized(): return - if dist.get_world_size() < self._device_mesh.data_world_size: - # World size is smaller than the requested dp group; skip to avoid crash. + if dist.get_world_size() < raw_world_size: + # World size is smaller than the requested dp/fsdp group; skip to avoid crash. return dims = [dim for dim in ('dp', 'fsdp') if self._device_mesh.has_dim(dim)] if not dims: @@ -230,15 +243,9 @@ def _ensure_sp_strategy(self) -> None: return from .strategy.sequence_parallel import SequenceParallelStrategy - sp_config = {} - # When data-parallel gradient averaging runs across SP shards (native FSDP or - # accelerate DDP/FSDP paths), compensate SP loss backward to keep gradient scale. 
- if isinstance(self.strategy, (NativeFSDPStrategy, AccelerateStrategy)) and self.device_mesh is not None: - if (self.device_mesh.ulysses_size or 1) > 1 and (self.device_mesh.data_world_size or 1) > 1: - sp_config['compensate_fsdp_avg'] = True self.sp_strategy = SequenceParallelStrategy( self.device_mesh, - sp_config, + {}, model=self.model, tokenizer_id=self.tokenizer_id, ) @@ -360,9 +367,7 @@ def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajec inputs = optimizer_config.template.batch_encode(inputs) # noqa processor: InputProcessor = optimizer_config.processor assert isinstance(processor, InputProcessor), 'Set a correct `InputProcessor` before forwarding' - inputs: Dict[str, Any] = processor(inputs) - if self.sp_strategy is not None: - inputs = self.sp_strategy.preprocess_inputs(inputs) + inputs: Dict[str, Any] = processor(inputs, sp_strategy=self.sp_strategy) labels: torch.Tensor = inputs.pop('labels', None) optimizer_config.accumulate_metrics(True) outputs = self.model(**inputs) @@ -418,9 +423,7 @@ def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], List[T with torch.no_grad(): processor: InputProcessor = optimizer_config.processor assert isinstance(processor, InputProcessor), 'Set InputProcessor correctly before forwarding' - inputs: Dict[str, Any] = processor(inputs) - if self.sp_strategy is not None: - inputs = self.sp_strategy.preprocess_inputs(inputs) + inputs: Dict[str, Any] = processor(inputs, sp_strategy=self.sp_strategy) labels = inputs.pop('labels', None) optimizer_config.accumulate_metrics(False) unwrapped_model = self.strategy.unwrap_model(self.model) @@ -467,7 +470,11 @@ def calculate_loss(self, **kwargs): inputs = optimizer_config.train_status.inputs outputs = optimizer_config.train_status.outputs assert inputs is not None and outputs is not None, 'Cannot calculate loss of empty inputs and outputs' - result = loss_instance(inputs, outputs, **kwargs) + loss_inputs = inputs + loss_outputs = outputs + if self.sp_strategy is not None: + loss_inputs, loss_outputs = self.sp_strategy.gather_loss_tensors(inputs, outputs) + result = loss_instance(loss_inputs, loss_outputs, **kwargs) loss_value = result['loss'] counts = result['num_tokens'] if not counts: @@ -488,11 +495,6 @@ def calculate_loss(self, **kwargs): counts = counts / self.device_mesh.data_world_size optimizer_config = self.optimizer_group[adapter_name] optimizer_config.train_status.num_tokens += counts.item() - if self.sp_strategy is not None and 'labels' in inputs: - reduction = getattr(loss_instance, 'reduction', None) - if reduction is not None: - self.sp_strategy.sp_config['loss_reduction'] = str(reduction) - loss_value = self.sp_strategy.reduce_loss(loss_value, inputs['labels']) optimizer_config.train_status.loss_value += loss_value outputs['loss'] = optimizer_config.train_status.loss_value return optimizer_config.train_status.loss_value.item() diff --git a/src/twinkle/processor/base.py b/src/twinkle/processor/base.py index 576db8cd..d704f8df 100644 --- a/src/twinkle/processor/base.py +++ b/src/twinkle/processor/base.py @@ -23,6 +23,7 @@ class PackedSeqParams: class InputProcessor: padding_map = { 'input_ids': 0, + 'mm_token_type_ids': 0, 'inputs_embeds': 0.0, 'attention_mask': 0, 'labels': -100, @@ -66,6 +67,7 @@ def __init__(self, self.to_transformers_dict, self.add_extra_padding_free_args, self.split_cp, + self.apply_transformers_sp, self.prepare_outputs, ] @@ -108,10 +110,18 @@ def to_tensor(_input): return [to_tensor(_input) for _input in inputs] + def 
apply_transformers_sp(self, inputs: List[InputFeature], **kwargs) -> List[InputFeature]: + sp_strategy = kwargs.get('sp_strategy') + if self.framework != 'transformers' or sp_strategy is None: + return inputs + return [InputFeature(**sp_strategy.preprocess_inputs(dict(_input))) for _input in inputs] + def pad_cp(self, inputs: List[InputFeature], **kwargs) -> List[InputFeature]: if self.device_mesh is None: return inputs + if self.framework == 'transformers': + return inputs def _pad_cp(_input: InputFeature) -> InputFeature: # Pad sequence for parallel compatibility @@ -172,6 +182,8 @@ def split_cp(self, inputs: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any if self.device_mesh is None: return inputs + if self.framework == 'transformers': + return inputs def _split_cp(inputs: Dict[str, Any]) -> Dict[str, Any]: diff --git a/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py b/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py deleted file mode 100644 index b035db3c..00000000 --- a/tests/moe/test_expert_parallel_qwen3_fsdp_sp.py +++ /dev/null @@ -1,664 +0,0 @@ -# Copyright (c) ModelScope Contributors. All rights reserved. -import copy -import json -import numpy as np -import os -import socket -import sys -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -import torch.nn.functional as F -import unittest -from datetime import timedelta -from pathlib import Path -from torch import nn -from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig -from typing import Dict, List, Optional, Tuple - -from twinkle.model.transformers.moe import apply_expert_parallel -from twinkle.model.transformers.strategy import NativeFSDPStrategy -from twinkle.model.transformers.strategy.sequence_parallel import (SequenceParallelStrategy, - _get_sp_group_from_device_mesh, sequence_parallel) -from twinkle.utils import DeviceMesh - -# QWEN3_MOE_MODEL_ID=/path/to/Qwen3-MoE \ -# QWEN3_MOE_LOCAL_ONLY=1 \ -# pytest -q tests/moe/test_expert_parallel_qwen3_fsdp_sp.py -rs - - -def _find_free_port() -> int: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(('127.0.0.1', 0)) - return sock.getsockname()[1] - - -def _enable_strict_determinism(seed: int) -> None: - """Best-effort deterministic knobs (still not guaranteed bitwise with NCCL collectives).""" - # These should be set before CUDA context is initialized for best effect. - os.environ.setdefault('PYTHONHASHSEED', str(seed)) - os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':16:8') - os.environ.setdefault('NCCL_DETERMINISTIC', '1') - os.environ.setdefault('FLASH_ATTENTION_DETERMINISTIC', '1') - os.environ.setdefault('NCCL_ASYNC_ERROR_HANDLING', '1') - - torch.backends.cuda.matmul.allow_tf32 = False - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - torch.backends.cudnn.enabled = False - # Disable reduced-precision bf16 reductions when possible. 
- if hasattr(torch.backends.cuda.matmul, 'allow_bf16_reduced_precision_reduction'): - torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False - - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.use_deterministic_algorithms(True, warn_only=True) - - -def _find_moe_blocks(model: nn.Module) -> List[nn.Module]: - blocks = [] - for module in model.modules(): - experts = getattr(module, 'experts', None) - if experts is None: - continue - if not isinstance(experts, nn.ModuleList): - if not (hasattr(experts, 'gate_up_proj') and hasattr(experts, 'down_proj')): - continue - gate = getattr(module, 'gate', None) or getattr(module, 'router', None) - if gate is None: - continue - blocks.append(module) - return blocks - - -def _get_top_k(block: nn.Module) -> int: - if hasattr(block, 'num_experts_per_tok') and getattr(block, 'num_experts_per_tok') is not None: - return int(getattr(block, 'num_experts_per_tok')) - if hasattr(block, 'top_k') and getattr(block, 'top_k') is not None: - return int(getattr(block, 'top_k')) - gate = getattr(block, 'gate', None) or getattr(block, 'router', None) - if gate is not None and hasattr(gate, 'top_k') and getattr(gate, 'top_k') is not None: - return int(getattr(gate, 'top_k')) - raise RuntimeError('Cannot infer top_k for MoE block.') - - -def _capture_router_state(model: nn.Module): - # Return a list aligned with _find_moe_blocks order. - states: List[Dict[str, torch.Tensor]] = [] - handles = [] - for block in _find_moe_blocks(model): - gate = getattr(block, 'gate', None) or getattr(block, 'router', None) - if gate is None: - continue - top_k = _get_top_k(block) - norm_topk_prob = getattr(block, 'norm_topk_prob', False) - - def _hook(module, inputs, output, *, _top_k=top_k, _norm=norm_topk_prob): - if isinstance(output, tuple): - router_logits, routing_weights, selected_experts = output[:3] - else: - router_logits = output - routing_weights = torch.softmax(router_logits, dim=-1, dtype=torch.float32) - routing_weights, selected_experts = torch.topk(routing_weights, _top_k, dim=-1) - if _norm: - routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) - states.append({ - 'selected_experts': selected_experts.detach().cpu(), - 'routing_weights': routing_weights.detach().cpu(), - }) - - handles.append(gate.register_forward_hook(_hook)) - return states, handles - - -def _load_qwen3_moe_config(model_id: str, local_files_only: bool): - try: - return AutoConfig.from_pretrained( - model_id, - trust_remote_code=True, - local_files_only=local_files_only, - ) - except Exception as exc: # noqa: BLE001 - config_path = Path(model_id) / 'config.json' - if not config_path.exists(): - raise exc - with config_path.open('r', encoding='utf-8') as handle: - data = json.load(handle) - if 'model_type' not in data: - data['model_type'] = 'qwen3_moe' - if 'architectures' not in data: - data['architectures'] = ['Qwen3MoeForCausalLM'] - try: - return AutoConfig.from_dict(data) - except Exception as exc: # noqa: BLE001 - print(f'AutoConfig.from_dict fallback to PretrainedConfig for {model_id}: {exc}') - return PretrainedConfig.from_dict(data) - - -def _load_qwen3_moe_pretrained(model_id: str, local_files_only: bool, device: torch.device) -> nn.Module: - config = _load_qwen3_moe_config(model_id, local_files_only) - if hasattr(config, 'num_hidden_layers'): - config.num_hidden_layers = 1 - if hasattr(config, 'use_cache'): - config.use_cache = False - if hasattr(config, '_experts_implementation'): - 
config._experts_implementation = 'eager' - model = AutoModelForCausalLM.from_pretrained( - model_id, - config=config, - torch_dtype=torch.bfloat16, - low_cpu_mem_usage=True, - trust_remote_code=True, - local_files_only=local_files_only, - ) - model.to(device) - model.eval() - return model - - -def _ensure_embed_tokens(model, embed) -> None: - # SequenceParallel's forward hook looks for `_self.language_model.embed_tokens` or `_self.embed_tokens` - # where `_self` is the top-level model passed to `sequence_parallel.prepare(...)`. - # - # HF models vary: some expose `.language_model`, others expose `.model` (decoder), etc. - targets = [model] - for attr in ('language_model', 'model'): - if hasattr(model, attr): - targets.append(getattr(model, attr)) - for t in targets: - if t is not None and getattr(t, 'embed_tokens', None) is None: - t.embed_tokens = embed - - -def _per_token_ce_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - # [B,S,V] + [B,S] -> [B,S] (sum/avg applied by caller) - loss_1d = F.cross_entropy( - logits.view(-1, logits.size(-1)), - labels.view(-1), - ignore_index=-100, - reduction='none', - ) - return loss_1d.view(labels.shape) - - -def _sp_slice_range_for_seq_len( - seq_len: int, - *, - sp_group: Optional[dist.ProcessGroup], - sp_size: int, -) -> Tuple[int, int]: - if sp_group is None or sp_size <= 1: - return 0, seq_len - sp_rank = dist.get_rank(sp_group) - if seq_len % sp_size != 0: - raise ValueError(f'seq_len ({seq_len}) must be divisible by sp_size ({sp_size}) in this test.') - local = seq_len // sp_size - start = sp_rank * local - end = start + local - return start, end - - -def _gather_full_seq_grad_from_sp(local_grad: torch.Tensor, *, sp_group: Optional[dist.ProcessGroup]) -> torch.Tensor: - """Gather per-rank local sequence gradients into a full-sequence gradient on every rank.""" - if sp_group is None or dist.get_world_size(sp_group) <= 1: - return local_grad.contiguous() - world = dist.get_world_size(sp_group) - chunks = [torch.empty_like(local_grad) for _ in range(world)] - dist.all_gather(chunks, local_grad.contiguous(), group=sp_group) - return torch.cat(chunks, dim=1).contiguous() - - -def _collect_active_local_expert_grad_tensors( - block: nn.Module, - active_global_experts: torch.Tensor, -) -> Dict[str, torch.Tensor]: - """Return a {f\"expert{global}.{param_name}\": grad_tensor_cpu} dict for active local experts only.""" - active = {int(x) for x in active_global_experts.reshape(-1).tolist()} - grads: Dict[str, torch.Tensor] = {} - if isinstance(block.experts, nn.ModuleList): - for local_idx, expert in enumerate(block.experts): - global_idx = int(block._ep_local_start + local_idx) - if global_idx not in active: - continue - for name, param in expert.named_parameters(): - if param.grad is None: - continue - grads[f'expert{global_idx}.{name}'] = param.grad.detach().cpu() - return grads - - # Tensor experts: gradients are indexed by local expert id. 
- gate_up = block.experts.gate_up_proj - down = block.experts.down_proj - gate_up_grad = gate_up.grad - down_grad = down.grad - for local_idx in range(gate_up.shape[0]): - global_idx = int(block._ep_local_start + local_idx) - if global_idx not in active: - continue - if gate_up_grad is not None: - grads[f'expert{global_idx}.gate_up_proj'] = gate_up_grad[local_idx].detach().cpu() - if down_grad is not None: - grads[f'expert{global_idx}.down_proj'] = down_grad[local_idx].detach().cpu() - return grads - - -def _compare_grad_dicts( - *, - rank: int, - baseline: Dict[str, torch.Tensor], - sp: Dict[str, torch.Tensor], - rel_tol: float, -) -> None: - keys = sorted(set(baseline.keys()) | set(sp.keys())) - for k in keys: - a = baseline.get(k) - b = sp.get(k) - if a is None or b is None: - raise AssertionError(f'[rank{rank}] Missing grad key={k} baseline={a is not None} sp={b is not None}') - a32 = a.to(dtype=torch.float32) - b32 = b.to(dtype=torch.float32) - diff = b32 - a32 - rel = diff.norm() / (a32.norm() + 1e-12) - assert rel.item() <= rel_tol - - -def _run_worker_ep_fsdp_sp_align( - rank: int, - world_size: int, - port: int, - model_id: str, - local_files_only: bool, -): - os.environ['RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - # Some utilities (e.g. Platform.get_local_device()) rely on LOCAL_RANK. - os.environ['LOCAL_RANK'] = str(rank) - os.environ['LOCAL_WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = str(port) - - strict = os.environ.get('TWINKLE_STRICT_ALIGN', '0') == '1' - if strict: - _enable_strict_determinism(seed=1234) - - if not torch.cuda.is_available(): - raise RuntimeError('This test requires CUDA (4 GPUs).') - device = torch.device(f'cuda:{rank}') - torch.cuda.set_device(device) - - dist.init_process_group( - backend='nccl', - rank=rank, - world_size=world_size, - init_method=f'tcp://127.0.0.1:{port}', - device_id=device, - timeout=timedelta(minutes=15), - ) - dist.barrier() - - try: - torch.manual_seed(1234) - torch.cuda.manual_seed_all(1234) - - # 4 GPUs: (fsdp=2, ep=2); SP is derived with ulysses_size=2 over raw data ranks (fsdp). - device_mesh = DeviceMesh( - device_type='cuda', - mesh=np.arange(world_size).reshape(2, 2), - mesh_dim_names=('fsdp', 'ep'), - ulysses_size=2, - ) - sp_size = 2 - sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size) - - # Shared input (same across ranks) + per-rank slice loss (matches SP slice ownership). - # Keep seq_len divisible by sp_size to avoid padding complexity here. - batch_size = 2 - seq_len = 8 - - # --- Baseline: EP+FSDP (no SP) --- - model_base = _load_qwen3_moe_pretrained(model_id, local_files_only, device) - vocab_size = int(model_base.config.vocab_size) - input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device=device) - dist.broadcast(input_ids, src=0) - position_ids = torch.arange(seq_len, device=device).unsqueeze(0).repeat(batch_size, 1) - attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device) - - # Prepare labels for causal LM: set first token ignore so roll won't create wrap-around target. 
- labels_raw = input_ids.clone() - labels_raw[:, 0] = -100 - labels_shifted = torch.roll(labels_raw, shifts=-1, dims=1) - - embed_base = model_base.get_input_embeddings() - _ensure_embed_tokens(model_base, embed_base) - base_embeds = embed_base(input_ids).detach() - - apply_expert_parallel( - getattr(model_base, 'model', model_base), - device_mesh, - config={ - 'enabled': True, - 'router_dtype': 'fp32', - 'all_to_all': 'torch', - 'keep_router_logits': False, - }, - ) - fsdp_strategy = NativeFSDPStrategy(device_mesh=device_mesh, mixed_precision='bf16', fsdp_config={}) - model_base, _ = fsdp_strategy.wrap_model(model_base, optimizer=None) - - base_states, base_state_handles = _capture_router_state(getattr(model_base, 'model', model_base)) - base_out = model_base( - inputs_embeds=base_embeds, - position_ids=position_ids, - attention_mask=attention_mask, - use_cache=False, - ) - for h in base_state_handles: - h.remove() - base_logits = base_out.logits.detach() - - start, end = _sp_slice_range_for_seq_len(seq_len, sp_group=sp_group, sp_size=sp_size) - base_token_loss = _per_token_ce_loss(base_out.logits, labels_shifted) - base_loss_sum = base_token_loss[:, start:end].sum() - base_loss_sum.backward() - - # Collect active experts (slice-only) and corresponding local expert grads. - base_blocks = _find_moe_blocks(getattr(model_base, 'model', model_base)) - if not base_blocks: - raise RuntimeError('No MoE blocks found in Qwen3 MoE model.') - assert len(base_states) == len(base_blocks) - base_active_grads: Dict[str, torch.Tensor] = {} - for block, state in zip(base_blocks, base_states): - sel = state['selected_experts'] # [tokens, top_k] (flattened) - # Router hook captures all tokens; reshape to [B,S,top_k] and slice same seq range. - top_k = sel.shape[-1] - sel = sel.view(batch_size, seq_len, top_k)[:, start:end, :].reshape(-1, top_k) - active = torch.unique(sel) - base_active_grads.update(_collect_active_local_expert_grad_tensors(block, active)) - - # --- SP variant: EP+FSDP+SP --- - # Note: SP does global patching; keep it after baseline in this process. - model_sp = _load_qwen3_moe_pretrained(model_id, local_files_only, device) - embed_sp = model_sp.get_input_embeddings() - _ensure_embed_tokens(model_sp, embed_sp) - sp_embeds = embed_sp(input_ids).detach() - - apply_expert_parallel( - getattr(model_sp, 'model', model_sp), - device_mesh, - config={ - 'enabled': True, - 'router_dtype': 'fp32', - 'all_to_all': 'torch', - 'keep_router_logits': False, - }, - ) - sp_strategy = SequenceParallelStrategy( - device_mesh=device_mesh, - sp_config={ - 'enabled': True, - 'ulysses_size': sp_size, - 'gather_logits': True - }, - model=model_sp, - tokenizer_id=model_id, - ) - sp_strategy.initialize() - model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None) - - # Preprocess labels through SP strategy so they are shifted + split consistently. - # Keep label semantics consistent with the baseline path: next-token aligned labels. 
- sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids} - sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs) - sp_local_labels = sp_label_inputs['labels'] - - sequence_parallel.extra_kwargs['position_ids'] = position_ids.clone() - sp_states, sp_state_handles = _capture_router_state(getattr(model_sp, 'model', model_sp)) - sp_out = model_sp( - inputs_embeds=sp_embeds, - position_ids=position_ids, - attention_mask=attention_mask, - use_cache=False, - ) - for h in sp_state_handles: - h.remove() - sp_local_logits = sp_out.logits - sp_out = sp_strategy.postprocess_outputs(sp_out) - sp_logits = sp_out.logits.detach() - - # Forward alignment (full-seq logits reconstructed by SP gather). - assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4) - - # Router alignment on this rank's slice: compare selected experts exactly. - # SP captures only local tokens; baseline captures full tokens (we slice it). - sp_blocks = _find_moe_blocks(getattr(model_sp, 'model', model_sp)) - assert len(sp_states) == len(sp_blocks) == len(base_blocks) - for idx, (base_state, sp_state) in enumerate(zip(base_states, sp_states)): - base_sel = base_state['selected_experts'].view(batch_size, seq_len, -1)[:, start:end, :].contiguous() - # sp_sel is already local-seq; shape should match [B, local_seq, top_k] or [tokens, top_k] - sp_sel = sp_state['selected_experts'] - if sp_sel.dim() == 2: - sp_sel = sp_sel.view(batch_size, end - start, -1) - assert torch.equal(base_sel, sp_sel) - - # Backward alignment (expert grads on active local experts for this slice). - sp_loss_sum = F.cross_entropy( - sp_local_logits.view(-1, sp_local_logits.size(-1)), - sp_local_labels.view(-1), - ignore_index=-100, - reduction='sum', - ) - sp_loss_sum.backward() - - sp_active_grads: Dict[str, torch.Tensor] = {} - for block, state in zip(sp_blocks, sp_states): - active = torch.unique(state['selected_experts']) - sp_active_grads.update(_collect_active_local_expert_grad_tensors(block, active)) - - # Mixed precision + extra collectives => allow a bit more slack on gradients than logits. 
- grad_rel_tol = float(os.environ.get('TWINKLE_EXPERT_GRAD_REL_TOL', '1e-3')) - _compare_grad_dicts(rank=rank, baseline=base_active_grads, sp=sp_active_grads, rel_tol=grad_rel_tol) - finally: - dist.destroy_process_group() - - -class TestExpertParallelFSDPSequenceParallelPretrained(unittest.TestCase): - - def test_qwen3_moe_pretrained_ep_fsdp_sp_alignment(self): - if not dist.is_available(): - self.skipTest('torch.distributed is not available') - if not torch.cuda.is_available(): - self.skipTest('CUDA is required for this test.') - world_size = 4 - if torch.cuda.device_count() < world_size: - self.skipTest('Requires at least 4 GPUs for EP+FSDP+SP alignment test.') - model_id = os.environ.get('QWEN3_MOE_MODEL_ID', 'Qwen/Qwen3.5-4B') - local_files_only = os.environ.get('QWEN3_MOE_LOCAL_ONLY', '1') != '0' - try: - _load_qwen3_moe_config(model_id, local_files_only) - except Exception as exc: # noqa: BLE001 - self.skipTest(f'Qwen3 MoE model not available locally: {exc}') - port = _find_free_port() - mp.spawn( - _run_worker_ep_fsdp_sp_align, - args=(world_size, port, model_id, local_files_only), - nprocs=world_size, - join=True, - ) - - -def _run_worker_fsdp_sp_align( - rank: int, - world_size: int, - port: int, - model_id: str, - local_files_only: bool, -): - """Compare FSDP-only vs FSDP+SP for a Qwen3 MoE pretrained model.""" - os.environ['RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['LOCAL_WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = str(port) - - strict = os.environ.get('TWINKLE_STRICT_ALIGN', '0') == '1' - if strict: - _enable_strict_determinism(seed=1234) - - if not torch.cuda.is_available(): - raise RuntimeError('This test requires CUDA (4 GPUs).') - device = torch.device(f'cuda:{rank}') - torch.cuda.set_device(device) - - dist.init_process_group( - backend='nccl', - rank=rank, - world_size=world_size, - init_method=f'tcp://127.0.0.1:{port}', - device_id=device, - timeout=timedelta(minutes=15), - ) - dist.barrier() - - try: - torch.manual_seed(1234) - torch.cuda.manual_seed_all(1234) - - # 4 GPUs: fsdp=4, dp=1; SP is derived via ulysses_size=2 over raw data ranks (fsdp). - device_mesh = DeviceMesh.from_sizes( - fsdp_size=world_size, - dp_size=1, - ulysses_size=2, - device_type='cuda', - ) - sp_size = 2 - sp_group = _get_sp_group_from_device_mesh(device_mesh, sp_size) - - batch_size = 2 - seq_len = 16 - - # Loading the pretrained checkpoint twice per-rank is very slow and can look "hung". - # Load once, then deepcopy to get a second identical model for the SP variant. - model_fsdp = _load_qwen3_moe_pretrained(model_id, local_files_only, device) - model_sp = copy.deepcopy(model_fsdp) - vocab_size = int(model_fsdp.config.vocab_size) - - input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device=device) - dist.broadcast(input_ids, src=0) - position_ids = torch.arange(seq_len, device=device).unsqueeze(0).repeat(batch_size, 1) - attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device) - - labels_raw = input_ids.clone() - labels_raw[:, 0] = -100 - labels_shifted = torch.roll(labels_raw, shifts=-1, dims=1) - - fsdp_strategy = NativeFSDPStrategy(device_mesh=device_mesh, mixed_precision='bf16', fsdp_config={}) - - # --- Baseline: FSDP only (no SP). Use full-sequence loss (sum over all tokens). 
- embed_fsdp = model_fsdp.get_input_embeddings() - _ensure_embed_tokens(model_fsdp, embed_fsdp) - base_embeds = embed_fsdp(input_ids).detach().requires_grad_(True) - model_fsdp, _ = fsdp_strategy.wrap_model(model_fsdp, optimizer=None) - - base_out = model_fsdp( - inputs_embeds=base_embeds, - position_ids=position_ids, - attention_mask=attention_mask, - use_cache=False, - ) - base_logits = base_out.logits.detach() - base_loss_sum = F.cross_entropy( - base_out.logits.view(-1, base_out.logits.size(-1)), - labels_shifted.view(-1), - ignore_index=-100, - reduction='sum', - ) - base_loss_sum.backward() - base_embed_grad = base_embeds.grad.detach().cpu() - model_fsdp.zero_grad(set_to_none=True) - - # --- Variant: FSDP + SP. - sp_strategy = SequenceParallelStrategy( - device_mesh=device_mesh, - sp_config={ - 'enabled': True, - 'ulysses_size': sp_size, - 'gather_logits': True - }, - model=model_sp, - tokenizer_id=model_id, - ) - sp_strategy.initialize() - - # Compute inputs_embeds before DTensor wrapping to avoid mixed Tensor/DTensor embedding op. - embed_sp = model_sp.get_input_embeddings() - _ensure_embed_tokens(model_sp, embed_sp) - sp_embeds = embed_sp(input_ids).detach().requires_grad_(True) - model_sp, _ = fsdp_strategy.wrap_model(model_sp, optimizer=None) - - # Keep label semantics consistent with the baseline path: next-token aligned labels. - sp_label_inputs = {'labels': labels_shifted, 'position_ids': position_ids} - sp_label_inputs = sp_strategy.preprocess_inputs(sp_label_inputs) - sp_local_labels = sp_label_inputs['labels'] - - sequence_parallel.extra_kwargs['position_ids'] = position_ids.clone() - sp_out = model_sp( - inputs_embeds=sp_embeds, - position_ids=position_ids, - attention_mask=attention_mask, - use_cache=False, - ) - sp_local_logits = sp_out.logits - sp_out = sp_strategy.postprocess_outputs(sp_out) - sp_logits = sp_out.logits.detach() - - # Forward alignment (full-seq logits reconstructed by SP gather). - assert torch.allclose(sp_logits, base_logits, rtol=1e-3, atol=1e-4) - - # Backward alignment: local CE(sum) on SP, compare gathered full-seq inputs_embeds grads. - sp_loss_sum = F.cross_entropy( - sp_local_logits.view(-1, sp_local_logits.size(-1)), - sp_local_labels.view(-1), - ignore_index=-100, - reduction='sum', - ) - sp_loss_sum.backward() - sp_embed_grad = sp_embeds.grad.detach().cpu() - - # Backward alignment: gather SP local-seq grads into a full-seq grad and compare. 
- start, end = _sp_slice_range_for_seq_len(seq_len, sp_group=sp_group, sp_size=sp_size) - sp_local = sp_embed_grad.to(device=device, dtype=torch.float32)[:, start:end].contiguous() - sp_full = _gather_full_seq_grad_from_sp(sp_local, sp_group=sp_group) - base_full = base_embed_grad.to(device=device, dtype=torch.float32)[:, :seq_len].contiguous() - diff = sp_full - base_full - rel = diff.norm() / (base_full.norm() + 1e-12) - grad_rel_tol = float(os.environ.get('TWINKLE_INPUT_GRAD_REL_TOL', '1e-2')) - assert rel.item() <= grad_rel_tol - finally: - dist.destroy_process_group() - - -class TestFSDPSequenceParallelQwen3MoePretrained(unittest.TestCase): - - def test_qwen3_pretrained_fsdp_sp_alignment(self): - if not dist.is_available(): - self.skipTest('torch.distributed is not available') - if not torch.cuda.is_available(): - self.skipTest('CUDA is required for this test.') - world_size = 4 - if torch.cuda.device_count() < world_size: - self.skipTest('Requires at least 4 GPUs for FSDP+SP alignment test.') - model_id = os.environ.get('QWEN3_MOE_MODEL_ID', 'Qwen/Qwen3-0.6B') - local_files_only = os.environ.get('QWEN3_MOE_LOCAL_ONLY', '1') != '0' - try: - _load_qwen3_moe_config(model_id, local_files_only) - except Exception as exc: # noqa: BLE001 - self.skipTest(f'Qwen3 MoE model not available locally: {exc}') - port = _find_free_port() - mp.spawn( - _run_worker_fsdp_sp_align, - args=(world_size, port, model_id, local_files_only), - nprocs=world_size, - join=True, - ) diff --git a/tests/transformers/test_qwen35_linear_attention_sp.py b/tests/transformers/test_qwen35_linear_attention_sp.py new file mode 100644 index 00000000..f8f50185 --- /dev/null +++ b/tests/transformers/test_qwen35_linear_attention_sp.py @@ -0,0 +1,300 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+import copy +import os +import socket +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn.functional as F +import unittest +from datetime import timedelta +from transformers.modeling_flash_attention_utils import is_flash_attn_available +from transformers.utils.import_utils import is_flash_linear_attention_available +from types import SimpleNamespace + +from twinkle.loss import CrossEntropyLoss +from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallelStrategy, sequence_parallel +from twinkle.utils import DeviceMesh, selective_log_softmax + +try: + from transformers import Qwen3_5ForCausalLM, Qwen3_5TextConfig + from transformers.models.qwen3_5 import modeling_qwen3_5 as hf_qwen35 + + _HAS_QWEN35 = True +except Exception: + Qwen3_5ForCausalLM = None + Qwen3_5TextConfig = None + hf_qwen35 = None + _HAS_QWEN35 = False + +WORLD_SIZE = 2 +LOGITS_RTOL = 5e-3 +LOGITS_ATOL = 5e-3 +LOSS_ATOL = 5e-3 +GRAD_RTOL = 5e-3 +GRAD_ATOL = 2e-3 +_HAS_FLA_PREFILL = bool( + _HAS_QWEN35 and (getattr(hf_qwen35, 'causal_conv1d_fn', None) is not None or is_flash_linear_attention_available())) + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(('127.0.0.1', 0)) + return sock.getsockname()[1] + + +def _init_dist(rank: int, world_size: int, port: int) -> torch.device: + os.environ['RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['LOCAL_WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(port) + + device = torch.device(f'cuda:{rank}') + torch.cuda.set_device(device) + dist.init_process_group( + backend='nccl', + rank=rank, + world_size=world_size, + init_method=f'tcp://127.0.0.1:{port}', + device_id=device, + timeout=timedelta(minutes=15), + ) + return device + + +def _set_determinism(seed: int) -> None: + os.environ.setdefault('PYTHONHASHSEED', str(seed)) + os.environ.setdefault('CUBLAS_WORKSPACE_CONFIG', ':16:8') + os.environ.setdefault('NCCL_DETERMINISTIC', '1') + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def _model_dtype() -> torch.dtype: + return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + + +def _build_tiny_qwen35(device: torch.device, + *, + attn_implementation: str = 'sdpa', + layer_types: list[str] | None = None) -> Qwen3_5ForCausalLM: + if layer_types is None: + layer_types = ['linear_attention', 'linear_attention'] + config = Qwen3_5TextConfig( + vocab_size=128, + hidden_size=64, + intermediate_size=256, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + head_dim=16, + linear_conv_kernel_dim=4, + linear_key_head_dim=16, + linear_value_head_dim=16, + linear_num_key_heads=2, + linear_num_value_heads=4, + layer_types=layer_types, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + attention_dropout=0.0, + use_cache=False, + ) + config._attn_implementation = attn_implementation + model = Qwen3_5ForCausalLM(config) + model.to(device=device, dtype=_model_dtype()) + model.eval() + return model + + +def _make_strategy(model: Qwen3_5ForCausalLM, world_size: int) -> SequenceParallelStrategy: + strategy = SequenceParallelStrategy( + device_mesh=DeviceMesh.from_sizes( + world_size=world_size, + fsdp_size=world_size, + dp_size=1, 
+ ulysses_size=world_size, + device_type='cuda', + ), + sp_config={ + 'enabled': True, + 'ulysses_size': world_size, + 'gather_logits': True, + }, + model=model, + tokenizer_id=None, + ) + strategy._tokenizer = SimpleNamespace(pad_token_id=0) + strategy.initialize() + return strategy + + +def _make_shift_labels(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + labels = torch.full_like(input_ids, -100) + labels[..., :-1] = input_ids[..., 1:] + labels = labels.clone() + labels[attention_mask == 0] = -100 + labels[..., -1] = -100 + return labels + + +def _make_train_batch(device: torch.device): + input_ids = torch.tensor([ + [0, 0, 11, 12, 13, 14, 15, 16], + [21, 22, 23, 24, 25, 26, 27, 28], + ], + device=device, + dtype=torch.long) + attention_mask = torch.tensor([ + [0, 0, 1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1, 1], + ], + device=device, + dtype=torch.long) + position_ids = torch.arange(input_ids.shape[1], device=device, dtype=torch.long).unsqueeze(0).expand_as(input_ids) + labels = _make_shift_labels(input_ids, attention_mask) + return input_ids, attention_mask, position_ids, labels + + +def _get_qkv_weight(model: Qwen3_5ForCausalLM) -> torch.nn.Parameter: + for layer in model.model.layers: + linear_attn = getattr(layer, 'linear_attn', None) + if linear_attn is not None: + return linear_attn.in_proj_qkv.weight + raise AssertionError('No linear attention layer found in Qwen3.5 test model.') + + +def _allreduce_sp_grad(grad: torch.Tensor) -> torch.Tensor: + reduced = grad.detach().float().contiguous() + if sequence_parallel.world_size is not None and sequence_parallel.world_size > 1: + dist.all_reduce(reduced, op=dist.ReduceOp.SUM, group=sequence_parallel._sp_group) + return reduced + + +def _run_prefill_alignment_worker(rank: int, + world_size: int, + port: int, + attn_implementation: str = 'sdpa', + layer_types: list[str] | None = None): + device = _init_dist(rank, world_size, port) + try: + _set_determinism(1234) + os.environ['QWEN35_SP_LINEAR_HEAD_PARALLEL'] = '1' + + baseline_model = _build_tiny_qwen35(device, attn_implementation=attn_implementation, layer_types=layer_types) + sp_model = copy.deepcopy(baseline_model) + input_ids, attention_mask, position_ids, labels = _make_train_batch(device) + + baseline_outputs = baseline_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + ) + baseline_logits = baseline_outputs.logits.float() + baseline_loss = F.cross_entropy( + baseline_logits.reshape(-1, baseline_logits.size(-1)), + labels.reshape(-1), + ignore_index=-100, + reduction='mean', + ) + baseline_loss.backward() + baseline_qkv_grad = _get_qkv_weight(baseline_model).grad.detach().float().cpu() + + strategy = _make_strategy(sp_model, world_size) + processed_inputs = strategy.preprocess_inputs({ + 'input_ids': input_ids, + 'position_ids': position_ids, + 'labels': labels, + }) + local_labels = processed_inputs['labels'] + sp_outputs = sp_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False, + ) + gathered_outputs = strategy.postprocess_outputs(copy.copy(sp_outputs)) + gathered_logits = gathered_outputs.logits.float() + if not torch.allclose(gathered_logits, baseline_logits, rtol=LOGITS_RTOL, atol=LOGITS_ATOL): + max_diff = (gathered_logits - baseline_logits).abs().max().item() + raise AssertionError(f'prefill logits mismatch on rank {rank}: max_diff={max_diff}') + + loss_instance = CrossEntropyLoss(reduction='mean') + local_logits = 
sp_outputs.logits + masked_local_labels = local_labels.masked_fill(local_labels == -100, 0) + local_logps = selective_log_softmax(local_logits, masked_local_labels) + loss_inputs = {'labels': local_labels} + loss_outputs = {'logits': local_logits, 'logps': local_logps} + loss_inputs, loss_outputs = strategy.gather_loss_tensors(loss_inputs, loss_outputs) + result = loss_instance(loss_inputs, loss_outputs) + sp_loss = result['loss'] + if not torch.allclose(sp_loss.detach(), baseline_loss.detach(), atol=LOSS_ATOL, rtol=0): + raise AssertionError( + f'prefill loss mismatch on rank {rank}: baseline={baseline_loss.item()} sp={sp_loss.item()}') + sp_loss.backward() + + sp_qkv_grad = _allreduce_sp_grad(_get_qkv_weight(sp_model).grad).cpu() + if not torch.allclose(sp_qkv_grad, baseline_qkv_grad, rtol=GRAD_RTOL, atol=GRAD_ATOL): + max_diff = (sp_qkv_grad - baseline_qkv_grad).abs().max().item() + raise AssertionError(f'qkv grad mismatch on rank {rank}: max_diff={max_diff}') + finally: + if dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + +@unittest.skipUnless(_HAS_QWEN35, 'transformers Qwen3.5 is not available in this environment') +@unittest.skipUnless(torch.cuda.is_available() and torch.cuda.device_count() >= WORLD_SIZE, 'requires 2 CUDA devices') +@unittest.skipUnless( + _HAS_FLA_PREFILL, + 'requires either transformers qwen3.5 causal_conv1d_fn or flash-linear-attention kernels for Qwen3.5 SP patch') +class TestQwen35LinearAttentionSP(unittest.TestCase): + + def test_qwen35_linear_attention_prefill_logits_and_qkv_grad_alignment(self): + port = _find_free_port() + mp.spawn( + _run_prefill_alignment_worker, + args=(WORLD_SIZE, port, 'sdpa', ['linear_attention', 'linear_attention']), + nprocs=WORLD_SIZE, + join=True, + ) + + def test_qwen35_mixed_attention_prefill_logits_and_qkv_grad_alignment(self): + port = _find_free_port() + mp.spawn( + _run_prefill_alignment_worker, + args=(WORLD_SIZE, port, 'sdpa', ['full_attention', 'linear_attention']), + nprocs=WORLD_SIZE, + join=True, + ) + + @unittest.skipUnless(is_flash_attn_available(), 'requires flash_attention_2 support in transformers') + def test_qwen35_linear_attention_prefill_logits_and_qkv_grad_alignment_fa2(self): + port = _find_free_port() + mp.spawn( + _run_prefill_alignment_worker, + args=(WORLD_SIZE, port, 'flash_attention_2', ['linear_attention', 'linear_attention']), + nprocs=WORLD_SIZE, + join=True, + ) + + @unittest.skipUnless(is_flash_attn_available(), 'requires flash_attention_2 support in transformers') + def test_qwen35_mixed_attention_prefill_logits_and_qkv_grad_alignment_fa2(self): + port = _find_free_port() + mp.spawn( + _run_prefill_alignment_worker, + args=(WORLD_SIZE, port, 'flash_attention_2', ['full_attention', 'linear_attention']), + nprocs=WORLD_SIZE, + join=True, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/transformers/test_sequence_parallel_and_cp.py b/tests/transformers/test_sequence_parallel_and_cp.py new file mode 100644 index 00000000..513cbb11 --- /dev/null +++ b/tests/transformers/test_sequence_parallel_and_cp.py @@ -0,0 +1,667 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
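+# Alignment and memory-profile tests for the Transformers sequence-parallel route.
+# When ulysses_size does not divide num_attention_heads, _validate_case_config (below)
+# expects the degree to be split into a head-parallel part and a derived ring part:
+#     sp_world_size = gcd(num_attention_heads, ulysses_size)
+#     rp_world_size = ulysses_size // sp_world_size
+# Worked example for the 'cp_sp' case: 6 heads with ulysses_size=4 gives
+# sp_world_size = gcd(6, 4) = 2 and rp_world_size = 4 // 2 = 2, i.e. mixed SP + ring;
+# 'sp_only' (8 heads, ulysses_size=2) stays pure SP, while 'cp_only' (1 head,
+# ulysses_size=2) degenerates to a pure ring.
+# The *_memory cases are opt-in grids (TWINKLE_RUN_MEMORY_TESTS=1) that record peak
+# allocator usage over seq_len/batch_size and print a summary table on rank 0.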
+import copy +import math +import os +import socket +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import unittest +from datetime import timedelta +from transformers import LlamaConfig, LlamaForCausalLM +from transformers.modeling_flash_attention_utils import is_flash_attn_available +from types import SimpleNamespace + +from twinkle.loss import CrossEntropyLoss +from twinkle.model.transformers.strategy.sequence_parallel import SequenceParallelStrategy, sequence_parallel +from twinkle.utils import DeviceMesh, Platform, ensure_hccl_socket_env, selective_log_softmax, torch_util + +LOGITS_RTOL = 1e-2 +LOGITS_ATOL = 5e-3 +LOSS_ATOL = 5e-3 +GRAD_RTOL = 2e-2 +GRAD_ATOL = 1e-2 + + +def _find_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(('127.0.0.1', 0)) + return sock.getsockname()[1] + + +def _make_labels(input_ids: torch.Tensor) -> torch.Tensor: + labels = torch.full_like(input_ids, -100) + labels[..., :-1] = input_ids[..., 1:] + return labels + + +def _make_case(case_name: str) -> dict: + cases = { + 'sp_only': { + 'expected_mode': 'sp_only', + 'world_size': 2, + 'ulysses_size': 2, + 'num_attention_heads': 8, + 'hidden_size': 128, + 'seq_len': 769, + }, + 'cp_only': { + 'expected_mode': 'cp_only', + 'world_size': 2, + 'ulysses_size': 2, + 'num_attention_heads': 1, + 'hidden_size': 64, + 'seq_len': 769, + }, + 'cp_sp': { + 'expected_mode': 'cp_sp', + 'world_size': 4, + 'ulysses_size': 4, + 'num_attention_heads': 6, + 'hidden_size': 96, + 'seq_len': 769, + }, + 'sp_only_memory': { + 'expected_mode': 'sp_only', + 'world_size': 4, + 'ulysses_size': 2, + 'num_attention_heads': 8, + 'hidden_size': 128, + 'num_hidden_layers': 2, + 'seq_lens': [255, 511, 1023], + 'batch_sizes': [1, 2, 4], + }, + 'cp_only_memory': { + 'expected_mode': 'cp_only', + 'world_size': 2, + 'ulysses_size': 2, + 'num_attention_heads': 1, + 'hidden_size': 128, + 'num_hidden_layers': 4, + 'seq_lens': [511, 1023, 2047], + 'batch_sizes': [1], + }, + 'cp_sp_memory': { + 'expected_mode': 'cp_sp', + 'world_size': 4, + 'ulysses_size': 4, + # gcd(6, 4) = 2 -> sp=2, rp=2 + 'num_attention_heads': 6, + 'hidden_size': 192, + 'num_hidden_layers': 4, + 'seq_lens': [511, 1023, 2047], + 'batch_sizes': [1], + }, + } + return copy.deepcopy(cases[case_name]) + + +def _validate_case_config(case: dict) -> tuple[str, int, int]: + hidden_size = int(case['hidden_size']) + num_heads = int(case['num_attention_heads']) + ulysses_size = int(case['ulysses_size']) + expected_mode = case.get('expected_mode') + + if hidden_size % num_heads != 0: + raise ValueError(f'Invalid test case config: hidden_size ({hidden_size}) must be divisible by ' + f'num_attention_heads ({num_heads}).') + + head_dim = hidden_size // num_heads + if head_dim % 2 != 0: + raise ValueError(f'Invalid test case config: head_dim ({head_dim}) must be even for RoPE. ' + f'Got hidden_size={hidden_size}, num_attention_heads={num_heads}.') + + sp_world_size = math.gcd(num_heads, ulysses_size) + rp_world_size = ulysses_size // sp_world_size + mode = 'sp_only' + if rp_world_size > 1 and sp_world_size == 1: + mode = 'cp_only' + elif rp_world_size > 1 and sp_world_size > 1: + mode = 'cp_sp' + + if expected_mode is not None and mode != expected_mode: + raise ValueError(f'Invalid test case config: expected {expected_mode}, but derived {mode}. 
' + f'Got ulysses_size={ulysses_size}, num_attention_heads={num_heads}, ' + f'sp_world_size={sp_world_size}, rp_world_size={rp_world_size}.') + return mode, sp_world_size, rp_world_size + + +def _get_runtime_backend() -> dict | None: + if torch.cuda.is_available(): + return { + 'device_type': 'cuda', + 'dist_backend': 'nccl', + 'device_count': int(torch.cuda.device_count()), + 'label': 'CUDA', + } + if torch_util.is_npu_available(): + return { + 'device_type': 'npu', + 'dist_backend': 'hccl', + 'device_count': int(torch.npu.device_count()), + 'label': 'NPU', + } + return None + + +def _get_device_module(device_type: str): + if device_type == 'cuda': + return torch.cuda + if device_type == 'npu': + return torch.npu + raise ValueError(f'Unsupported device_type for derived ring tests: {device_type}') + + +def _supports_peak_memory_stats(device_type: str) -> bool: + device_module = _get_device_module(device_type) + required_apis = ('empty_cache', 'reset_peak_memory_stats', 'synchronize', 'max_memory_allocated') + return all(hasattr(device_module, name) for name in required_apis) + + +def _get_model_dtype(device_type: str) -> torch.dtype: + if device_type == 'cuda': + return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + if device_type == 'npu': + is_bf16_supported = getattr(torch.npu, 'is_bf16_supported', None) + if callable(is_bf16_supported): + try: + if is_bf16_supported(): + return torch.bfloat16 + except Exception: + pass + return torch.float16 + raise ValueError(f'Unsupported device_type for derived ring tests: {device_type}') + + +def _seed_backend(seed: int, device_type: str) -> None: + torch.manual_seed(seed) + if device_type == 'cuda': + torch.cuda.manual_seed_all(seed) + elif device_type == 'npu': + torch.npu.manual_seed_all(seed) + + +def _build_tiny_llama(case: dict, + device: torch.device, + device_type: str, + attn_implementation: str = 'flash_attention_2') -> LlamaForCausalLM: + _validate_case_config(case) + hidden_size = int(case['hidden_size']) + num_heads = int(case['num_attention_heads']) + num_hidden_layers = int(case.get('num_hidden_layers', 1)) + dtype = _get_model_dtype(device_type) + max_seq_len = int(case.get('seq_len', max(case.get('seq_lens', [1024])))) + 32 + config = LlamaConfig( + vocab_size=256, + hidden_size=hidden_size, + intermediate_size=hidden_size * 4, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_heads, + num_key_value_heads=num_heads, + max_position_embeddings=max_seq_len, + attention_dropout=0.0, + rms_norm_eps=1e-5, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + use_cache=False, + ) + config._attn_implementation = attn_implementation + model = LlamaForCausalLM(config) + model.to(device=device, dtype=dtype) + model.eval() + return model + + +def _make_strategy(model: LlamaForCausalLM, device_mesh: DeviceMesh, ulysses_size: int) -> SequenceParallelStrategy: + strategy = SequenceParallelStrategy( + device_mesh=device_mesh, + sp_config={ + 'enabled': True, + 'ulysses_size': ulysses_size, + 'gather_logits': True, + }, + model=model, + tokenizer_id=None, + ) + strategy._tokenizer = SimpleNamespace(pad_token_id=0) + strategy.initialize() + return strategy + + +def _prepare_label_inputs(strategy: SequenceParallelStrategy, input_ids: torch.Tensor, + position_ids: torch.Tensor) -> torch.Tensor: + labels = _make_labels(input_ids) + processed = strategy.preprocess_inputs({ + 'input_ids': input_ids, + 'position_ids': position_ids, + 'labels': labels, + }) + return processed['labels'] + + +def 
_compute_training_path_loss( + logits: torch.Tensor, + labels: torch.Tensor, + strategy: SequenceParallelStrategy | None = None, +) -> tuple[torch.Tensor, int]: + masked_labels = labels.masked_fill(labels == -100, 0) + loss_inputs = {'labels': labels} + loss_outputs = {'logps': selective_log_softmax(logits, masked_labels)} + if strategy is not None: + loss_inputs, loss_outputs = strategy.gather_loss_tensors(loss_inputs, loss_outputs) + result = CrossEntropyLoss(reduction='sum')(loss_inputs, loss_outputs) + num_tokens = result['num_tokens'] + if torch.is_tensor(num_tokens): + num_tokens = int(num_tokens.item()) + else: + num_tokens = int(num_tokens) + return result['loss'], num_tokens + + +def _normalize_grad_dict(grads: dict[str, torch.Tensor], num_tokens: int) -> dict[str, torch.Tensor]: + denom = float(max(num_tokens, 1)) + return {name: grad / denom for name, grad in grads.items()} + + +def _average_model_grads_over_group(model: LlamaForCausalLM, group: dist.ProcessGroup | None) -> None: + if group is None: + return + group_world_size = dist.get_world_size(group) + if group_world_size <= 1: + return + for param in model.parameters(): + if param.grad is None: + continue + dist.all_reduce(param.grad, group=group) + param.grad.div_(group_world_size) + + +def _collect_attention_param_grads(model: LlamaForCausalLM) -> dict[str, torch.Tensor]: + grads = {} + for name, param in model.named_parameters(): + if '.self_attn.' not in name: + continue + if param.grad is None: + continue + grads[name] = param.grad.detach().float().cpu() + if not grads: + raise AssertionError('No attention gradients were collected from the model.') + return grads + + +def _assert_grad_dict_close(case_name: str, rank: int, baseline_grads: dict[str, torch.Tensor], + sp_grads: dict[str, torch.Tensor]): + baseline_keys = sorted(baseline_grads.keys()) + sp_keys = sorted(sp_grads.keys()) + if baseline_keys != sp_keys: + raise AssertionError( + f'{case_name} attention grad keys mismatch on rank {rank}: baseline={baseline_keys}, sp={sp_keys}') + for key in baseline_keys: + baseline = baseline_grads[key] + current = sp_grads[key] + if not torch.allclose(current, baseline, rtol=GRAD_RTOL, atol=GRAD_ATOL): + max_diff = (current - baseline).abs().max().item() + raise AssertionError(f'{case_name} attention grad mismatch on rank {rank} for {key}: max_diff={max_diff}') + + +def _init_dist(rank: int, world_size: int, port: int, backend_config: dict): + os.environ['RANK'] = str(rank) + os.environ['WORLD_SIZE'] = str(world_size) + os.environ['LOCAL_RANK'] = str(rank) + os.environ['LOCAL_WORLD_SIZE'] = str(world_size) + os.environ['MASTER_ADDR'] = '127.0.0.1' + os.environ['MASTER_PORT'] = str(port) + + device_type = backend_config['device_type'] + dist_backend = backend_config['dist_backend'] + if dist_backend == 'hccl': + ensure_hccl_socket_env(port) + device = torch.device(Platform.get_local_device(rank, platform=device_type)) + _get_device_module(device_type).set_device(rank) + dist.init_process_group( + backend=dist_backend, + rank=rank, + world_size=world_size, + init_method=f'tcp://127.0.0.1:{port}', + device_id=device, + timeout=timedelta(minutes=15), + ) + return device + + +def _measure_peak_memory( + model: LlamaForCausalLM, + strategy: SequenceParallelStrategy, + *, + batch_size: int, + seq_len: int, + device: torch.device, + device_type: str, +) -> int: + vocab_size = int(model.config.vocab_size) + input_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len), device=device) + position_ids = 
torch.arange(seq_len, device=device).unsqueeze(0).repeat(batch_size, 1) + local_labels = _prepare_label_inputs(strategy, input_ids, position_ids) + + model.zero_grad(set_to_none=True) + device_module = _get_device_module(device_type) + device_module.empty_cache() + device_module.reset_peak_memory_stats(device) + outputs = model( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=None, + use_cache=False, + ) + local_logits = outputs.logits + loss_sum, _ = _compute_training_path_loss(local_logits, local_labels, strategy) + loss_sum.backward() + device_module.synchronize(device) + + peak = torch.tensor([int(device_module.max_memory_allocated(device))], device=device) + dist.all_reduce(peak, op=dist.ReduceOp.MAX) + return int(peak.item()) + + +def _format_memory_table(case_name: str, peaks: list[dict]) -> str: + header = f'[{case_name}] peak memory' + columns = ( + 'batch_size', + 'seq_len', + 'baseline_bytes', + 'baseline_mib', + 'peak_bytes', + 'peak_mib', + 'delta_bytes', + 'saving_ratio_pct', + ) + rows = [] + for row in peaks: + rows.append(( + str(row['batch_size']), + str(row['seq_len']), + str(row['baseline_bytes']), + f"{row['baseline_mib']:.2f}", + str(row['peak_bytes']), + f"{row['peak_mib']:.2f}", + str(row['delta_bytes']), + f"{row['saving_ratio_pct']:.2f}", + )) + + widths = [len(col) for col in columns] + for row in rows: + for i, value in enumerate(row): + widths[i] = max(widths[i], len(value)) + + def _fmt(values): + return ' | '.join(value.ljust(widths[i]) for i, value in enumerate(values)) + + lines = [ + header, + _fmt(columns), + '-+-'.join('-' * width for width in widths), + ] + lines.extend(_fmt(row) for row in rows) + return '\n'.join(lines) + + +def _run_precision_worker(rank: int, + world_size: int, + port: int, + case_name: str, + backend_config: dict, + attn_implementation: str = 'flash_attention_2'): + device = _init_dist(rank, world_size, port, backend_config) + try: + _seed_backend(1234, backend_config['device_type']) + case = _make_case(case_name) + + base_model = _build_tiny_llama(case, device, backend_config['device_type'], attn_implementation) + sp_model = _build_tiny_llama(case, device, backend_config['device_type'], attn_implementation) + sp_model.load_state_dict(base_model.state_dict()) + + seq_len = int(case['seq_len']) + input_ids = torch.randint(low=0, high=int(base_model.config.vocab_size), size=(1, seq_len), device=device) + position_ids = torch.arange(seq_len, device=device).unsqueeze(0) + labels = _make_labels(input_ids) + + base_model.zero_grad(set_to_none=True) + base_outputs = base_model( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=None, + use_cache=False, + ) + base_logits = base_outputs.logits.detach().float() + base_loss_sum, base_num_tokens = _compute_training_path_loss(base_outputs.logits, labels) + base_loss_sum.backward() + base_loss = base_loss_sum / max(base_num_tokens, 1) + base_attention_grads = _normalize_grad_dict(_collect_attention_param_grads(base_model), base_num_tokens) + + device_mesh = DeviceMesh.from_sizes( + fsdp_size=world_size, + dp_size=1, + ulysses_size=int(case['ulysses_size']), + device_type=backend_config['device_type'], + ) + strategy = _make_strategy(sp_model, device_mesh, int(case['ulysses_size'])) + processed_inputs = strategy.preprocess_inputs({ + 'input_ids': input_ids, + 'position_ids': position_ids, + 'labels': labels, + }) + local_labels = processed_inputs['labels'] + + sp_model.zero_grad(set_to_none=True) + sp_outputs = sp_model( + input_ids=input_ids, + 
position_ids=position_ids, + attention_mask=None, + use_cache=False, + ) + local_logits = sp_outputs.logits + gathered_outputs = strategy.postprocess_outputs(sp_outputs) + sp_logits = gathered_outputs.logits.detach().float() + + sp_loss_sum, sp_num_tokens = _compute_training_path_loss(local_logits, local_labels, strategy) + sp_loss_sum.backward() + global_loss = sp_loss_sum / max(sp_num_tokens, 1) + _average_model_grads_over_group(sp_model, sequence_parallel._data_rank_group) + sp_attention_grads = _normalize_grad_dict(_collect_attention_param_grads(sp_model), sp_num_tokens) + + if sp_num_tokens != base_num_tokens: + raise AssertionError( + f'{case_name} num_tokens mismatch on rank {rank}: sp={sp_num_tokens} base={base_num_tokens}') + + if not torch.allclose(sp_logits[:, :seq_len], base_logits, rtol=LOGITS_RTOL, atol=LOGITS_ATOL): + diff = (sp_logits[:, :seq_len] - base_logits).abs().max().item() + raise AssertionError(f'{case_name} logits mismatch on rank {rank}: max_diff={diff}') + if abs(global_loss.item() - base_loss.item()) > LOSS_ATOL: + raise AssertionError( + f'{case_name} loss mismatch on rank {rank}: sp={global_loss.item()} base={base_loss.item()}') + _assert_grad_dict_close(case_name, rank, base_attention_grads, sp_attention_grads) + dist.barrier() + finally: + dist.destroy_process_group() + + +def _run_memory_worker(rank: int, world_size: int, port: int, case_name: str, backend_config: dict): + device = _init_dist(rank, world_size, port, backend_config) + try: + _seed_backend(1234, backend_config['device_type']) + case = _make_case(case_name) + baseline_device_mesh = DeviceMesh.from_sizes( + fsdp_size=world_size, + dp_size=1, + ulysses_size=1, + device_type=backend_config['device_type'], + ) + baseline_model = _build_tiny_llama(case, device, backend_config['device_type']) + baseline_strategy = _make_strategy(baseline_model, baseline_device_mesh, 1) + + baseline_peaks = {} + for batch_size in case['batch_sizes']: + for seq_len in case['seq_lens']: + baseline_peak = _measure_peak_memory( + baseline_model, + baseline_strategy, + batch_size=batch_size, + seq_len=seq_len, + device=device, + device_type=backend_config['device_type'], + ) + baseline_peaks[(int(batch_size), int(seq_len))] = int(baseline_peak) + + del baseline_model + del baseline_strategy + _get_device_module(backend_config['device_type']).empty_cache() + + device_mesh = DeviceMesh.from_sizes( + fsdp_size=world_size, + dp_size=1, + ulysses_size=int(case['ulysses_size']), + device_type=backend_config['device_type'], + ) + model = _build_tiny_llama(case, device, backend_config['device_type']) + strategy = _make_strategy(model, device_mesh, int(case['ulysses_size'])) + + peaks = [] + for batch_size in case['batch_sizes']: + for seq_len in case['seq_lens']: + peak = _measure_peak_memory( + model, + strategy, + batch_size=batch_size, + seq_len=seq_len, + device=device, + device_type=backend_config['device_type'], + ) + if rank == 0: + baseline_peak = baseline_peaks[(int(batch_size), int(seq_len))] + delta_bytes = int(peak) - int(baseline_peak) + saving_ratio_pct = 0.0 + if baseline_peak > 0: + saving_ratio_pct = (float(baseline_peak) - float(peak)) / float(baseline_peak) * 100.0 + peaks.append({ + 'batch_size': int(batch_size), + 'seq_len': int(seq_len), + 'baseline_bytes': int(baseline_peak), + 'baseline_mib': float(baseline_peak) / (1024**2), + 'peak_bytes': int(peak), + 'peak_mib': float(peak) / (1024**2), + 'delta_bytes': delta_bytes, + 'saving_ratio_pct': saving_ratio_pct, + }) + + if rank == 0: + for key in 
('peak_bytes', 'baseline_bytes'): + by_batch = {} + for row in peaks: + by_batch.setdefault(row['batch_size'], []).append(row) + for rows in by_batch.values(): + rows.sort(key=lambda item: item['seq_len']) + for prev, cur in zip(rows, rows[1:]): + if cur[key] < prev[key]: + raise AssertionError( + f'{case_name} {key} should be non-decreasing with seq_len, got {prev} then {cur}') + + by_seq = {} + for row in peaks: + by_seq.setdefault(row['seq_len'], []).append(row) + for rows in by_seq.values(): + rows.sort(key=lambda item: item['batch_size']) + for prev, cur in zip(rows, rows[1:]): + if cur[key] < prev[key]: + raise AssertionError( + f'{case_name} {key} should be non-decreasing with batch_size, got {prev} then {cur}') + + print(_format_memory_table(case_name, peaks)) + dist.barrier() + finally: + dist.destroy_process_group() + + +class TestDerivedRingPrecision(unittest.TestCase): + + def _get_backend_or_skip(self, world_size: int = 4) -> dict: + backend = _get_runtime_backend() + if backend is None: + self.skipTest('CUDA or NPU is required for derived ring runtime tests.') + if backend['device_count'] < world_size: + self.skipTest(f'Requires at least {world_size} {backend["label"]} devices.') + return backend + + def _require_attn_impl_or_skip(self, backend: dict, attn_implementation: str) -> None: + if attn_implementation == 'flash_attention_2' and not is_flash_attn_available(): + if backend['device_type'] == 'npu': + self.skipTest( + 'Derived ring runtime tests currently require flash_attention_2, which is unavailable on NPU in ' + 'this environment.') + self.skipTest('flash_attention_2 is required for derived ring runtime tests.') + + def test_cp_only_precision_alignment(self): + case = _make_case('cp_only') + world_size = int(case['world_size']) + backend = self._get_backend_or_skip(world_size) + self._require_attn_impl_or_skip(backend, 'flash_attention_2') + port = _find_free_port() + mp.spawn( + _run_precision_worker, + args=(world_size, port, 'cp_only', backend, 'flash_attention_2'), + nprocs=world_size, + join=True) + + def test_cp_sp_precision_alignment(self): + case = _make_case('cp_sp') + world_size = int(case['world_size']) + backend = self._get_backend_or_skip(world_size) + self._require_attn_impl_or_skip(backend, 'flash_attention_2') + port = _find_free_port() + mp.spawn( + _run_precision_worker, + args=(world_size, port, 'cp_sp', backend, 'flash_attention_2'), + nprocs=world_size, + join=True) + + +class TestDerivedRingMemoryProfile(unittest.TestCase): + + def _get_backend_or_skip(self, world_size: int = 4) -> dict: + if os.environ.get('TWINKLE_RUN_MEMORY_TESTS', '0') != '1': + self.skipTest('Set TWINKLE_RUN_MEMORY_TESTS=1 to run derived ring memory profile tests.') + backend = _get_runtime_backend() + if backend is None: + self.skipTest('CUDA or NPU is required for derived ring memory tests.') + if backend['device_count'] < world_size: + self.skipTest(f'Requires at least {world_size} {backend["label"]} devices.') + if not is_flash_attn_available(): + if backend['device_type'] == 'npu': + self.skipTest( + 'Derived ring memory tests currently require flash_attention_2, which is unavailable on NPU in ' + 'this environment.') + self.skipTest('flash_attention_2 is required for derived ring memory tests.') + if not _supports_peak_memory_stats(backend['device_type']): + self.skipTest(f'{backend["label"]} peak-memory stats are unavailable in this environment.') + return backend + + def test_sp_only_memory_profile_grid(self): + case = _make_case('sp_only_memory') + world_size = 
int(case['world_size']) + backend = self._get_backend_or_skip(world_size) + port = _find_free_port() + mp.spawn(_run_memory_worker, args=(world_size, port, 'sp_only_memory', backend), nprocs=world_size, join=True) + + def test_cp_only_memory_profile_grid(self): + case = _make_case('cp_only_memory') + world_size = int(case['world_size']) + backend = self._get_backend_or_skip(world_size) + port = _find_free_port() + mp.spawn(_run_memory_worker, args=(world_size, port, 'cp_only_memory', backend), nprocs=world_size, join=True) + + def test_cp_sp_memory_profile_grid(self): + case = _make_case('cp_sp_memory') + world_size = int(case['world_size']) + backend = self._get_backend_or_skip(world_size) + port = _find_free_port() + mp.spawn(_run_memory_worker, args=(world_size, port, 'cp_sp_memory', backend), nprocs=world_size, join=True)
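+
+
+# Entry point kept consistent with test_qwen35_linear_attention_sp.py so this file can
+# also be run directly (the memory-profile grids above additionally require
+# TWINKLE_RUN_MEMORY_TESTS=1).
+if __name__ == '__main__':
+    unittest.main()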