[None][feat] Add encoder_max_batch_size & encoder_max_num_tokens to TorchLlmArgs #13503
yechank-nvidia wants to merge 3 commits into NVIDIA:main from
Conversation
…onMetadata Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
📝 Walkthrough

Introduces a new …

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant Engine as Model Engine
    participant Config as Config (llm_args)
    participant Model as Loaded Model
    participant Mixin as MmEncoderMixin
    participant Metadata as AttentionMetadata
    Engine->>Config: get_encoder_runtime_sizes()
    Config-->>Engine: (batch_size, num_tokens)
    Engine->>Model: Scan submodules
    Model-->>Engine: Find MmEncoderMixin instances
    loop For each MmEncoderMixin
        Engine->>Mixin: setup_attn_metadata(batch_size, num_tokens)
        Mixin->>Metadata: Create with runtime params
        Metadata-->>Mixin: AttentionMetadata instance
        Mixin->>Mixin: Store as attn_metadata
    end
    Engine-->>Engine: Initialization complete
```
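The diagram's flow can be sketched in plain Python. This is a simplified stand-in, not the actual TensorRT-LLM implementation: the names `MmEncoderMixin`, `setup_attn_metadata`, and `get_encoder_runtime_sizes` come from the PR, while `AttentionMetadata`, `DummyVisionEncoder`, `DummyConfig`, and `init_encoder_metadata` are illustrative stubs.

```python
class AttentionMetadata:
    """Stub standing in for the real attention metadata class."""

    def __init__(self, max_num_requests, max_num_tokens, kv_cache_manager=None):
        self.max_num_requests = max_num_requests
        self.max_num_tokens = max_num_tokens


class MmEncoderMixin:
    # Subclasses would set their own metadata class; we default to the stub.
    metadata_cls = AttentionMetadata

    def setup_attn_metadata(self, max_num_requests, max_num_tokens):
        # Create and store metadata sized by the encoder runtime limits.
        self.attn_metadata = self.metadata_cls(
            max_num_requests=max_num_requests,
            max_num_tokens=max_num_tokens)


class DummyVisionEncoder(MmEncoderMixin):
    pass


class DummyConfig:
    max_batch_size = 8
    max_num_tokens = 4096

    def get_encoder_runtime_sizes(self):
        # Unset encoder knobs fall back to the LLM-side limits.
        return (self.max_batch_size, self.max_num_tokens)


def init_encoder_metadata(model_submodules, config):
    # The engine scans submodules and initializes every MmEncoderMixin found.
    batch_size, num_tokens = config.get_encoder_runtime_sizes()
    for module in model_submodules:
        if isinstance(module, MmEncoderMixin):
            module.setup_attn_metadata(batch_size, num_tokens)


encoder = DummyVisionEncoder()
init_encoder_metadata([encoder], DummyConfig())
```

After initialization, `encoder.attn_metadata` carries the fallback sizes (8, 4096) from the config.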
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/models/modeling_multimodal_utils.py`:
- Around line 37-56: Initialize attn_metadata during construction so forwards
that run outside the engine don't hit None: after each subclass sets
metadata_cls in its __init__ (e.g., PixtralVisionModel.__init__ and
CLIPVisionModel.__init__), call setup_attn_metadata with conservative defaults
(for example setup_attn_metadata(max_num_requests=1, max_num_tokens=1)) to
create a default AttentionMetadata instance; alternatively add an
MmEncoderMixin.__init__ that expects metadata_cls to already be set and calls
setup_attn_metadata(1,1). This ensures self.attn_metadata is non-None before
_prepare_attn_metadata mutates it.
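The reviewer's suggestion above can be sketched as follows. This is a minimal illustration, not the actual code: the class names mirror the PR (`MmEncoderMixin`, `setup_attn_metadata`, `metadata_cls`), but the bodies, the stub `AttentionMetadata`, and `CLIPVisionModelSketch` are hypothetical stand-ins assuming each subclass sets `metadata_cls` before the mixin's `__init__` runs.

```python
class AttentionMetadata:
    """Stub standing in for the real attention metadata class."""

    def __init__(self, max_num_requests, max_num_tokens, kv_cache_manager=None):
        self.max_num_requests = max_num_requests
        self.max_num_tokens = max_num_tokens


class MmEncoderMixin:
    def __init__(self):
        # Expects the subclass to have set self.metadata_cls already.
        # Seed attn_metadata with conservative defaults so forward()
        # paths that run outside the engine never see None.
        self.setup_attn_metadata(max_num_requests=1, max_num_tokens=1)

    def setup_attn_metadata(self, max_num_requests, max_num_tokens):
        self.attn_metadata = self.metadata_cls(
            max_num_requests=max_num_requests,
            max_num_tokens=max_num_tokens,
            kv_cache_manager=None)


class CLIPVisionModelSketch(MmEncoderMixin):
    def __init__(self):
        self.metadata_cls = AttentionMetadata  # set before the mixin init runs
        super().__init__()


model = CLIPVisionModelSketch()
```

With this ordering, `model.attn_metadata` is non-None immediately after construction, and the engine can later re-run `setup_attn_metadata` with the real runtime sizes.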
In `@tensorrt_llm/_torch/models/modeling_qwen2vl.py`:
- Around line 691-702: The new full_attn_metadata and window_attn_metadata can
remain None after construction and forward() still passes them into
prepare_attn_metadata(), which will dereference them and cause a NoneType crash;
add a fail-fast guard in either forward() (before calling
prepare_attn_metadata()) or at the start of prepare_attn_metadata() that checks
self.full_attn_metadata and self.window_attn_metadata and raises a clear
RuntimeError instructing the caller to call
setup_attn_metadata(max_num_requests, max_num_tokens) (referencing the
attributes full_attn_metadata, window_attn_metadata and the method
setup_attn_metadata) so missing initialization is detected with an explicit
error instead of a later NoneType exception.
In `@tensorrt_llm/llmapi/llm_args.py`:
- Around line 2880-2892: The encoder knobs encoder_max_batch_size and
encoder_max_num_tokens currently accept negative integers and are treated as
"unset" when falsy (e.g., 0) due to using Optional[int] and `or` fallback logic;
change their Pydantic types to Optional[NonNegativeInt] (or PositiveInt if zero
should be disallowed) so negatives fail validation, and update
get_encoder_runtime_sizes() to use explicit `is not None` checks (instead of
`or`) when deciding whether to use encoder_max_batch_size/encoder_max_num_tokens
or fall back to max_batch_size/max_num_tokens; apply the same type and fallback
fix for the duplicate fields mentioned around lines 3166-3175.
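The falsy-fallback pitfall behind this comment is easy to demonstrate in plain Python. The two helper functions below are hypothetical illustrations (not the actual `get_encoder_runtime_sizes()` implementation); they contrast the `or` fallback with an explicit `is not None` check when the user sets a knob to 0.

```python
def sizes_with_or(encoder_max, default):
    # `or` treats every falsy value (including an explicit 0) as "unset".
    return encoder_max or default


def sizes_with_none_check(encoder_max, default):
    # Only None means "unset"; an explicit 0 is honored.
    return encoder_max if encoder_max is not None else default


print(sizes_with_or(0, 4096))           # 4096: the user's explicit 0 is ignored
print(sizes_with_none_check(0, 4096))   # 0: the explicit setting is honored
print(sizes_with_none_check(None, 4096))  # 4096: unset falls back to the default
```

Pairing the `is not None` check with a constrained Pydantic type such as `PositiveInt` additionally rejects negative values at validation time instead of letting them flow into the encoder metadata.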
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: d7fe393d-501b-42ac-90aa-6183fcb3668d
📒 Files selected for processing (9)
tensorrt_llm/_torch/models/modeling_clip.py
tensorrt_llm/_torch/models/modeling_multimodal_utils.py
tensorrt_llm/_torch/models/modeling_pixtral.py
tensorrt_llm/_torch/models/modeling_qwen2vl.py
tensorrt_llm/_torch/models/modeling_qwen3vl.py
tensorrt_llm/_torch/models/modeling_radio.py
tensorrt_llm/_torch/models/modeling_siglip.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
tensorrt_llm/llmapi/llm_args.py
```python
        self.full_attn_metadata: Optional[AttentionMetadata] = None
        self.window_attn_metadata: Optional[AttentionMetadata] = None

    def setup_attn_metadata(self, max_num_requests: int,
                            max_num_tokens: int) -> None:
        # Override: Qwen2/2.5-VL uses two metadata objects (full + window
        # attention) instead of the mixin's single `attn_metadata`.
        kwargs = dict(max_num_requests=max_num_requests,
                      max_num_tokens=max_num_tokens,
                      kv_cache_manager=None)
        self.full_attn_metadata = self.metadata_cls(**kwargs)
        self.window_attn_metadata = self.metadata_cls(**kwargs)
```
Fail fast if encoder metadata was not initialized.
These fields now stay None after construction, but forward() still passes them into prepare_attn_metadata() at Lines 804-807, and that method dereferences attn_metadata on Line 785. Any caller that misses setup_attn_metadata() will now crash with a NoneType error during inference.
Suggested guard

```diff
 @torch.inference_mode()
 def forward(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor,
             **kwargs) -> torch.Tensor:
+    if self.full_attn_metadata is None or self.window_attn_metadata is None:
+        raise RuntimeError(
+            "Qwen2_5_VisionModel.setup_attn_metadata() must be called before forward()."
+        )
+
     window_index, window_seq_lens = self.get_window_index(grid_thw)
     seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2],
                                        grid_thw[:, 0]).tolist()
```
```python
    encoder_max_batch_size: Optional[int] = Field(
        default=None,
        description=(
            "Maximum batch size for the multimodal encoder's AttentionMetadata. "
            "Falls back to `max_batch_size` when unset."),
        status="prototype")

    encoder_max_num_tokens: Optional[int] = Field(
        default=None,
        description=(
            "Maximum number of tokens for the multimodal encoder's "
            "AttentionMetadata. Falls back to `max_num_tokens` when unset."),
        status="prototype")
```
Treat only None as "unset" for the new encoder runtime knobs.
get_encoder_runtime_sizes() currently uses `or`, so an explicit 0 is silently treated as "fallback to LLM defaults" even though the field was set. At the same time, the new knobs are plain Optional[int], so negative values are accepted and can flow into the encoder's AttentionMetadata. These are user-facing Pydantic limits, so they should fail fast on invalid values and use explicit `is not None` fallback logic.
Suggested fix

```diff
-    encoder_max_batch_size: Optional[int] = Field(
+    encoder_max_batch_size: Optional[PositiveInt] = Field(
         default=None,
         description=(
             "Maximum batch size for the multimodal encoder's AttentionMetadata. "
             "Falls back to `max_batch_size` when unset."),
         status="prototype")
-    encoder_max_num_tokens: Optional[int] = Field(
+    encoder_max_num_tokens: Optional[PositiveInt] = Field(
         default=None,
         description=(
             "Maximum number of tokens for the multimodal encoder's "
             "AttentionMetadata. Falls back to `max_num_tokens` when unset."),
         status="prototype")
@@
     def get_encoder_runtime_sizes(self) -> Tuple[int, int]:
         """Return encoder runtime batch and token limits.
@@
         return (
-            self.encoder_max_batch_size or self.max_batch_size,
-            self.encoder_max_num_tokens or self.max_num_tokens,
+            self.encoder_max_batch_size
+            if self.encoder_max_batch_size is not None else self.max_batch_size,
+            self.encoder_max_num_tokens
+            if self.encoder_max_num_tokens is not None else self.max_num_tokens,
         )
```

As per coding guidelines, "Prefer PositiveInt, NonNegativeInt, NonNegativeFloat, PositiveFloat, Field(gt=0), etc. for numeric constraints in Python Pydantic fields instead of custom validators".
Also applies to: 3166-3175
Summary

Adds two user-facing knobs to `TorchLlmArgs` for sizing the multimodal encoder's `AttentionMetadata` independently from the LLM-side runtime sizes:

- `encoder_max_batch_size`: Maximum batch size for the multimodal encoder's `AttentionMetadata`. Falls back to `max_batch_size` when unset.
- `encoder_max_num_tokens`: Maximum number of tokens for the multimodal encoder's `AttentionMetadata`. Falls back to `max_num_tokens` when unset.

Both fields are `Optional[int]` with `status="prototype"`.

Summary by CodeRabbit
New Features
Refactor