From ca0be5aa5992b0790831feac02c2bcf9a777b062 Mon Sep 17 00:00:00 2001 From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com> Date: Tue, 10 Feb 2026 03:29:33 +0530 Subject: [PATCH 1/2] Fix mRoPE position ID crash when Qwen2-VL prompts are truncated When training Qwen2.5-VL with agent-lightning + verl, prompt truncation changes the token count but image_grid_thw is computed from the original (untruncated) image_urls. This causes get_rope_index to fail with a shape mismatch because it finds fewer image tokens in the truncated input_ids than entries in image_grid_thw. After prompt truncation, count remaining image regions in the truncated token sequence and slice image_urls to match before computing image_grid_thw, ensuring consistency between the token content and the mRoPE spatial metadata. Fixes #441 --- agentlightning/verl/daemon.py | 49 ++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index 98c58f330..ca0fa3599 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -310,6 +310,45 @@ def _resolve_image_path(self, path: str) -> str: raise ValueError(f"Relative path '{path}' requires 'image_base_dir' to be set.") return os.path.join(self.image_base_dir, path) + def _count_images_in_tokens(self, token_ids: List[int]) -> int: + """Count the number of complete image regions in a token ID sequence. + + Image regions are identified by finding ``vision_start_token_id`` + followed by ``image_token_id``, matching the detection logic used by + ``get_rope_index`` in the Qwen2-VL / Qwen2.5-VL model implementation. + This is needed to reconcile ``image_grid_thw`` with truncated prompts + so that mRoPE position IDs are computed correctly. + + Args: + token_ids: List of token IDs (possibly truncated). + + Returns: + Number of image regions found in the token sequence, or ``-1`` if + the required special-token IDs could not be resolved (in which case + the caller should fall back to the original image count). + """ + # Resolve image_token_id from the processor (set during __init__) + image_token_id = getattr(self.processor, "image_token_id", None) + if image_token_id is None and hasattr(self.tokenizer, "convert_tokens_to_ids"): + image_token_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>") + + # Resolve vision_start_token_id -- not stored on the processor, so we + # try the tokenizer first and fall back to the well-known default. + vision_start_token_id = None + if hasattr(self.tokenizer, "convert_tokens_to_ids"): + vision_start_token_id = self.tokenizer.convert_tokens_to_ids("<|vision_start|>") + if vision_start_token_id is None: + vision_start_token_id = 151652 # Qwen2-VL / Qwen2.5-VL default + + if image_token_id is None: + return -1 + + count = 0 + for i in range(len(token_ids) - 1): + if token_ids[i] == vision_start_token_id and token_ids[i + 1] == image_token_id: + count += 1 + return count + def _get_image_grid_thw(self, image_urls: List[str]) -> Optional[torch.Tensor]: """Compute image_grid_thw from image URLs for M-RoPE computation. @@ -907,9 +946,17 @@ def get_train_data_batch( rollout_id_list.append(rollout_id) turn_index_list.append(turn_index) - # Compute image_grid_thw for this triplet using image_urls from prompt + # Compute image_grid_thw for this triplet using image_urls from prompt. + # After prompt truncation, some image tokens may have been removed, + # so we must reconcile image_urls with the actual images remaining + # in the (possibly truncated) prompt to avoid shape mismatches in + # get_rope_index when computing mRoPE position IDs. if self._use_mrope: image_urls = trace.get("image_urls", []) + if image_urls: + n_images_in_tokens = self._count_images_in_tokens(prompt_ids) + if n_images_in_tokens >= 0 and n_images_in_tokens < len(image_urls): + image_urls = image_urls[:n_images_in_tokens] image_grid_thw_list.append(self._get_image_grid_thw(image_urls)) elif self.trace_aggregator.get("level", "transition") == "trajectory": From bf8be4ff626cc2943dcdb888dbb9c510644620a1 Mon Sep 17 00:00:00 2001 From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com> Date: Tue, 2 Jun 2026 03:33:19 +0530 Subject: [PATCH 2/2] Skip get_rope_index for dropped samples in mRoPE position_ids MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit @totoluo (2026-03-03) noted that the current count-based fix still fails for a corner case where an image is truncated in the middle of the prompt — the count will increment by 1 and get_rope_index crashes at the same place. This change uses the existing is_drop_list to skip _compute_mrope_position_ids for samples that will be dropped by is_drop_mask downstream, substituting a zero placeholder. Those samples are removed by the trainer, so the placeholder pos_ids are never consumed. This is a strict superset of the previous fix. --- agentlightning/verl/daemon.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/agentlightning/verl/daemon.py b/agentlightning/verl/daemon.py index ca0fa3599..1e7866419 100644 --- a/agentlightning/verl/daemon.py +++ b/agentlightning/verl/daemon.py @@ -1088,12 +1088,22 @@ def get_train_data_batch( # For Qwen2-VL: compute 4D position_ids (batch_size, 4, seq_length) position_ids_list: list[torch.Tensor] = [] for i in range(n_transition): - pos_ids = self._compute_mrope_position_ids( - input_ids=batch_seq[i], - attention_mask=attention_mask[i], - image_grid_thw=image_grid_thw_list[i] if image_grid_thw_list else None, - ) # (4, seq_length) - position_ids_list.append(pos_ids) + if is_drop_list[i]: + # Skip get_rope_index for dropped samples (e.g. truncated + # image in the middle of the prompt) — it would crash. + # is_drop_mask removes this sample in the trainer, so the + # placeholder pos_ids are never used. + seq_len = batch_seq[i].size(0) + position_ids_list.append( + torch.zeros(4, seq_len, dtype=torch.long, device=device) + ) + else: + pos_ids = self._compute_mrope_position_ids( + input_ids=batch_seq[i], + attention_mask=attention_mask[i], + image_grid_thw=image_grid_thw_list[i] if image_grid_thw_list else None, + ) # (4, seq_length) + position_ids_list.append(pos_ids) # Stack to (batch_size, 4, seq_length) position_ids = torch.stack(position_ids_list, dim=0) else: