Commit 912744d

[Fix] optimize visual token mask with caching and multi-token support (#28374)
Signed-off-by: Ferrebo <itachi971009@gmail.com>
Signed-off-by: kebo01 <kebo01@baidu.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 15be507 commit 912744d

File tree

1 file changed: +29 −5 lines changed


vllm/model_executor/models/ernie45_vl.py

Lines changed: 29 additions & 5 deletions
@@ -1367,6 +1367,23 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
         )
+        if getattr(self.config, "im_patch_id", None):
+            visual_token_ids = [
+                token_id
+                for token_id in [
+                    self.config.im_patch_id,
+                    getattr(self.config, "image_start_token_id", None),
+                    getattr(self.config, "image_end_token_id", None),
+                    getattr(self.config, "video_start_token_id", None),
+                    getattr(self.config, "video_end_token_id", None),
+                ]
+                if token_id is not None
+            ]
+            self._visual_token_ids_tensor_cache = torch.tensor(
+                visual_token_ids, dtype=torch.long
+            )
+        else:
+            self._visual_token_ids_tensor_cache = None
 
     def compute_logits(
         self,
@@ -1398,12 +1415,19 @@ def _vision_forward(
         return image_features
 
     def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
-        if getattr(self.config, "im_patch_id", None) is not None:
-            self.visual_token_mask = (input_ids == self.config.im_patch_id).reshape(
-                -1, 1
-            )
-        else:
+        """Set mask for visual tokens (image/video patches and delimiters)."""
+        if self._visual_token_ids_tensor_cache is None:
             self.visual_token_mask = None
+            return
+        # Create tensor on the correct device
+        visual_token_ids_tensor = self._visual_token_ids_tensor_cache.to(
+            device=input_ids.device,
+            dtype=input_ids.dtype,
+        )
+
+        self.visual_token_mask = torch.isin(input_ids, visual_token_ids_tensor).reshape(
+            -1, 1
+        )
 
     def get_mrope_input_positions(
         self,
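For readers skimming the diff, here is a minimal, self-contained sketch of the behavior the change aims for: a 1-D tensor of visual token IDs built once and cached, then torch.isin marking every position in input_ids that holds any of those IDs, where the old code only matched im_patch_id. The token ID values and the sample prompt below are made up for illustration and are not taken from a real Ernie 4.5 VL config.

import torch

# Hypothetical visual token IDs: patch, image_start, image_end.
# In the real model these come from the config, and any delimiter IDs
# left unset (None) are filtered out before the tensor is built.
visual_token_ids = [100295, 100296, 100297]
visual_token_ids_tensor_cache = torch.tensor(visual_token_ids, dtype=torch.long)

# A fake prompt: text tokens interleaved with image delimiter and patch tokens.
input_ids = torch.tensor([1, 5, 100296, 100295, 100295, 100297, 7])

# Move/cast the cached tensor to match input_ids, then mark every visual position.
ids_on_device = visual_token_ids_tensor_cache.to(
    device=input_ids.device, dtype=input_ids.dtype
)
visual_token_mask = torch.isin(input_ids, ids_on_device).reshape(-1, 1)

print(visual_token_mask.squeeze(1))
# tensor([False, False,  True,  True,  True,  True, False])

The previous single-ID comparison (input_ids == im_patch_id) would miss the delimiter tokens at positions 2 and 5, which is the "multi-token support" part of the commit title; building the ID tensor once at init is the "caching" part, since _set_visual_token_mask no longer reconstructs it on every call.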
