@@ -1367,6 +1367,23 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
13671367 self .make_empty_intermediate_tensors = (
13681368 self .language_model .make_empty_intermediate_tensors
13691369 )
# Cache the ids of tokens that stand for visual content (the image patch
# token plus any image/video start/end delimiter tokens) so the per-step
# visual-token mask can be computed with a single torch.isin call.
# NOTE: compare with "is not None" rather than truthiness so a configured
# im_patch_id of 0 is still honored (the pre-existing code used the
# "is not None" form; plain `if getattr(...)` would silently drop id 0).
if getattr(self.config, "im_patch_id", None) is not None:
    visual_token_ids = [
        token_id
        for token_id in (
            self.config.im_patch_id,
            getattr(self.config, "image_start_token_id", None),
            getattr(self.config, "image_end_token_id", None),
            getattr(self.config, "video_start_token_id", None),
            getattr(self.config, "video_end_token_id", None),
        )
        if token_id is not None
    ]
    self._visual_token_ids_tensor_cache = torch.tensor(
        visual_token_ids, dtype=torch.long
    )
else:
    # No visual vocabulary configured; _set_visual_token_mask will
    # produce a None mask in this case.
    self._visual_token_ids_tensor_cache = None
13701387
13711388 def compute_logits (
13721389 self ,
@@ -1398,12 +1415,19 @@ def _vision_forward(
13981415 return image_features
13991416
14001417 def _set_visual_token_mask (self , input_ids : torch .Tensor ) -> None :
1401- if getattr (self .config , "im_patch_id" , None ) is not None :
1402- self .visual_token_mask = (input_ids == self .config .im_patch_id ).reshape (
1403- - 1 , 1
1404- )
1405- else :
1418+ """Set mask for visual tokens (image/video patches and delimiters)."""
1419+ if self ._visual_token_ids_tensor_cache is None :
14061420 self .visual_token_mask = None
1421+ return
1422+ # Create tensor on the correct device
1423+ visual_token_ids_tensor = self ._visual_token_ids_tensor_cache .to (
1424+ device = input_ids .device ,
1425+ dtype = input_ids .dtype ,
1426+ )
1427+
1428+ self .visual_token_mask = torch .isin (input_ids , visual_token_ids_tensor ).reshape (
1429+ - 1 , 1
1430+ )
14071431
14081432 def get_mrope_input_positions (
14091433 self ,
0 commit comments