[3/3][Refactor]: Extract HFSpecDecMixin for HF spec-decoding plugins #1297 (Draft)
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
- Add `dflash_offline` config flag for training from pre-computed hidden states; deletes base model layers to save memory.
- Move `dflash_mask_token_id` auto-detection from `main.py` into `DFlashConfig` Pydantic validators; derive `dflash_offline` from `data_args.offline_data_path`.
- Add `DFlashBaseModelOutput.from_offline_dict` classmethod for consuming pre-computed hidden states in the forward path.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Extract duplicated base-model discovery, forward pass, NVTX profiling, and torch.compile logic from HFEagleModel / HFDFlashModel into a shared mixin (hf_spec_mixin.py). HFEagleModel and HFDFlashModel now inherit from (HFSpecDecMixin, EagleModel/DFlashModel). Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
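The inheritance layering described above can be sketched as follows. Only `HFSpecDecMixin`, `HFEagleModel`, and `EagleModel` are names from this PR; the method bodies are illustrative stand-ins, not the actual implementation.

```python
# Hedged sketch of the mixin layering described above. The method bodies are
# stand-ins for the shared logic (base-model discovery, forward pass, NVTX
# profiling, torch.compile), not the real implementation.
class HFSpecDecMixin:
    """Shared HF spec-decoding plugin logic (stubbed here)."""

    def find_base_model(self):
        # Stand-in for the previously duplicated discovery logic.
        return getattr(self, "base_model", None)


class EagleModel:
    """Stand-in for the algorithm-specific base class."""


class HFEagleModel(HFSpecDecMixin, EagleModel):
    pass


# The mixin precedes the algorithm class in the MRO, so shared helpers
# resolve to the mixin first.
print([c.__name__ for c in HFEagleModel.__mro__])
```

Because Python resolves attributes left-to-right through the MRO, listing the mixin first lets its shared helpers override or supplement the algorithm base class without either side needing to know about the other.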
Base automatically changed from `haoguo/dflash-offline` to `haoguo/spec-file-reorg` on April 23, 2026 22:05.
h-guo18 added a commit that referenced this pull request on Apr 24, 2026:
### What does this PR do?

Type of change: refactoring

Part 1 of a 3-PR series splitting #1271:

- **[1/3] this PR**: File reorg + deprecate `ParallelDraft`
- **[2/3] #1295**: Offline DFlash training
- **[3/3] #1297**: Extract `HFSpecDecMixin`

Changes:

- **File reorg**: `transformers.py` → `hf_eagle.py`; extract `HFMedusaModel` → `hf_medusa.py`; extract `EagleModule` / `EagleBaseModelOutput` → `modeling_eagle.py`; extract `DFlashModule` / `DFlashAttention` / `DFlashDecoderLayer` / `build_target_layer_ids` / `apply_rotary_pos_emb` → `modeling_dflash.py`.
- **Deprecate `ParallelDraft`**: remove `parallel_draft_step`, `parallel_draft_heads_num_layers`, and the `ParallelDraft` module from HF Eagle; remove the `EagleMedusaExporter` branch from `HFEagleModel.get_exporter()` (the `EagleMedusaExporter` class itself still lives in `hf_spec_export.py` for Megatron parity).
- **Rename**: `_draft_model_config` → `eagle_config` in the export plugin.
- Update imports in `examples/speculative_decoding/` and `modelopt/torch/speculative/utils.py` to follow the module rename.

### Testing

Validated with existing Eagle and DFlash training scripts (re-run after `9ae5302729 revert behavior change`).

### Before your PR is "*Ready for review*"

Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`).

Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ❌ — renames `modelopt.torch.speculative.plugins.transformers` → `.hf_eagle`; removes `parallel_draft_step` / `parallel_draft_heads_num_layers` from the Eagle config; renames `_draft_model_config` → `eagle_config` in the export plugin.
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: N/A
- Did you write any new necessary tests?: N/A — pure refactor; existing tests updated for the rename. `test_hf_spec_rope_export.py` assertions were also corrected to reflect the actual production path (the old assertions were masked by `MagicMock` not invoking the `_draft_model_config` `@property`).
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ❌

### Additional Information

Breaking changes:

- `modelopt.torch.speculative.plugins.transformers` → `.hf_eagle`
- `parallel_draft_step` / `parallel_draft_heads_num_layers` removed from the Eagle config
- `_draft_model_config` → `eagle_config` in the export plugin

## Summary by CodeRabbit

- **Refactoring**
  - Reorganized speculative-decoding plugins into focused modules, converting the legacy "transformers" entry into a deprecated shim that re-exports the new plugin surface.
  - Consolidated the DFlash implementation into a shared modeling component and introduced a dedicated EAGLE decoder module.
- **New Features**
  - Added a Medusa speculative-decoding plugin with configurable heads and combined-loss training behavior.
- **Chores**
  - Updated the pre-commit license-hook exclusion and feature-flag wiring.
- **Tests**
  - Updated export tests to expect rope-scaling fallback semantics.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
ChenhanYu pushed a commit that referenced this pull request on Apr 26, 2026:
### What does this PR do?

Type of change: new feature

Part 2 of a 3-PR series splitting #1271:

- **[1/3] #1296**: File reorg + deprecate `ParallelDraft`
- **[2/3] this PR**: Offline DFlash training (depends on #1296)
- **[3/3] #1297**: Extract `HFSpecDecMixin`

Changes:

- Add a `dflash_offline` flag to `DFlashConfig` for training from pre-computed hidden states; deletes base model layers to save memory.
- Add Pydantic validators on `DFlashConfig`:
  - `_derive_dflash_offline` — auto-derive `dflash_offline` from `data_args.offline_data_path` in the validation context. Not user-configurable: any user-supplied value is overridden by the derived value.
  - `_resolve_mask_token_id` — auto-detect `dflash_mask_token_id` from `tokenizer.mask_token_id`.
  - `_check_mask_token_id` — fail fast if unset after resolution.
- `HFDFlashModel.modify()`: select `num_orig_hidden_layers` when offline; pick the `_base_model_lm_head` device when no base layers are present; drop the base-model `layers` module.
- `HFDFlashModel.forward()`: add an offline branch — consumes precomputed `base_model_outputs` via `DFlashBaseModelOutput.from_offline_dict`, and when `dflash_self_logit_distillation` is enabled with `base_model_logits` absent, recomputes logits from `base_model_hidden_states` via `_base_model_lm_head`. Raises a clear error from the non-training / `pseudo_speculative_generate` paths when `dflash_offline=True`, since base-model layers have been deleted.
- `DFlashBaseModelOutput` dataclass in `modeling_dflash.py` (with a `from_offline_dict` classmethod) to unify online/offline output shapes. `aux_hidden_states` is required in `from_offline_dict` so missing keys fail fast at the entry point rather than deeper in the forward.
- `examples/speculative_decoding/main.py`: replace the inline `mask_token_id` auto-detect with `DFlashConfig.model_validate(dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args})`.
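The derive-from-validation-context mechanism described above can be sketched with Pydantic v2. The field and validator names mirror the PR; the `data_args` shape (a plain dict with an `offline_data_path` key) is an assumption for illustration, not the project's actual type.

```python
# Hedged sketch of Pydantic v2 validation-context derivation, assuming a
# dict-shaped data_args. Not the real DFlashConfig implementation.
from pydantic import BaseModel, ValidationInfo, model_validator


class DFlashConfigSketch(BaseModel):
    dflash_offline: bool = False

    @model_validator(mode="after")
    def _derive_dflash_offline(self, info: ValidationInfo):
        ctx = info.context or {}
        data_args = ctx.get("data_args")
        if data_args is not None:
            # Derived value overrides any user-supplied one.
            self.dflash_offline = data_args.get("offline_data_path") is not None
        return self


# User-supplied False is overridden because the context carries an offline path.
cfg = DFlashConfigSketch.model_validate(
    {"dflash_offline": False},
    context={"data_args": {"offline_data_path": "/data/hidden_states"}},
)
print(cfg.dflash_offline)  # True

# Without a context, the validator falls through and the field keeps its value.
print(DFlashConfigSketch().dflash_offline)  # False
```

Passing `context=` at `model_validate` time is what lets one config class behave differently depending on the runtime objects (tokenizer, data args) available at the call site, without those objects becoming config fields.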
### Silent bug fix — `add_generation_template` → `add_generation_prompt`

The pre-refactor `compute_hidden_states_hf.py` passed `add_generation_template=False` to `tokenizer.apply_chat_template`. This kwarg does not exist on the HF `apply_chat_template` and was being silently ignored, so the intended "don't append a generation prompt" behavior was never actually applied. The new `tokenize_with_loss_mask` helper in `examples/speculative_decoding/collect_hidden_states/common.py` uses the correct `add_generation_prompt=False`. **This is a real behavior change** for anyone re-dumping hidden states: trailing generation prompts that were previously appended to the tokenized sequences will no longer be included.

### Testing

- New tests:
  - `tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py` — CPU unit tests for the convert path (online keeps base layers, offline deletes them; `num_orig_hidden_layers` drives `target_layer_ids` in offline mode) and the `DFlashConfig._derive_dflash_offline` validator.
  - `TestDFlashOfflineForwardGPU` in `tests/gpu/torch/speculative/plugins/test_hf_dflash.py` — GPU forward smoke test with precomputed `base_model_outputs`, plus the `dflash_self_logit_distillation` logit-recompute path.
- Training test:

<img width="454" height="317" alt="image" src="https://github.com/user-attachments/assets/79b92790-4d15-4313-bb9b-f35665b012e6" />
<img width="456" height="310" alt="image" src="https://github.com/user-attachments/assets/4558559f-9c35-49ed-b36e-82fbc99eab23" />

### Before your PR is "*Ready for review*"

Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`).

Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- Is this change backward compatible?: ✅ — additive `dflash_offline` flag defaulting to `False`; validators fall through when the context is not provided.
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: N/A
- Did you write any new necessary tests?: ✅ — see the Testing section above.
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ✅

### TODO (follow-up)

- [x] Update `examples/speculative_decoding/collect_hidden_states/compute_hidden_states_*.py` to support DFlash offline data. The current scripts are Eagle-specific — they hardcode the `[2, N/2, N-3]` aux-layer selection and emit `{input_ids, hidden_states, aux_hidden_states}`. DFlash offline needs:
  - Aux layer indices driven by `build_target_layer_ids(num_orig_hidden_layers, num_draft_layers)` (or a configurable list), not the Eagle triplet.
  - A `base_model_hidden_states` key (last-layer hidden states) so `DFlashBaseModelOutput.from_offline_dict` and the `dflash_self_logit_distillation` recompute path can consume it.
  - An optional `base_model_logits` dump so offline training can skip the self-distillation logit recomputation when logits are available.

### Additional Information

Base branch is #1296 (file reorg). Retarget to `main` once #1296 merges.
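For intuition about the layer-index derivation mentioned in the TODO above, here is a hypothetical sketch of the kind of selection `build_target_layer_ids` could perform. Only the function name and signature come from the PR; the even-spacing policy below is an assumption for illustration, not the real implementation in `modeling_dflash.py`.

```python
# Hypothetical even-spacing policy; the real build_target_layer_ids in
# modeling_dflash.py may select layers differently.
def build_target_layer_ids(num_orig_hidden_layers: int, num_draft_layers: int) -> list[int]:
    """Pick num_draft_layers base-model layer indices, evenly spaced."""
    step = num_orig_hidden_layers / num_draft_layers
    return [round(i * step) for i in range(num_draft_layers)]


# E.g. a 32-layer base model feeding a 4-layer draft module:
print(build_target_layer_ids(32, 4))  # [0, 8, 16, 24]
```

Whatever the actual policy, the point of the TODO is that the dump scripts should call this helper (or accept a configurable list) rather than hardcode the Eagle `[2, N/2, N-3]` triplet.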
## Summary by CodeRabbit

- **New Features**
  - Offline DFlash speculative-decoding training from precomputed base-model hidden states
  - Answer-only-loss training with persisted loss masks and optional chat-template support
  - Flexible auxiliary-layer selection via CLI and an exposed default aux-layer helper
  - Auto-derived offline flag in config and automatic memory optimization during offline conversion
- **Documentation**
  - Updated guides for the offline pipeline, aux-layer selection, and loss-masking options
- **Tests**
  - New unit, GPU, and regression tests covering offline conversion, training, and config derivation

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
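The silent-kwarg bug described in the PR above is easy to reproduce in miniature: chat-template APIs typically accept `**kwargs` for extra template variables, so a misspelled keyword raises no error. The function below is a stand-in, not the Hugging Face implementation.

```python
# Minimal illustration (not HF code) of why `add_generation_template=False`
# was silently ignored: the typo'd keyword falls into **kwargs, so the
# default add_generation_prompt=True still applies.
def apply_chat_template_sketch(messages, add_generation_prompt=True, **kwargs):
    text = "".join(m["content"] for m in messages)
    if add_generation_prompt:
        text += "<assistant>"  # stand-in for the trailing generation prompt
    return text


msgs = [{"role": "user", "content": "hi"}]

# Typo'd kwarg is swallowed by **kwargs; the prompt is still appended.
print(apply_chat_template_sketch(msgs, add_generation_template=False))  # hi<assistant>

# Correct kwarg actually suppresses the prompt.
print(apply_chat_template_sketch(msgs, add_generation_prompt=False))  # hi
```

This is why the fix is a behavior change for re-dumped hidden states: with the correct keyword, the trailing generation prompt is no longer tokenized into the sequences.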
What does this PR do?
Type of change: refactoring
Part 3 of a 3-PR series splitting #1271:

- **[1/3] #1296**: File reorg + deprecate `ParallelDraft`
- **[2/3] #1295**: Offline DFlash training
- **[3/3] this PR**: Extract `HFSpecDecMixin`

Changes:

- Extract duplicated base-model discovery, forward pass, NVTX profiling, and `torch.compile` logic from `HFEagleModel` / `HFDFlashModel` into a shared mixin (`hf_spec_mixin.py`).
- `HFEagleModel` and `HFDFlashModel` now inherit from `(HFSpecDecMixin, EagleModel/DFlashModel)`.
Testing
No behavioral change expected. Verified MRO includes `HFSpecDecMixin` and existing Eagle/DFlash training scripts run unchanged.
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
Additional Information
Base branch is #1295. Retarget to `main` once #1296 and #1295 merge.