From fd2e368ad47b80373d2e4444a7de09a33c05f104 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Fri, 19 Jun 2026 21:50:48 +0000 Subject: [PATCH 1/5] squash: d-pace Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- CHANGELOG.rst | 1 + examples/speculative_decoding/doc/dflash.md | 30 +++++++ modelopt/torch/speculative/config.py | 18 +++- .../torch/speculative/dflash/dflash_model.py | 22 +++++ .../torch/speculative/plugins/hf_dflash.py | 84 +++++++++++++++-- .../speculative/plugins/test_hf_dflash.py | 90 +++++++++++++++++++ 6 files changed, 239 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 56380b01b55..4d08b27286d 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -11,6 +11,7 @@ Changelog **New Features** +- Add the **D-PACE** loss objective for DFlash speculative-decoding training (`arXiv:2605.18810 `_). Set ``dflash_loss_objective: dpace`` to replace the static exponential position decay with dynamic, confidence-derived per-position weights that adapt to whichever block positions currently limit acceptance. Smoothing is controlled by ``dflash_dpace_alpha`` (default 0.5). Training-only and detached from the gradient (no architecture or inference change). - Add the ``day0-release`` agent skill (``.agents/skills/day0-release/``), a deterministic end-to-end driver that chains the PTQ → evaluation → comparison skills (the evaluation stage deploys the checkpoint itself) with an enforced gate after each stage and returns a publish decision (ACCEPT / REGRESSION / ANOMALOUS / INFEASIBLE). Ships three GPU-free, unit-tested gate scripts (``gate_ptq.py``, ``gate_run.py``, ``gate_compare.py``) that validate checkpoint coverage, evaluation-run completeness, and baseline-vs-candidate accuracy threshold. v1 reports and stops on regression; the recipe-search loop is deferred. - Add **streaming** speculative-decoding training (EAGLE3 / DFlash): the draft trains on base-model hidden states produced on the fly by a co-located ``vllm serve`` (no disk dump), moved trainer-side over NIXL RDMA, scaling to multi-node (dedicated serve replicas + DDP trainers). New launcher examples for NVFP4 Kimi-K2.5 / K2.6 on GB200/aarch64 under ``tools/launcher/examples/moonshotai/``. - Add a fused Triton fast path for ``local_hessian`` NVFP4 weight-scale search (the Hessian-weighted FP8-E4M3 scale sweep). For each NVFP4 block it minimizes ``dwᵀ H dw`` over the 126 candidate scales using the per-cin-block local Hessian on tensor cores, replacing the per-weight Python reference sweep — roughly **34x** faster on a single 8192x4096 weight and bit-exact with the reference for fp32/fp16 weights. Used automatically during ``local_hessian`` calibration for both dense and fused-MoE expert weights; falls back to the reference sweep on CPU, when Triton is unavailable, or via ``MODELOPT_NVFP4_TRITON_SWEEP=0``. diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index 0150e0884e1..b6a66700cff 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -162,6 +162,8 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model | `dflash.dflash_block_size` | 8 | Block size for parallel prediction | | `dflash.dflash_num_anchors` | 512 | Random anchor positions per sample (see below) | | `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables, see below) | +| `dflash.dflash_loss_objective` | `decay` | Position weighting: `decay` (static) or `dpace` (dynamic, see below) | +| `dflash.dflash_dpace_alpha` | 0.5 | D-PACE smoothing factor in (0, 1]; only used when objective is `dpace` | | `dflash.dflash_self_logit_distillation` | true | Use target model logits as soft labels (vs hard CE) | | `dflash.dflash_mask_token_id` | auto | Token ID for masked positions (see note below) | | `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | @@ -244,6 +246,34 @@ Note: this is different from EAGLE3's `eagle_loss_decay_factor` which multiplies `alpha^step` across TTT steps. DFlash decay operates within a single block, weighting early positions higher because they gate acceptance of all later positions. +### D-PACE (Dynamic Position-Aware Cross-Entropy) + +Set `dflash.dflash_loss_objective: dpace` to replace the static decay with **D-PACE** +([arXiv:2605.18810](https://arxiv.org/abs/2605.18810)), which derives per-position weights +from a differentiable surrogate of expected accepted block length. Where static decay uses +a fixed schedule, D-PACE adapts to the draft's own per-position confidence and shifts +training signal toward whichever positions currently limit acceptance as the drafter improves. + +For each block, let `q_i = exp(-CE_i)` be the draft confidence on the target token at +predicted position `i`. D-PACE smooths it (Eq.7) and weights each position by the suffix-sum +of prefix products (Eq.8): + +```text +q~_i = (1 - alpha) * q_i + alpha +w_j = sum_{m >= j} prod_{i <= m} q~_i # detached; multiplies the per-token CE +``` + +The weight factors into the prefix-acceptance probability (`prod_{i<=j} q~_i`) times the +remaining accepted-length value, so it directly targets expected accepted length. The +weights are detached from the gradient — D-PACE only reshapes credit assignment and adds +~2.3% training overhead with no change to the draft architecture or inference. + +- `dflash_dpace_alpha` is the asymmetric smoothing floor (`q~_i >= alpha`) that keeps later + weights from vanishing. Stable in `[0.3, 0.7]`; `alpha=0` is rejected (cumulative product + collapses), and `alpha → 1` flattens toward uniform weighting. Default `0.5`. +- D-PACE is mutually exclusive with `dflash_loss_decay_factor`; when objective is `dpace`, + the decay factor is ignored. + ### Checkpoint Resume DFlash supports checkpoint resume transparently. Rotary embeddings are lazily diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 8da7b6ec93e..db8d1b3fe8b 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -103,7 +103,23 @@ class DFlashConfig(ModeloptBaseConfig): dflash_loss_decay_factor: float = ModeloptField( default=0.0, description="Gamma for exponential loss decay weighting (paper Eq.4). " - "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. 0 disables.", + "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. 0 disables. " + "Only used when dflash_loss_objective='decay'.", + ) + + dflash_loss_objective: str = ModeloptField( + default="decay", + description="Block-position loss weighting objective. 'decay' uses the static " + "exponential decay of dflash_loss_decay_factor (DFlash, arXiv:2602.06036 Eq.4). " + "'dpace' uses dynamic, confidence-derived per-position weights " + "(D-PACE, arXiv:2605.18810 Eq.8).", + ) + + dflash_dpace_alpha: float = ModeloptField( + default=0.5, + description="D-PACE asymmetric smoothing factor alpha in (0, 1] (paper Eq.7). Used only " + "when dflash_loss_objective='dpace'. Stable in [0.3, 0.7]; alpha=0 is degenerate " + "(cumulative product vanishes) and alpha->1 removes the adaptive signal.", ) dflash_num_anchors: int = ModeloptField( diff --git a/modelopt/torch/speculative/dflash/dflash_model.py b/modelopt/torch/speculative/dflash/dflash_model.py index 702d9812482..3d499713c50 100644 --- a/modelopt/torch/speculative/dflash/dflash_model.py +++ b/modelopt/torch/speculative/dflash/dflash_model.py @@ -15,8 +15,12 @@ """DFlash model to support block-wise parallel speculative decoding.""" +import logging + from modelopt.torch.opt.dynamic import DynamicModule +logger = logging.getLogger(__name__) + class DFlashModel(DynamicModule): """Base DFlash Model.""" @@ -31,6 +35,24 @@ def modify(self, config): self.dflash_block_size = config.dflash_block_size self.dflash_freeze_base_model = config.dflash_freeze_base_model self.dflash_loss_decay_factor = config.dflash_loss_decay_factor + self.dflash_loss_objective = config.dflash_loss_objective + self.dflash_dpace_alpha = config.dflash_dpace_alpha + if self.dflash_loss_objective not in ("decay", "dpace"): + raise ValueError( + f"dflash_loss_objective must be 'decay' or 'dpace', got " + f"{self.dflash_loss_objective!r}" + ) + if self.dflash_loss_objective == "dpace" and not 0.0 < self.dflash_dpace_alpha <= 1.0: + raise ValueError( + f"dflash_dpace_alpha must be in (0, 1] for the D-PACE objective, got " + f"{self.dflash_dpace_alpha}" + ) + if self.dflash_loss_objective == "dpace" and self.dflash_loss_decay_factor > 0: + logger.warning( + "dflash_loss_decay_factor=%s is ignored when dflash_loss_objective='dpace'; " + "D-PACE derives per-position weights dynamically from draft confidence.", + self.dflash_loss_decay_factor, + ) self.dflash_self_logit_distillation = config.dflash_self_logit_distillation self.dflash_num_anchors = config.dflash_num_anchors self.dflash_report_acc = config.dflash_report_acc diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 9e54ee531ce..cc3b7c94061 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -74,6 +74,55 @@ __all__ = ["HFDFlashModel"] +def _dpace_position_weights( + confidences: torch.Tensor, alpha: float, valid_mask: torch.Tensor | None = None +) -> torch.Tensor: + """Compute detached D-PACE per-position weights from draft confidences. + + Derived from D-PACE (arXiv:2605.18810). The paper factorizes the per-position + weight (Fig. 2 / Eq. 8) into a *cumulative confidence* times a *continuation + value*, which is equivalently the suffix sum of the cumulative confidences:: + + C_j = prod_{i<=j} q~_i # cumulative confidence (Eq. 8) + w_j = sum_{m>=j} C_m # = C_j * continuation value f~_j + + Each confidence is asymmetrically smoothed toward 1 (Eq. 7):: + + q~_i = (1 - alpha) * q_i + alpha, alpha in (0, 1], + + so the floor ``q~_i >= alpha`` keeps every cumulative product (hence every + weight) strictly positive. We evaluate the suffix sum from its definition as + ``total - exclusive_prefix_sum`` of ``C`` rather than reversing the tensor. + Positions with ``valid_mask == 0`` are multiplicative no-ops in ``C`` and + contribute nothing to the sum, matching the per-token loss mask. Weights are + detached (Eq. 9): they reweight the cross-entropy without adding gradient. + + Args: + confidences: ``[..., L]`` draft confidence ``q_i = exp(-CE_i)`` per position. + alpha: smoothing factor in (0, 1]; raises if outside that range. + valid_mask: optional ``[..., L]`` 0/1 mask; ``None`` treats all positions valid. + + Returns: + Detached weights with the same shape and dtype as ``confidences``. + """ + if not 0.0 < alpha <= 1.0: + raise ValueError(f"dflash_dpace_alpha must be in (0, 1], got {alpha}") + + with torch.no_grad(): + smoothed = alpha + (1.0 - alpha) * confidences.float() # Eq. 7 + if valid_mask is not None: + keep = valid_mask.to(torch.bool) + smoothed = torch.where(keep, smoothed, torch.ones_like(smoothed)) + cum_conf = torch.cumprod(smoothed, dim=-1) # Eq. 8 cumulative confidence C_j + if valid_mask is not None: + cum_conf = cum_conf * keep.to(cum_conf.dtype) + # Suffix sum w_j = sum_{m>=j} C_m, written as total minus the exclusive + # prefix sum so no axis reversal is needed (Eq. 8). + inclusive = torch.cumsum(cum_conf, dim=-1) + weights = inclusive[..., -1:] - inclusive + cum_conf + return weights.to(dtype=confidences.dtype) + + @DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) class HFDFlashModel(DFlashModel): """DFlash Model for HuggingFace transformers.""" @@ -368,14 +417,40 @@ def _compute_loss( binary_eval_mask = weight_mask.view(-1) - # Optional loss decay - if self.dflash_loss_decay_factor > 0: + flat_logits = logits.view(-1, logits.size(-1)) + flat_targets = target_ids.view(-1) + + # Non-KD loss is per-token cross-entropy; compute it once (grad enabled) so the + # D-PACE confidences below can reuse it instead of a second CE pass. The KD path + # (base_logits is not None) optimizes KL, so its confidences need a dedicated + # no_grad CE pass. + loss_per_token = None + if base_logits is None: + loss_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none") + + # Block-position loss weighting: dynamic D-PACE weights or static exponential decay. + if self.dflash_loss_objective == "dpace" and block_size > 1: + # Draft confidence q_i = exp(-CE) on the target-selected token, over the + # predicted positions (slot 0 is the given anchor, already masked above). + # Weights are detached (paper Eq.9), so this adds the documented ~2.3% + # training overhead without altering the cross-entropy gradient. + with torch.no_grad(): + conf_ce = ( + loss_per_token.detach() + if loss_per_token is not None + else F.cross_entropy(flat_logits, flat_targets, reduction="none") + ).view(bsz, n_blocks, block_size) + confidences = torch.exp(-conf_ce[..., 1:].float()) + dpace = torch.ones_like(weight_mask) + dpace[..., 1:] = _dpace_position_weights( + confidences, self.dflash_dpace_alpha, valid_mask=weight_mask[..., 1:] + ) + weight_mask = weight_mask * dpace + elif self.dflash_loss_decay_factor > 0: k = torch.arange(block_size, device=device).view(1, 1, -1) decay = torch.exp(-(k - 1).clamp(min=0).float() / self.dflash_loss_decay_factor) weight_mask = weight_mask * decay - flat_logits = logits.view(-1, logits.size(-1)) - flat_targets = target_ids.view(-1) flat_weights = weight_mask.view(-1) valid_count = flat_weights.sum() + 1e-6 @@ -394,7 +469,6 @@ def _compute_loss( kd_loss = -(target_soft * draft_logsoft).sum(dim=-1) loss = (kd_loss * flat_weights).sum() / valid_count else: - loss_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none") loss = (loss_per_token * flat_weights).sum() / valid_count with torch.no_grad(): diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index e35ac698e76..c27646dd1b1 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -24,6 +24,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock +import pytest import torch from _test_utils.torch.transformers_models import ( get_tiny_llama, @@ -38,6 +39,7 @@ DFlashAttention, DFlashModule, HFDFlashModel, + _dpace_position_weights, build_target_layer_ids, ) from modelopt.torch.speculative.utils import AcceptanceRateValidation @@ -116,6 +118,94 @@ def test_convert_sets_mask_token_id(self): assert model.mask_token_id == 0 +class TestDPaceWeights: + """Test the D-PACE position-weighting objective (arXiv:2605.18810).""" + + @staticmethod + def _reference_weights(conf, alpha): + """Paper closed form computed by explicit summation (Eq.7-8). + + q~_i = (1-a)q_i + a; C_m = prod_{i<=m} q~_i; w_j = sum_{m>=j} C_m. + Deliberately a plain double loop so it is an independent oracle for the + vectorized implementation under test. + """ + smoothed = alpha + (1.0 - alpha) * conf + length = smoothed.shape[-1] + cum = torch.ones_like(smoothed) + running = torch.ones(smoothed.shape[:-1]) + for m in range(length): + running = running * smoothed[..., m] + cum[..., m] = running + expected = torch.zeros_like(smoothed) + for j in range(length): + expected[..., j] = cum[..., j:].sum(dim=-1) + return expected + + def test_weights_match_paper_formula(self): + """w_j = sum_{m>=j} prod_{i<=m} q~_i with q~_i = (1-a)q_i + a (Eq.7-8).""" + alpha = 0.5 + conf = torch.tensor([[0.9, 0.6, 0.3, 0.8]]) + weights = _dpace_position_weights(conf, alpha) + assert torch.allclose(weights, self._reference_weights(conf, alpha), atol=1e-6) + + def test_mask_makes_invalid_positions_noops(self): + """Invalid positions neither shrink the prefix product nor add to the sum.""" + alpha = 0.5 + conf = torch.tensor([[0.9, 0.2, 0.3, 0.8]]) + mask = torch.tensor([[1.0, 0.0, 1.0, 1.0]]) + masked = _dpace_position_weights(conf, alpha, valid_mask=mask) + # Dropping the invalid slot entirely must give the same weights at the kept slots. + kept = _dpace_position_weights(conf[:, [0, 2, 3]], alpha) + assert torch.allclose(masked[:, [0, 2, 3]], kept, atol=1e-6) + + def test_weights_are_detached(self): + """Weights must carry no gradient (paper Eq.9 detaches them).""" + conf = torch.rand(2, 3, 5, requires_grad=True) + weights = _dpace_position_weights(conf, 0.5) + assert not weights.requires_grad + + def test_weights_monotonic_nonincreasing(self): + """Suffix-sum of positive prefix products is non-increasing along the block.""" + conf = torch.rand(4, 8).clamp(0.05, 0.99) + weights = _dpace_position_weights(conf, 0.5) + assert torch.all(weights[:, :-1] >= weights[:, 1:] - 1e-6) + + def test_smoothing_keeps_later_weights_nonzero(self): + """With alpha>0, q~_i >= alpha so cumulative products cannot vanish.""" + conf = torch.zeros(1, 6) # worst case: zero confidence everywhere + weights = _dpace_position_weights(conf, alpha=0.5) + assert torch.all(weights > 0) + + def test_invalid_alpha_raises(self): + with pytest.raises(ValueError, match="dflash_dpace_alpha"): + _dpace_position_weights(torch.rand(1, 4), alpha=1.5) + + def test_convert_with_dpace_objective(self): + """Convert with the dpace objective wires the attributes onto the model.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + config["dflash_loss_objective"] = "dpace" + config["dflash_dpace_alpha"] = 0.3 + mtsp.convert(model, [("dflash", config)]) + assert model.dflash_loss_objective == "dpace" + assert model.dflash_dpace_alpha == 0.3 + + def test_convert_rejects_bad_objective(self): + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + config["dflash_loss_objective"] = "nope" + with pytest.raises(ValueError, match="dflash_loss_objective"): + mtsp.convert(model, [("dflash", config)]) + + def test_convert_rejects_degenerate_alpha(self): + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + config["dflash_loss_objective"] = "dpace" + config["dflash_dpace_alpha"] = 0.0 + with pytest.raises(ValueError, match="dflash_dpace_alpha"): + mtsp.convert(model, [("dflash", config)]) + + class TestDFlashSaveRestore: """Test DFlash model save and restore.""" From 79f749e1c2b7693a15e71685e507afc9f19df9d9 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Fri, 19 Jun 2026 22:16:05 +0000 Subject: [PATCH 2/5] set dpace to default Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- CHANGELOG.rst | 2 +- examples/speculative_decoding/doc/dflash.md | 13 +++++++------ modelopt/torch/speculative/config.py | 5 +++-- modelopt/torch/speculative/dflash/dflash_model.py | 5 ----- .../torch/speculative/plugins/test_hf_dflash.py | 7 +++++++ 5 files changed, 18 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 4d08b27286d..9ae73512f52 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -11,7 +11,7 @@ Changelog **New Features** -- Add the **D-PACE** loss objective for DFlash speculative-decoding training (`arXiv:2605.18810 `_). Set ``dflash_loss_objective: dpace`` to replace the static exponential position decay with dynamic, confidence-derived per-position weights that adapt to whichever block positions currently limit acceptance. Smoothing is controlled by ``dflash_dpace_alpha`` (default 0.5). Training-only and detached from the gradient (no architecture or inference change). +- Add the **D-PACE** loss objective for DFlash speculative-decoding training (`arXiv:2605.18810 `_) and make it the default (``dflash_loss_objective: dpace``). It replaces the static exponential position decay with dynamic, confidence-derived per-position weights that adapt to whichever block positions currently limit acceptance. Smoothing is controlled by ``dflash_dpace_alpha`` (default 0.5); set ``dflash_loss_objective: decay`` to restore the previous static schedule. Training-only and detached from the gradient (no architecture or inference change). - Add the ``day0-release`` agent skill (``.agents/skills/day0-release/``), a deterministic end-to-end driver that chains the PTQ → evaluation → comparison skills (the evaluation stage deploys the checkpoint itself) with an enforced gate after each stage and returns a publish decision (ACCEPT / REGRESSION / ANOMALOUS / INFEASIBLE). Ships three GPU-free, unit-tested gate scripts (``gate_ptq.py``, ``gate_run.py``, ``gate_compare.py``) that validate checkpoint coverage, evaluation-run completeness, and baseline-vs-candidate accuracy threshold. v1 reports and stops on regression; the recipe-search loop is deferred. - Add **streaming** speculative-decoding training (EAGLE3 / DFlash): the draft trains on base-model hidden states produced on the fly by a co-located ``vllm serve`` (no disk dump), moved trainer-side over NIXL RDMA, scaling to multi-node (dedicated serve replicas + DDP trainers). New launcher examples for NVFP4 Kimi-K2.5 / K2.6 on GB200/aarch64 under ``tools/launcher/examples/moonshotai/``. - Add a fused Triton fast path for ``local_hessian`` NVFP4 weight-scale search (the Hessian-weighted FP8-E4M3 scale sweep). For each NVFP4 block it minimizes ``dwᵀ H dw`` over the 126 candidate scales using the per-cin-block local Hessian on tensor cores, replacing the per-weight Python reference sweep — roughly **34x** faster on a single 8192x4096 weight and bit-exact with the reference for fp32/fp16 weights. Used automatically during ``local_hessian`` calibration for both dense and fused-MoE expert weights; falls back to the reference sweep on CPU, when Triton is unavailable, or via ``MODELOPT_NVFP4_TRITON_SWEEP=0``. diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index b6a66700cff..c468c7652be 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -162,7 +162,7 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model | `dflash.dflash_block_size` | 8 | Block size for parallel prediction | | `dflash.dflash_num_anchors` | 512 | Random anchor positions per sample (see below) | | `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables, see below) | -| `dflash.dflash_loss_objective` | `decay` | Position weighting: `decay` (static) or `dpace` (dynamic, see below) | +| `dflash.dflash_loss_objective` | `dpace` | Position weighting: `decay` (static) or `dpace` (dynamic, see below) | | `dflash.dflash_dpace_alpha` | 0.5 | D-PACE smoothing factor in (0, 1]; only used when objective is `dpace` | | `dflash.dflash_self_logit_distillation` | true | Use target model logits as soft labels (vs hard CE) | | `dflash.dflash_mask_token_id` | auto | Token ID for masked positions (see note below) | @@ -248,11 +248,12 @@ early positions higher because they gate acceptance of all later positions. ### D-PACE (Dynamic Position-Aware Cross-Entropy) -Set `dflash.dflash_loss_objective: dpace` to replace the static decay with **D-PACE** -([arXiv:2605.18810](https://arxiv.org/abs/2605.18810)), which derives per-position weights -from a differentiable surrogate of expected accepted block length. Where static decay uses -a fixed schedule, D-PACE adapts to the draft's own per-position confidence and shifts -training signal toward whichever positions currently limit acceptance as the drafter improves. +**D-PACE** ([arXiv:2605.18810](https://arxiv.org/abs/2605.18810)) is the default position-weighting +objective (`dflash_loss_objective: dpace`). It derives per-position weights from a differentiable +surrogate of expected accepted block length. Where the static decay above uses a fixed schedule, +D-PACE adapts to the draft's own per-position confidence and shifts training signal toward +whichever positions currently limit acceptance as the drafter improves. Set +`dflash_loss_objective: decay` to fall back to the static schedule. For each block, let `q_i = exp(-CE_i)` be the draft confidence on the target token at predicted position `i`. D-PACE smooths it (Eq.7) and weights each position by the suffix-sum diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index db8d1b3fe8b..3bd6dfb2fb9 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -16,6 +16,7 @@ """Configurations for speculative decoding modes.""" from copy import deepcopy +from typing import Literal from pydantic import model_validator @@ -107,8 +108,8 @@ class DFlashConfig(ModeloptBaseConfig): "Only used when dflash_loss_objective='decay'.", ) - dflash_loss_objective: str = ModeloptField( - default="decay", + dflash_loss_objective: Literal["decay", "dpace"] = ModeloptField( + default="dpace", description="Block-position loss weighting objective. 'decay' uses the static " "exponential decay of dflash_loss_decay_factor (DFlash, arXiv:2602.06036 Eq.4). " "'dpace' uses dynamic, confidence-derived per-position weights " diff --git a/modelopt/torch/speculative/dflash/dflash_model.py b/modelopt/torch/speculative/dflash/dflash_model.py index 3d499713c50..3513d304de1 100644 --- a/modelopt/torch/speculative/dflash/dflash_model.py +++ b/modelopt/torch/speculative/dflash/dflash_model.py @@ -37,11 +37,6 @@ def modify(self, config): self.dflash_loss_decay_factor = config.dflash_loss_decay_factor self.dflash_loss_objective = config.dflash_loss_objective self.dflash_dpace_alpha = config.dflash_dpace_alpha - if self.dflash_loss_objective not in ("decay", "dpace"): - raise ValueError( - f"dflash_loss_objective must be 'decay' or 'dpace', got " - f"{self.dflash_loss_objective!r}" - ) if self.dflash_loss_objective == "dpace" and not 0.0 < self.dflash_dpace_alpha <= 1.0: raise ValueError( f"dflash_dpace_alpha must be in (0, 1] for the D-PACE objective, got " diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index c27646dd1b1..1dd438f38b8 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -180,6 +180,13 @@ def test_invalid_alpha_raises(self): with pytest.raises(ValueError, match="dflash_dpace_alpha"): _dpace_position_weights(torch.rand(1, 4), alpha=1.5) + def test_default_objective_is_dpace(self): + """D-PACE is the default objective (alpha=0.5) when none is specified.""" + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, [("dflash", _get_dflash_config())]) + assert model.dflash_loss_objective == "dpace" + assert model.dflash_dpace_alpha == 0.5 + def test_convert_with_dpace_objective(self): """Convert with the dpace objective wires the attributes onto the model.""" model = get_tiny_llama(num_hidden_layers=4) From b9d965c2f45e3d960a2182a0d0384de4ecb230c0 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Fri, 19 Jun 2026 22:29:45 +0000 Subject: [PATCH 3/5] address comments Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- modelopt/torch/speculative/config.py | 8 ++++++++ modelopt/torch/speculative/dflash/dflash_model.py | 6 +----- .../general/speculative_decoding/dflash.yaml | 6 +++++- tests/unit/torch/speculative/plugins/test_hf_dflash.py | 9 +++++++++ 4 files changed, 23 insertions(+), 6 deletions(-) diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 3bd6dfb2fb9..0cf08fa6209 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -163,6 +163,14 @@ class DFlashConfig(ModeloptBaseConfig): ), ) + @model_validator(mode="after") + def _check_dpace_alpha(self) -> "DFlashConfig": + # Validate at construction regardless of the active objective, so a bad alpha + # is rejected even if it only becomes active after a later objective override. + if not 0.0 < self.dflash_dpace_alpha <= 1.0: + raise ValueError(f"dflash_dpace_alpha must be in (0, 1], got {self.dflash_dpace_alpha}") + return self + class MedusaConfig(ModeloptBaseConfig): """Medusa config.""" diff --git a/modelopt/torch/speculative/dflash/dflash_model.py b/modelopt/torch/speculative/dflash/dflash_model.py index 3513d304de1..ffffc4739d7 100644 --- a/modelopt/torch/speculative/dflash/dflash_model.py +++ b/modelopt/torch/speculative/dflash/dflash_model.py @@ -37,11 +37,7 @@ def modify(self, config): self.dflash_loss_decay_factor = config.dflash_loss_decay_factor self.dflash_loss_objective = config.dflash_loss_objective self.dflash_dpace_alpha = config.dflash_dpace_alpha - if self.dflash_loss_objective == "dpace" and not 0.0 < self.dflash_dpace_alpha <= 1.0: - raise ValueError( - f"dflash_dpace_alpha must be in (0, 1] for the D-PACE objective, got " - f"{self.dflash_dpace_alpha}" - ) + # dflash_dpace_alpha range is validated on DFlashConfig at construction time. if self.dflash_loss_objective == "dpace" and self.dflash_loss_decay_factor > 0: logger.warning( "dflash_loss_decay_factor=%s is ignored when dflash_loss_objective='dpace'; " diff --git a/modelopt_recipes/general/speculative_decoding/dflash.yaml b/modelopt_recipes/general/speculative_decoding/dflash.yaml index 021cccd475d..c4577ce0e02 100644 --- a/modelopt_recipes/general/speculative_decoding/dflash.yaml +++ b/modelopt_recipes/general/speculative_decoding/dflash.yaml @@ -56,7 +56,11 @@ dflash: dflash_num_anchors: 512 dflash_use_torch_compile: false dflash_self_logit_distillation: true - dflash_loss_decay_factor: 4.0 + # Position-loss weighting defaults to D-PACE (dynamic, confidence-derived). + dflash_dpace_alpha: 0.5 + # To use the static exponential decay instead, uncomment: + # dflash_loss_objective: decay + # dflash_loss_decay_factor: 4.0 # gamma: 7 for block_size=16, 5 for 10, 4 for 8 dflash_architecture_config: num_hidden_layers: 5 # mask_token_id: auto-detected from model vocab (override for specific models) diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index 1dd438f38b8..7bfc357c1fe 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -141,6 +141,15 @@ def _reference_weights(conf, alpha): expected[..., j] = cum[..., j:].sum(dim=-1) return expected + def test_weights_match_hand_computed(self): + """Pin Eq.7-8 to a hand-worked example, independent of the implementation. + + conf=[0.8, 0.5], alpha=0.5 -> q~=[0.9, 0.75] -> prefix=[0.9, 0.675] + -> w=[0.9+0.675, 0.675]=[1.575, 0.675]. + """ + weights = _dpace_position_weights(torch.tensor([[0.8, 0.5]]), alpha=0.5) + assert torch.allclose(weights, torch.tensor([[1.575, 0.675]]), atol=1e-6) + def test_weights_match_paper_formula(self): """w_j = sum_{m>=j} prod_{i<=m} q~_i with q~_i = (1-a)q_i + a (Eq.7-8).""" alpha = 0.5 From 5e7744e135ae81b133ca999a5e1259eff46b94f7 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Fri, 19 Jun 2026 23:56:22 +0000 Subject: [PATCH 4/5] revert recipe change Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- modelopt_recipes/general/speculative_decoding/dflash.yaml | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/modelopt_recipes/general/speculative_decoding/dflash.yaml b/modelopt_recipes/general/speculative_decoding/dflash.yaml index c4577ce0e02..021cccd475d 100644 --- a/modelopt_recipes/general/speculative_decoding/dflash.yaml +++ b/modelopt_recipes/general/speculative_decoding/dflash.yaml @@ -56,11 +56,7 @@ dflash: dflash_num_anchors: 512 dflash_use_torch_compile: false dflash_self_logit_distillation: true - # Position-loss weighting defaults to D-PACE (dynamic, confidence-derived). - dflash_dpace_alpha: 0.5 - # To use the static exponential decay instead, uncomment: - # dflash_loss_objective: decay - # dflash_loss_decay_factor: 4.0 # gamma: 7 for block_size=16, 5 for 10, 4 for 8 + dflash_loss_decay_factor: 4.0 dflash_architecture_config: num_hidden_layers: 5 # mask_token_id: auto-detected from model vocab (override for specific models) From e8a0bcdd7331e3eac20b39e99fd4d59580e18d49 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sat, 20 Jun 2026 01:06:48 +0000 Subject: [PATCH 5/5] update tests Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../speculative/plugins/test_hf_dflash.py | 93 +++++++++++++------ 1 file changed, 66 insertions(+), 27 deletions(-) diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index 7bfc357c1fe..cdaba1766e3 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -19,6 +19,7 @@ """ import json +import logging import os from copy import deepcopy from types import SimpleNamespace @@ -141,21 +142,18 @@ def _reference_weights(conf, alpha): expected[..., j] = cum[..., j:].sum(dim=-1) return expected - def test_weights_match_hand_computed(self): - """Pin Eq.7-8 to a hand-worked example, independent of the implementation. + def test_weights_match_paper_formula(self): + """Eq.7-8, pinned both to a hand-worked value and an independent loop oracle. conf=[0.8, 0.5], alpha=0.5 -> q~=[0.9, 0.75] -> prefix=[0.9, 0.675] -> w=[0.9+0.675, 0.675]=[1.575, 0.675]. """ - weights = _dpace_position_weights(torch.tensor([[0.8, 0.5]]), alpha=0.5) - assert torch.allclose(weights, torch.tensor([[1.575, 0.675]]), atol=1e-6) - - def test_weights_match_paper_formula(self): - """w_j = sum_{m>=j} prod_{i<=m} q~_i with q~_i = (1-a)q_i + a (Eq.7-8).""" - alpha = 0.5 + hand = _dpace_position_weights(torch.tensor([[0.8, 0.5]]), alpha=0.5) + assert torch.allclose(hand, torch.tensor([[1.575, 0.675]]), atol=1e-6) conf = torch.tensor([[0.9, 0.6, 0.3, 0.8]]) - weights = _dpace_position_weights(conf, alpha) - assert torch.allclose(weights, self._reference_weights(conf, alpha), atol=1e-6) + assert torch.allclose( + _dpace_position_weights(conf, 0.5), self._reference_weights(conf, 0.5), atol=1e-6 + ) def test_mask_makes_invalid_positions_noops(self): """Invalid positions neither shrink the prefix product nor add to the sum.""" @@ -173,37 +171,21 @@ def test_weights_are_detached(self): weights = _dpace_position_weights(conf, 0.5) assert not weights.requires_grad - def test_weights_monotonic_nonincreasing(self): - """Suffix-sum of positive prefix products is non-increasing along the block.""" - conf = torch.rand(4, 8).clamp(0.05, 0.99) - weights = _dpace_position_weights(conf, 0.5) - assert torch.all(weights[:, :-1] >= weights[:, 1:] - 1e-6) - - def test_smoothing_keeps_later_weights_nonzero(self): - """With alpha>0, q~_i >= alpha so cumulative products cannot vanish.""" - conf = torch.zeros(1, 6) # worst case: zero confidence everywhere - weights = _dpace_position_weights(conf, alpha=0.5) - assert torch.all(weights > 0) - def test_invalid_alpha_raises(self): with pytest.raises(ValueError, match="dflash_dpace_alpha"): _dpace_position_weights(torch.rand(1, 4), alpha=1.5) def test_default_objective_is_dpace(self): - """D-PACE is the default objective (alpha=0.5) when none is specified.""" + """D-PACE is the default (alpha=0.5); an explicit alpha override is wired through.""" model = get_tiny_llama(num_hidden_layers=4) mtsp.convert(model, [("dflash", _get_dflash_config())]) assert model.dflash_loss_objective == "dpace" assert model.dflash_dpace_alpha == 0.5 - def test_convert_with_dpace_objective(self): - """Convert with the dpace objective wires the attributes onto the model.""" model = get_tiny_llama(num_hidden_layers=4) config = _get_dflash_config() - config["dflash_loss_objective"] = "dpace" config["dflash_dpace_alpha"] = 0.3 mtsp.convert(model, [("dflash", config)]) - assert model.dflash_loss_objective == "dpace" assert model.dflash_dpace_alpha == 0.3 def test_convert_rejects_bad_objective(self): @@ -221,6 +203,63 @@ def test_convert_rejects_degenerate_alpha(self): with pytest.raises(ValueError, match="dflash_dpace_alpha"): mtsp.convert(model, [("dflash", config)]) + def test_convert_dpace_with_decay_factor_warns(self, caplog): + """dpace + a non-zero decay factor converts but warns that decay is ignored.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + config["dflash_loss_objective"] = "dpace" + config["dflash_loss_decay_factor"] = 4.0 + with caplog.at_level(logging.WARNING): + mtsp.convert(model, [("dflash", config)]) + assert any("dflash_loss_decay_factor" in r.message for r in caplog.records) + + +class TestDPaceLossIntegration: + """Exercise the _compute_loss block-weighting branches on CPU.""" + + @staticmethod + def _make_inputs(vocab=32, seq_len=SEQ_LEN, n_blocks=2): + """Synthetic CPU inputs for _compute_loss (no model forward needed).""" + bsz = 1 + logits = torch.randn(bsz, n_blocks * BLOCK_SIZE, vocab) + input_ids = torch.randint(0, vocab, (bsz, seq_len)) + anchor_positions = torch.tensor([[0, BLOCK_SIZE]])[:, :n_blocks] + block_keep_mask = torch.ones(bsz, n_blocks) + loss_mask = torch.ones(bsz, seq_len) + return logits, input_ids, anchor_positions, block_keep_mask, loss_mask + + def _converted_model(self, objective, **overrides): + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + config["dflash_loss_objective"] = objective + config.update(overrides) + mtsp.convert(model, [("dflash", config)]) + return model + + def test_compute_loss_dpace_branch(self): + """Default dpace objective produces a finite loss and valid accuracy.""" + model = self._converted_model("dpace") + loss, acc = model._compute_loss(*self._make_inputs()) + assert torch.isfinite(loss).item() and loss.item() > 0 + assert 0.0 <= acc <= 1.0 + + def test_compute_loss_decay_branch(self): + """The static-decay objective path also produces a finite loss.""" + model = self._converted_model("decay", dflash_loss_decay_factor=4.0) + loss, acc = model._compute_loss(*self._make_inputs()) + assert torch.isfinite(loss).item() and loss.item() > 0 + assert 0.0 <= acc <= 1.0 + + def test_compute_loss_dpace_kd_branch(self): + """dpace + KD (base_logits given): confidences use a dedicated no_grad CE pass.""" + vocab = 32 + model = self._converted_model("dpace") + inputs = self._make_inputs(vocab=vocab) + base_logits = torch.randn(1, SEQ_LEN, vocab) + loss, acc = model._compute_loss(*inputs, base_logits=base_logits) + assert torch.isfinite(loss).item() + assert 0.0 <= acc <= 1.0 + class TestDFlashSaveRestore: """Test DFlash model save and restore."""