Refactor: add model-specific patch registry#966

Open
h-guo18 wants to merge 3 commits into main from haoguo/model_patches

Conversation


@h-guo18 h-guo18 commented Mar 4, 2026

What does this PR do?

Type of change: Refactor

  • Move model-specific patches to a separate file for clarity and extensibility.
  • Fix Kimi K2 to work out of the box (OOTB)

Usage

# Add a code snippet demonstrating how to use this
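The Usage section was left as a template placeholder; as an illustration only, a registry like the one this PR describes (`register_model_patch` / `apply_model_patch`, names taken from the walkthrough below; signatures and the example patch body are hypothetical) might be used like this:

```python
# Hypothetical sketch of the model-patch registry added by this PR.
# register_model_patch / apply_model_patch match the names in the PR
# walkthrough; the signatures and the kimi_k2 patch body are illustrative.
_MODEL_PATCH_REGISTRY = {}


def register_model_patch(model_type):
    """Decorator: register a patch function for a given HF model_type."""
    def decorator(fn):
        _MODEL_PATCH_REGISTRY[model_type] = fn
        return fn
    return decorator


def apply_model_patch(module):
    """Dispatch: apply the registered patch for this model, if any."""
    model_type = getattr(module.config, "model_type", None)
    patch = _MODEL_PATCH_REGISTRY.get(model_type)
    if patch is not None:
        patch(module)


@register_model_patch("kimi_k2")
def _patch_kimi_k2(module):
    # Stand-in for the real Kimi-K2 adjustments (version check,
    # quantization ignore pattern, TTT mask fix).
    module.patched = True
```

Models with no registered `model_type` pass through `apply_model_patch` unchanged, which is what lets the call site in `transformers.py` stay unconditional.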

Testing

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, using torch.load(..., weights_only=True), avoiding pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other source, did you follow IP policy in CONTRIBUTING.md?: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

Summary by CodeRabbit

  • Bug Fixes

    • More reliable GPU detection and distributed training behavior for single- and multi-node runs.
    • Improved Kimi-K2 decoder compatibility to prevent shape and argument mismatches.
  • New Features

    • Pluggable, model-specific patching mechanism to apply targeted runtime adjustments.
  • Compatibility

    • Conditional behavior for Transformers 5.0+ to preserve correct parallelism and patching.
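The Transformers 5.0 boundary mentioned above can be gated with a simple version check; this is a minimal sketch (the helper name is hypothetical, the `>= 5.0` cutoff comes from this PR's walkthrough):

```python
from packaging.version import Version


def needs_new_style_config(version_str: str) -> bool:
    """Return True when the given transformers version is 5.0 or newer.

    Hypothetical helper illustrating the version gate this PR applies
    before touching training_args.parallelism_config.
    """
    return Version(version_str) >= Version("5.0")
```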

h-guo18 added 2 commits March 3, 2026 23:54
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>

copy-pr-bot bot commented Mar 4, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.


coderabbitai bot commented Mar 4, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 87358ea2-6205-4d53-abf0-e20f45322678

📥 Commits

Reviewing files that changed from the base of the PR and between aaa847c and 8e0c137.

📒 Files selected for processing (1)
  • modelopt/torch/speculative/plugins/hf_model_patches.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/speculative/plugins/hf_model_patches.py

📝 Walkthrough

Walkthrough

Multi-node GPU counting and FSDP argument logic in the speculative decoding training script were updated; a registry-based HF model patch system was added and integrated, moving Kimi-K2 special-casing into model-specific patches and adding a decoder-forward compatibility wrapper.

Changes

Cohort / File(s) Summary
Training scripts
examples/speculative_decoding/launch_train.sh, examples/speculative_decoding/main.py
Multi-node GPU counting added: multi-node uses per-node nvidia-smi count and computes TOTAL_GPU = NUM_NODES * GPU_PER_NODE; single-node uses Python/PyTorch to detect GPUs. FSDP is enabled when TOTAL_GPU > 1 with --fsdp 'full_shard'; for transformers >= 5.0 an --fsdp_config file is added. main.py now guards training_args.parallelism_config assignment behind a transformers version check (>= 5.0).
Model patch registry
modelopt/torch/speculative/plugins/hf_model_patches.py
New module implementing a registry (_MODEL_PATCH_REGISTRY) with register_model_patch(model_type) decorator and apply_model_patch(module) dispatcher. Registers a kimi_k2 patch that enforces transformers < 5.0, rejects flex_attention in eagle_config, appends an ignore pattern to CompressedTensorsConfig.quantization_config.ignore, and replaces module._compute_ttt_attention_mask with a patched version that repeats the mask along the batch dimension. Exports apply_model_patch.
Patch integration / transformers plugin
modelopt/torch/speculative/plugins/transformers.py
Now imports and calls apply_model_patch(self) during model modification. Removed the previous inline quantization ignore gating and the Kimi-K2 TTT mask repeat workaround; those behaviors are now driven by the model-patch mechanism.
Decoder compatibility
modelopt/torch/speculative/utils.py
Adds a compatibility wrapper for DeepseekV3DecoderLayer.forward used in the Kimi-K2 setup: maps past_key_values to past_key_value in kwargs to preserve older transformer expectation, delegating to the original forward. Minor import/formatting tweaks.
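The decoder-compatibility wrapper above can be sketched as a generic kwarg-renaming decorator (names here are illustrative; the real patch targets `DeepseekV3DecoderLayer.forward`, and popping rather than copying the old key follows the reviewer's nitpick below):

```python
import functools


def make_kwarg_compat_wrapper(original_forward):
    """Wrap a forward() so the newer `past_key_values` kwarg is renamed
    to the older `past_key_value` name the wrapped layer expects.

    Hypothetical sketch of the compatibility wrapper this PR adds in
    modelopt/torch/speculative/utils.py.
    """
    @functools.wraps(original_forward)
    def wrapper(self, *args, **kwargs):
        if "past_key_values" in kwargs:
            # pop, not copy, so old and new key names never conflict
            kwargs["past_key_value"] = kwargs.pop("past_key_values")
        return original_forward(self, *args, **kwargs)
    return wrapper
```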

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main refactoring objective: extracting model-specific patches into a dedicated registry-based system to improve code organization and extensibility.
Security Anti-Patterns ✅ Passed No security anti-patterns found: no torch.load with weights_only=False, numpy.load with allow_pickle=True, hardcoded trust_remote_code=True, eval(), exec(), or nosec comments.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
  • 📝 Generate docstrings (stacked PR)
  • 📝 Generate docstrings (commit on current branch)
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch haoguo/model_patches

Comment @coderabbitai help to get the list of available commands and usage tips.

@h-guo18 h-guo18 changed the title from Haoguo/model patches to Refactor: add model-specific patch registry on Mar 4, 2026
@h-guo18 h-guo18 marked this pull request as ready for review March 4, 2026 01:33
@h-guo18 h-guo18 requested a review from a team as a code owner March 4, 2026 01:33
@h-guo18 h-guo18 requested a review from ChenhanYu March 4, 2026 01:33
@h-guo18 h-guo18 self-assigned this Mar 4, 2026
@h-guo18 h-guo18 marked this pull request as draft March 4, 2026 01:33
@h-guo18 h-guo18 marked this pull request as ready for review March 4, 2026 01:36

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
examples/speculative_decoding/main.py (2)

173-176: ⚠️ Potential issue | 🔴 Critical

CRITICAL: trust_remote_code=True is hardcoded.

Per coding guidelines, trust_remote_code=True should not be hardcoded when loading transformers models. This flag tells Transformers to execute arbitrary Python shipped with a checkpoint, which is an RCE vector if the model source is untrusted. The flag should be exposed as a caller-configurable parameter defaulting to False.

This same issue appears at multiple locations in this file (lines 174, 176, 185, 188, 192, 195).

Suggested approach

Add a command-line argument to control this behavior:

 `@dataclass`
 class ModelArguments:
     model_name_or_path: str | None = field(default="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={"help": "Whether to trust remote code when loading models/tokenizers."}
+    )

Then use model_args.trust_remote_code instead of hardcoding True.

As per coding guidelines: "trust_remote_code=True hardcoded for transformers model or tokenizer loading" is flagged as CRITICAL security issue.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 173 - 176, The code
currently hardcodes trust_remote_code=True when calling
load_vlm_or_llm_with_kwargs and transformers.AutoTokenizer.from_pretrained; add
a command-line or function argument (e.g., model_args.trust_remote_code
defaulting to False) and use that variable instead of True in all places in this
file (references: load_vlm_or_llm_with_kwargs,
transformers.AutoTokenizer.from_pretrained); update the CLI parser or function
signature to expose trust_remote_code to callers and replace every hardcoded
True occurrence noted in the review with model_args.trust_remote_code.

148-155: ⚠️ Potential issue | 🟠 Major

Potential AttributeError when cp_size > 1 on older transformers versions.

If TRANSFORMERS_VERSION < 5.0 and cp_size > 1, training_args.parallelism_config is never assigned, but line 155 attempts to access training_args.parallelism_config.sp_backend, which will raise an AttributeError.

Suggested fix
     if Version("5.0") <= TRANSFORMERS_VERSION:
         training_args.parallelism_config = ParallelismConfig(
             cp_size=training_args.cp_size, dp_shard_size=training_args.dp_shard_size
         )
-    if training_args.cp_size > 1:
-        patch_ring_attention_for_ttt()
-        # Specific patch to accelerate 1.12.0. Removable after move to 1.13.0
-        training_args.parallelism_config.sp_backend = None
+        if training_args.cp_size > 1:
+            patch_ring_attention_for_ttt()
+            # Specific patch to accelerate 1.12.0. Removable after move to 1.13.0
+            training_args.parallelism_config.sp_backend = None

Or handle the case where parallelism_config is not available for older versions.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/main.py` around lines 148 - 155, If
TRANSFORMERS_VERSION < 5.0 and training_args.cp_size > 1,
training_args.parallelism_config may be unset and accessing
training_args.parallelism_config.sp_backend will raise AttributeError; fix by
ensuring training_args.parallelism_config is created before use (e.g.,
instantiate ParallelismConfig and assign to training_args.parallelism_config
when cp_size>1 or guard access), then call patch_ring_attention_for_ttt() and
only modify training_args.parallelism_config.sp_backend if the attribute exists;
update the block around TRANSFORMERS_VERSION, ParallelismConfig,
training_args.cp_size, patch_ring_attention_for_ttt, and sp_backend accordingly.
🧹 Nitpick comments (1)
modelopt/torch/speculative/utils.py (1)

444-453: Consider removing past_key_values from kwargs after mapping to avoid potential conflicts.

The patch adds past_key_value but doesn't remove past_key_values from kwargs. If the original forward method has logic that checks for or uses past_key_values, this could cause unexpected behavior or duplicate key issues.

Suggested fix
     def patched_decoder_layer_fwd(self, *args, **kwargs):
-        kwargs["past_key_value"] = kwargs.get("past_key_values")
+        kwargs["past_key_value"] = kwargs.pop("past_key_values", None)
         return original_decoder_layer_forward(self, *args, **kwargs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/utils.py` around lines 444 - 453, The
patched_forward currently copies kwargs["past_key_values"] to
kwargs["past_key_value"] but leaves the original key present, which can cause
duplicate/conflicting behavior; in patched_decoder_layer_fwd (which replaces
kimi_k2_module.DeepseekV3DecoderLayer.forward and calls
original_decoder_layer_forward), after mapping set kwargs["past_key_value"] =
kwargs.get("past_key_values") and then remove the original by popping
kwargs.pop("past_key_values", None) before calling
original_decoder_layer_forward(self, *args, **kwargs) so only the expected key
remains.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/plugins/hf_model_patches.py`:
- Line 28: Module defines a module-level export list as all =
["apply_model_patch"] which is a typo; rename it to __all__ =
["apply_model_patch"] so Python's import mechanism recognizes the intended
exported symbol (apply_model_patch). Update the variable name in the module
where apply_model_patch is defined to use double-underscore __all__ to control
from module import * exports.
- Around line 74-76: The patched _patched_compute_ttt_attention_mask calls
.repeat() on original_func(...) but original_func (the underlying
_compute_ttt_attention_mask) can return a BlockMask (from flex_attention) which
has no repeat method; update _patched_compute_ttt_attention_mask to detect the
return type (e.g., isinstance(tensor_mask, BlockMask) or hasattr(tensor_mask,
"repeat")) and handle both cases: if it's a torch.Tensor call
.repeat(batch_size,1,1,1) as now, and if it's a BlockMask return it unchanged
(or apply the BlockMask-appropriate transformation) so you avoid calling .repeat
on BlockMask; reference the symbols _patched_compute_ttt_attention_mask,
original_func, and BlockMask when making the change.
- Around line 64-67: The code incorrectly accesses a nested attribute on a
CompressedTensorsConfig instance; locate where quant_config is retrieved from
module.config (variable name quant_config) and replace the incorrect access
quant_config.quantization_config.ignore with the direct attribute
quant_config.ignore so it appends "re:.*eagle_module.*" to the
CompressedTensorsConfig.ignore list (ensure you only modify the line that
mutates the ignore list and keep the isinstance check for
CompressedTensorsConfig).

---

Outside diff comments:
In `@examples/speculative_decoding/main.py`:
- Around line 173-176: The code currently hardcodes trust_remote_code=True when
calling load_vlm_or_llm_with_kwargs and
transformers.AutoTokenizer.from_pretrained; add a command-line or function
argument (e.g., model_args.trust_remote_code defaulting to False) and use that
variable instead of True in all places in this file (references:
load_vlm_or_llm_with_kwargs, transformers.AutoTokenizer.from_pretrained); update
the CLI parser or function signature to expose trust_remote_code to callers and
replace every hardcoded True occurrence noted in the review with
model_args.trust_remote_code.
- Around line 148-155: If TRANSFORMERS_VERSION < 5.0 and training_args.cp_size >
1, training_args.parallelism_config may be unset and accessing
training_args.parallelism_config.sp_backend will raise AttributeError; fix by
ensuring training_args.parallelism_config is created before use (e.g.,
instantiate ParallelismConfig and assign to training_args.parallelism_config
when cp_size>1 or guard access), then call patch_ring_attention_for_ttt() and
only modify training_args.parallelism_config.sp_backend if the attribute exists;
update the block around TRANSFORMERS_VERSION, ParallelismConfig,
training_args.cp_size, patch_ring_attention_for_ttt, and sp_backend accordingly.

---

Nitpick comments:
In `@modelopt/torch/speculative/utils.py`:
- Around line 444-453: The patched_forward currently copies
kwargs["past_key_values"] to kwargs["past_key_value"] but leaves the original
key present, which can cause duplicate/conflicting behavior; in
patched_decoder_layer_fwd (which replaces
kimi_k2_module.DeepseekV3DecoderLayer.forward and calls
original_decoder_layer_forward), after mapping set kwargs["past_key_value"] =
kwargs.get("past_key_values") and then remove the original by popping
kwargs.pop("past_key_values", None) before calling
original_decoder_layer_forward(self, *args, **kwargs) so only the expected key
remains.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a4fde49 and aaa847c.

📒 Files selected for processing (5)
  • examples/speculative_decoding/launch_train.sh
  • examples/speculative_decoding/main.py
  • modelopt/torch/speculative/plugins/hf_model_patches.py
  • modelopt/torch/speculative/plugins/transformers.py
  • modelopt/torch/speculative/utils.py

from packaging.version import Version
from transformers.utils.quantization_config import CompressedTensorsConfig

all = ["apply_model_patch"]

⚠️ Potential issue | 🟡 Minor

Typo: all should be __all__.

The module-level all variable has no special meaning in Python. It should be __all__ (with double underscores) to properly control what gets exported when using from module import *.

Fix
-all = ["apply_model_patch"]
+__all__ = ["apply_model_patch"]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_model_patches.py` at line 28, Module
defines a module-level export list as all = ["apply_model_patch"] which is a
typo; rename it to __all__ = ["apply_model_patch"] so Python's import mechanism
recognizes the intended exported symbol (apply_model_patch). Update the variable
name in the module where apply_model_patch is defined to use double-underscore
__all__ to control from module import * exports.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>