@@ -153,6 +153,42 @@ def test_find_packed_sequence_indices(self):
         EXPECTED_SEQUENCE_INDICES = torch.tensor([[0, 0, 0, 0, 1, 1, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 1, 1, 1, 1]])
         self.assertTrue((find_packed_sequence_indices(position_ids) == EXPECTED_SEQUENCE_INDICES).all())
 
+    def test_nonpacked_sequence_mask_skip(self):
+        config = LlamaConfig()
+        config._attn_implementation = "sdpa"
+
+        batch_size = 2
+        sequence_length = 10
+        cache_position = torch.arange(sequence_length)
+
+        # Non-packed sequences
+        position_ids = torch.arange(sequence_length)[None, :]
+
+        causal_mask = create_causal_mask(
+            config=config,
+            # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
+            input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
+            attention_mask=None,
+            cache_position=cache_position,
+            past_key_values=None,
+            position_ids=position_ids,
+        )
+        # non-packed sequence without padding: mask creation should be skipped
+        self.assertTrue(causal_mask is None)
+
+        create_causal_mask_compiled = torch.compile(create_causal_mask, mode="reduce-overhead")
+        causal_mask = create_causal_mask_compiled(
+            config=config,
+            # we only need batch size, seq_length and dtype here - we don't care about the values of the embeddings
+            input_embeds=torch.empty((batch_size, sequence_length), dtype=torch.float16),
+            attention_mask=None,
+            cache_position=cache_position,
+            past_key_values=None,
+            position_ids=position_ids,
+        )
+        # cannot be skipped under compile, should result in a triu mask
+        self.assertTrue(torch.equal(~torch.ones(*causal_mask.shape).triu(diagonal=1).bool(), causal_mask))
+
     def test_chunked_mask_with_left_padding_and_large_prefill(self):
         # Make sure we have an attention_chunk_size in the config
         config = LlamaConfig(attention_chunk_size=3, attn_implementation="sdpa")