diff --git a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py index bb8419cb..8388c8e4 100644 --- a/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py +++ b/src/twinkle/model/transformers/strategy/sequence_parallel/__init__.py @@ -140,6 +140,14 @@ def sdpa_mask(batch_size, q_length=None, kv_length=None, *args, **kwargs): else: kv_length = global_length + # padding_mask may be an AttentionMask wrapper instead of Tensor + _pm = kwargs.get('padding_mask') + if _pm is not None and not isinstance(_pm, torch.Tensor): + if hasattr(_pm, 'to_tensor'): + kwargs['padding_mask'] = _pm.to_tensor() + elif hasattr(_pm, 'mask'): + kwargs['padding_mask'] = _pm.mask + if origin_uses_cache_position: if cache_position is None: cache_position = q_length if torch.is_tensor(q_length) else torch.arange(