[2/3][Feat]: Offline DFlash training#1343
Conversation
No actionable comments were generated in the recent review. 🎉
ℹ️ Recent review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID:
📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
📝 Walkthrough
Adds offline DFlash speculative-decoding training: pre-computed base-model hidden-state dumps, shared CLI/tokenization utilities for aux-layer selection and answer-only loss, DFlashConfig derivation for offline mode, model conversion/forward changes to accept offline inputs and free base-model memory, and tests plus changelog/docs updates.
Changes
Sequence Diagram(s)
sequenceDiagram
participant CLI as User/Script
participant Config as DFlashConfig
participant Dump as Hidden-State Dump
participant Dataset as OfflineSupervisedDataset
participant HFDFlash as HFDFlashModel
participant Loss as Loss Computation
Note over CLI,HFDFlash: Online training (existing)
CLI->>HFDFlash: forward(input_ids)
activate HFDFlash
HFDFlash->>HFDFlash: run base model -> hidden_states (no_grad)
HFDFlash->>HFDFlash: build target_hidden from hidden_states
HFDFlash->>Loss: compute loss (KD / reconstruction)
deactivate HFDFlash
Loss-->>CLI: loss
Note over CLI,Loss: Offline training (new)
CLI->>Config: model_validate(context={data_args, tokenizer})
activate Config
Config->>Config: derive dflash_offline from data_args.offline_data_path
Config-->>CLI: validated config (dflash_offline=True)
deactivate Config
CLI->>Dump: run compute_hidden_states_* (produce .pt with aux_hidden_states, loss_mask)
Dump-->>Dataset: store .pt files
CLI->>Dataset: load batch (aux_hidden_states, loss_mask)
CLI->>HFDFlash: forward(base_model_outputs=dict)
activate HFDFlash
HFDFlash->>HFDFlash: DFlashBaseModelOutput.from_offline_dict()
HFDFlash->>HFDFlash: extract target_hidden
HFDFlash->>Loss: compute loss using provided loss_mask
deactivate HFDFlash
Loss-->>CLI: loss
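To ground the offline path sketched above, here is a minimal, hypothetical illustration of what an offline batch looks like and how it reaches the draft model. The tensor shapes and the `base_model_outputs` keyword follow the walkthrough and diagram, but the exact forward signature should be checked against the PR.

```python
import torch

# Hypothetical shapes, for illustration only.
seq_len, hidden_size, num_aux_layers = 128, 2048, 3

# One pre-computed sample, shaped the way the collect_hidden_states dump scripts
# are described to emit it (one .pt file per conversation).
offline_sample = {
    "input_ids": torch.randint(0, 32000, (1, seq_len)),
    "aux_hidden_states": torch.randn(1, seq_len, num_aux_layers * hidden_size),
    "loss_mask": torch.ones(1, seq_len, dtype=torch.long),
}

# With dflash_offline=True the base-model layers have been deleted, so a training
# step passes the dump instead of running the frozen base model, roughly:
#
#   outputs = dflash_model(
#       input_ids=offline_sample["input_ids"],
#       base_model_outputs=offline_sample,  # parsed via DFlashBaseModelOutput.from_offline_dict
#   )
```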
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 5 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings.
✨ Finishing Touches
📝 Generate docstrings
🧪 Generate unit tests (beta)
Actionable comments posted: 11
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py (1)
128-136: ⚠️ Potential issue | 🟠 Major — Skip rows without IDs instead of asserting.
CI is already failing here with `AssertionError: conversation_id is required`. One malformed sample currently aborts the entire dump before you reach the later invalid-row handling, so this needs to degrade to a skip/count path instead of a hard assert.
Suggested change
```diff
 def keep_conversation(entry):
     conversation_id = entry.get("conversation_id", entry.get("uuid", None))
-    assert conversation_id is not None, "conversation_id is required"
+    if conversation_id is None:
+        return False
     output_file = args.output_dir / f"{conversation_id}.pt"
     return not output_file.exists()
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py` around lines 128 - 136, Replace the hard assert in keep_conversation with logic that skips rows missing an ID: check conversation_id = entry.get("conversation_id", entry.get("uuid")); if it's None, increment a skip counter (e.g., initialize skipped_missing_id = 0 above and use nonlocal or a mutable container to update it) and optionally log a warning, then return False; otherwise compute output_file = args.output_dir / f"{conversation_id}.pt" and return not output_file.exists(); leave the dataset = dataset.filter(keep_conversation) call as-is so malformed samples are skipped instead of aborting.
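For reference, a minimal sketch of the skip-and-count behavior the prompt describes. Names like `skip_stats` are illustrative, not from the PR, and `args` is assumed to be the dump script's parsed CLI namespace.

```python
skip_stats = {"missing_id": 0}  # mutable so the nested filter can update it


def keep_conversation(entry):
    conversation_id = entry.get("conversation_id", entry.get("uuid"))
    if conversation_id is None:
        skip_stats["missing_id"] += 1
        return False  # skip malformed rows instead of aborting the whole dump
    output_file = args.output_dir / f"{conversation_id}.pt"  # `args` from the script's CLI
    return not output_file.exists()


# dataset = dataset.filter(keep_conversation)  # call site stays unchanged
```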
🧹 Nitpick comments (6)
tests/unit/torch/export/test_hf_spec_rope_export.py (1)
37-46: Prefer a real namespace or a spec’d mock here.
`MagicMock` is the reason Lines 40-42 need a manual reset, and it can still mask future typos in `_export_config` by auto-creating missing attributes. A `SimpleNamespace` or `Mock(spec_set=...)` will make these rope-resolution tests fail loudly when the exporter asks for the wrong field.
♻️ Tighten the test double
```diff
-from unittest.mock import MagicMock
+from types import SimpleNamespace
 ...
-    model = MagicMock()
-    model.eagle_config.eagle_decoder_type = "llama"
-    model.eagle_config.rope_scaling = {"rope_type": rope_type, "rope_theta": rope_theta}
-    # rope_theta lives inside rope_scaling in transformers 5.x; clear the top-level attr
-    # so the fallback path is exercised instead of MagicMock's auto-attr.
-    model.eagle_config.rope_theta = None
-    model.eagle_export_rope_scaling = eagle_export_rope_scaling
-    model._draft_model_config = None
-    model.config.rope_scaling = None
-    model.config.rope_theta = None
+    model = SimpleNamespace(
+        eagle_config=SimpleNamespace(
+            eagle_decoder_type="llama",
+            rope_scaling={"rope_type": rope_type, "rope_theta": rope_theta},
+            rope_theta=None,
+        ),
+        eagle_export_rope_scaling=eagle_export_rope_scaling,
+        _draft_model_config=None,
+        config=SimpleNamespace(rope_scaling=None, rope_theta=None),
+    )
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/torch/export/test_hf_spec_rope_export.py` around lines 37 - 46, Replace the loose MagicMock used for model with a stricter test double so missing/typo'd attributes fail loudly: instead of model = MagicMock(), create either a SimpleNamespace populated with the exact attributes used (e.g., eagle_config, eagle_export_rope_scaling, _draft_model_config, config) or use unittest.mock.Mock(spec_set=YourModelConfigSpec) configured with only the needed fields; update the test setup around model.eagle_config.*, model.eagle_export_rope_scaling, model._draft_model_config, and model.config.* to remove the manual resets (model.eagle_config.rope_theta = None etc.) since the stricter double will accurately reflect missing attributes and catch incorrect attribute access in the exporter.modelopt/torch/speculative/plugins/hf_eagle.py (5)
624-630: Consider using `ValueError` instead of `assert` for clearer error messages.
Line 626 uses `assert` for validating required kwargs, which raises `AssertionError` with minimal context. In production with assertions disabled (`-O` flag), this check would be skipped entirely.
💡 More explicit error handling
```diff
         if self.eagle_offline:
             # Parse base model outputs forwarded from teacher
-            assert "base_model_outputs" in kwargs
+            if "base_model_outputs" not in kwargs:
+                raise ValueError("eagle_offline=True requires 'base_model_outputs' in forward kwargs")
             base_outputs = EagleBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"])
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 624 - 630, Replace the runtime assertion with explicit error handling: in the eagle_offline branch where the code currently does `assert "base_model_outputs" in kwargs`, raise a ValueError (e.g., `raise ValueError("missing required kwarg 'base_model_outputs' for Eagle offline mode")`) so the check is not skipped under -O and provides a clear message; keep the subsequent logic that calls EagleBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"]) and the fallback to computing logits via self._base_model_lm_head(base_outputs.out_hiddens) unchanged.
241-251: Docstring describes KL divergence, but the implementation computes cross-entropy.
The implementation computes `-sum(softmax(ref) * log_softmax(lora))`, which is the cross-entropy H(p, q), not KL(p || q) = H(p, q) - H(p). For optimization purposes they are equivalent (H(p) is constant w.r.t. the LoRA params), but the docstring could be more precise.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 241 - 251, The docstring for _preservation_loss currently claims it computes KL(ref || lora) but the implementation computes cross-entropy H(ref, lora) = -sum(softmax(ref) * log_softmax(lora)); update the docstring to accurately describe the computed quantity (cross-entropy weighted by eagle_base_lora_preservation_loss_weight) or, if you truly want KL, modify the implementation to subtract the entropy term H(ref) (i.e., compute H(ref, lora) - H(ref)) using ref_logits.detach() so only the LoRA params are optimized; refer to the _preservation_loss function, ref_logits, lora_logits, and eagle_base_lora_preservation_loss_weight when making the change.
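For the record, the identity behind the "equivalent for optimization" claim above:

$$\mathrm{KL}(p \,\|\, q) \;=\; \sum_i p_i \log \frac{p_i}{q_i} \;=\; \underbrace{-\sum_i p_i \log q_i}_{H(p,\,q)} \;-\; \underbrace{\Big(-\sum_i p_i \log p_i\Big)}_{H(p)},$$

so with $p = \mathrm{softmax}(\text{ref})$ detached, minimizing $H(p, q)$ and minimizing $\mathrm{KL}(p \,\|\, q)$ produce identical gradients w.r.t. the LoRA parameters.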
49-54: Edge case: Index 1 is always included regardless of `num_layers`.
For models with `num_layers < 2`, this returns indices that don't exist (e.g., `num_layers=1` yields `{0, 1}`). While unlikely in practice for EAGLE-3 use cases, consider adding a guard.
💡 Optional defensive check
```diff
 def default_eagle_aux_layer_ids(num_layers: int) -> list[int]:
     """Default EAGLE-3 auxiliary hidden-state layer IDs (0-based).

     Three layers near the start, middle, and end of the stack.
     """
+    if num_layers < 2:
+        return list(range(num_layers))
     return sorted({1, max(0, num_layers // 2 - 1), max(0, num_layers - 4)})
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 49 - 54, The function default_eagle_aux_layer_ids always includes index 1 which can be out of range for small models; update default_eagle_aux_layer_ids to build the candidate set (1, max(0, num_layers//2 - 1), max(0, num_layers - 4)), then filter out any indices not in the valid range 0 <= idx < num_layers before returning the sorted list so you never return non-existent layer IDs for small num_layers.
844-855: Prefer `tensor.clone()` over `copy.deepcopy()` for tensors.
`copy.deepcopy` works but is heavier than necessary for tensors. `input_ids.clone()` is the idiomatic and more efficient approach.
💡 Simpler tensor copy
```diff
 def get_ground_truth(self, input_ids, osl):
     """This function returns ground truth output tokens from the base model."""
-    input_ids = copy.deepcopy(input_ids).to(torch.cuda.current_device())
+    input_ids = input_ids.clone().to(torch.cuda.current_device())
     for _ in range(osl):
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 844 - 855, The get_ground_truth implementation uses copy.deepcopy(input_ids) which is inefficient for tensors; replace that with input_ids.clone() and then move it to the CUDA device (e.g., input_ids.clone().to(torch.cuda.current_device())) so HFARValidation.get_ground_truth works identically but uses the idiomatic, lighter-weight tensor copy for the input_ids variable.
117-131: Global `torch._dynamo.config.suppress_errors` modification affects all compiled code.
Setting `suppress_errors = True` globally will mask compilation errors in all `torch.compile` calls in the process, not just this model. This could hide issues in other code paths. Consider making this configurable via a model attribute or documenting the side effect:
```diff
 def _activate_torch_compile(self):
     import torch._dynamo

-    torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode
+    # NOTE: This globally affects all torch.compile calls in the process
+    if getattr(self, "eagle_dynamo_suppress_errors", True):
+        torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 117 - 131, The code currently sets the global torch._dynamo.config.suppress_errors in _activate_torch_compile which affects all torch.compile calls; change this to be non-global by making the behavior configurable on the instance (e.g., self.suppress_dynamo_errors) and by temporarily setting suppress_errors only around your compile loop (save original = torch._dynamo.config.suppress_errors, set it to self.suppress_dynamo_errors, run the compile_targets loop invoking torch.compile on getattr(self, name), then restore the original value in a finally block); alternatively document the side effect and gate the global change behind the new attribute so callers can opt into the global suppression.
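A minimal sketch of the scoped-override variant the prompt suggests. The attribute name `suppress_dynamo_errors` and the `compile_targets` list are taken from the prompt's wording, not the current code, so treat this as an assumption-laden outline.

```python
import torch


def _activate_torch_compile(self):
    """Sketch: scope the dynamo error suppression instead of setting it globally."""
    import torch._dynamo

    original = torch._dynamo.config.suppress_errors
    torch._dynamo.config.suppress_errors = getattr(self, "suppress_dynamo_errors", True)
    try:
        # `compile_targets` is assumed to be the list of submodule names to compile.
        for name in self.compile_targets:
            setattr(self, name, torch.compile(getattr(self, name)))
    finally:
        # Restore the previous value so other torch.compile users are unaffected.
        torch._dynamo.config.suppress_errors = original
```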
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@CHANGELOG.rst`:
- Line 38: Update the module path in the changelog entry so it points to the new
post-reorg location: replace the old path string
"modelopt.torch.quantization.src.conv" with
"modelopt.torch.kernels.quantization.conv" in the Conv3D kernel description;
ensure the sentence still references nn.Conv3d and ModelOpt PTQ and that grouped
convolution and inference/training notes remain unchanged.
In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py`:
- Line 199: The call building "eagle3_layers_to_capture" passes
set(resolve_aux_layers(args, num_hidden_layers)) but num_hidden_layers may be
missing from the model config; validate that AutoConfig exposes a concrete
integer num_hidden_layers before calling resolve_aux_layers. In practice, obtain
the config (AutoConfig), check getattr(config, "num_hidden_layers", None) is an
int > 0 (or raise a clear ValueError/RuntimeError), and only then call
resolve_aux_layers(args, num_hidden_layers); include a helpful error message
referencing the model/config when failing so the failure matches the HF dumper
guard.
- Around line 26-27: The TRTLLM dumper must accept and persist the loss_mask so
generated dumps are compatible with
OfflineSupervisedDataset(answer_only_loss=True); update
compute_hidden_states_trtllm.py to parse/propagate the loss_mask from the
dataset or aux-layer helpers (use add_aux_layers_args and resolve_aux_layers
where aux fields are collected) and include a loss_mask field in the saved dump
payload produced by the TRTLLM dumper (the code path around the dump/save call
near the dumper or dump function at the end of the script). Ensure the emitted
dump keys match what OfflineSupervisedDataset expects (loss_mask present when
answer_only_loss=True) to avoid hard failures.
In `@modelopt/torch/speculative/plugins/__init__.py`:
- Around line 32-35: Wrap each HF plugin import in its own guarded
import_plugin("transformers") context (or try/except) so a failure importing
hf_dflash does not stop hf_eagle and hf_medusa from being attempted;
specifically, change the block that currently does with
import_plugin("transformers"): from .hf_dflash import * from .hf_eagle import *
from .hf_medusa import * to three independent guarded imports for the modules
hf_dflash, hf_eagle, and hf_medusa (using import_plugin("transformers")
separately or try/except around each import) and ensure any import exception is
caught and logged rather than letting it abort the whole block (see the import sketch after this prompt list).
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 447-455: When running offline dflash with
dflash_self_logit_distillation=True the code looks for
kwargs["base_model_outputs"]["base_model_hidden_states"] but the dumper stores
the last layer under "hidden_states"; update the offline-path in the block
guarded by self.dflash_offline (the callsite using
DFlashBaseModelOutput.from_offline_dict and the variable base_outputs) to
prefer/normalize the dumper schema: read "hidden_states" if
"base_model_hidden_states" is missing (or adjust
DFlashBaseModelOutput.from_offline_dict to expose base_model_hidden_states from
hidden_states), then pass those hiddens into self._base_model_lm_head to compute
logits so logits are correctly recomputed offline.
- Around line 163-167: When computing num_target_layers in the dflash
conversion, seed and persist the original layer count before offline mode
removes layers: if self.dflash_offline is true and base_config lacks
num_orig_hidden_layers, capture base_config.num_hidden_layers into
base_config.num_orig_hidden_layers (or into a persistent attribute on self) on
first entry, then use base_config.num_orig_hidden_layers for num_target_layers
thereafter; update the logic around num_target_layers, self.dflash_offline, and
mtsp.convert so the original count is available after layer deletion and
repeated offline calls succeed.
In `@modelopt/torch/speculative/plugins/hf_eagle.py`:
- Around line 344-350: The cache key in _get_ttt_attention_mask currently uses
only ttt_step but _compute_ttt_attention_mask depends on seq_length; change the
key in self._cached_attn_blk_masks to include seq_length (e.g., use (ttt_step,
seq_length)) when storing and retrieving the mask so you won't return an
incorrectly-sized mask for different sequence lengths; update both the
.update/store call and the return lookup to use that composite key (referencing
_get_ttt_attention_mask, _compute_ttt_attention_mask, and
_cached_attn_blk_masks).
In `@modelopt/torch/speculative/plugins/hf_medusa.py`:
- Around line 112-123: The model forward call in hf_medusa.py incorrectly passes
rcache_position as a kwarg to self.model; change that argument name to
cache_position (i.e., replace rcache_position=cache_position with
cache_position=cache_position) inside the forward/inference block where
self.model(...) is invoked so it matches Transformers 4.56.0 causal LM forward
signatures and avoids kwarg validation failures in the Medusa plugin.
- Around line 126-145: The base-model loss currently compares logits and labels
at the same positions; update the computation in the block guarded by
freeze_base_model so logits and labels are causally shifted to match HF LM
convention and the Medusa loss: compute shift_logits = logits[..., :-1, :] and
shift_labels = labels[..., 1:], and if logits_to_keep is used ensure the same
slice is applied to labels before shifting; then pass shift_logits.view(-1,
shift_logits.shape[-1]) and shift_labels.view(-1) into loss_fct
(CrossEntropyLoss) to compute base_model_loss consistently with medusa_heads and
medusa loss behavior.
In `@modelopt/torch/speculative/plugins/modeling_eagle.py`:
- Around line 206-215: from_offline_dict currently reads out_hiddens from
d.get("base_model_hidden_states") which misses HF dumps that use
"hidden_states"; update the classmethod from_offline_dict to populate
out_hiddens from d.get("hidden_states") with a fallback to
d.get("base_model_hidden_states") for backward compatibility, and ensure the
same pattern (input_embeds/logits/aux_hiddens) remains unchanged so consumers of
the container receive the expected final activations.
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 20-29: The deprecation warning only directs users to
modelopt.torch.speculative.plugins.hf_eagle while this shim still re-exports
HFMedusaModel from hf_medusa; update the warning text in transformers.py to tell
callers to update imports for both hf_eagle (for HF EAGLE symbols) and hf_medusa
(for HFMedusaModel) so Medusa users are not misdirected, and keep the existing
re-export lines (from .hf_eagle import * and from .hf_medusa import *)
unchanged.
---
Outside diff comments:
In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py`:
- Around line 128-136: Replace the hard assert in keep_conversation with logic
that skips rows missing an ID: check conversation_id =
entry.get("conversation_id", entry.get("uuid")); if it's None, increment a skip
counter (e.g., initialize skipped_missing_id = 0 above and use nonlocal or a
mutable container to update it) and optionally log a warning, then return False;
otherwise compute output_file = args.output_dir / f"{conversation_id}.pt" and
return not output_file.exists(); leave the dataset =
dataset.filter(keep_conversation) call as-is so malformed samples are skipped
instead of aborting.
---
Nitpick comments:
In `@modelopt/torch/speculative/plugins/hf_eagle.py`:
- Around line 624-630: Replace the runtime assertion with explicit error
handling: in the eagle_offline branch where the code currently does `assert
"base_model_outputs" in kwargs`, raise a ValueError (e.g., `raise
ValueError("missing required kwarg 'base_model_outputs' for Eagle offline
mode")`) so the check is not skipped under -O and provides a clear message; keep
the subsequent logic that calls
EagleBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"]) and the
fallback to computing logits via
self._base_model_lm_head(base_outputs.out_hiddens) unchanged.
- Around line 241-251: The docstring for _preservation_loss currently claims it
computes KL(ref || lora) but the implementation computes cross-entropy H(ref,
lora) = -sum(softmax(ref) * log_softmax(lora)); update the docstring to
accurately describe the computed quantity (cross-entropy weighted by
eagle_base_lora_preservation_loss_weight) or, if you truly want KL, modify the
implementation to subtract the entropy term H(ref) (i.e., compute H(ref, lora) -
H(ref)) using ref_logits.detach() so only the LoRA params are optimized; refer
to the _preservation_loss function, ref_logits, lora_logits, and
eagle_base_lora_preservation_loss_weight when making the change.
- Around line 49-54: The function default_eagle_aux_layer_ids always includes
index 1 which can be out of range for small models; update
default_eagle_aux_layer_ids to build the candidate set (1, max(0, num_layers//2
- 1), max(0, num_layers - 4)), then filter out any indices not in the valid
range 0 <= idx < num_layers before returning the sorted list so you never return
non-existent layer IDs for small num_layers.
- Around line 844-855: The get_ground_truth implementation uses
copy.deepcopy(input_ids) which is inefficient for tensors; replace that with
input_ids.clone() and then move it to the CUDA device (e.g.,
input_ids.clone().to(torch.cuda.current_device())) so
HFARValidation.get_ground_truth works identically but uses the idiomatic,
lighter-weight tensor copy for the input_ids variable.
- Around line 117-131: The code currently sets the global
torch._dynamo.config.suppress_errors in _activate_torch_compile which affects
all torch.compile calls; change this to be non-global by making the behavior
configurable on the instance (e.g., self.suppress_dynamo_errors) and by
temporarily setting suppress_errors only around your compile loop (save original
= torch._dynamo.config.suppress_errors, set it to self.suppress_dynamo_errors,
run the compile_targets loop invoking torch.compile on getattr(self, name), then
restore the original value in a finally block); alternatively document the side
effect and gate the global change behind the new attribute so callers can opt
into the global suppression.
In `@tests/unit/torch/export/test_hf_spec_rope_export.py`:
- Around line 37-46: Replace the loose MagicMock used for model with a stricter
test double so missing/typo'd attributes fail loudly: instead of model =
MagicMock(), create either a SimpleNamespace populated with the exact attributes
used (e.g., eagle_config, eagle_export_rope_scaling, _draft_model_config,
config) or use unittest.mock.Mock(spec_set=YourModelConfigSpec) configured with
only the needed fields; update the test setup around model.eagle_config.*,
model.eagle_export_rope_scaling, model._draft_model_config, and model.config.*
to remove the manual resets (model.eagle_config.rope_theta = None etc.) since
the stricter double will accurately reflect missing attributes and catch
incorrect attribute access in the exporter.
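To illustrate the `modelopt/torch/speculative/plugins/__init__.py` item above: a sketch of independently guarded imports, assuming `import_plugin` is the package's existing context manager that swallows missing-dependency errors (it is not imported here).

```python
# import_plugin is the existing guard used by this package; each HF plugin now gets
# its own context so a failure in one does not prevent the others from loading.
with import_plugin("transformers"):
    from .hf_dflash import *  # noqa: F403

with import_plugin("transformers"):
    from .hf_eagle import *  # noqa: F403

with import_plugin("transformers"):
    from .hf_medusa import *  # noqa: F403
```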
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 289e65a1-7ac4-49fa-ade8-1285057efe75
📒 Files selected for processing (25)
.pre-commit-config.yaml
CHANGELOG.rst
examples/speculative_decoding/collect_hidden_states/common.py
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
examples/speculative_decoding/eagle_utils.py
examples/speculative_decoding/main.py
examples/speculative_decoding/scripts/ar_validate.py
modelopt/torch/export/plugins/hf_spec_export.py
modelopt/torch/speculative/config.py
modelopt/torch/speculative/dflash/dflash_model.py
modelopt/torch/speculative/eagle/default_config.py
modelopt/torch/speculative/eagle/utils.py
modelopt/torch/speculative/plugins/__init__.py
modelopt/torch/speculative/plugins/hf_dflash.py
modelopt/torch/speculative/plugins/hf_eagle.py
modelopt/torch/speculative/plugins/hf_medusa.py
modelopt/torch/speculative/plugins/modeling_dflash.py
modelopt/torch/speculative/plugins/modeling_eagle.py
modelopt/torch/speculative/plugins/transformers.py
modelopt/torch/speculative/utils.py
tests/gpu/torch/speculative/plugins/test_hf_dflash.py
tests/regression/torch/speculative/test_dflash_offline.py
tests/unit/torch/export/test_hf_spec_rope_export.py
tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py
| - [Early Testing] Add Claude Code PTQ skill (``.claude/skills/ptq/``) for agent-assisted post-training quantization. The skill guides the agent through environment detection, model support checking, format selection, and execution via the launcher or manual SLURM/Docker/bare GPU paths. Includes handling for unlisted models with custom module patching. This feature is in early testing — use with caution. | ||
| - Add performant layerwise calibration for large models that don't fit on GPU (e.g. DeepSeek-R1, Kimi-K2). See `modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml>`_ for usage. Layerwise calibration also supports PTQ with intermediate progress saving — useful when long PTQ runs get hit with Slurm timeouts. See `modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml <https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt_recipes/general/ptq/nvfp4_default-none_kv_gptq.yaml>`_ for usage. | ||
| - Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.kernels.quantization.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning. | ||
| - Add implicit GEMM CUDA kernel for Conv3D with fused NVFP4 fake quantization (``modelopt.torch.quantization.src.conv``). When NVFP4 quantization is applied to an ``nn.Conv3d`` layer via ModelOpt PTQ, the implicit GEMM path is used automatically instead of cuDNN. Uses BF16 WMMA tensor cores (SM80+) with FP32 accumulation and in-kernel FP4 (E2M1) activation quantization. Grouped convolution (``groups > 1``) falls back to the default cuDNN path. Inference only — training mode falls back to cuDNN with a warning. |
Update this module path to the post-reorg location.
The same 0.45 section already says the Conv3D kernel moved under modelopt.torch.kernels.quantization.conv, so this entry now points readers at the removed path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@CHANGELOG.rst` at line 38, Update the module path in the changelog entry so
it points to the new post-reorg location: replace the old path string
"modelopt.torch.quantization.src.conv" with
"modelopt.torch.kernels.quantization.conv" in the Conv3D kernel description;
ensure the sentence still references nn.Conv3d and ModelOpt PTQ and that grouped
convolution and inference/training notes remain unchanged.
```python
from common import add_aux_layers_args, resolve_aux_layers
from datasets import load_dataset
```
Wire the TRTLLM dumper into the new answer-only-loss flow too.
This script picked up the shared aux-layer helpers, but it still never accepts or persists a loss_mask. That leaves TRTLLM-generated dumps incompatible with the new offline answer_only_loss=True training path, because OfflineSupervisedDataset now hard-fails when that field is absent.
Also applies to: 126-126
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py`
around lines 26 - 27, The TRTLLM dumper must accept and persist the loss_mask so
generated dumps are compatible with
OfflineSupervisedDataset(answer_only_loss=True); update
compute_hidden_states_trtllm.py to parse/propagate the loss_mask from the
dataset or aux-layer helpers (use add_aux_layers_args and resolve_aux_layers
where aux fields are collected) and include a loss_mask field in the saved dump
payload produced by the TRTLLM dumper (the code path around the dump/save call
near the dumper or dump function at the end of the script). Ensure the emitted
dump keys match what OfflineSupervisedDataset expects (loss_mask present when
answer_only_loss=True) to avoid hard failures.
| "write_interval": 1, | ||
| "file_prefix": f"dp_{args.dp_rank}", | ||
| "eagle3_layers_to_capture": {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4}, | ||
| "eagle3_layers_to_capture": set(resolve_aux_layers(args, num_hidden_layers)), |
Validate num_hidden_layers before resolving aux layers.
resolve_aux_layers() assumes a concrete layer count. If AutoConfig does not expose num_hidden_layers, this path now fails later with a much less helpful error than the HF dumper, which already added an explicit guard.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In
`@examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py`
at line 199, The call building "eagle3_layers_to_capture" passes
set(resolve_aux_layers(args, num_hidden_layers)) but num_hidden_layers may be
missing from the model config; validate that AutoConfig exposes a concrete
integer num_hidden_layers before calling resolve_aux_layers. In practice, obtain
the config (AutoConfig), check getattr(config, "num_hidden_layers", None) is an
int > 0 (or raise a clear ValueError/RuntimeError), and only then call
resolve_aux_layers(args, num_hidden_layers); include a helpful error message
referencing the model/config when failing so the failure matches the HF dumper
guard.
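A sketch of the guard the prompt asks for, mirroring the HF dumper's behavior; the `args.model` attribute name and the exact error wording are assumptions about the script's CLI.

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained(args.model)  # `args.model` assumed from the script's CLI
num_hidden_layers = getattr(config, "num_hidden_layers", None)
if not isinstance(num_hidden_layers, int) or num_hidden_layers <= 0:
    raise ValueError(
        f"Model config for {args.model!r} does not expose a positive integer "
        "`num_hidden_layers`; cannot resolve aux layers for eagle3_layers_to_capture."
    )
layers_to_capture = set(resolve_aux_layers(args, num_hidden_layers))
```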
- Add `dflash_offline` config flag for training from pre-computed hidden states; deletes base model layers to save memory.
- Move `dflash_mask_token_id` auto-detection from `main.py` into `DFlashConfig` Pydantic validators; derive `dflash_offline` from `data_args.offline_data_path`.
- Add `DFlashBaseModelOutput.from_offline_dict` classmethod for consuming pre-computed hidden states in the forward path.
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Actionable comments posted: 1
♻️ Duplicate comments (3)
CHANGELOG.rst (1)
38-39: ⚠️ Potential issue | 🟠 Major — Update Conv3D implicit GEMM kernel path to match the reorg migration table.
This changelog entry references `modelopt.torch.quantization.src.conv`, but the migration/backward-breaking guidance in the same file indicates the kernel subpackages moved under `modelopt.torch.kernels...` (and specifically `modelopt.torch.kernels.quantization.conv` for the old `quantization/src/conv` path).
Please update this entry's module path to the post-reorg path to avoid misleading users.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@CHANGELOG.rst` around lines 38 - 39, Update the changelog entry text to reference the new post-reorg module path: replace the old import path `modelopt.torch.quantization.src.conv` with `modelopt.torch.kernels.quantization.conv` in the Conv3D implicit GEMM paragraph so the entry matches the migration/backward-breaking guidance and avoids misleading users about the package location.modelopt/torch/speculative/plugins/hf_eagle.py (1)
344-350: ⚠️ Potential issue | 🟠 Major — Cache flex attention masks using a shape-dependent key (include `seq_length`).
`_get_ttt_attention_mask()` caches by `ttt_step` only, but `_compute_ttt_attention_mask(batch_size, seq_length, ttt_step)` depends on `seq_length` (mask dimensions change). Reusing a cached mask across different sequence lengths can produce incorrect masks or shape/runtime errors.
This matches a previously reported concern; please verify and fix the keying now (e.g., `(seq_length, ttt_step)`, or include any other shape params that affect mask construction).
Please confirm how `_compute_ttt_attention_mask()` is called across the codebase (especially whether `seq_length` varies while `ttt_step` repeats). If it does, update `_cached_attn_blk_masks` accordingly and add/adjust a unit test that calls the function with two seq_lens.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 344 - 350, The cache key for flex attention masks is currently only ttt_step causing masks computed by _compute_ttt_attention_mask(batch_size, seq_length, ttt_step) to be reused across different seq_length values; update _get_ttt_attention_mask to key _cached_attn_blk_masks by (seq_length, ttt_step) (or another tuple including all shape-dependent params such as batch_size if relevant), and store/retrieve masks using that composite key instead of just ttt_step; then add/adjust a unit test that calls _get_ttt_attention_mask (or directly _compute_ttt_attention_mask) with the same ttt_step but two different seq_length values to verify distinct masks are produced and cached separately.modelopt/torch/speculative/plugins/hf_dflash.py (1)
446-456: ⚠️ Potential issue | 🔴 Critical — Normalize the offline logits-recompute hidden-state key (avoid brittle `base_model_hidden_states`).
The offline logits recompute block uses:
out_hiddens = kwargs["base_model_outputs"]["base_model_hidden_states"]
If the offline dump schema provides `hidden_states` (and the collator doesn't rename), this will KeyError and break offline training with `dflash_self_logit_distillation=True` when `base_model_logits` is absent.
This matches an earlier reported issue; please fix by accepting both keys (or by ensuring `DFlashBaseModelOutput.from_offline_dict()` exposes the needed hidden states in a consistent attribute).
🔧 Proposed fix (accept both keys)
```diff
-            out_hiddens = kwargs["base_model_outputs"]["base_model_hidden_states"]
+            offline_outputs = kwargs["base_model_outputs"]
+            out_hiddens = offline_outputs.get("base_model_hidden_states")
+            if out_hiddens is None:
+                # Fall back to dump-script key
+                out_hiddens = offline_outputs.get("hidden_states")
+            if out_hiddens is None:
+                raise KeyError(
+                    "Missing hidden states for offline logits recompute. Expected "
+                    "'base_model_hidden_states' or 'hidden_states' in base_model_outputs."
+                )
             base_outputs.logits = self._base_model_lm_head(out_hiddens)
```
Please verify the exact structure of `kwargs["base_model_outputs"]` produced by the offline dataloader/collator during DFlash offline training. If it always uses `base_model_hidden_states`, then this can be downgraded; otherwise keep the normalization.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 446 - 456, The offline logits-recompute path currently assumes kwargs["base_model_outputs"] contains "base_model_hidden_states" which is brittle; update the logic in the dflash_offline branch (around DFlashBaseModelOutput.from_offline_dict, the dflash_self_logit_distillation check, and the call to self._base_model_lm_head) to tolerate either key name by first normalizing/locating hidden states: prefer an attribute from DFlashBaseModelOutput (extend from_offline_dict to populate a consistent attribute like .hidden_states or .base_model_hidden_states) or, if leaving the dict-based approach, lookup hidden = kwargs["base_model_outputs"].get("base_model_hidden_states") or kwargs["base_model_outputs"].get("hidden_states") and use that for self._base_model_lm_head; ensure base_outputs.logits is set when hidden states are found and preserve base_outputs.target_hidden.
🧹 Nitpick comments (6)
modelopt/torch/speculative/config.py (1)
67-157: Clarifydflash_offline“not user-configurable” wording.The description states
dflash_offlineis “not user-configurable”. However,_derive_dflash_offlineonly overrides the value wheninfo.contextcontainsdata_argsanddatais a dict. If context isn’t provided, user-supplieddflash_offlinecan still take effect.Recommend rewording description to something like “auto-derived when context includes
data_args”.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/config.py` around lines 67 - 157, The description for dflash_offline is misleading — it claims the field is "not user-configurable" but _derive_dflash_offline only overrides when info.context contains data_args and data is a dict, so a user value can persist when context is absent; update the DFlashConfig.dflash_offline ModeloptField description to state that the field is "auto-derived when context includes data_args (overrides provided value)" or similar phrasing; ensure you edit the string in the dflash_offline ModeloptField (and optionally update the docstring of _derive_dflash_offline) to match this behavior.examples/speculative_decoding/collect_hidden_states/common.py (2)
113-130: Make generation-tag validation stricter (exact Jinja tags).
verify_generation_tags()currently checks("generation" in chat_template and "endgeneration" in chat_template). This can produce false positives if those substrings appear in comments/other content, while still missing the exact Jinja tags required byapply_chat_template(..., return_assistant_tokens_mask=True).Consider validating the specific token patterns (e.g.
"{% generation %}"and"{% endgeneration %}"or variants with whitespace control).🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/collect_hidden_states/common.py` around lines 113 - 130, The current verify_generation_tags function should check for the exact Jinja tag patterns rather than substrings: update verify_generation_tags to look for specific tokens like "{% generation %}" and "{% endgeneration %}" and also accept common whitespace-control variants (e.g. "{%- generation -%}", "{%generation%}", "{%-generation-%}") so it only passes when the actual template block markers required by apply_chat_template(..., return_assistant_tokens_mask=True) are present; change the conditional that now uses "generation" / "endgeneration" to a robust check for these exact tag strings (or a small normalized/regex match covering the whitespace-control variants) and keep the same ValueError message if validation fails.
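A possible shape for the stricter check described above. The regex and the single-string-argument signature are illustrative; the accepted tag variants should match what `apply_chat_template(..., return_assistant_tokens_mask=True)` actually requires.

```python
import re

_GENERATION_TAG = re.compile(r"\{%-?\s*generation\s*-?%\}")
_ENDGENERATION_TAG = re.compile(r"\{%-?\s*endgeneration\s*-?%\}")


def verify_generation_tags(chat_template: str) -> None:
    """Sketch: require the exact Jinja block markers, not bare substrings."""
    if not (_GENERATION_TAG.search(chat_template) and _ENDGENERATION_TAG.search(chat_template)):
        raise ValueError(
            "Chat template must contain {% generation %} ... {% endgeneration %} blocks "
            "to build an assistant-token loss mask."
        )
```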
132-171: Harden assistant mask shape handling intokenize_with_loss_mask.When
answer_only_loss=True, you do:
mask = out["assistant_masks"]- optional
torch.tensor(...)loss_mask = mask.squeeze(0).to(torch.long)- then check
loss_mask.shape[0] == seq_lenIf
assistant_masksis already 1D (or has unexpected rank),squeeze(0)can silently produce an incorrect shape or throw a confusing error.Recommend:
- assert
loss_mask.ndim == 1after squeeze/reshape (or useloss_mask = mask.reshape(-1)with explicit checks),- validate rank explicitly and raise a clear error message if shapes don’t match
(seq_len,).🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/speculative_decoding/collect_hidden_states/common.py` around lines 132 - 171, In tokenize_with_loss_mask, harden handling of out["assistant_masks"] (used when answer_only_loss=True) by avoiding blind mask.squeeze(0): after extracting mask (and converting to a torch.Tensor if needed), explicitly normalize it to 1D (e.g., mask.reshape(-1) or mask.view(-1)) and then assert loss_mask.ndim == 1 and loss_mask.shape[0] == seq_len; if the rank/length is unexpected, raise a clear RuntimeError mentioning "assistant_masks" and both the observed shape and the expected seq_len so the error is unambiguous (update references around tokenize_with_loss_mask, mask, loss_mask, and assistant_masks).modelopt/torch/speculative/eagle/utils.py (1)
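One way to express that hardening as a standalone helper — a sketch only; the real code operates directly on the tokenizer output dict inside `tokenize_with_loss_mask`.

```python
import torch


def _normalize_assistant_mask(mask, seq_len: int) -> torch.Tensor:
    """Sketch: coerce `assistant_masks` to a 1-D long tensor of length seq_len or fail loudly."""
    if not isinstance(mask, torch.Tensor):
        mask = torch.tensor(mask)
    loss_mask = mask.reshape(-1).to(torch.long)
    if loss_mask.shape[0] != seq_len:
        raise RuntimeError(
            f"Unexpected 'assistant_masks' shape {tuple(mask.shape)}; expected ({seq_len},)."
        )
    return loss_mask
```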
81-143: Clarify / standardizeloss_maskdtype for downstream consumers.When
answer_only_loss=True, you set:
loss_mask = offline_data["loss_mask"].to(offline_data["input_ids"].dtype).Many training pipelines (including DFlash) treat
loss_maskas numeric weights and do comparisons like> 0.5or multiplication with float tensors. Int masks work, but it’s easy to accidentally introduce dtype assumptions later.Suggestion:
- either keep
loss_maskastorch.float32(and ensure dump/script emits 0/1), or- add a short comment/docstring in this file stating the expected dtype contract (int vs float) for offline speculative decoding trainers.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/eagle/utils.py` around lines 81 - 143, The loss_mask dtype is inconsistent: in OfflineSupervisedDataset.__getitem__ you cast loss_mask to input_ids.dtype which can be integer; change this to a consistent floating dtype for downstream trainers (e.g., torch.float32) and ensure the else branch also returns float tensors; specifically, when loading offline_data["loss_mask"] in __getitem__, cast it with .to(torch.float32) (and make torch.ones_like(..., dtype=torch.float32) for the default) so consumers can safely compare/multiply loss_mask as numeric weights.tests/regression/torch/speculative/test_dflash_offline.py (1)
112-146: Consider making the loss assertion less brittle across runs.
assert final_loss < 3.0may be sensitive to:
- batch/seed differences,
- dataset slice ordering,
- minor model/kernel changes.
If there’s already precedent for flakiness in this suite, consider:
- using a relative improvement threshold primarily (
final_loss < first_loss) and relaxing the absolute ceiling, or- pinning seeds / deterministic flags in the training launcher for this regression.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/regression/torch/speculative/test_dflash_offline.py` around lines 112 - 146, The absolute loss ceiling in test_dflash_offline_training is brittle; change the assertions to rely primarily on relative improvement (keep the existing assert final_loss < first_loss) and replace the hard assert final_loss < 3.0 with a relaxed, tolerant check such as a relative threshold (e.g., assert final_loss < first_loss * 0.95) and/or a looser ceiling (e.g., final_loss < 3.5) so runs with minor variance pass; additionally consider ensuring deterministic seeds are passed to the training launcher (the run_example_command call / launch_train.sh overrides) if you want reproducibility.tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py (1)
41-95: Add a unit test for missingnum_orig_hidden_layersin offline conversion.Right now the offline conversion tests set
model.config.num_orig_hidden_layers = NUM_BASE_LAYERSexplicitly. If upstream ever forgets to seed it, the offline modify path may crash.Recommend adding a test:
- create tiny model without setting
num_orig_hidden_layers,- call
mtsp.convert(... offline=True ...),- assert either it auto-seeds (if you implement the fallback) or raises a clear ValueError with guidance.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py` around lines 41 - 95, Add a unit test in tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py that verifies offline conversion fails fast when model.config.num_orig_hidden_layers is not set: create a tiny model via get_tiny_llama without setting model.config.num_orig_hidden_layers, call mtsp.convert(model, [("dflash", _get_dflash_config(offline=True))]) and assert it raises a ValueError (or a clear exception) with a message guiding the user to set config.num_orig_hidden_layers; reference mtsp.convert, get_tiny_llama, model.config.num_orig_hidden_layers and _get_dflash_config to locate the relevant logic to test.
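A sketch of the suggested test. `get_tiny_llama` and `_get_dflash_config` are the helpers already used in this test file; the expected exception type is an assumption until the conversion path gains an explicit check.

```python
import pytest

import modelopt.torch.speculative as mtsp


def test_offline_convert_requires_num_orig_hidden_layers():
    # num_orig_hidden_layers is deliberately left unset on the tiny model's config.
    model = get_tiny_llama()
    with pytest.raises((ValueError, AttributeError)):
        mtsp.convert(model, [("dflash", _get_dflash_config(offline=True))])
```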
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/hf_eagle.py`:
- Around line 49-55: default_eagle_aux_layer_ids may return out-of-range indices
for small num_layers (e.g., returns 1 when num_layers==1); update the function
(default_eagle_aux_layer_ids) to generate the intended candidate indices (start,
middle, end), clamp each candidate into the valid 0..num_layers-1 range, filter
duplicates, and return them sorted — e.g., compute candidates = {1,
num_layers//2 - 1, num_layers - 4}, clamp each with min/max using num_layers-1
and 0, then remove any invalid entries for num_layers <= 0 before sorting and
returning.
---
Duplicate comments:
In `@CHANGELOG.rst`:
- Around line 38-39: Update the changelog entry text to reference the new
post-reorg module path: replace the old import path
`modelopt.torch.quantization.src.conv` with
`modelopt.torch.kernels.quantization.conv` in the Conv3D implicit GEMM paragraph
so the entry matches the migration/backward-breaking guidance and avoids
misleading users about the package location.
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 446-456: The offline logits-recompute path currently assumes
kwargs["base_model_outputs"] contains "base_model_hidden_states" which is
brittle; update the logic in the dflash_offline branch (around
DFlashBaseModelOutput.from_offline_dict, the dflash_self_logit_distillation
check, and the call to self._base_model_lm_head) to tolerate either key name by
first normalizing/locating hidden states: prefer an attribute from
DFlashBaseModelOutput (extend from_offline_dict to populate a consistent
attribute like .hidden_states or .base_model_hidden_states) or, if leaving the
dict-based approach, lookup hidden =
kwargs["base_model_outputs"].get("base_model_hidden_states") or
kwargs["base_model_outputs"].get("hidden_states") and use that for
self._base_model_lm_head; ensure base_outputs.logits is set when hidden states
are found and preserve base_outputs.target_hidden.
In `@modelopt/torch/speculative/plugins/hf_eagle.py`:
- Around line 344-350: The cache key for flex attention masks is currently only
ttt_step causing masks computed by _compute_ttt_attention_mask(batch_size,
seq_length, ttt_step) to be reused across different seq_length values; update
_get_ttt_attention_mask to key _cached_attn_blk_masks by (seq_length, ttt_step)
(or another tuple including all shape-dependent params such as batch_size if
relevant), and store/retrieve masks using that composite key instead of just
ttt_step; then add/adjust a unit test that calls _get_ttt_attention_mask (or
directly _compute_ttt_attention_mask) with the same ttt_step but two different
seq_length values to verify distinct masks are produced and cached separately.
---
Nitpick comments:
In `@examples/speculative_decoding/collect_hidden_states/common.py`:
- Around line 113-130: The current verify_generation_tags function should check
for the exact Jinja tag patterns rather than substrings: update
verify_generation_tags to look for specific tokens like "{% generation %}" and
"{% endgeneration %}" and also accept common whitespace-control variants (e.g.
"{%- generation -%}", "{%generation%}", "{%-generation-%}") so it only passes
when the actual template block markers required by apply_chat_template(...,
return_assistant_tokens_mask=True) are present; change the conditional that now
uses "generation" / "endgeneration" to a robust check for these exact tag
strings (or a small normalized/regex match covering the whitespace-control
variants) and keep the same ValueError message if validation fails.
- Around line 132-171: In tokenize_with_loss_mask, harden handling of
out["assistant_masks"] (used when answer_only_loss=True) by avoiding blind
mask.squeeze(0): after extracting mask (and converting to a torch.Tensor if
needed), explicitly normalize it to 1D (e.g., mask.reshape(-1) or mask.view(-1))
and then assert loss_mask.ndim == 1 and loss_mask.shape[0] == seq_len; if the
rank/length is unexpected, raise a clear RuntimeError mentioning
"assistant_masks" and both the observed shape and the expected seq_len so the
error is unambiguous (update references around tokenize_with_loss_mask, mask,
loss_mask, and assistant_masks).
In `@modelopt/torch/speculative/config.py`:
- Around line 67-157: The description for dflash_offline is misleading — it
claims the field is "not user-configurable" but _derive_dflash_offline only
overrides when info.context contains data_args and data is a dict, so a user
value can persist when context is absent; update the DFlashConfig.dflash_offline
ModeloptField description to state that the field is "auto-derived when context
includes data_args (overrides provided value)" or similar phrasing; ensure you
edit the string in the dflash_offline ModeloptField (and optionally update the
docstring of _derive_dflash_offline) to match this behavior.
In `@modelopt/torch/speculative/eagle/utils.py`:
- Around line 81-143: The loss_mask dtype is inconsistent: in
OfflineSupervisedDataset.__getitem__ you cast loss_mask to input_ids.dtype which
can be integer; change this to a consistent floating dtype for downstream
trainers (e.g., torch.float32) and ensure the else branch also returns float
tensors; specifically, when loading offline_data["loss_mask"] in __getitem__,
cast it with .to(torch.float32) (and make torch.ones_like(...,
dtype=torch.float32) for the default) so consumers can safely compare/multiply
loss_mask as numeric weights.
In `@tests/regression/torch/speculative/test_dflash_offline.py`:
- Around line 112-146: The absolute loss ceiling in test_dflash_offline_training
is brittle; change the assertions to rely primarily on relative improvement
(keep the existing assert final_loss < first_loss) and replace the hard assert
final_loss < 3.0 with a relaxed, tolerant check such as a relative threshold
(e.g., assert final_loss < first_loss * 0.95) and/or a looser ceiling (e.g.,
final_loss < 3.5) so runs with minor variance pass; additionally consider
ensuring deterministic seeds are passed to the training launcher (the
run_example_command call / launch_train.sh overrides) if you want
reproducibility.
In `@tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py`:
- Around line 41-95: Add a unit test in
tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py that verifies
offline conversion fails fast when model.config.num_orig_hidden_layers is not
set: create a tiny model via get_tiny_llama without setting
model.config.num_orig_hidden_layers, call mtsp.convert(model, [("dflash",
_get_dflash_config(offline=True))]) and assert it raises a ValueError (or a
clear exception) with a message guiding the user to set
config.num_orig_hidden_layers; reference mtsp.convert, get_tiny_llama,
model.config.num_orig_hidden_layers and _get_dflash_config to locate the
relevant logic to test.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 86ab1669-fed2-4c25-b2d1-4f9217ad52ae
📒 Files selected for processing (15)
CHANGELOG.rst
examples/speculative_decoding/collect_hidden_states/common.py
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
examples/speculative_decoding/eagle_utils.py
examples/speculative_decoding/main.py
modelopt/torch/speculative/config.py
modelopt/torch/speculative/dflash/dflash_model.py
modelopt/torch/speculative/eagle/utils.py
modelopt/torch/speculative/plugins/hf_dflash.py
modelopt/torch/speculative/plugins/hf_eagle.py
modelopt/torch/speculative/plugins/modeling_dflash.py
tests/gpu/torch/speculative/plugins/test_hf_dflash.py
tests/regression/torch/speculative/test_dflash_offline.py
tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py
✅ Files skipped from review due to trivial changes (1)
- examples/speculative_decoding/main.py
🚧 Files skipped from review as they are similar to previous changes (6)
- examples/speculative_decoding/eagle_utils.py
- modelopt/torch/speculative/dflash/dflash_model.py
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_trtllm.py
- modelopt/torch/speculative/plugins/modeling_dflash.py
- tests/gpu/torch/speculative/plugins/test_hf_dflash.py
- examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py
```python
def default_eagle_aux_layer_ids(num_layers: int) -> list[int]:
    """Default EAGLE-3 auxiliary hidden-state layer IDs (0-based).

    Three layers near the start, middle, and end of the stack.
    """
    return sorted({1, max(0, num_layers // 2 - 1), max(0, num_layers - 4)})
```
Fix potential out-of-range auxiliary layer IDs for small num_layers.
default_eagle_aux_layer_ids() returns sorted({1, ..., max(0, num_layers - 4)}). For num_layers=1, this becomes {0, 1}, which includes 1 even though valid 0-based layer IDs are only [0]. That can break hook registration / hidden-state collection for tiny configs.
🔧 Proposed fix (clamp to valid range)
```diff
 def default_eagle_aux_layer_ids(num_layers: int) -> list[int]:
 @@
-    return sorted({1, max(0, num_layers // 2 - 1), max(0, num_layers - 4)})
+    if num_layers <= 0:
+        return []
+    # Always clamp into [0, num_layers-1]
+    start = min(1, num_layers - 1)
+    mid = max(0, num_layers // 2 - 1)
+    end = max(0, num_layers - 4)
+    end = min(end, num_layers - 1)
+    return sorted({start, mid, end})
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/hf_eagle.py` around lines 49 - 55,
default_eagle_aux_layer_ids may return out-of-range indices for small num_layers
(e.g., returns 1 when num_layers==1); update the function
(default_eagle_aux_layer_ids) to generate the intended candidate indices (start,
middle, end), clamp each candidate into the valid 0..num_layers-1 range, filter
duplicates, and return them sorted — e.g., compute candidates = {1,
num_layers//2 - 1, num_layers - 4}, clamp each with min/max using num_layers-1
and 0, then remove any invalid entries for num_layers <= 0 before sorting and
returning.
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #1343 +/- ##
==========================================
- Coverage 76.27% 75.82% -0.45%
==========================================
Files 471 471
Lines 50323 50375 +52
==========================================
- Hits 38383 38199 -184
- Misses 11940 12176 +236
Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/regression/torch/speculative/test_dflash_offline.py`:
- Around line 124-126: The current check only asserts that some .pt files exist
(pt_files), which allows partial dumps to pass; replace this with an explicit
check that the number of dumped files equals the expected count: compute
expected_num_files from the test configuration (e.g., the training steps /
checkpoints variables used elsewhere in the test such as num_steps, max_steps,
checkpoint_interval or define expected_num_files explicitly for this test) and
assert len(pt_files) == expected_num_files, raising a clear error that includes
dump_dir and the sorted pt_files list when it does not match.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: e9cc0271-d27d-495b-9817-d5f769e28e3b
📒 Files selected for processing (1)
tests/regression/torch/speculative/test_dflash_offline.py
```python
    pt_files = list(dump_dir.rglob("*.pt"))
    assert pt_files, f"No .pt files dumped under {dump_dir}"
    return dump_dir
```
Assert full hidden-state dump completion, not just non-empty output.
assert pt_files can still pass after a partial dump, which can make this regression falsely green with a much smaller training set.
Suggested change
- pt_files = list(dump_dir.rglob("*.pt"))
- assert pt_files, f"No .pt files dumped under {dump_dir}"
+ pt_files = list(dump_dir.rglob("*.pt"))
+ assert len(pt_files) == _DUMP_NUM_CONVERSATIONS, (
+ f"Expected {_DUMP_NUM_CONVERSATIONS} .pt files under {dump_dir}, got {len(pt_files)}"
+ )🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/regression/torch/speculative/test_dflash_offline.py` around lines 124 -
126, The current check only asserts that some .pt files exist (pt_files), which
allows partial dumps to pass; replace this with an explicit check that the
number of dumped files equals the expected count: compute expected_num_files
from the test configuration (e.g., the training steps / checkpoints variables
used elsewhere in the test such as num_steps, max_steps, checkpoint_interval or
define expected_num_files explicitly for this test) and assert len(pt_files) ==
expected_num_files, raising a clear error that includes dump_dir and the sorted
pt_files list when it does not match.
What does this PR do?
Type of change: new feature
Part 2 of a 3-PR series splitting #1271:
ParallelDraftHFSpecDecMixin
Changes:
- Add `dflash_offline` flag to `DFlashConfig` for training from pre-computed hidden states; deletes base model layers to save memory.
- `DFlashConfig`: `_derive_dflash_offline` — auto-derive `dflash_offline` from `data_args.offline_data_path` in validation context. Not user-configurable: any user-supplied value is overridden by the derived value. `_resolve_mask_token_id` — auto-detect `dflash_mask_token_id` from `tokenizer.mask_token_id`. `_check_mask_token_id` — fail fast if unset after resolution.
- `HFDFlashModel.modify()`: select `num_orig_hidden_layers` when offline; pick the `_base_model_lm_head` device when no base layers are present; drop the base-model `layers` module.
- `HFDFlashModel.forward()`: add an offline branch — consumes precomputed `base_model_outputs` via `DFlashBaseModelOutput.from_offline_dict`, and when `dflash_self_logit_distillation` is enabled with `base_model_logits` absent, recomputes logits from `base_model_hidden_states` via `_base_model_lm_head`. Raises a clear error from the non-training / `pseudo_speculative_generate` paths when `dflash_offline=True`, since the base-model layers have been deleted.
- Add a `DFlashBaseModelOutput` dataclass in `modeling_dflash.py` (with a `from_offline_dict` classmethod) to unify online/offline output shapes. `aux_hidden_states` is required in `from_offline_dict` so missing keys fail fast at the entry point rather than deeper in the forward.
- `examples/speculative_decoding/main.py`: replace the inline `mask_token_id` auto-detect with `DFlashConfig.model_validate(dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args})` (see the sketch after this list).
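For instance, the `main.py` wiring now looks roughly like this — a sketch based on the description above, not a verbatim excerpt from the PR:

```python
# Sketch of the main.py change described above: validators derive dflash_offline
# and resolve dflash_mask_token_id from the validation context.
dflash_config = DFlashConfig.model_validate(
    dflash_cfg,
    context={"tokenizer": tokenizer, "data_args": data_args},
)
# dflash_config.dflash_offline is now derived from data_args.offline_data_path,
# and dflash_config.dflash_mask_token_id falls back to tokenizer.mask_token_id.
```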
add_generation_template→add_generation_promptThe pre-refactor
compute_hidden_states_hf.pypassedadd_generation_template=Falsetotokenizer.apply_chat_template. This kwarg does not exist on HFapply_chat_templateand was being silently ignored, so the intended "don't append a generation prompt" behavior was never actually applied. The newtokenize_with_loss_maskhelper inexamples/speculative_decoding/collect_hidden_states/common.pyuses the correctadd_generation_prompt=False. This is a real behavior change for anyone re-dumping hidden states: trailing generation prompts that were previously appended to the tokenized sequences will no longer be included.Testing
New tests:
tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py— CPU unit tests for convert path (online keeps base layers, offline deletes them;num_orig_hidden_layersdrivestarget_layer_idsin offline mode) andDFlashConfig._derive_dflash_offlinevalidator.TestDFlashOfflineForwardGPUintests/gpu/torch/speculative/plugins/test_hf_dflash.py— GPU forward smoke with precomputedbase_model_outputs, plus thedflash_self_logit_distillationlogit-recompute path.training test:

Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (
git commit -s -S).Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded
trust_remote_code=True,torch.load(..., weights_only=False),pickle, etc.).dflash_offlineflag defaulting toFalse; validators fall through when context not provided.CONTRIBUTING.md: N/ATODO (follow-up)
examples/speculative_decoding/collect_hidden_states/compute_hidden_states_*.pyto support DFlash offline data. Current scripts are Eagle-specific — they hardcode the[2, N/2, N-3]aux-layer selection and emit{input_ids, hidden_states, aux_hidden_states}. DFlash offline needs:build_target_layer_ids(num_orig_hidden_layers, num_draft_layers)(or a configurable list), not the Eagle triplet.base_model_hidden_stateskey (last-layer hidden) soDFlashBaseModelOutput.from_offline_dict+ thedflash_self_logit_distillationrecompute path can consume it.base_model_logitsdump so offline training can skip the self-distillation logit recomputation when logits are available.Additional Information
Base branch is #1296 (file reorg). Retarget to
mainonce #1296 merges.Summary by CodeRabbit
New Features
Documentation
Tests