[3/3][Refactor]: Extract HFSpecDecMixin for HF spec-decoding plugins#1297

Draft
h-guo18 wants to merge 3 commits into main from haoguo/spec-mixin-new

Conversation

h-guo18 (Contributor) commented Apr 19, 2026

What does this PR do?

Type of change: refactoring

Part 3 of a 3-PR series splitting #1271:

  • [1/3] #1296: File reorg + deprecate `ParallelDraft`
  • [2/3] #1295: Offline DFlash training
  • [3/3] this PR: Extract `HFSpecDecMixin`

Changes:

  • New `modelopt/torch/speculative/plugins/hf_spec_mixin.py` containing `HFSpecDecMixin` (skeleton sketched after this list) with:
    • Properties: `_base_model`, `_base_model_embeddings`, `_base_model_lm_head`, `_base_llm_config` (VLM-aware).
    • `_find_base_model_parts()` — probe `modeling_fakebase` paths.
    • `_base_model_forward()` — generic base forward with optional freeze + CE loss.
    • `_nvtx_range()` and `_activate_torch_compile()` driven by subclass `_compile_targets`.
  • `HFEagleModel` now `(HFSpecDecMixin, EagleModel)`; drops the duplicated helpers; sets `_compile_targets` and `self._enable_nvtx` in `modify()`.
  • `HFDFlashModel` now `(HFSpecDecMixin, DFlashModel)`; drops the duplicated helpers; `_dflash_base_model_forward` reuses the mixin's generic forward.
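
For orientation, a skeleton of the mixin's shape (method names follow the bullets above; bodies, signatures, and defaults are assumptions, not the actual implementation):

```python
# Hypothetical skeleton of HFSpecDecMixin -- illustrative only.
import contextlib

import torch
import torch.nn.functional as F


class HFSpecDecMixin:
    _compile_targets: tuple[str, ...] = ()  # subclasses set this in modify()
    _enable_nvtx: bool = False

    def _find_base_model_parts(self) -> None:
        """Probe candidate attribute paths and cache base model, embeddings, lm_head."""
        raise NotImplementedError

    def _base_model_forward(self, input_ids, labels=None, freeze_base_model=True):
        """Generic base-model forward with optional freeze and CE loss."""
        ctx = torch.no_grad() if freeze_base_model else contextlib.nullcontext()
        with ctx:
            out = self._base_model(input_ids=input_ids, output_hidden_states=True)
        loss = None
        if labels is not None:
            logits = self._base_model_lm_head(out.hidden_states[-1])
            loss = F.cross_entropy(
                logits.flatten(0, 1), labels.flatten(), ignore_index=-100
            )
        return out, loss

    @contextlib.contextmanager
    def _nvtx_range(self, name: str):
        """NVTX range that is a no-op unless the subclass enabled profiling."""
        if self._enable_nvtx:
            torch.cuda.nvtx.range_push(name)
        try:
            yield
        finally:
            if self._enable_nvtx:
                torch.cuda.nvtx.range_pop()

    def _activate_torch_compile(self) -> None:
        """Compile the submodules a subclass listed in _compile_targets."""
        for name in self._compile_targets:
            setattr(self, name, torch.compile(getattr(self, name)))
```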

Testing

No behavioral change expected. Verified MRO includes `HFSpecDecMixin` and existing Eagle/DFlash training scripts run unchanged.
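
The MRO check amounts to something like this (import path assumed from the plugin layout described above):

```python
from modelopt.torch.speculative.plugins.hf_eagle import HFEagleModel

# The mixin must precede EagleModel so its helpers win attribute lookup.
mro = [cls.__name__ for cls in HFEagleModel.__mro__]
assert mro.index("HFSpecDecMixin") < mro.index("EagleModel")
```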

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

  • Is this change backward compatible?: ✅ — internal refactor; no public API change.
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A
  • Did you write any new necessary tests?: N/A — pure refactor.
  • Did you update Changelog?: ❌

Additional Information

Base branch is #1295. Retarget to `main` once #1296 and #1295 merge.

h-guo18 added 3 commits April 19, 2026 21:46
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
- Add `dflash_offline` config flag for training from pre-computed hidden states;
  deletes base model layers to save memory.
- Move `dflash_mask_token_id` auto-detection from `main.py` into `DFlashConfig`
  Pydantic validators; derive `dflash_offline` from `data_args.offline_data_path`.
- Add `DFlashBaseModelOutput.from_offline_dict` classmethod for consuming
  pre-computed hidden states in the forward path.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Extract duplicated base-model discovery, forward pass, NVTX profiling, and
torch.compile logic from HFEagleModel / HFDFlashModel into a shared mixin
(hf_spec_mixin.py). HFEagleModel and HFDFlashModel now inherit from
(HFSpecDecMixin, EagleModel/DFlashModel).

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
copy-pr-bot (Bot) commented Apr 19, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

coderabbitai (Bot) commented Apr 19, 2026

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 81e36b72-3144-46e9-b217-9598fb20a176

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Comment @coderabbitai help to get the list of available commands and usage tips.

h-guo18 changed the title from "[Refactor]: HFSpecDecMixin shared across HF spec-decoding plugins" to "[3/3][Refactor]: Extract HFSpecDecMixin for HF spec-decoding plugins" on Apr 19, 2026
h-guo18 force-pushed the haoguo/dflash-offline branch 2 times, most recently from f7a27dc to e063406 on April 23, 2026 21:45
Base automatically changed from haoguo/dflash-offline to haoguo/spec-file-reorg on April 23, 2026 22:05
h-guo18 force-pushed the haoguo/spec-file-reorg branch from aedd188 to 536ea48 on April 23, 2026 22:08
h-guo18 added a commit that referenced this pull request Apr 24, 2026
### What does this PR do?

Type of change: refactoring

Part 1 of a 3-PR series splitting #1271:
- **[1/3] this PR**: File reorg + deprecate `ParallelDraft`
- **[2/3] #1295**: Offline DFlash training
- **[3/3] #1297**: Extract `HFSpecDecMixin`

Changes:
- **File reorg**: `transformers.py` → `hf_eagle.py` (the legacy module
becomes a deprecated re-export shim; sketched after this list); extract
`HFMedusaModel` → `hf_medusa.py`; extract `EagleModule` /
`EagleBaseModelOutput` → `modeling_eagle.py`; extract `DFlashModule` /
`DFlashAttention` / `DFlashDecoderLayer` / `build_target_layer_ids` /
`apply_rotary_pos_emb` → `modeling_dflash.py`.
- **Deprecate `ParallelDraft`**: remove `parallel_draft_step`,
`parallel_draft_heads_num_layers`, and the `ParallelDraft` module from
HF Eagle; remove the `EagleMedusaExporter` branch from
`HFEagleModel.get_exporter()` (the `EagleMedusaExporter` class itself
still lives in `hf_spec_export.py` for Megatron parity).
- **Rename**: `_draft_model_config` → `eagle_config` in export plugin.
- Update imports in `examples/speculative_decoding/` and
`modelopt/torch/speculative/utils.py` to follow the module rename.
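
The release-notes summary below describes the legacy `transformers` entry as a deprecated shim re-exporting the new plugin surface; a minimal sketch of what such a shim typically looks like (the actual file contents are an assumption):

```python
# modelopt/torch/speculative/plugins/transformers.py -- hypothetical shim sketch.
import warnings

warnings.warn(
    "modelopt.torch.speculative.plugins.transformers is deprecated; "
    "import from modelopt.torch.speculative.plugins.hf_eagle instead.",
    DeprecationWarning,
    stacklevel=2,
)

from modelopt.torch.speculative.plugins.hf_eagle import *  # noqa: E402,F401,F403
```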

### Testing

Validated with existing Eagle and DFlash training scripts, re-run after
commit `9ae5302729` ("revert behavior change").

### Before your PR is "*Ready for review*"

Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)
and your commits are signed (`git commit -s -S`).

Make sure you read and follow the [Security Best
Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors)
(e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(...,
weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ❌ — renames
`modelopt.torch.speculative.plugins.transformers` → `.hf_eagle`; removes
`parallel_draft_step` / `parallel_draft_heads_num_layers` from Eagle
config; renames `_draft_model_config` → `eagle_config` in export plugin.
- If you copied code from any other sources or added a new PIP
dependency, did you follow guidance in `CONTRIBUTING.md`: N/A
- Did you write any new necessary tests?: N/A — pure refactor; existing
tests updated for the rename. `test_hf_spec_rope_export.py` assertions
were also corrected to reflect the actual production path (the old
assertions were masked by `MagicMock` not invoking the
`_draft_model_config` `@property`; see the note after this checklist).
- Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?:
❌
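
On the `MagicMock` masking noted above: attribute access on a `MagicMock` fabricates a child mock rather than executing the real class's `@property`, so assertions against `_draft_model_config` passed vacuously. A minimal illustration:

```python
from unittest.mock import MagicMock

exporter = MagicMock()
# Auto-created child mock; the production _draft_model_config property never
# runs, so assertions written against it compare mocks, not real config values.
print(type(exporter._draft_model_config))  # <class 'unittest.mock.MagicMock'>
```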

### Additional Information

Breaking changes:
- `modelopt.torch.speculative.plugins.transformers` → `.hf_eagle`
- `parallel_draft_step` / `parallel_draft_heads_num_layers` removed from
Eagle config
- `_draft_model_config` → `eagle_config` in export plugin

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Refactoring**
* Reorganized speculative-decoding plugins into focused modules,
converting the legacy "transformers" entry into a deprecated shim that
re-exports the new plugin surface.
* Consolidated DFlash implementation into a shared modeling component
and introduced a dedicated EAGLE decoder module.

* **New Features**
* Added a Medusa speculative-decoding plugin with configurable heads and
combined-loss training behavior.

* **Chores**
  * Updated pre-commit license-hook exclusion and feature-flag wiring.

* **Tests**
  * Updated export tests to expect rope-scaling fallback semantics.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Base automatically changed from haoguo/spec-file-reorg to main on April 24, 2026 21:46
ChenhanYu pushed a commit that referenced this pull request Apr 26, 2026
### What does this PR do?

Type of change: new feature

Part 2 of a 3-PR series splitting #1271:
- **[1/3] #1296**: File reorg + deprecate `ParallelDraft`
- **[2/3] this PR**: Offline DFlash training (depends on #1296)
- **[3/3] #1297**: Extract `HFSpecDecMixin`

Changes:
- Add `dflash_offline` flag to `DFlashConfig` for training from
pre-computed hidden states; deletes base model layers to save memory.
- Add Pydantic validators on `DFlashConfig` (sketched after this list):
- `_derive_dflash_offline` — auto-derive `dflash_offline` from
`data_args.offline_data_path` in validation context. Not
user-configurable: any user-supplied value is overridden by the derived
value.
- `_resolve_mask_token_id` — auto-detect `dflash_mask_token_id` from
`tokenizer.mask_token_id`.
  - `_check_mask_token_id` — fail fast if unset after resolution.
- `HFDFlashModel.modify()`: select `num_orig_hidden_layers` when
offline; pick `_base_model_lm_head` device when no base layers present;
drop base-model `layers` module.
- `HFDFlashModel.forward()`: add offline branch — consumes precomputed
`base_model_outputs` via `DFlashBaseModelOutput.from_offline_dict`, and
when `dflash_self_logit_distillation` is enabled with
`base_model_logits` absent, recomputes logits from
`base_model_hidden_states` via `_base_model_lm_head`. Raises a clear
error from the non-training / `pseudo_speculative_generate` paths when
`dflash_offline=True`, since base-model layers have been deleted.
- `DFlashBaseModelOutput` dataclass in `modeling_dflash.py` (with
`from_offline_dict` classmethod; also sketched after this list) to unify
online/offline output shapes. `aux_hidden_states` is required in
`from_offline_dict` so missing keys fail fast at the entry point rather
than deeper in the forward.
- `examples/speculative_decoding/main.py`: replace inline
`mask_token_id` auto-detect with
`DFlashConfig.model_validate(dflash_cfg, context={"tokenizer":
tokenizer, "data_args": data_args})`.
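
For orientation, a sketch of the context-driven validation described above (validator, field, and key names are taken from this PR; the bodies are assumptions, not the actual implementation):

```python
# Hypothetical sketch of DFlashConfig's context-driven validators (Pydantic v2).
from pydantic import BaseModel, ValidationInfo, model_validator


class DFlashConfig(BaseModel):
    dflash_offline: bool = False
    dflash_mask_token_id: int | None = None

    @model_validator(mode="after")
    def _derive_dflash_offline(self, info: ValidationInfo):
        data_args = (info.context or {}).get("data_args")
        if data_args is not None:
            # Not user-configurable: the derived value overrides any supplied one.
            self.dflash_offline = bool(getattr(data_args, "offline_data_path", None))
        return self

    @model_validator(mode="after")
    def _resolve_mask_token_id(self, info: ValidationInfo):
        ctx = info.context or {}
        tokenizer = ctx.get("tokenizer")
        if self.dflash_mask_token_id is None and tokenizer is not None:
            self.dflash_mask_token_id = tokenizer.mask_token_id
        # _check_mask_token_id's job, folded in here for brevity: fail fast if
        # still unset after resolution (validators fall through without context).
        if ctx and self.dflash_mask_token_id is None:
            raise ValueError("dflash_mask_token_id could not be resolved")
        return self
```

`main.py` then calls `DFlashConfig.model_validate(dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args})`, per the last bullet above. The offline output plumbing could look like this (field names inferred from the keys named above; hypothetical):

```python
from dataclasses import dataclass

import torch


@dataclass
class DFlashBaseModelOutput:
    base_model_hidden_states: torch.Tensor
    aux_hidden_states: torch.Tensor
    base_model_logits: torch.Tensor | None = None

    @classmethod
    def from_offline_dict(cls, d: dict) -> "DFlashBaseModelOutput":
        # aux_hidden_states is a required key: a missing dump fails fast here,
        # at the entry point, instead of deeper in the forward.
        return cls(
            base_model_hidden_states=d["base_model_hidden_states"],
            aux_hidden_states=d["aux_hidden_states"],
            base_model_logits=d.get("base_model_logits"),
        )
```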

### Silent bug fix — `add_generation_template` → `add_generation_prompt`

The pre-refactor `compute_hidden_states_hf.py` passed
`add_generation_template=False` to `tokenizer.apply_chat_template`. This
kwarg does not exist on HF `apply_chat_template` and was being silently
ignored, so the intended "don't append a generation prompt" behavior was
never actually applied. The new `tokenize_with_loss_mask` helper in
`examples/speculative_decoding/collect_hidden_states/common.py` uses the
correct `add_generation_prompt=False`. **This is a real behavior
change** for anyone re-dumping hidden states: trailing generation
prompts that were previously appended to the tokenized sequences will no
longer be included.
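
A minimal repro of the kwarg mix-up (the model name is a placeholder; any tokenizer with a chat template behaves the same):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder
messages = [{"role": "user", "content": "hello"}]

# Pre-refactor call: 'add_generation_template' is not a real parameter; it is
# swallowed by **kwargs (forwarded to the template) and raises no error.
ids_old = tok.apply_chat_template(messages, add_generation_template=False)

# Fixed call: the actual parameter name.
ids_new = tok.apply_chat_template(messages, add_generation_prompt=False)
```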


### Testing
- New tests:
- `tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py` — CPU
unit tests for convert path (online keeps base layers, offline deletes
them; `num_orig_hidden_layers` drives `target_layer_ids` in offline
mode) and `DFlashConfig._derive_dflash_offline` validator.
- `TestDFlashOfflineForwardGPU` in
`tests/gpu/torch/speculative/plugins/test_hf_dflash.py` — GPU forward
smoke with precomputed `base_model_outputs`, plus the
`dflash_self_logit_distillation` logit-recompute path.

- Training test: results shown in two training-run screenshots attached to the PR.


### Before your PR is "*Ready for review*"

Make sure you read and follow [Contributor
guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)
and your commits are signed (`git commit -s -S`).

Make sure you read and follow the [Security Best
Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors)
(e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(...,
weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ✅ — additive `dflash_offline`
flag defaulting to `False`; validators fall through when context not
provided.
- If you copied code from any other sources or added a new PIP
dependency, did you follow guidance in `CONTRIBUTING.md`: N/A
- Did you write any new necessary tests?: ✅ — see Testing section above.
- Did you update
[Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?:
✅

### TODO (follow-up)

- [x] Update
`examples/speculative_decoding/collect_hidden_states/compute_hidden_states_*.py`
to support DFlash offline data. Current scripts are Eagle-specific —
they hardcode the `[2, N/2, N-3]` aux-layer selection and emit
`{input_ids, hidden_states, aux_hidden_states}`. DFlash offline needs:
- Aux layer indices driven by
`build_target_layer_ids(num_orig_hidden_layers, num_draft_layers)` (or a
configurable list), not the Eagle triplet; both selection schemes are
sketched after this list.
- `base_model_hidden_states` key (last-layer hidden) so
`DFlashBaseModelOutput.from_offline_dict` + the
`dflash_self_logit_distillation` recompute path can consume it.
- Optional `base_model_logits` dump so offline training can skip the
self-distillation logit recomputation when logits are available.
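
For concreteness, the two selection schemes contrasted in the first TODO item (the Eagle triplet is verbatim from the bullet above; the evenly-spaced helper body is a guess standing in for `build_target_layer_ids`):

```python
# Sketch contrasting aux-layer selection; only the Eagle triplet is taken
# from this PR, the DFlash helper body is a guess at even spacing.
def eagle_aux_layers(num_layers: int) -> list[int]:
    """Hardcoded triplet the current dump scripts emit."""
    return [2, num_layers // 2, num_layers - 3]


def dflash_aux_layers(num_orig_hidden_layers: int, num_draft_layers: int) -> list[int]:
    """Evenly spaced target layers, in the spirit of build_target_layer_ids."""
    step = num_orig_hidden_layers / num_draft_layers
    return [round((i + 1) * step) - 1 for i in range(num_draft_layers)]


print(eagle_aux_layers(32))      # [2, 16, 29]
print(dflash_aux_layers(32, 4))  # [7, 15, 23, 31]
```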

### Additional Information

Base branch is #1296 (file reorg). Retarget to `main` once #1296 merges.



<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **New Features**
* Offline DFlash speculative-decoding training from precomputed
base-model hidden states
* Answer-only-loss training with persisted loss masks and optional
chat-template support
* Flexible auxiliary-layer selection via CLI and an exposed default
aux-layer helper
* Auto-derived offline flag in config and automatic memory optimization
during offline conversion

* **Documentation**
* Updated guides for offline pipeline, aux-layer selection, and
loss-masking options

* **Tests**
* New unit, GPU, and regression tests covering offline conversion,
training, and config derivation
<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>