
Commit 672dc07

[Attn Masks] Add skip option for non-packed sequences (#42367)
skip option
1 parent 59ed41e commit 672dc07

File tree

2 files changed: +44, -3 lines changed

src/transformers/masking_utils.py

Lines changed: 8 additions & 3 deletions

```diff
@@ -644,7 +644,7 @@ class AttentionMaskInterface(GeneralInterface):
 ALL_MASK_ATTENTION_FUNCTIONS: AttentionMaskInterface = AttentionMaskInterface()
 
 
-def find_packed_sequence_indices(position_ids: torch.Tensor) -> torch.Tensor:
+def find_packed_sequence_indices(position_ids: torch.Tensor) -> Optional[torch.Tensor]:
     """
     Find the indices of the sequence to which each new query token in the sequence belongs when using packed
     tensor format (i.e. several sequences packed in the same batch dimension).
@@ -656,6 +656,9 @@ def find_packed_sequence_indices(position_ids: torch.Tensor) -> torch.Tensor:
     Returns:
         A 2D tensor where each similar integer indicates that the tokens belong to the same sequence. For example, if we
         pack 3 sequences of 2, 3 and 1 tokens respectively along a single batch dim, this will return [[0, 0, 1, 1, 1, 2]].
+
+        If there is only one sequence in each batch item (and we don't compile), then we return `None`, indicating
+        no packed sequences. This is the same as [[0, 0, 0, 0, 0, 0]] for the example above.
     """
     # What separates different sequences is when 2 consecutive position_ids are separated by more than 1. So
     # taking the diff (by prepending the first value - 1 to keep correct indexing) and applying cumsum to the result
@@ -666,8 +669,10 @@ def find_packed_sequence_indices(position_ids: torch.Tensor) -> torch.Tensor:
     position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
     packed_sequence_mask = (position_diff != 1).cumsum(-1)
 
-    # Here it would be nice to return None if we did not detect packed sequence format, i.e. if `packed_sequence_mask[:, -1] == 0`
-    # but it causes issues with export
+    # Sadly this is dynamic control flow, so we cannot enable this check for anything compile-related
+    if not is_tracing(packed_sequence_mask) and (packed_sequence_mask[:, -1] == 0).all():
+        return None
+
     return packed_sequence_mask
```
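To illustrate the detection logic in the hunk above, here is a minimal, self-contained sketch. The function name and the `tracing` flag are illustrative stand-ins (the real code uses the library's `is_tracing` helper and lives inside `masking_utils.py`), so treat it as a sketch of the technique, not the implementation itself:

```python
from typing import Optional

import torch


def find_packed_sequence_indices_sketch(position_ids: torch.Tensor, tracing: bool = False) -> Optional[torch.Tensor]:
    """Standalone sketch of the packed-sequence detection added in this commit (illustrative only)."""
    # Two consecutive position_ids that differ by more than 1 mark the start of a new packed sequence.
    # Prepending `first value - 1` keeps the result aligned with the original token indexing.
    first_dummy_value = position_ids[:, :1] - 1
    position_diff = torch.diff(position_ids, prepend=first_dummy_value, dim=-1)
    packed_sequence_mask = (position_diff != 1).cumsum(-1)

    # Data-dependent control flow: only allowed when we are not tracing/compiling.
    if not tracing and (packed_sequence_mask[:, -1] == 0).all():
        return None
    return packed_sequence_mask


# Three sequences of lengths 2, 3 and 1 packed along one batch row -> tensor([[0, 0, 1, 1, 1, 2]])
print(find_packed_sequence_indices_sketch(torch.tensor([[0, 1, 0, 1, 2, 0]])))
# A single ordinary sequence -> None, so downstream mask construction can be skipped entirely
print(find_packed_sequence_indices_sketch(torch.arange(6)[None, :]))
```

Returning `None` lets the mask-building path be skipped for the common non-packed case, while the tensor-returning path is kept for anything traced or compiled, where this data-dependent branch is not allowed.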

tests/utils/test_masking_utils.py

Lines changed: 36 additions & 0 deletions

```diff
@@ -153,6 +153,42 @@ def test_find_packed_sequence_indices(self):
         EXPECTED_SEQUENCE_INDICES = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
         self.assertTrue((find_packed_sequence_indices(position_ids) == EXPECTED_SEQUENCE_INDICES).all())
 
+    def test_nonpacked_sequence_mask_skip(self):
+        config = LlamaConfig()
+        config._attn_implementation = "sdpa"
+
+        batch_size = 2
+        sequence_length = 10
+        cache_position = torch.arange(sequence_length)
+
+        # Non-packed sequences
+        position_ids = torch.arange(sequence_length)[None, :]
+
+        causal_mask = create_causal_mask(
+            config=config,
+            # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
+            input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
+            attention_mask=None,
+            cache_position=cache_position,
+            past_key_values=None,
+            position_ids=position_ids,
+        )
+        # mask creation for the non-packed sequence should be skipped
+        self.assertTrue(causal_mask is None)
+
+        create_causal_mask_compiled = torch.compile(create_causal_mask, mode="reduce-overhead")
+        causal_mask = create_causal_mask_compiled(
+            config=config,
+            # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
+            input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
+            attention_mask=None,
+            cache_position=cache_position,
+            past_key_values=None,
+            position_ids=position_ids,
+        )
+        # cannot be skipped under compile, should result in a triu mask
+        self.assertTrue(torch.equal(~torch.ones(*causal_mask.shape).triu(diagonal=1).bool(), causal_mask))
+
     def test_chunked_mask_with_left_padding_and_large_prefill(self):
         # Make sure we have an attention_chunk_size in the config
         config = LlamaConfig(attention_chunk_size=3, attn_implementation="sdpa")
```
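Not part of the commit, but as a quick illustration of the `triu` assertion above and of why a `None` mask is safe for a single non-packed sequence: the boolean causal pattern is just the complement of `triu(diagonal=1)`, and PyTorch's `scaled_dot_product_attention` can rebuild the same pattern on its own via `is_causal=True` when no mask is passed. How transformers dispatches this internally is not shown here; this only demonstrates the equivalence:

```python
import torch
import torch.nn.functional as F

seq_len, head_dim = 10, 8

# The boolean causal pattern the compiled-path assertion checks for: True = "may attend".
causal_bool = ~torch.ones(seq_len, seq_len).triu(diagonal=1).bool()
assert torch.equal(causal_bool, torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)))

# For a single (non-packed) sequence, skipping the explicit mask and letting SDPA apply
# causality itself gives the same attention output as passing the boolean mask.
q = k = v = torch.randn(1, 1, seq_len, head_dim)
out_with_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_bool)
out_is_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(out_with_mask, out_is_causal, atol=1e-6)
```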
