[2/3][Feat]: Offline DFlash training #1295
Conversation
Codecov Report

```
@@           Coverage Diff                        @@
##   haoguo/spec-file-reorg    #1295     +/-    ##
==================================================
+ Coverage     60.34%    75.73%    +15.39%
==================================================
  Files           470       471         +1
  Lines         50255     50375       +120
==================================================
+ Hits          30325     38154      +7829
+ Misses        19930     12221      -7709
```

Flags with carried-forward coverage won't be shown.
Force-pushed from 9e4eeb0 to f208109.
Force-pushed from f208109 to 178b191.
Self-review follow-ups (not in this PR — will address separately so this one stays focused):
AI-assisted self-review (Claude) — findings posted for transparency. Fixed in this PR:
Deferred to follow-up (see earlier comment):
Review
CI test failure root cause
The `test_unified_export_megatron` failure (`TypeError: '>' not supported between instances of 'NoneType' and 'int'`) is at `megatron_eagle.py:526`:

```python
if self.config.parallel_draft_step > 1:
```

This is because PR #1296 (the [1/3] dependency) removed `parallel_draft_step` from `eagle/default_config.py`, but `megatron_eagle.py` still references it. The config attribute is now `None`, causing the comparison to fail. Fix: guard with `getattr(self.config, "parallel_draft_step", 1) > 1` or restore the default in the config.
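The suggested guard can be sketched as follows (a stand-in `Config` class simulating the post-#1296 state; note that if the attribute still exists but is `None`, `getattr` alone returns `None`, so this sketch also coalesces `None` to the old default of 1):

```python
class Config:
    # Stand-in simulating the post-#1296 state: the attribute
    # survives on the config object but is now None.
    parallel_draft_step = None

def uses_parallel_draft(config) -> bool:
    # getattr covers a missing attribute; "or 1" additionally covers the
    # attribute-present-but-None case that triggers the TypeError above.
    step = getattr(config, "parallel_draft_step", 1) or 1
    return step > 1

assert uses_parallel_draft(Config()) is False  # no TypeError; behaves like the old default
```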
Missing tests/gpu_regression for DFlash offline
The PR adds GPU integration and CPU unit tests that verify the forward pass runs and returns a finite loss — good. But there's no regression test that verifies end-to-end training convergence.
Please add a tests/gpu_regression test for DFlash offline that:
- Dumps hidden states from the same synthetic dataset used by the existing online DFlash regression test
- Trains offline DFlash for a few steps
- Verifies loss decreases (or matches a golden threshold)
Without this, the offline training path could silently regress while all existing tests still pass.
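The requested check could take roughly this shape (plain gradient descent on a quadratic as a stand-in for the real offline DFlash trainer; every name here is hypothetical, and the real test would dump hidden states and call the training entry point instead):

```python
def train_steps(num_steps: int, lr: float = 0.1) -> list[float]:
    """Stand-in for a few offline training steps: minimize (w - 3)^2."""
    w = 0.0
    losses = []
    for _ in range(num_steps):
        losses.append((w - 3.0) ** 2)
        w -= lr * 2.0 * (w - 3.0)  # gradient step
    return losses

losses = train_steps(5)
# Regression assertion: loss must decrease (or beat a golden threshold).
assert losses[-1] < losses[0]
```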
Force-pushed from 9c1ed15 to 536ea48.
- Add `dflash_offline` config flag for training from pre-computed hidden states; deletes base model layers to save memory.
- Move `dflash_mask_token_id` auto-detection from `main.py` into `DFlashConfig` Pydantic validators; derive `dflash_offline` from `data_args.offline_data_path`.
- Add `DFlashBaseModelOutput.from_offline_dict` classmethod for consuming pre-computed hidden states in the forward path.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Force-pushed from f7a27dc to e063406.
Accidentally closed. Reopened in #1337.

This PR is reopened in #1337. Comment 1 is addressed in the previous PR; comment 2 is addressed in this PR. Thanks!
### What does this PR do?

**Type of change:** refactoring

Part 1 of a 3-PR series splitting #1271:
- **[1/3] this PR**: File reorg + deprecate `ParallelDraft`
- **[2/3] #1295**: Offline DFlash training
- **[3/3] #1297**: Extract `HFSpecDecMixin`

Changes:
- **File reorg**: `transformers.py` → `hf_eagle.py`; extract `HFMedusaModel` → `hf_medusa.py`; extract `EagleModule` / `EagleBaseModelOutput` → `modeling_eagle.py`; extract `DFlashModule` / `DFlashAttention` / `DFlashDecoderLayer` / `build_target_layer_ids` / `apply_rotary_pos_emb` → `modeling_dflash.py`.
- **Deprecate `ParallelDraft`**: remove `parallel_draft_step`, `parallel_draft_heads_num_layers`, and the `ParallelDraft` module from HF Eagle; remove the `EagleMedusaExporter` branch from `HFEagleModel.get_exporter()` (the `EagleMedusaExporter` class itself still lives in `hf_spec_export.py` for Megatron parity).
- **Rename**: `_draft_model_config` → `eagle_config` in export plugin.
- Update imports in `examples/speculative_decoding/` and `modelopt/torch/speculative/utils.py` to follow the module rename.

### Testing

Validated with existing Eagle and DFlash training scripts (re-run after `9ae5302729 revert behavior change`).

### Before your PR is "*Ready for review*"

- Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- Is this change backward compatible?: ❌ — renames `modelopt.torch.speculative.plugins.transformers` → `.hf_eagle`; removes `parallel_draft_step` / `parallel_draft_heads_num_layers` from Eagle config; renames `_draft_model_config` → `eagle_config` in export plugin.
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A
- Did you write any new necessary tests?: N/A — pure refactor; existing tests updated for the rename. `test_hf_spec_rope_export.py` assertions were also corrected to reflect the actual production path (the old assertions were masked by `MagicMock` not invoking the `_draft_model_config` `@property`).
- Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ❌

### Additional Information

Breaking changes:
- `modelopt.torch.speculative.plugins.transformers` → `.hf_eagle`
- `parallel_draft_step` / `parallel_draft_heads_num_layers` removed from Eagle config
- `_draft_model_config` → `eagle_config` in export plugin

## Summary by CodeRabbit

* **Refactoring**
  * Reorganized speculative-decoding plugins into focused modules, converting the legacy "transformers" entry into a deprecated shim that re-exports the new plugin surface.
  * Consolidated the DFlash implementation into a shared modeling component and introduced a dedicated EAGLE decoder module.
* **New Features**
  * Added a Medusa speculative-decoding plugin with configurable heads and combined-loss training behavior.
* **Chores**
  * Updated pre-commit license-hook exclusion and feature-flag wiring.
* **Tests**
  * Updated export tests to expect rope-scaling fallback semantics.

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
What does this PR do?

Type of change: new feature

Part 2 of a 3-PR series splitting #1271:
- **[1/3] #1296**: File reorg + deprecate `ParallelDraft`
- **[2/3] this PR**: Offline DFlash training
- **[3/3] #1297**: Extract `HFSpecDecMixin`

Changes:
- Add a `dflash_offline` flag to `DFlashConfig` for training from pre-computed hidden states; when set, base model layers are deleted to save memory.
- New `DFlashConfig` validators:
  - `_derive_dflash_offline` — auto-derives `dflash_offline` from `data_args.offline_data_path` in the validation context. Not user-configurable: any user-supplied value is overridden by the derived value.
  - `_resolve_mask_token_id` — auto-detects `dflash_mask_token_id` from `tokenizer.mask_token_id`.
  - `_check_mask_token_id` — fails fast if unset after resolution.
- `HFDFlashModel.modify()`: selects `num_orig_hidden_layers` when offline; picks the `_base_model_lm_head` device when no base layers are present; drops the base-model `layers` module.
- `HFDFlashModel.forward()`: adds an offline branch that consumes precomputed `base_model_outputs` via `DFlashBaseModelOutput.from_offline_dict`; when `dflash_self_logit_distillation` is enabled and `base_model_logits` is absent, recomputes logits from `base_model_hidden_states` via `_base_model_lm_head`. Raises a clear error from the non-training / `pseudo_speculative_generate` paths when `dflash_offline=True`, since the base-model layers have been deleted.
- New `DFlashBaseModelOutput` dataclass in `modeling_dflash.py` (with a `from_offline_dict` classmethod) to unify online/offline output shapes. `aux_hidden_states` is required in `from_offline_dict` so missing keys fail fast at the entry point rather than deeper in the forward.
- `examples/speculative_decoding/main.py`: replaces the inline `mask_token_id` auto-detect with `DFlashConfig.model_validate(dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args})`.

Silent bug fix — `add_generation_template` → `add_generation_prompt`

The pre-refactor `compute_hidden_states_hf.py` passed `add_generation_template=False` to `tokenizer.apply_chat_template`. This kwarg does not exist on HF `apply_chat_template` and was silently ignored, so the intended "don't append a generation prompt" behavior was never actually applied. The new `tokenize_with_loss_mask` helper in `examples/speculative_decoding/collect_hidden_states/common.py` uses the correct `add_generation_prompt=False`. This is a real behavior change for anyone re-dumping hidden states: trailing generation prompts that were previously appended to the tokenized sequences will no longer be included.

Testing
New tests:
- `tests/unit/torch/speculative/plugins/test_hf_dflash_offline.py` — CPU unit tests for the convert path (online keeps base layers, offline deletes them; `num_orig_hidden_layers` drives `target_layer_ids` in offline mode) and the `DFlashConfig._derive_dflash_offline` validator.
- `TestDFlashOfflineForwardGPU` in `tests/gpu/torch/speculative/plugins/test_hf_dflash.py` — GPU forward smoke test with precomputed `base_model_outputs`, plus the `dflash_self_logit_distillation` logit-recompute path.

Training test:
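The silent-ignore failure mode behind the `add_generation_template` bug fix above is easy to reproduce: an API that forwards unknown keyword arguments (as tokenizer kwargs are forwarded in HF's `apply_chat_template`) swallows a misspelled keyword without any error. This stub is illustrative, not the real HF signature:

```python
def apply_chat_template_stub(messages, add_generation_prompt=True, **tokenizer_kwargs):
    """Stand-in for an API that forwards unknown kwargs: a misspelled
    keyword lands in **tokenizer_kwargs and is silently ignored."""
    return add_generation_prompt

# Misspelled kwarg: the intended override never reaches the real parameter.
assert apply_chat_template_stub([], add_generation_template=False) is True
# Correct spelling actually changes behavior.
assert apply_chat_template_stub([], add_generation_prompt=False) is False
```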
Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- Is this change backward compatible?: the new `dflash_offline` flag defaults to `False`, and the validators fall through when no context is provided.
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`: N/A

TODO (follow-up)
- Extend `examples/speculative_decoding/collect_hidden_states/compute_hidden_states_*.py` to support DFlash offline data. The current scripts are Eagle-specific — they hardcode the `[2, N/2, N-3]` aux-layer selection and emit `{input_ids, hidden_states, aux_hidden_states}`. DFlash offline needs:
  - aux-layer selection from `build_target_layer_ids(num_orig_hidden_layers, num_draft_layers)` (or a configurable list), not the Eagle triplet.
  - a `base_model_hidden_states` key (last-layer hidden) so `DFlashBaseModelOutput.from_offline_dict` and the `dflash_self_logit_distillation` recompute path can consume it.
  - an optional `base_model_logits` dump so offline training can skip the self-distillation logit recomputation when logits are available.

Additional Information
Base branch is #1296 (file reorg). Retarget to `main` once #1296 merges.
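Putting the dump keys from the TODO together with the fail-fast `from_offline_dict` contract described in the Changes, a stdlib-only sketch (field names follow the PR text; shapes, types, and the class name are illustrative):

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class OfflineOutputSketch:
    """Illustrative stand-in for DFlashBaseModelOutput's offline path."""
    base_model_hidden_states: Any
    aux_hidden_states: Any
    base_model_logits: Any = None  # optional: lets training skip logit recompute

    @classmethod
    def from_offline_dict(cls, d: dict) -> "OfflineOutputSketch":
        # Required keys fail fast at the entry point, not deep in forward().
        if "aux_hidden_states" not in d:
            raise KeyError("offline dump is missing required key 'aux_hidden_states'")
        return cls(
            base_model_hidden_states=d["base_model_hidden_states"],
            aux_hidden_states=d["aux_hidden_states"],
            base_model_logits=d.get("base_model_logits"),
        )

dump = {"base_model_hidden_states": [...], "aux_hidden_states": [...]}
out = OfflineOutputSketch.from_offline_dict(dump)
```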