@@ -323,19 +323,21 @@ def batch_type(self) -> Union[FlashBatch, PaddedBatch]:
     def embed(self, batch: Union[FlashBatch, PaddedBatch]) -> List[Embedding]:
         if isinstance(batch, PaddedBatch):
             input_lens = batch.attention_mask.cumsum(-1)[:, -1].to(torch.int32)
-            max_input_lens = input_lens.max().item()
+            max_input_lens = 0  # This value will not be used
             cu_seqlens = torch.cat(
                 (input_lens.new_tensor([0]), input_lens.cumsum(-1).int())
             )
             mask = batch.attention_mask.bool()
-            batch_size = input_lens.size(0)
+            bsz, tgt_len = mask.size()
+            min_val = torch.finfo(self.dtype).min
             attn_mask = torch.full(
-                [batch_size, 1, 1, mask.shape[-1]],
-                fill_value=torch.finfo(self.dtype).min,
+                [bsz, 1, tgt_len, tgt_len],
+                fill_value=min_val,
             device=self.device,
                 dtype=self.dtype,
             )
-            attn_mask.masked_fill_(mask[:, None, None, :], 0)
+            expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, tgt_len)
+            attn_mask = attn_mask.masked_fill(expanded_mask, 0.0)
         elif isinstance(batch, FlashBatch):
             cu_seqlens = batch.cu_seqlens
             mask = None
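For context, the change replaces the broadcastable [bsz, 1, 1, tgt_len] additive attention mask with a fully materialized [bsz, 1, tgt_len, tgt_len] mask, expanded over the query dimension. The two constructions encode the same padding information; the sketch below is a minimal standalone check of that equivalence, using an illustrative dtype and a toy attention_mask rather than anything from the diff:

import torch

dtype = torch.float32
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 1 = real token, 0 = padding
mask = attention_mask.bool()
bsz, tgt_len = mask.size()
min_val = torch.finfo(dtype).min

# Old construction: one row per sequence, broadcast over query positions.
old = torch.full([bsz, 1, 1, tgt_len], fill_value=min_val, dtype=dtype)
old.masked_fill_(mask[:, None, None, :], 0)

# New construction: materialize the full [tgt_len, tgt_len] mask per sequence.
expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, tgt_len)
new = torch.full([bsz, 1, tgt_len, tgt_len], fill_value=min_val, dtype=dtype)
new = new.masked_fill(expanded, 0.0)

# The new mask is exactly the old mask repeated across the query dimension.
assert torch.equal(new, old.expand(bsz, 1, tgt_len, tgt_len))

Note that the out-of-place masked_fill in the new code also avoids mutating a tensor whose singleton dimensions were expanded, which in-place masked_fill_ would not permit on a non-contiguous expanded view.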