
[2/3][Feat]: Offline DFlash training #1343

Merged
ChenhanYu merged 11 commits into main from haoguo/dflash-offline
Apr 26, 2026

Conversation

@h-guo18
Contributor

@h-guo18 h-guo18 commented Apr 24, 2026

What does this PR do?

Type of change: new feature

Part 2 of a 3-PR series splitting #1271:

Changes:

  • Add dflash_offline flag to DFlashConfig for training from pre-computed hidden states; deletes base model layers to save memory.
  • Add Pydantic validators on DFlashConfig (a minimal sketch follows this list):
    • _derive_dflash_offline — auto-derive dflash_offline from data_args.offline_data_path in validation context. Not user-configurable: any user-supplied value is overridden by the derived value.
    • _resolve_mask_token_id — auto-detect dflash_mask_token_id from tokenizer.mask_token_id.
    • _check_mask_token_id — fail fast if unset after resolution.
  • HFDFlashModel.modify(): use num_orig_hidden_layers when offline; resolve the device from _base_model_lm_head when no base layers are present; drop the base-model layers module.
  • HFDFlashModel.forward(): add offline branch — consumes precomputed base_model_outputs via DFlashBaseModelOutput.from_offline_dict, and when dflash_self_logit_distillation is enabled with base_model_logits absent, recomputes logits from base_model_hidden_states via _base_model_lm_head. Raises a clear error from the non-training / pseudo_speculative_generate paths when dflash_offline=True, since base-model layers have been deleted.
  • DFlashBaseModelOutput dataclass in modeling_dflash.py (with from_offline_dict classmethod) to unify online/offline output shapes. aux_hidden_states is required in from_offline_dict so missing keys fail fast at the entry point rather than deeper in the forward.
  • examples/speculative_decoding/main.py: replace inline mask_token_id auto-detect with DFlashConfig.model_validate(dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}).
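
A minimal sketch of how the validators above can be wired with Pydantic v2 (field and method names follow the bullets; the real DFlashConfig has more fields, so treat this as illustrative rather than the shipped code):

from pydantic import BaseModel, ValidationInfo, model_validator

class DFlashConfigSketch(BaseModel):
    dflash_offline: bool = False
    dflash_mask_token_id: int | None = None

    @model_validator(mode="before")
    @classmethod
    def _derive_dflash_offline(cls, data, info: ValidationInfo):
        # Auto-derive from the validation context; any user-supplied value
        # is overridden whenever data_args is present in the context.
        data_args = (info.context or {}).get("data_args")
        if isinstance(data, dict) and data_args is not None:
            data["dflash_offline"] = bool(getattr(data_args, "offline_data_path", None))
        return data

    @model_validator(mode="before")
    @classmethod
    def _resolve_mask_token_id(cls, data, info: ValidationInfo):
        # Fall back to the tokenizer's mask token when none was given.
        tokenizer = (info.context or {}).get("tokenizer")
        if isinstance(data, dict) and data.get("dflash_mask_token_id") is None and tokenizer is not None:
            data["dflash_mask_token_id"] = tokenizer.mask_token_id
        return data

    @model_validator(mode="after")
    def _check_mask_token_id(self):
        # Fail fast if the mask token is still unset after resolution
        # (the real validator may scope this check differently).
        if self.dflash_mask_token_id is None:
            raise ValueError("dflash_mask_token_id is unset; pass it or provide a tokenizer with mask_token_id")
        return self

Validation with context then looks like the main.py call above: DFlashConfig.model_validate(dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}).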

Silent bug fix — add_generation_template → add_generation_prompt

The pre-refactor compute_hidden_states_hf.py passed add_generation_template=False to tokenizer.apply_chat_template. This kwarg does not exist on HF apply_chat_template and was being silently ignored, so the intended "don't append a generation prompt" behavior was never actually applied. The new tokenize_with_loss_mask helper in examples/speculative_decoding/collect_hidden_states/common.py uses the correct add_generation_prompt=False. This is a real behavior change for anyone re-dumping hidden states: trailing generation prompts that were previously appended to the tokenized sequences will no longer be included.
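
For clarity, the corrected call shape (illustrative; in HF tokenizers, unknown kwargs to apply_chat_template are forwarded into the Jinja template context rather than rejected, which is why the typo never raised):

messages = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there!"},
]
# Correct kwarg: add_generation_prompt. The old add_generation_template
# was silently swallowed and had no effect on tokenization.
input_ids = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=False,  # do not append a trailing generation prompt
)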

Testing

  • New tests:

    • tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py — CPU unit tests for convert path (online keeps base layers, offline deletes them; num_orig_hidden_layers drives target_layer_ids in offline mode) and DFlashConfig._derive_dflash_offline validator.
    • TestDFlashOfflineForwardGPU in tests/gpu/torch/speculative/plugins/test_hf_dflash.py — GPU forward smoke with precomputed base_model_outputs, plus the dflash_self_logit_distillation logit-recompute path.
  • Training test: (screenshots attached)

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ — additive dflash_offline flag defaulting to False; validators fall through when context not provided.
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: ✅ — see Testing section above.
  • Did you update Changelog?: ✅

TODO (follow-up)

  • Update examples/speculative_decoding/collect_hidden_states/compute_hidden_states_*.py to support DFlash offline data. Current scripts are Eagle-specific — they hardcode the [2, N/2, N-3] aux-layer selection and emit {input_ids, hidden_states, aux_hidden_states}. DFlash offline needs (a possible dump shape is sketched after this list):
    • Aux layer indices driven by build_target_layer_ids(num_orig_hidden_layers, num_draft_layers) (or a configurable list), not the Eagle triplet.
    • base_model_hidden_states key (last-layer hidden) so DFlashBaseModelOutput.from_offline_dict + the dflash_self_logit_distillation recompute path can consume it.
    • Optional base_model_logits dump so offline training can skip the self-distillation logit recomputation when logits are available.
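
As referenced above, a hypothetical per-conversation dump satisfying these requirements could look like the following (aux_hidden_states, base_model_hidden_states, and base_model_logits are named in this PR; the remaining keys and shapes are assumptions):

import torch

dump = {
    "input_ids": input_ids,                    # (seq_len,) token IDs
    "aux_hidden_states": aux_hidden_states,    # required by from_offline_dict; missing key fails fast
    "base_model_hidden_states": last_hidden,   # last-layer hiddens for the logit-recompute path
    "base_model_logits": logits,               # optional; when present, skips the self-distillation recompute
    "loss_mask": loss_mask,                    # 0/1 per token, for answer-only-loss training
}
torch.save(dump, output_dir / f"{conversation_id}.pt")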

Additional Information

Base branch is #1296 (file reorg). Retarget to main once #1296 merges.

Summary by CodeRabbit

  • New Features

    • Offline DFlash speculative-decoding training from precomputed base-model hidden states
    • Answer-only-loss training with persisted loss masks and optional chat-template support
    • Flexible auxiliary-layer selection via CLI and an exposed default aux-layer helper
    • Auto-derived offline flag in config and automatic memory optimization during offline conversion
  • Documentation

    • Updated guides for offline pipeline, aux-layer selection, and loss-masking options
  • Tests

    • New unit, GPU, and regression tests covering offline conversion, training, and config derivation

@h-guo18 h-guo18 requested review from a team as code owners April 24, 2026 21:49
@coderabbitai
Contributor

coderabbitai Bot commented Apr 24, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 4b6bf289-8d97-4f81-8fc9-2ef6491d3172

📥 Commits

Reviewing files that changed from the base of the PR and between 063eeb5 and 6c11e95.

📒 Files selected for processing (1)
  • tests/regression/torch/speculative/test_dflash_offline.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/regression/torch/speculative/test_dflash_offline.py

📝 Walkthrough


Adds offline DFlash speculative-decoding training: pre-computed base-model hidden-state dumps, shared CLI/tokenization utilities for aux-layer selection and answer-only loss, DFlashConfig derivation for offline mode, model conversion/forward changes to accept offline inputs and free base-model memory, and tests plus changelog/docs updates.

Changes

  • Changelog (CHANGELOG.rst): Documented the offline DFlash training flow, --aux-layers usage, the answer-only-loss option, and corrected the Conv3D kernel module path.
  • Shared CLI / Tokenization (examples/speculative_decoding/collect_hidden_states/common.py): New helpers for CLI and tokenization: add_aux_layers_args, resolve_aux_layers, add_answer_only_loss_args, a chat-template loader/validator, and tokenize_with_loss_mask producing input_ids plus an aligned loss_mask.
  • Hidden-state dump scripts (examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py, examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py): Wired in the common helpers for aux-layer selection and optional answer-only loss; unified tokenization producing loss_mask; removed hardcoded aux-layer selection.
  • Data pipeline / Dataset wiring (examples/speculative_decoding/eagle_utils.py, modelopt/torch/speculative/eagle/utils.py): Pass answer_only_loss into OfflineSupervisedDataset; the dataset now requires and returns the stored loss_mask when enabled, otherwise produces an all-ones mask.
  • Main CLI / Conversion (examples/speculative_decoding/main.py): Use DFlashConfig.model_validate(...context...) to derive the config (including the mask token and offline flag) and pass the validated config to conversion.
  • Config & Validation (modelopt/torch/speculative/config.py): Added dflash_offline: bool; pre-validation derives it from data_args.offline_data_path and resolves dflash_mask_token_id from the tokenizer; post-validation enforces mask-token presence.
  • DFlash model surface (modelopt/torch/speculative/dflash/dflash_model.py): DFlashModel.modify now stores config.dflash_offline on the model instance as self.dflash_offline.
  • HFDFlash plugin (modelopt/torch/speculative/plugins/hf_dflash.py): Offline-mode conversion removes base-model layers to save memory; forward accepts base_model_outputs via DFlashBaseModelOutput.from_offline_dict, extracts target_hidden, optionally computes logits for self-logit distillation, accepts an external loss_mask, and prevents eval/generation in offline mode.
  • Model interop dataclass (modelopt/torch/speculative/plugins/modeling_dflash.py): Added the DFlashBaseModelOutput dataclass with a from_offline_dict factory to represent offline base-model outputs (aux_hidden_states, optional logits).
  • EAGLE plugin helper (modelopt/torch/speculative/plugins/hf_eagle.py): Exported default_eagle_aux_layer_ids(num_layers) and refactored default-layer selection to use it.
  • Tests — unit (tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py): New unit tests validating offline conversion: dflash_offline derivation, base-layer removal, and target-layer-id derivation from the original layer count.
  • Tests — GPU (tests/gpu/torch/speculative/plugins/test_hf_dflash.py): Added GPU tests for the offline forward: differentiable finite loss, and self-logit-distillation behavior when base logits are absent.
  • Tests — regression / e2e (tests/regression/torch/speculative/test_dflash_offline.py): New regression test exercising the hidden-state dump with compute_hidden_states_hf.py and an end-to-end offline-training convergence check.

Sequence Diagram(s)

sequenceDiagram
    participant CLI as User/Script
    participant Config as DFlashConfig
    participant Dump as Hidden-State Dump
    participant Dataset as OfflineSupervisedDataset
    participant HFDFlash as HFDFlashModel
    participant Loss as Loss Computation

    Note over CLI,HFDFlash: Online training (existing)
    CLI->>HFDFlash: forward(input_ids)
    activate HFDFlash
    HFDFlash->>HFDFlash: run base model -> hidden_states (no_grad)
    HFDFlash->>HFDFlash: build target_hidden from hidden_states
    HFDFlash->>Loss: compute loss (KD / reconstruction)
    deactivate HFDFlash
    Loss-->>CLI: loss

    Note over CLI,Loss: Offline training (new)
    CLI->>Config: model_validate(context={data_args, tokenizer})
    activate Config
    Config->>Config: derive dflash_offline from data_args.offline_data_path
    Config-->>CLI: validated config (dflash_offline=True)
    deactivate Config
    CLI->>Dump: run compute_hidden_states_* (produce .pt with aux_hidden_states, loss_mask)
    Dump-->>Dataset: store .pt files
    CLI->>Dataset: load batch (aux_hidden_states, loss_mask)
    CLI->>HFDFlash: forward(base_model_outputs=dict)
    activate HFDFlash
    HFDFlash->>HFDFlash: DFlashBaseModelOutput.from_offline_dict()
    HFDFlash->>HFDFlash: extract target_hidden
    HFDFlash->>Loss: compute loss using provided loss_mask
    deactivate HFDFlash
    Loss-->>CLI: loss

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 77.78%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (5 passed)

  • Description Check (✅ Passed): check skipped; CodeRabbit’s high-level summary is enabled.
  • Title check (✅ Passed): the title "[2/3][Feat]: Offline DFlash training" accurately and specifically describes the main feature being added: offline DFlash training support as the second part of a multi-PR series.
  • Linked Issues check (✅ Passed): check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check (✅ Passed): check skipped because no linked issues were found for this pull request.
  • Security Anti-Patterns (✅ Passed): the PR's security implementation is sound, with torch.load() using weights_only=True and trust_remote_code properly exposed with safe defaults.



@h-guo18 h-guo18 self-assigned this Apr 24, 2026
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 11

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1)

128-136: ⚠️ Potential issue | 🟠 Major

Skip rows without IDs instead of asserting.

CI is already failing here with AssertionError: conversation_id is required. One malformed sample currently aborts the entire dump before you reach the later invalid-row handling, so this needs to degrade to a skip/count path instead of a hard assert.

Suggested change
     def keep_conversation(entry):
         conversation_id = entry.get("conversation_id", entry.get("uuid", None))
-        assert conversation_id is not None, "conversation_id is required"
+        if conversation_id is None:
+            return False
         output_file = args.output_dir / f"{conversation_id}.pt"
         return not output_file.exists()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py`
around lines 128 - 136, Replace the hard assert in keep_conversation with logic
that skips rows missing an ID: check conversation_id =
entry.get("conversation_id", entry.get("uuid")); if it's None, increment a skip
counter (e.g., initialize skipped_missing_id = 0 above and use nonlocal or a
mutable container to update it) and optionally log a warning, then return False;
otherwise compute output_file = args.output_dir / f"{conversation_id}.pt" and
return not output_file.exists(); leave the dataset =
dataset.filter(keep_conversation) call as-is so malformed samples are skipped
instead of aborting.
🧹 Nitpick comments (6)
tests/unit/torch/export/test_hf_spec_rope_export.py (1)

37-46: Prefer a real namespace or a spec’d mock here.

MagicMock is the reason Lines 40-42 need a manual reset, and it can still mask future typos in _export_config by auto-creating missing attributes. A SimpleNamespace or Mock(spec_set=...) will make these rope-resolution tests fail loudly when the exporter asks for the wrong field.

♻️ Tighten the test double
-from unittest.mock import MagicMock
+from types import SimpleNamespace
...
-    model = MagicMock()
-    model.eagle_config.eagle_decoder_type = "llama"
-    model.eagle_config.rope_scaling = {"rope_type": rope_type, "rope_theta": rope_theta}
-    # rope_theta lives inside rope_scaling in transformers 5.x; clear the top-level attr
-    # so the fallback path is exercised instead of MagicMock's auto-attr.
-    model.eagle_config.rope_theta = None
-    model.eagle_export_rope_scaling = eagle_export_rope_scaling
-    model._draft_model_config = None
-    model.config.rope_scaling = None
-    model.config.rope_theta = None
+    model = SimpleNamespace(
+        eagle_config=SimpleNamespace(
+            eagle_decoder_type="llama",
+            rope_scaling={"rope_type": rope_type, "rope_theta": rope_theta},
+            rope_theta=None,
+        ),
+        eagle_export_rope_scaling=eagle_export_rope_scaling,
+        _draft_model_config=None,
+        config=SimpleNamespace(rope_scaling=None, rope_theta=None),
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/export/test_hf_spec_rope_export.py` around lines 37 - 46,
Replace the loose MagicMock used for model with a stricter test double so
missing/typo'd attributes fail loudly: instead of model = MagicMock(), create
either a SimpleNamespace populated with the exact attributes used (e.g.,
eagle_config, eagle_export_rope_scaling, _draft_model_config, config) or use
unittest.mock.Mock(spec_set=YourModelConfigSpec) configured with only the needed
fields; update the test setup around model.eagle_config.*,
model.eagle_export_rope_scaling, model._draft_model_config, and model.config.*
to remove the manual resets (model.eagle_config.rope_theta = None etc.) since
the stricter double will accurately reflect missing attributes and catch
incorrect attribute access in the exporter.
modelopt/torch/speculative/plugins/hf_eagle.py (5)

624-630: Consider using ValueError instead of assert for clearer error messages.

Line 626 uses assert for validating required kwargs, which raises AssertionError with minimal context. In production with assertions disabled (-O flag), this check would be skipped entirely.

💡 More explicit error handling
         if self.eagle_offline:
             # Parse base model outputs forwarded from teacher
-            assert "base_model_outputs" in kwargs
+            if "base_model_outputs" not in kwargs:
+                raise ValueError("eagle_offline=True requires 'base_model_outputs' in forward kwargs")
             base_outputs = EagleBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"])
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 624 - 630,
Replace the runtime assertion with explicit error handling: in the eagle_offline
branch where the code currently does `assert "base_model_outputs" in kwargs`,
raise a ValueError (e.g., `raise ValueError("missing required kwarg
'base_model_outputs' for Eagle offline mode")`) so the check is not skipped
under -O and provides a clear message; keep the subsequent logic that calls
EagleBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"]) and the
fallback to computing logits via
self._base_model_lm_head(base_outputs.out_hiddens) unchanged.

241-251: Docstring describes KL divergence but implementation computes cross-entropy.

The implementation computes -sum(softmax(ref) * log_softmax(lora)), which is cross-entropy H(p, q), not KL(p || q) = H(p, q) - H(p). For optimization purposes they're equivalent (H(p) is constant w.r.t. LoRA params), but the docstring could be more precise.
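
For reference, the identity behind this note:

KL(p \parallel q) = \sum_i p_i \log \frac{p_i}{q_i} = -\sum_i p_i \log q_i + \sum_i p_i \log p_i = H(p, q) - H(p)

Since p comes from the frozen reference logits, H(p) contributes no gradient, so minimizing H(p, q) and minimizing KL(p || q) are equivalent for the LoRA parameters.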

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 241 - 251, The
docstring for _preservation_loss currently claims it computes KL(ref || lora)
but the implementation computes cross-entropy H(ref, lora) = -sum(softmax(ref) *
log_softmax(lora)); update the docstring to accurately describe the computed
quantity (cross-entropy weighted by eagle_base_lora_preservation_loss_weight)
or, if you truly want KL, modify the implementation to subtract the entropy term
H(ref) (i.e., compute H(ref, lora) - H(ref)) using ref_logits.detach() so only
the LoRA params are optimized; refer to the _preservation_loss function,
ref_logits, lora_logits, and eagle_base_lora_preservation_loss_weight when
making the change.

49-54: Edge case: Index 1 is always included regardless of num_layers.

For models with num_layers < 2, this returns indices that don't exist (e.g., num_layers=1 yields {0, 1}). While unlikely in practice for EAGLE-3 use cases, consider adding a guard.

💡 Optional defensive check
 def default_eagle_aux_layer_ids(num_layers: int) -> list[int]:
     """Default EAGLE-3 auxiliary hidden-state layer IDs (0-based).

     Three layers near the start, middle, and end of the stack.
     """
+    if num_layers < 2:
+        return list(range(num_layers))
     return sorted({1, max(0, num_layers // 2 - 1), max(0, num_layers - 4)})
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 49 - 54, The
function default_eagle_aux_layer_ids always includes index 1 which can be out of
range for small models; update default_eagle_aux_layer_ids to build the
candidate set (1, max(0, num_layers//2 - 1), max(0, num_layers - 4)), then
filter out any indices not in the valid range 0 <= idx < num_layers before
returning the sorted list so you never return non-existent layer IDs for small
num_layers.

844-855: Prefer tensor.clone() over copy.deepcopy() for tensors.

copy.deepcopy works but is heavier than necessary for tensors. input_ids.clone() is the idiomatic and more efficient approach.

💡 Simpler tensor copy
     def get_ground_truth(self, input_ids, osl):
         """This function returns ground truth output tokens from the base model."""
-        input_ids = copy.deepcopy(input_ids).to(torch.cuda.current_device())
+        input_ids = input_ids.clone().to(torch.cuda.current_device())
         for _ in range(osl):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 844 - 855, The
get_ground_truth implementation uses copy.deepcopy(input_ids) which is
inefficient for tensors; replace that with input_ids.clone() and then move it to
the CUDA device (e.g., input_ids.clone().to(torch.cuda.current_device())) so
HFARValidation.get_ground_truth works identically but uses the idiomatic,
lighter-weight tensor copy for the input_ids variable.

117-131: Global torch._dynamo.config.suppress_errors modification affects all compiled code.

Setting suppress_errors = True globally will mask compilation errors in all torch.compile calls in the process, not just this model. This could hide issues in other code paths.

Consider making this configurable via a model attribute or documenting the side effect:

     def _activate_torch_compile(self):
         import torch._dynamo

-        torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode
+        # NOTE: This globally affects all torch.compile calls in the process
+        if getattr(self, "eagle_dynamo_suppress_errors", True):
+            torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 117 - 131, The
code currently sets the global torch._dynamo.config.suppress_errors in
_activate_torch_compile which affects all torch.compile calls; change this to be
non-global by making the behavior configurable on the instance (e.g.,
self.suppress_dynamo_errors) and by temporarily setting suppress_errors only
around your compile loop (save original = torch._dynamo.config.suppress_errors,
set it to self.suppress_dynamo_errors, run the compile_targets loop invoking
torch.compile on getattr(self, name), then restore the original value in a
finally block); alternatively document the side effect and gate the global
change behind the new attribute so callers can opt into the global suppression.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@CHANGELOG.rst`:
- Line 38: Update the module path in the changelog entry so it points to the new
post-reorg location: replace the old path string
"modelopt.torch.quantization.src.conv" with
"modelopt.torch.kernels.quantization.conv" in the Conv3D kernel description;
ensure the sentence still references nn.Conv3d and ModelOpt PTQ and that grouped
convolution and inference/training notes remain unchanged.

In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py`:
- Line 199: The call building "eagle3_layers_to_capture" passes
set(resolve_aux_layers(args, num_hidden_layers)) but num_hidden_layers may be
missing from the model config; validate that AutoConfig exposes a concrete
integer num_hidden_layers before calling resolve_aux_layers. In practice, obtain
the config (AutoConfig), check getattr(config, "num_hidden_layers", None) is an
int > 0 (or raise a clear ValueError/RuntimeError), and only then call
resolve_aux_layers(args, num_hidden_layers); include a helpful error message
referencing the model/config when failing so the failure matches the HF dumper
guard.
- Around line 26-27: The TRTLLM dumper must accept and persist the loss_mask so
generated dumps are compatible with
OfflineSupervisedDataset(answer_only_loss=True); update
compute_hidden_states_trtllm.py to parse/propagate the loss_mask from the
dataset or aux-layer helpers (use add_aux_layers_args and resolve_aux_layers
where aux fields are collected) and include a loss_mask field in the saved dump
payload produced by the TRTLLM dumper (the code path around the dump/save call
near the dumper or dump function at the end of the script). Ensure the emitted
dump keys match what OfflineSupervisedDataset expects (loss_mask present when
answer_only_loss=True) to avoid hard failures.

In `@modelopt/torch/speculative/plugins/__init__.py`:
- Around line 32-35: Wrap each HF plugin import in its own guarded
import_plugin("transformers") context (or try/except) so a failure importing
hf_dflash does not stop hf_eagle and hf_medusa from being attempted;
specifically, change the block that currently does with
import_plugin("transformers"): from .hf_dflash import * from .hf_eagle import *
from .hf_medusa import * to three independent guarded imports for the modules
hf_dflash, hf_eagle, and hf_medusa (using import_plugin("transformers")
separately or try/except around each import) and ensure any import exception is
caught and logged rather than letting it abort the whole block.

In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 447-455: When running offline dflash with
dflash_self_logit_distillation=True the code looks for
kwargs["base_model_outputs"]["base_model_hidden_states"] but the dumper stores
the last layer under "hidden_states"; update the offline-path in the block
guarded by self.dflash_offline (the callsite using
DFlashBaseModelOutput.from_offline_dict and the variable base_outputs) to
prefer/normalize the dumper schema: read "hidden_states" if
"base_model_hidden_states" is missing (or adjust
DFlashBaseModelOutput.from_offline_dict to expose base_model_hidden_states from
hidden_states), then pass those hiddens into self._base_model_lm_head to compute
logits so logits are correctly recomputed offline.
- Around line 163-167: When computing num_target_layers in the dflash
conversion, seed and persist the original layer count before offline mode
removes layers: if self.dflash_offline is true and base_config lacks
num_orig_hidden_layers, capture base_config.num_hidden_layers into
base_config.num_orig_hidden_layers (or into a persistent attribute on self) on
first entry, then use base_config.num_orig_hidden_layers for num_target_layers
thereafter; update the logic around num_target_layers, self.dflash_offline, and
mtsp.convert so the original count is available after layer deletion and
repeated offline calls succeed.

In `@modelopt/torch/speculative/plugins/hf_eagle.py`:
- Around line 344-350: The cache key in _get_ttt_attention_mask currently uses
only ttt_step but _compute_ttt_attention_mask depends on seq_length; change the
key in self._cached_attn_blk_masks to include seq_length (e.g., use (ttt_step,
seq_length)) when storing and retrieving the mask so you won't return an
incorrectly-sized mask for different sequence lengths; update both the
.update/store call and the return lookup to use that composite key (referencing
_get_ttt_attention_mask, _compute_ttt_attention_mask, and
_cached_attn_blk_masks).

In `@modelopt/torch/speculative/plugins/hf_medusa.py`:
- Around line 112-123: The model forward call in hf_medusa.py incorrectly passes
rcache_position as a kwarg to self.model; change that argument name to
cache_position (i.e., replace rcache_position=cache_position with
cache_position=cache_position) inside the forward/inference block where
self.model(...) is invoked so it matches Transformers 4.56.0 causal LM forward
signatures and avoids kwarg validation failures in the Medusa plugin.
- Around line 126-145: The base-model loss currently compares logits and labels
at the same positions; update the computation in the block guarded by
freeze_base_model so logits and labels are causally shifted to match HF LM
convention and the Medusa loss: compute shift_logits = logits[..., :-1, :] and
shift_labels = labels[..., 1:], and if logits_to_keep is used ensure the same
slice is applied to labels before shifting; then pass shift_logits.view(-1,
shift_logits.shape[-1]) and shift_labels.view(-1) into loss_fct
(CrossEntropyLoss) to compute base_model_loss consistently with medusa_heads and
medusa loss behavior.

In `@modelopt/torch/speculative/plugins/modeling_eagle.py`:
- Around line 206-215: from_offline_dict currently reads out_hiddens from
d.get("base_model_hidden_states") which misses HF dumps that use
"hidden_states"; update the classmethod from_offline_dict to populate
out_hiddens from d.get("hidden_states") with a fallback to
d.get("base_model_hidden_states") for backward compatibility, and ensure the
same pattern (input_embeds/logits/aux_hiddens) remains unchanged so consumers of
the container receive the expected final activations.

In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 20-29: The deprecation warning only directs users to
modelopt.torch.speculative.plugins.hf_eagle while this shim still re-exports
HFMedusaModel from hf_medusa; update the warning text in transformers.py to tell
callers to update imports for both hf_eagle (for HF EAGLE symbols) and hf_medusa
(for HFMedusaModel) so Medusa users are not misdirected, and keep the existing
re-export lines (from .hf_eagle import * and from .hf_medusa import *)
unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 289e65a1-7ac4-49fa-ade8-1285057efe75

📥 Commits

Reviewing files that changed from the base of the PR and between 7c80d85 and 855ab0a.

📒 Files selected for processing (25)
  • .pre-commit-config.yaml
  • CHANGELOG.rst
  • examples/speculative_decoding/collect_hidden_states/common.py
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/main.py
  • examples/speculative_decoding/scripts/ar_validate.py
  • modelopt/torch/export/plugins/hf_spec_export.py
  • modelopt/torch/speculative/config.py
  • modelopt/torch/speculative/dflash/dflash_model.py
  • modelopt/torch/speculative/eagle/default_config.py
  • modelopt/torch/speculative/eagle/utils.py
  • modelopt/torch/speculative/plugins/__init__.py
  • modelopt/torch/speculative/plugins/hf_dflash.py
  • modelopt/torch/speculative/plugins/hf_eagle.py
  • modelopt/torch/speculative/plugins/hf_medusa.py
  • modelopt/torch/speculative/plugins/modeling_dflash.py
  • modelopt/torch/speculative/plugins/modeling_eagle.py
  • modelopt/torch/speculative/plugins/transformers.py
  • modelopt/torch/speculative/utils.py
  • tests/gpu/torch/speculative/plugins/test_hf_dflash.py
  • tests/regression/torch/speculative/test_dflash_offline.py
  • tests/unit/torch/export/test_hf_spec_rope_export.py
  • tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py

Comment thread CHANGELOG.rst
- [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution.
- Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml>`_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml>`_ for usage.
- Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.kernels.quantization.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning.
- Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.quantization.src.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning.
Contributor


⚠️ Potential issue | 🟡 Minor

Update this module path to the post-reorg location.

The same 0.45 section already says the Conv3D kernel moved under modelopt.torch.kernels.quantization.conv, so this entry now points readers at the removed path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@CHANGELOG.rst` at line 38, Update the module path in the changelog entry so
it points to the new post-reorg location: replace the old path string
"modelopt.torch.quantization.src.conv" with
"modelopt.torch.kernels.quantization.conv" in the Conv3D kernel description;
ensure the sentence still references nn.Conv3d and ModelOpt PTQ and that grouped
convolution and inference/training notes remain unchanged.

Comment on lines +26 to 27
from common import add_aux_layers_args, resolve_aux_layers
from datasets import load_dataset
Contributor


⚠️ Potential issue | 🟠 Major

Wire the TRTLLM dumper into the new answer-only-loss flow too.

This script picked up the shared aux-layer helpers, but it still never accepts or persists a loss_mask. That leaves TRTLLM-generated dumps incompatible with the new offline answer_only_loss=True training path, because OfflineSupervisedDataset now hard-fails when that field is absent.

Also applies to: 126-126

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py`
around lines 26 - 27, The TRTLLM dumper must accept and persist the loss_mask so
generated dumps are compatible with
OfflineSupervisedDataset(answer_only_loss=True); update
compute_hidden_states_trtllm.py to parse/propagate the loss_mask from the
dataset or aux-layer helpers (use add_aux_layers_args and resolve_aux_layers
where aux fields are collected) and include a loss_mask field in the saved dump
payload produced by the TRTLLM dumper (the code path around the dump/save call
near the dumper or dump function at the end of the script). Ensure the emitted
dump keys match what OfflineSupervisedDataset expects (loss_mask present when
answer_only_loss=True) to avoid hard failures.

"write_interval": 1,
"file_prefix": f"dp_{args.dp_rank}",
"eagle3_layers_to_capture": {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4},
"eagle3_layers_to_capture": set(resolve_aux_layers(args, num_hidden_layers)),
Contributor


⚠️ Potential issue | 🟡 Minor

Validate num_hidden_layers before resolving aux layers.

resolve_aux_layers() assumes a concrete layer count. If AutoConfig does not expose num_hidden_layers, this path now fails later with a much less helpful error than the HF dumper, which already added an explicit guard.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py`
at line 199, The call building "eagle3_layers_to_capture" passes
set(resolve_aux_layers(args, num_hidden_layers)) but num_hidden_layers may be
missing from the model config; validate that AutoConfig exposes a concrete
integer num_hidden_layers before calling resolve_aux_layers. In practice, obtain
the config (AutoConfig), check getattr(config, "num_hidden_layers", None) is an
int > 0 (or raise a clear ValueError/RuntimeError), and only then call
resolve_aux_layers(args, num_hidden_layers); include a helpful error message
referencing the model/config when failing so the failure matches the HF dumper
guard.

Comment thread modelopt/torch/speculative/plugins/__init__.py
Comment thread modelopt/torch/speculative/plugins/hf_dflash.py
Comment thread modelopt/torch/speculative/plugins/hf_eagle.py
Comment thread modelopt/torch/speculative/plugins/hf_medusa.py
Comment thread modelopt/torch/speculative/plugins/hf_medusa.py
Comment thread modelopt/torch/speculative/plugins/modeling_eagle.py
Comment thread modelopt/torch/speculative/plugins/transformers.py
@h-guo18 h-guo18 requested a review from ChenhanYu April 24, 2026 23:44
h-guo18 added 9 commits April 24, 2026 23:46
- Add `dflash_offline` config flag for training from pre-computed hidden states;
  deletes base model layers to save memory.
- Move `dflash_mask_token_id` auto-detection from `main.py` into `DFlashConfig`
  Pydantic validators; derive `dflash_offline` from `data_args.offline_data_path`.
- Add `DFlashBaseModelOutput.from_offline_dict` classmethod for consuming
  pre-computed hidden states in the forward path.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@h-guo18 h-guo18 force-pushed the haoguo/dflash-offline branch from 855ab0a to 2cf1784 on April 24, 2026 23:46
@github-actions
Contributor

github-actions Bot commented Apr 24, 2026

PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-04-26 01:38 UTC

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (3)
CHANGELOG.rst (1)

38-39: ⚠️ Potential issue | 🟠 Major

Update Conv3D implicit GEMM kernel path to match the reorg migration table.

This changelog entry references modelopt.torch.quantization.src.conv, but the migration/backward-breaking guidance in the same file indicates the kernel subpackages moved under modelopt.torch.kernels... (and specifically modelopt.torch.kernels.quantization.conv for the old quantization/src/conv path).

Please update this entry’s module path to the post-reorg path to avoid misleading users.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@CHANGELOG.rst` around lines 38 - 39, Update the changelog entry text to
reference the new post-reorg module path: replace the old import path
`modelopt.torch.quantization.src.conv` with
`modelopt.torch.kernels.quantization.conv` in the Conv3D implicit GEMM paragraph
so the entry matches the migration/backward-breaking guidance and avoids
misleading users about the package location.
modelopt/torch/speculative/plugins/hf_eagle.py (1)

344-350: ⚠️ Potential issue | 🟠 Major

Cache flex attention masks using a shape-dependent key (include seq_length).

_get_ttt_attention_mask() caches by ttt_step only, but _compute_ttt_attention_mask(batch_size, seq_length, ttt_step) depends on seq_length (mask dimensions change). Reusing a cached mask across different sequence lengths can produce incorrect masks or shape/runtime errors.

This matches a previously reported concern; please verify and fix the keying now (e.g., (seq_length, ttt_step) or include any other shape params that affect mask construction).

Please confirm how _compute_ttt_attention_mask() is called across the codebase (especially whether seq_length varies while ttt_step repeats). If it does, update _cached_attn_blk_masks accordingly and add/adjust a unit test that calls the function with two seq_lens.
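
A sketch of the composite-key caching (attribute and method names are taken from this comment, not verified against the source):

key = (seq_length, ttt_step)
if key not in self._cached_attn_blk_masks:
    self._cached_attn_blk_masks[key] = self._compute_ttt_attention_mask(
        batch_size, seq_length, ttt_step
    )
return self._cached_attn_blk_masks[key]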

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 344 - 350, The
cache key for flex attention masks is currently only ttt_step causing masks
computed by _compute_ttt_attention_mask(batch_size, seq_length, ttt_step) to be
reused across different seq_length values; update _get_ttt_attention_mask to key
_cached_attn_blk_masks by (seq_length, ttt_step) (or another tuple including all
shape-dependent params such as batch_size if relevant), and store/retrieve masks
using that composite key instead of just ttt_step; then add/adjust a unit test
that calls _get_ttt_attention_mask (or directly _compute_ttt_attention_mask)
with the same ttt_step but two different seq_length values to verify distinct
masks are produced and cached separately.
modelopt/torch/speculative/plugins/hf_dflash.py (1)

446-456: ⚠️ Potential issue | 🔴 Critical

Normalize offline logits-recompute hidden-state key (avoid brittle base_model_hidden_states).

The offline logits recompute block uses:
out_hiddens = kwargs["base_model_outputs"]["base_model_hidden_states"]

If the offline dump schema provides hidden_states (and the collator doesn’t rename), this will KeyError and break offline training with dflash_self_logit_distillation=True when base_model_logits is absent.

This matches an earlier reported issue; please fix by accepting both keys (or by ensuring DFlashBaseModelOutput.from_offline_dict() exposes the needed hidden states in a consistent attribute).

🔧 Proposed fix (accept both keys)
-                out_hiddens = kwargs["base_model_outputs"]["base_model_hidden_states"]
+                offline_outputs = kwargs["base_model_outputs"]
+                out_hiddens = offline_outputs.get("base_model_hidden_states")
+                if out_hiddens is None:
+                    # Fall back to dump-script key
+                    out_hiddens = offline_outputs.get("hidden_states")
+                if out_hiddens is None:
+                    raise KeyError(
+                        "Missing hidden states for offline logits recompute. Expected "
+                        "'base_model_hidden_states' or 'hidden_states' in base_model_outputs."
+                    )
                 base_outputs.logits = self._base_model_lm_head(out_hiddens)

Please verify the exact structure of kwargs["base_model_outputs"] produced by the offline dataloader/collator during DFlash offline training. If it always uses base_model_hidden_states, then this can be downgraded; otherwise keep the normalization.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 446 - 456, The
offline logits-recompute path currently assumes kwargs["base_model_outputs"]
contains "base_model_hidden_states" which is brittle; update the logic in the
dflash_offline branch (around DFlashBaseModelOutput.from_offline_dict, the
dflash_self_logit_distillation check, and the call to self._base_model_lm_head)
to tolerate either key name by first normalizing/locating hidden states: prefer
an attribute from DFlashBaseModelOutput (extend from_offline_dict to populate a
consistent attribute like .hidden_states or .base_model_hidden_states) or, if
leaving the dict-based approach, lookup hidden =
kwargs["base_model_outputs"].get("base_model_hidden_states") or
kwargs["base_model_outputs"].get("hidden_states") and use that for
self._base_model_lm_head; ensure base_outputs.logits is set when hidden states
are found and preserve base_outputs.target_hidden.
🧹 Nitpick comments (6)
modelopt/torch/speculative/config.py (1)

67-157: Clarify dflash_offline “not user-configurable” wording.

The description states dflash_offline is “not user-configurable”. However, _derive_dflash_offline only overrides the value when info.context contains data_args and data is a dict. If context isn’t provided, user-supplied dflash_offline can still take effect.

Recommend rewording description to something like “auto-derived when context includes data_args”.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/config.py` around lines 67 - 157, The description
for dflash_offline is misleading — it claims the field is "not
user-configurable" but _derive_dflash_offline only overrides when info.context
contains data_args and data is a dict, so a user value can persist when context
is absent; update the DFlashConfig.dflash_offline ModeloptField description to
state that the field is "auto-derived when context includes data_args (overrides
provided value)" or similar phrasing; ensure you edit the string in the
dflash_offline ModeloptField (and optionally update the docstring of
_derive_dflash_offline) to match this behavior.
examples/speculative_decoding/collect_hidden_states/common.py (2)

113-130: Make generation-tag validation stricter (exact Jinja tags).

verify_generation_tags() currently checks ("generation" in chat_template and "endgeneration" in chat_template). This can produce false positives if those substrings appear in comments/other content, while still missing the exact Jinja tags required by apply_chat_template(..., return_assistant_tokens_mask=True).

Consider validating the specific token patterns (e.g. "{% generation %}" and "{% endgeneration %}" or variants with whitespace control).
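
One possible shape for such a check (illustrative, not the PR's implementation), tolerating optional whitespace-control dashes:

import re

GENERATION_TAG = re.compile(r"\{%-?\s*generation\s*-?%\}")
END_GENERATION_TAG = re.compile(r"\{%-?\s*endgeneration\s*-?%\}")

def has_generation_tags(chat_template: str) -> bool:
    # Matches "{% generation %}", "{%- generation -%}", "{%generation%}", etc.
    return bool(GENERATION_TAG.search(chat_template) and END_GENERATION_TAG.search(chat_template))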

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/collect_hidden_states/common.py` around lines
113 - 130, The current verify_generation_tags function should check for the
exact Jinja tag patterns rather than substrings: update verify_generation_tags
to look for specific tokens like "{% generation %}" and "{% endgeneration %}"
and also accept common whitespace-control variants (e.g. "{%- generation -%}",
"{%generation%}", "{%-generation-%}") so it only passes when the actual template
block markers required by apply_chat_template(...,
return_assistant_tokens_mask=True) are present; change the conditional that now
uses "generation" / "endgeneration" to a robust check for these exact tag
strings (or a small normalized/regex match covering the whitespace-control
variants) and keep the same ValueError message if validation fails.

132-171: Harden assistant mask shape handling in tokenize_with_loss_mask.

When answer_only_loss=True, you do:

  • mask = out["assistant_masks"]
  • optional torch.tensor(...)
  • loss_mask = mask.squeeze(0).to(torch.long)
  • then check loss_mask.shape[0] == seq_len

If assistant_masks is already 1D (or has unexpected rank), squeeze(0) can silently produce an incorrect shape or throw a confusing error.

Recommend:

  • assert loss_mask.ndim == 1 after squeeze/reshape (or use loss_mask = mask.reshape(-1) with explicit checks),
  • validate rank explicitly and raise a clear error message if shapes don’t match (seq_len,).
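
For example (names follow this comment, not the shipped helper):

import torch

mask = out["assistant_masks"]
if not isinstance(mask, torch.Tensor):
    mask = torch.tensor(mask)
loss_mask = mask.reshape(-1).to(torch.long)  # normalize to 1D regardless of input rank
if loss_mask.shape[0] != seq_len:
    raise RuntimeError(
        f"assistant_masks has shape {tuple(mask.shape)}; expected a 1D mask of length {seq_len}"
    )
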
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/speculative_decoding/collect_hidden_states/common.py` around lines
132 - 171, In tokenize_with_loss_mask, harden handling of out["assistant_masks"]
(used when answer_only_loss=True) by avoiding blind mask.squeeze(0): after
extracting mask (and converting to a torch.Tensor if needed), explicitly
normalize it to 1D (e.g., mask.reshape(-1) or mask.view(-1)) and then assert
loss_mask.ndim == 1 and loss_mask.shape[0] == seq_len; if the rank/length is
unexpected, raise a clear RuntimeError mentioning "assistant_masks" and both the
observed shape and the expected seq_len so the error is unambiguous (update
references around tokenize_with_loss_mask, mask, loss_mask, and
assistant_masks).
modelopt/torch/speculative/eagle/utils.py (1)

81-143: Clarify / standardize loss_mask dtype for downstream consumers.

When answer_only_loss=True, you set:
loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype).

Many training pipelines (including DFlash) treat loss_mask as numeric weights and do comparisons like > 0.5 or multiplication with float tensors. Int masks work, but it’s easy to accidentally introduce dtype assumptions later.

Suggestion:

  • either keep loss_mask as torch.float32 (and ensure dump/script emits 0/1), or
  • add a short comment/docstring in this file stating the expected dtype contract (int vs float) for offline speculative decoding trainers.
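
For example (illustrative; offline_data and answer_only_loss follow this comment's naming):

import torch

loss_mask = (
    offline_data["loss_mask"].to(torch.float32)
    if answer_only_loss
    else torch.ones_like(offline_data["input_ids"], dtype=torch.float32)
)
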
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/eagle/utils.py` around lines 81 - 143, The
loss_mask dtype is inconsistent: in OfflineSupervisedDataset.__getitem__ you
cast loss_mask to input_ids.dtype which can be integer; change this to a
consistent floating dtype for downstream trainers (e.g., torch.float32) and
ensure the else branch also returns float tensors; specifically, when loading
offline_data["loss_mask"] in __getitem__, cast it with .to(torch.float32) (and
make torch.ones_like(..., dtype=torch.float32) for the default) so consumers can
safely compare/multiply loss_mask as numeric weights.
tests/regression/torch/speculative/test_dflash_offline.py (1)

112-146: Consider making the loss assertion less brittle across runs.

assert final_loss < 3.0 may be sensitive to:

  • batch/seed differences,
  • dataset slice ordering,
  • minor model/kernel changes.

If there’s already precedent for flakiness in this suite, consider:

  • using a relative improvement threshold primarily (final_loss < first_loss) and relaxing the absolute ceiling, or
  • pinning seeds / deterministic flags in the training launcher for this regression.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/regression/torch/speculative/test_dflash_offline.py` around lines 112 -
146, The absolute loss ceiling in test_dflash_offline_training is brittle;
change the assertions to rely primarily on relative improvement (keep the
existing assert final_loss < first_loss) and replace the hard assert final_loss
< 3.0 with a relaxed, tolerant check such as a relative threshold (e.g., assert
final_loss < first_loss * 0.95) and/or a looser ceiling (e.g., final_loss < 3.5)
so runs with minor variance pass; additionally consider ensuring deterministic
seeds are passed to the training launcher (the run_example_command call /
launch_train.sh overrides) if you want reproducibility.
tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py (1)

41-95: Add a unit test for missing num_orig_hidden_layers in offline conversion.

Right now the offline conversion tests set model.config.num_orig_hidden_layers = NUM_BASE_LAYERS explicitly. If upstream ever forgets to seed it, the offline modify path may crash.

Recommend adding a test:

  • create tiny model without setting num_orig_hidden_layers,
  • call mtsp.convert(... offline=True ...),
  • assert either it auto-seeds (if you implement the fallback) or raises a clear ValueError with guidance.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py` around lines
41 - 95, Add a unit test in
tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py that verifies
offline conversion fails fast when model.config.num_orig_hidden_layers is not
set: create a tiny model via get_tiny_llama without setting
model.config.num_orig_hidden_layers, call mtsp.convert(model, [("dflash",
_get_dflash_config(offline=True))]) and assert it raises a ValueError (or a
clear exception) with a message guiding the user to set
config.num_orig_hidden_layers; reference mtsp.convert, get_tiny_llama,
model.config.num_orig_hidden_layers and _get_dflash_config to locate the
relevant logic to test.
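A hypothetical shape for that test; get_tiny_llama and _get_dflash_config are the existing test helpers named above, and the error type/message assumes a fail-fast ValueError rather than an auto-seeding fallback:

import pytest
import modelopt.torch.speculative as mtsp

def test_offline_convert_requires_num_orig_hidden_layers():
    model = get_tiny_llama()  # intentionally no num_orig_hidden_layers on config
    with pytest.raises(ValueError, match="num_orig_hidden_layers"):
        mtsp.convert(model, [("dflash", _get_dflash_config(offline=True))])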
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/speculative/plugins/hf_eagle.py`:
- Around line 49-55: default_eagle_aux_layer_ids may return out-of-range indices
for small num_layers (e.g., returns 1 when num_layers==1); update the function
(default_eagle_aux_layer_ids) to generate the intended candidate indices (start,
middle, end), clamp each candidate into the valid 0..num_layers-1 range, filter
duplicates, and return them sorted — e.g., compute candidates = {1,
num_layers//2 - 1, num_layers - 4}, clamp each with min/max using num_layers-1
and 0, then remove any invalid entries for num_layers <= 0 before sorting and
returning.

---

Duplicate comments:
In `@CHANGELOG.rst`:
- Around line 38-39: Update the changelog entry text to reference the new
post-reorg module path: replace the old import path
`modelopt.torch.quantization.src.conv` with
`modelopt.torch.kernels.quantization.conv` in the Conv3D implicit GEMM paragraph
so the entry matches the migration/backward-breaking guidance and avoids
misleading users about the package location.

In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 446-456: The offline logits-recompute path currently assumes
kwargs["base_model_outputs"] contains "base_model_hidden_states" which is
brittle; update the logic in the dflash_offline branch (around
DFlashBaseModelOutput.from_offline_dict, the dflash_self_logit_distillation
check, and the call to self._base_model_lm_head) to tolerate either key name by
first normalizing/locating hidden states: prefer an attribute from
DFlashBaseModelOutput (extend from_offline_dict to populate a consistent
attribute like .hidden_states or .base_model_hidden_states) or, if leaving the
dict-based approach, lookup hidden =
kwargs["base_model_outputs"].get("base_model_hidden_states") or
kwargs["base_model_outputs"].get("hidden_states") and use that for
self._base_model_lm_head; ensure base_outputs.logits is set when hidden states
are found and preserve base_outputs.target_hidden.
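One way to make that lookup tolerant, sketched as a standalone helper; the key names come from the comment above. Note that chaining dict .get() calls with `or` would trip Tensor.__bool__ on multi-element tensors, so the sketch checks `is None` instead:

import torch

def _resolve_hidden_states(base_model_outputs: dict) -> torch.Tensor:
    hidden = base_model_outputs.get("base_model_hidden_states")
    if hidden is None:
        hidden = base_model_outputs.get("hidden_states")
    if hidden is None:
        raise KeyError(
            "base_model_outputs needs 'base_model_hidden_states' or "
            "'hidden_states' to recompute logits for self logit distillation."
        )
    return hidden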

In `@modelopt/torch/speculative/plugins/hf_eagle.py`:
- Around line 344-350: The cache key for flex attention masks is currently only
ttt_step causing masks computed by _compute_ttt_attention_mask(batch_size,
seq_length, ttt_step) to be reused across different seq_length values; update
_get_ttt_attention_mask to key _cached_attn_blk_masks by (seq_length, ttt_step)
(or another tuple including all shape-dependent params such as batch_size if
relevant), and store/retrieve masks using that composite key instead of just
ttt_step; then add/adjust a unit test that calls _get_ttt_attention_mask (or
directly _compute_ttt_attention_mask) with the same ttt_step but two different
seq_length values to verify distinct masks are produced and cached separately.
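A minimal sketch of the composite cache key; the method and attribute names mirror the comment above, and the real signature may differ:

def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step):
    # Key on every shape-dependent input, not just ttt_step, so masks for
    # different sequence lengths are cached separately.
    key = (batch_size, seq_length, ttt_step)
    if key not in self._cached_attn_blk_masks:
        self._cached_attn_blk_masks[key] = self._compute_ttt_attention_mask(
            batch_size, seq_length, ttt_step
        )
    return self._cached_attn_blk_masks[key]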

---

Nitpick comments:
In `@examples/speculative_decoding/collect_hidden_states/common.py`:
- Around line 113-130: The current verify_generation_tags function should check
for the exact Jinja tag patterns rather than substrings: update
verify_generation_tags to look for specific tokens like "{% generation %}" and
"{% endgeneration %}" and also accept common whitespace-control variants (e.g.
"{%- generation -%}", "{%generation%}", "{%-generation-%}") so it only passes
when the actual template block markers required by apply_chat_template(...,
return_assistant_tokens_mask=True) are present; change the conditional that now
uses "generation" / "endgeneration" to a robust check for these exact tag
strings (or a small normalized/regex match covering the whitespace-control
variants) and keep the same ValueError message if validation fails (a regex
sketch follows this list).
- Around line 132-171: In tokenize_with_loss_mask, harden handling of
out["assistant_masks"] (used when answer_only_loss=True) by avoiding blind
mask.squeeze(0): after extracting mask (and converting to a torch.Tensor if
needed), explicitly normalize it to 1D (e.g., mask.reshape(-1) or mask.view(-1))
and then assert loss_mask.ndim == 1 and loss_mask.shape[0] == seq_len; if the
rank/length is unexpected, raise a clear RuntimeError mentioning
"assistant_masks" and both the observed shape and the expected seq_len so the
error is unambiguous (update references around tokenize_with_loss_mask, mask,
loss_mask, and assistant_masks).
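A regex sketch for the exact-tag check in the first common.py comment above; the function name and error condition follow the review text, while the pattern itself is illustrative:

import re

_GENERATION_TAG = re.compile(r"\{%-?\s*generation\s*-?%\}")
_ENDGENERATION_TAG = re.compile(r"\{%-?\s*endgeneration\s*-?%\}")

def verify_generation_tags(chat_template: str) -> None:
    # Accepts {% generation %}, {%generation%}, {%- generation -%}, etc.
    if not (_GENERATION_TAG.search(chat_template)
            and _ENDGENERATION_TAG.search(chat_template)):
        raise ValueError(
            "Chat template must contain {% generation %}...{% endgeneration %} "
            "blocks for apply_chat_template(..., return_assistant_tokens_mask=True)."
        )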

In `@modelopt/torch/speculative/config.py`:
- Around line 67-157: The description for dflash_offline is misleading — it
claims the field is "not user-configurable" but _derive_dflash_offline only
overrides when info.context contains data_args and data is a dict, so a user
value can persist when context is absent; update the DFlashConfig.dflash_offline
ModeloptField description to state that the field is "auto-derived when context
includes data_args (overrides provided value)" or similar phrasing; ensure you
edit the string in the dflash_offline ModeloptField (and optionally update the
docstring of _derive_dflash_offline) to match this behavior.
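A toy Pydantic stand-in showing the context-dependent derivation the description should document; DFlashConfig's real base class, field default, and exact wording are assumptions here:

from pydantic import BaseModel, Field, model_validator

class DFlashConfigSketch(BaseModel):
    dflash_offline: bool = Field(
        default=False,
        description=(
            "Auto-derived from data_args.offline_data_path when the validation "
            "context includes data_args (overrides any provided value); "
            "otherwise the provided value is kept."
        ),
    )

    @model_validator(mode="before")
    @classmethod
    def _derive_dflash_offline(cls, data, info):
        # Only override when a context with data_args is supplied and the
        # input is still a plain dict, matching the behavior described above.
        ctx = info.context or {}
        data_args = ctx.get("data_args")
        if isinstance(data, dict) and data_args is not None:
            data["dflash_offline"] = bool(getattr(data_args, "offline_data_path", None))
        return data

For example, DFlashConfigSketch.model_validate({"dflash_offline": True}, context={"data_args": args}) returns the derived value, while model_validate({"dflash_offline": True}) with no context keeps True.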

In `@modelopt/torch/speculative/eagle/utils.py`:
- Around line 81-143: The loss_mask dtype is inconsistent: in
OfflineSupervisedDataset.__getitem__ you cast loss_mask to input_ids.dtype which
can be integer; change this to a consistent floating dtype for downstream
trainers (e.g., torch.float32) and ensure the else branch also returns float
tensors; specifically, when loading offline_data["loss_mask"] in __getitem__,
cast it with .to(torch.float32) (and make torch.ones_like(...,
dtype=torch.float32) for the default) so consumers can safely compare/multiply
loss_mask as numeric weights.

In `@tests/regression/torch/speculative/test_dflash_offline.py`:
- Around line 112-146: The absolute loss ceiling in test_dflash_offline_training
is brittle; change the assertions to rely primarily on relative improvement
(keep the existing assert final_loss < first_loss) and replace the hard assert
final_loss < 3.0 with a relaxed, tolerant check such as a relative threshold
(e.g., assert final_loss < first_loss * 0.95) and/or a looser ceiling (e.g.,
final_loss < 3.5) so runs with minor variance pass; additionally consider
ensuring deterministic seeds are passed to the training launcher (the
run_example_command call / launch_train.sh overrides) if you want
reproducibility.

In `@tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py`:
- Around line 41-95: Add a unit test in
tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py that verifies
offline conversion fails fast when model.config.num_orig_hidden_layers is not
set: create a tiny model via get_tiny_llama without setting
model.config.num_orig_hidden_layers, call mtsp.convert(model, [("dflash",
_get_dflash_config(offline=True))]) and assert it raises a ValueError (or a
clear exception) with a message guiding the user to set
config.num_orig_hidden_layers; reference mtsp.convert, get_tiny_llama,
model.config.num_orig_hidden_layers and _get_dflash_config to locate the
relevant logic to test.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 86ab1669-fed2-4c25-b2d1-4f9217ad52ae

📥 Commits

Reviewing files that changed from the base of the PR and between 855ab0a and 2cf1784.

📒 Files selected for processing (15)
  • CHANGELOG.rst
  • examples/speculative_decoding/collect_hidden_states/common.py
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
  • examples/speculative_decoding/eagle_utils.py
  • examples/speculative_decoding/main.py
  • modelopt/torch/speculative/config.py
  • modelopt/torch/speculative/dflash/dflash_model.py
  • modelopt/torch/speculative/eagle/utils.py
  • modelopt/torch/speculative/plugins/hf_dflash.py
  • modelopt/torch/speculative/plugins/hf_eagle.py
  • modelopt/torch/speculative/plugins/modeling_dflash.py
  • tests/gpu/torch/speculative/plugins/test_hf_dflash.py
  • tests/regression/torch/speculative/test_dflash_offline.py
  • tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py
✅ Files skipped from review due to trivial changes (1)
  • examples/speculative_decoding/main.py
🚧 Files skipped from review as they are similar to previous changes (6)
  • examples/speculative_decoding/eagle_utils.py
  • modelopt/torch/speculative/dflash/dflash_model.py
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
  • modelopt/torch/speculative/plugins/modeling_dflash.py
  • tests/gpu/torch/speculative/plugins/test_hf_dflash.py
  • examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py

Comment on lines +49 to +55
def default_eagle_aux_layer_ids(num_layers: int) -> list[int]:
"""Default EAGLE-3 auxiliary hidden-state layer IDs (0-based).

Three layers near the start, middle, and end of the stack.
"""
return sorted({1, max(0, num_layers // 2 - 1), max(0, num_layers - 4)})


⚠️ Potential issue | 🟠 Major

Fix potential out-of-range auxiliary layer IDs for small num_layers.

default_eagle_aux_layer_ids() returns sorted({1, ..., max(0, num_layers - 4)}). For num_layers=1, this becomes {0, 1}, which includes 1 even though valid 0-based layer IDs are only [0]. That can break hook registration / hidden-state collection for tiny configs.

🔧 Proposed fix (clamp to valid range)
 def default_eagle_aux_layer_ids(num_layers: int) -> list[int]:
@@
-    return sorted({1, max(0, num_layers // 2 - 1), max(0, num_layers - 4)})
+    if num_layers <= 0:
+        return []
+    # Always clamp into [0, num_layers-1]
+    start = min(1, num_layers - 1)
+    mid = max(0, num_layers // 2 - 1)
+    end = max(0, num_layers - 4)
+    end = min(end, num_layers - 1)
+    return sorted({start, mid, end})
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 49 - 55,
default_eagle_aux_layer_ids may return out-of-range indices for small num_layers
(e.g., returns 1 when num_layers==1); update the function
(default_eagle_aux_layer_ids) to generate the intended candidate indices (start,
middle, end), clamp each candidate into the valid 0..num_layers-1 range, filter
duplicates, and return them sorted — e.g., compute candidates = {1,
num_layers//2 - 1, num_layers - 4}, clamp each with min/max using num_layers-1
and 0, then remove any invalid entries for num_layers <= 0 before sorting and
returning.

@codecov

codecov Bot commented Apr 24, 2026

Codecov Report

❌ Patch coverage is 98.50746% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 75.82%. Comparing base (7c80d85) to head (6c11e95).
⚠️ Report is 1 commit behind head on main.

Files with missing lines                 | Patch % | Lines
modelopt/torch/speculative/config.py     | 95.65%  | 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1343      +/-   ##
==========================================
- Coverage   76.27%   75.82%   -0.45%     
==========================================
  Files         471      471              
  Lines       50323    50375      +52     
==========================================
- Hits        38383    38199     -184     
- Misses      11940    12176     +236     
Flag       | Coverage Δ
examples   | 41.60% <37.31%> (-0.16%) ⬇️
gpu        | 58.40% <77.61%> (-0.49%) ⬇️
regression | 14.92% <86.56%> (+0.20%) ⬆️
unit       | 52.75% <70.14%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.


Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/regression/torch/speculative/test_dflash_offline.py`:
- Around line 124-126: The current check only asserts that some .pt files exist
(pt_files), which allows partial dumps to pass; replace this with an explicit
check that the number of dumped files equals the expected count: compute
expected_num_files from the test configuration (e.g., the training steps /
checkpoints variables used elsewhere in the test such as num_steps, max_steps,
checkpoint_interval or define expected_num_files explicitly for this test) and
assert len(pt_files) == expected_num_files, raising a clear error that includes
dump_dir and the sorted pt_files list when it does not match.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: e9cc0271-d27d-495b-9817-d5f769e28e3b

📥 Commits

Reviewing files that changed from the base of the PR and between 2cf1784 and 063eeb5.

📒 Files selected for processing (1)
  • tests/regression/torch/speculative/test_dflash_offline.py

Comment on lines +124 to +126
pt_files = list(dump_dir.rglob("*.pt"))
assert pt_files, f"No .pt files dumped under {dump_dir}"
return dump_dir

⚠️ Potential issue | 🟡 Minor

Assert full hidden-state dump completion, not just non-empty output.

assert pt_files can still pass after a partial dump, which can make this regression falsely green with a much smaller training set.

Suggested change
-    pt_files = list(dump_dir.rglob("*.pt"))
-    assert pt_files, f"No .pt files dumped under {dump_dir}"
+    pt_files = list(dump_dir.rglob("*.pt"))
+    assert len(pt_files) == _DUMP_NUM_CONVERSATIONS, (
+        f"Expected {_DUMP_NUM_CONVERSATIONS} .pt files under {dump_dir}, got {len(pt_files)}"
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/regression/torch/speculative/test_dflash_offline.py` around lines 124 -
126, The current check only asserts that some .pt files exist (pt_files), which
allows partial dumps to pass; replace this with an explicit check that the
number of dumped files equals the expected count: compute expected_num_files
from the test configuration (e.g., the training steps / checkpoints variables
used elsewhere in the test such as num_steps, max_steps, checkpoint_interval or
define expected_num_files explicitly for this test) and assert len(pt_files) ==
expected_num_files, raising a clear error that includes dump_dir and the sorted
pt_files list when it does not match.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@ChenhanYu merged commit 1ec931c into main Apr 26, 2026
47 checks passed
@ChenhanYu deleted the haoguo/dflash-offline branch April 26, 2026 01:38