diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 56380b01b55..9ae73512f52 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -11,6 +11,7 @@ Changelog **New Features** +- Add the **D-PACE** loss objective for DFlash speculative-decoding training (`arXiv:2605.18810 <https://arxiv.org/abs/2605.18810>`_) and make it the default (``dflash_loss_objective: dpace``). It replaces the static exponential position decay with dynamic, confidence-derived per-position weights that adapt to whichever block positions currently limit acceptance. Smoothing is controlled by ``dflash_dpace_alpha`` (default 0.5); set ``dflash_loss_objective: decay`` to restore the previous static schedule. Training-only and detached from the gradient (no architecture or inference change). - Add the ``day0-release`` agent skill (``.agents/skills/day0-release/``), a deterministic end-to-end driver that chains the PTQ → evaluation → comparison skills (the evaluation stage deploys the checkpoint itself) with an enforced gate after each stage and returns a publish decision (ACCEPT / REGRESSION / ANOMALOUS / INFEASIBLE). Ships three GPU-free, unit-tested gate scripts (``gate_ptq.py``, ``gate_run.py``, ``gate_compare.py``) that validate checkpoint coverage, evaluation-run completeness, and baseline-vs-candidate accuracy threshold. v1 reports and stops on regression; the recipe-search loop is deferred. - Add **streaming** speculative-decoding training (EAGLE3 / DFlash): the draft trains on base-model hidden states produced on the fly by a co-located ``vllm serve`` (no disk dump), moved trainer-side over NIXL RDMA, scaling to multi-node (dedicated serve replicas + DDP trainers). New launcher examples for NVFP4 Kimi-K2.5 / K2.6 on GB200/aarch64 under ``tools/launcher/examples/moonshotai/``. - Add a fused Triton fast path for ``local_hessian`` NVFP4 weight-scale search (the Hessian-weighted FP8-E4M3 scale sweep). 
For each NVFP4 block it minimizes ``dwᵀ H dw`` over the 126 candidate scales using the per-cin-block local Hessian on tensor cores, replacing the per-weight Python reference sweep — roughly **34x** faster on a single 8192x4096 weight and bit-exact with the reference for fp32/fp16 weights. Used automatically during ``local_hessian`` calibration for both dense and fused-MoE expert weights; falls back to the reference sweep on CPU, when Triton is unavailable, or via ``MODELOPT_NVFP4_TRITON_SWEEP=0``. diff --git a/examples/speculative_decoding/doc/dflash.md b/examples/speculative_decoding/doc/dflash.md index 0150e0884e1..c468c7652be 100644 --- a/examples/speculative_decoding/doc/dflash.md +++ b/examples/speculative_decoding/doc/dflash.md @@ -162,6 +162,8 @@ See [`modelopt_recipes/general/speculative_decoding/dflash.yaml`](../../../model | `dflash.dflash_block_size` | 8 | Block size for parallel prediction | | `dflash.dflash_num_anchors` | 512 | Random anchor positions per sample (see below) | | `dflash.dflash_loss_decay_factor` | 4.0 | Exponential decay gamma (0 disables, see below) | +| `dflash.dflash_loss_objective` | `dpace` | Position weighting: `decay` (static) or `dpace` (dynamic, see below) | +| `dflash.dflash_dpace_alpha` | 0.5 | D-PACE smoothing factor in (0, 1]; only used when objective is `dpace` | | `dflash.dflash_self_logit_distillation` | true | Use target model logits as soft labels (vs hard CE) | | `dflash.dflash_mask_token_id` | auto | Token ID for masked positions (see note below) | | `dflash.dflash_architecture_config.num_hidden_layers` | 5 | Draft decoder layers | @@ -244,6 +246,35 @@ Note: this is different from EAGLE3's `eagle_loss_decay_factor` which multiplies `alpha^step` across TTT steps. DFlash decay operates within a single block, weighting early positions higher because they gate acceptance of all later positions. 
+### D-PACE (Dynamic Position-Aware Cross-Entropy) + +**D-PACE** ([arXiv:2605.18810](https://arxiv.org/abs/2605.18810)) is the default position-weighting +objective (`dflash_loss_objective: dpace`). It derives per-position weights from a differentiable +surrogate of expected accepted block length. Where the static decay above uses a fixed schedule, +D-PACE adapts to the draft's own per-position confidence and shifts training signal toward +whichever positions currently limit acceptance as the drafter improves. Set +`dflash_loss_objective: decay` to fall back to the static schedule. + +For each block, let `q_i = exp(-CE_i)` be the draft confidence on the target token at +predicted position `i`. D-PACE smooths it (Eq.7) and weights each position by the suffix-sum +of prefix products (Eq.8): + +```text +q~_i = (1 - alpha) * q_i + alpha +w_j = sum_{m >= j} prod_{i <= m} q~_i # detached; multiplies the per-token CE +``` + +The weight factors into the prefix-acceptance probability (`prod_{i<=j} q~_i`) times the +remaining accepted-length value, so it directly targets expected accepted length. The +weights are detached from the gradient — D-PACE only reshapes credit assignment and adds +~2.3% training overhead with no change to the draft architecture or inference. + +- `dflash_dpace_alpha` is the asymmetric smoothing floor (`q~_i >= alpha`) that keeps later + weights from vanishing. Stable in `[0.3, 0.7]`; `alpha=0` is rejected (cumulative product + collapses), and `alpha → 1` flattens toward uniform weighting. Default `0.5`. +- D-PACE is mutually exclusive with `dflash_loss_decay_factor`; when objective is `dpace`, + the decay factor is ignored. + ### Checkpoint Resume DFlash supports checkpoint resume transparently. 
Rotary embeddings are lazily diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 8da7b6ec93e..0cf08fa6209 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -16,6 +16,7 @@ """Configurations for speculative decoding modes.""" from copy import deepcopy +from typing import Literal from pydantic import model_validator @@ -103,7 +104,23 @@ class DFlashConfig(ModeloptBaseConfig): dflash_loss_decay_factor: float = ModeloptField( default=0.0, description="Gamma for exponential loss decay weighting (paper Eq.4). " - "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. 0 disables.", + "Suggested: 7 for block_size=16, 5 for 10, 4 for 8. 0 disables. " + "Only used when dflash_loss_objective='decay'.", + ) + + dflash_loss_objective: Literal["decay", "dpace"] = ModeloptField( + default="dpace", + description="Block-position loss weighting objective. 'decay' uses the static " + "exponential decay of dflash_loss_decay_factor (DFlash, arXiv:2602.06036 Eq.4). " + "'dpace' uses dynamic, confidence-derived per-position weights " + "(D-PACE, arXiv:2605.18810 Eq.8).", + ) + + dflash_dpace_alpha: float = ModeloptField( + default=0.5, + description="D-PACE asymmetric smoothing factor alpha in (0, 1] (paper Eq.7). Used only " + "when dflash_loss_objective='dpace'. Stable in [0.3, 0.7]; alpha=0 is degenerate " + "(cumulative product vanishes) and alpha->1 removes the adaptive signal.", ) dflash_num_anchors: int = ModeloptField( @@ -146,6 +163,14 @@ class DFlashConfig(ModeloptBaseConfig): ), ) + @model_validator(mode="after") + def _check_dpace_alpha(self) -> "DFlashConfig": + # Validate at construction regardless of the active objective, so a bad alpha + # is rejected even if it only becomes active after a later objective override. 
+ if not 0.0 < self.dflash_dpace_alpha <= 1.0: + raise ValueError(f"dflash_dpace_alpha must be in (0, 1], got {self.dflash_dpace_alpha}") + return self + class MedusaConfig(ModeloptBaseConfig): """Medusa config.""" diff --git a/modelopt/torch/speculative/dflash/dflash_model.py b/modelopt/torch/speculative/dflash/dflash_model.py index 702d9812482..ffffc4739d7 100644 --- a/modelopt/torch/speculative/dflash/dflash_model.py +++ b/modelopt/torch/speculative/dflash/dflash_model.py @@ -15,8 +15,12 @@ """DFlash model to support block-wise parallel speculative decoding.""" +import logging + from modelopt.torch.opt.dynamic import DynamicModule +logger = logging.getLogger(__name__) + class DFlashModel(DynamicModule): """Base DFlash Model.""" @@ -31,6 +35,15 @@ def modify(self, config): self.dflash_block_size = config.dflash_block_size self.dflash_freeze_base_model = config.dflash_freeze_base_model self.dflash_loss_decay_factor = config.dflash_loss_decay_factor + self.dflash_loss_objective = config.dflash_loss_objective + self.dflash_dpace_alpha = config.dflash_dpace_alpha + # dflash_dpace_alpha range is validated on DFlashConfig at construction time. 
+ if self.dflash_loss_objective == "dpace" and self.dflash_loss_decay_factor > 0: + logger.warning( + "dflash_loss_decay_factor=%s is ignored when dflash_loss_objective='dpace'; " + "D-PACE derives per-position weights dynamically from draft confidence.", + self.dflash_loss_decay_factor, + ) self.dflash_self_logit_distillation = config.dflash_self_logit_distillation self.dflash_num_anchors = config.dflash_num_anchors self.dflash_report_acc = config.dflash_report_acc diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 9e54ee531ce..cc3b7c94061 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -74,6 +74,55 @@ __all__ = ["HFDFlashModel"] +def _dpace_position_weights( + confidences: torch.Tensor, alpha: float, valid_mask: torch.Tensor | None = None +) -> torch.Tensor: + """Compute detached D-PACE per-position weights from draft confidences. + + Derived from D-PACE (arXiv:2605.18810). The paper factorizes the per-position + weight (Fig. 2 / Eq. 8) into a *cumulative confidence* times a *continuation + value*, which is equivalently the suffix sum of the cumulative confidences:: + + C_j = prod_{i<=j} q~_i # cumulative confidence (Eq. 8) + w_j = sum_{m>=j} C_m # = C_j * continuation value f~_j + + Each confidence is asymmetrically smoothed toward 1 (Eq. 7):: + + q~_i = (1 - alpha) * q_i + alpha, alpha in (0, 1], + + so the floor ``q~_i >= alpha`` keeps every cumulative product (hence every + weight) strictly positive. We evaluate the suffix sum from its definition as + ``total - exclusive_prefix_sum`` of ``C`` rather than reversing the tensor. + Positions with ``valid_mask == 0`` are multiplicative no-ops in ``C`` and + contribute nothing to the sum, matching the per-token loss mask. Weights are + detached (Eq. 9): they reweight the cross-entropy without adding gradient. 
+ + Args: + confidences: ``[..., L]`` draft confidence ``q_i = exp(-CE_i)`` per position. + alpha: smoothing factor in (0, 1]; raises if outside that range. + valid_mask: optional ``[..., L]`` 0/1 mask; ``None`` treats all positions valid. + + Returns: + Detached weights with the same shape and dtype as ``confidences``. + """ + if not 0.0 < alpha <= 1.0: + raise ValueError(f"dflash_dpace_alpha must be in (0, 1], got {alpha}") + + with torch.no_grad(): + smoothed = alpha + (1.0 - alpha) * confidences.float() # Eq. 7 + if valid_mask is not None: + keep = valid_mask.to(torch.bool) + smoothed = torch.where(keep, smoothed, torch.ones_like(smoothed)) + cum_conf = torch.cumprod(smoothed, dim=-1) # Eq. 8 cumulative confidence C_j + if valid_mask is not None: + cum_conf = cum_conf * keep.to(cum_conf.dtype) + # Suffix sum w_j = sum_{m>=j} C_m, written as total minus the exclusive + # prefix sum so no axis reversal is needed (Eq. 8). + inclusive = torch.cumsum(cum_conf, dim=-1) + weights = inclusive[..., -1:] - inclusive + cum_conf + return weights.to(dtype=confidences.dtype) + + @DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) class HFDFlashModel(DFlashModel): """DFlash Model for HuggingFace transformers.""" @@ -368,14 +417,40 @@ def _compute_loss( binary_eval_mask = weight_mask.view(-1) - # Optional loss decay - if self.dflash_loss_decay_factor > 0: + flat_logits = logits.view(-1, logits.size(-1)) + flat_targets = target_ids.view(-1) + + # Non-KD loss is per-token cross-entropy; compute it once (grad enabled) so the + # D-PACE confidences below can reuse it instead of a second CE pass. The KD path + # (base_logits is not None) optimizes KL, so its confidences need a dedicated + # no_grad CE pass. + loss_per_token = None + if base_logits is None: + loss_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none") + + # Block-position loss weighting: dynamic D-PACE weights or static exponential decay. 
+ if self.dflash_loss_objective == "dpace" and block_size > 1: + # Draft confidence q_i = exp(-CE) on the target-selected token, over the + # predicted positions (slot 0 is the given anchor, already masked above). + # Weights are detached (paper Eq.9), so this adds the documented ~2.3% + # training overhead without altering the cross-entropy gradient. + with torch.no_grad(): + conf_ce = ( + loss_per_token.detach() + if loss_per_token is not None + else F.cross_entropy(flat_logits, flat_targets, reduction="none") + ).view(bsz, n_blocks, block_size) + confidences = torch.exp(-conf_ce[..., 1:].float()) + dpace = torch.ones_like(weight_mask) + dpace[..., 1:] = _dpace_position_weights( + confidences, self.dflash_dpace_alpha, valid_mask=weight_mask[..., 1:] + ) + weight_mask = weight_mask * dpace + elif self.dflash_loss_decay_factor > 0: k = torch.arange(block_size, device=device).view(1, 1, -1) decay = torch.exp(-(k - 1).clamp(min=0).float() / self.dflash_loss_decay_factor) weight_mask = weight_mask * decay - flat_logits = logits.view(-1, logits.size(-1)) - flat_targets = target_ids.view(-1) flat_weights = weight_mask.view(-1) valid_count = flat_weights.sum() + 1e-6 @@ -394,7 +469,6 @@ def _compute_loss( kd_loss = -(target_soft * draft_logsoft).sum(dim=-1) loss = (kd_loss * flat_weights).sum() / valid_count else: - loss_per_token = F.cross_entropy(flat_logits, flat_targets, reduction="none") loss = (loss_per_token * flat_weights).sum() / valid_count with torch.no_grad(): diff --git a/tests/unit/torch/speculative/plugins/test_hf_dflash.py b/tests/unit/torch/speculative/plugins/test_hf_dflash.py index e35ac698e76..cdaba1766e3 100644 --- a/tests/unit/torch/speculative/plugins/test_hf_dflash.py +++ b/tests/unit/torch/speculative/plugins/test_hf_dflash.py @@ -19,11 +19,13 @@ """ import json +import logging import os from copy import deepcopy from types import SimpleNamespace from unittest.mock import MagicMock +import pytest import torch from 
_test_utils.torch.transformers_models import ( get_tiny_llama, @@ -38,6 +40,7 @@ DFlashAttention, DFlashModule, HFDFlashModel, + _dpace_position_weights, build_target_layer_ids, ) from modelopt.torch.speculative.utils import AcceptanceRateValidation @@ -116,6 +119,148 @@ def test_convert_sets_mask_token_id(self): assert model.mask_token_id == 0 +class TestDPaceWeights: + """Test the D-PACE position-weighting objective (arXiv:2605.18810).""" + + @staticmethod + def _reference_weights(conf, alpha): + """Paper closed form computed by explicit summation (Eq.7-8). + + q~_i = (1-a)q_i + a; C_m = prod_{i<=m} q~_i; w_j = sum_{m>=j} C_m. + Deliberately a plain double loop so it is an independent oracle for the + vectorized implementation under test. + """ + smoothed = alpha + (1.0 - alpha) * conf + length = smoothed.shape[-1] + cum = torch.ones_like(smoothed) + running = torch.ones(smoothed.shape[:-1]) + for m in range(length): + running = running * smoothed[..., m] + cum[..., m] = running + expected = torch.zeros_like(smoothed) + for j in range(length): + expected[..., j] = cum[..., j:].sum(dim=-1) + return expected + + def test_weights_match_paper_formula(self): + """Eq.7-8, pinned both to a hand-worked value and an independent loop oracle. + + conf=[0.8, 0.5], alpha=0.5 -> q~=[0.9, 0.75] -> prefix=[0.9, 0.675] + -> w=[0.9+0.675, 0.675]=[1.575, 0.675]. 
+ """ + hand = _dpace_position_weights(torch.tensor([[0.8, 0.5]]), alpha=0.5) + assert torch.allclose(hand, torch.tensor([[1.575, 0.675]]), atol=1e-6) + conf = torch.tensor([[0.9, 0.6, 0.3, 0.8]]) + assert torch.allclose( + _dpace_position_weights(conf, 0.5), self._reference_weights(conf, 0.5), atol=1e-6 + ) + + def test_mask_makes_invalid_positions_noops(self): + """Invalid positions neither shrink the prefix product nor add to the sum.""" + alpha = 0.5 + conf = torch.tensor([[0.9, 0.2, 0.3, 0.8]]) + mask = torch.tensor([[1.0, 0.0, 1.0, 1.0]]) + masked = _dpace_position_weights(conf, alpha, valid_mask=mask) + # Dropping the invalid slot entirely must give the same weights at the kept slots. + kept = _dpace_position_weights(conf[:, [0, 2, 3]], alpha) + assert torch.allclose(masked[:, [0, 2, 3]], kept, atol=1e-6) + + def test_weights_are_detached(self): + """Weights must carry no gradient (paper Eq.9 detaches them).""" + conf = torch.rand(2, 3, 5, requires_grad=True) + weights = _dpace_position_weights(conf, 0.5) + assert not weights.requires_grad + + def test_invalid_alpha_raises(self): + with pytest.raises(ValueError, match="dflash_dpace_alpha"): + _dpace_position_weights(torch.rand(1, 4), alpha=1.5) + + def test_default_objective_is_dpace(self): + """D-PACE is the default (alpha=0.5); an explicit alpha override is wired through.""" + model = get_tiny_llama(num_hidden_layers=4) + mtsp.convert(model, [("dflash", _get_dflash_config())]) + assert model.dflash_loss_objective == "dpace" + assert model.dflash_dpace_alpha == 0.5 + + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + config["dflash_dpace_alpha"] = 0.3 + mtsp.convert(model, [("dflash", config)]) + assert model.dflash_dpace_alpha == 0.3 + + def test_convert_rejects_bad_objective(self): + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + config["dflash_loss_objective"] = "nope" + with pytest.raises(ValueError, match="dflash_loss_objective"): + 
mtsp.convert(model, [("dflash", config)]) + + def test_convert_rejects_degenerate_alpha(self): + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + config["dflash_loss_objective"] = "dpace" + config["dflash_dpace_alpha"] = 0.0 + with pytest.raises(ValueError, match="dflash_dpace_alpha"): + mtsp.convert(model, [("dflash", config)]) + + def test_convert_dpace_with_decay_factor_warns(self, caplog): + """dpace + a non-zero decay factor converts but warns that decay is ignored.""" + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + config["dflash_loss_objective"] = "dpace" + config["dflash_loss_decay_factor"] = 4.0 + with caplog.at_level(logging.WARNING): + mtsp.convert(model, [("dflash", config)]) + assert any("dflash_loss_decay_factor" in r.message for r in caplog.records) + + +class TestDPaceLossIntegration: + """Exercise the _compute_loss block-weighting branches on CPU.""" + + @staticmethod + def _make_inputs(vocab=32, seq_len=SEQ_LEN, n_blocks=2): + """Synthetic CPU inputs for _compute_loss (no model forward needed).""" + bsz = 1 + logits = torch.randn(bsz, n_blocks * BLOCK_SIZE, vocab) + input_ids = torch.randint(0, vocab, (bsz, seq_len)) + anchor_positions = torch.tensor([[0, BLOCK_SIZE]])[:, :n_blocks] + block_keep_mask = torch.ones(bsz, n_blocks) + loss_mask = torch.ones(bsz, seq_len) + return logits, input_ids, anchor_positions, block_keep_mask, loss_mask + + def _converted_model(self, objective, **overrides): + model = get_tiny_llama(num_hidden_layers=4) + config = _get_dflash_config() + config["dflash_loss_objective"] = objective + config.update(overrides) + mtsp.convert(model, [("dflash", config)]) + return model + + def test_compute_loss_dpace_branch(self): + """Default dpace objective produces a finite loss and valid accuracy.""" + model = self._converted_model("dpace") + loss, acc = model._compute_loss(*self._make_inputs()) + assert torch.isfinite(loss).item() and loss.item() > 0 + assert 0.0 <= acc 
<= 1.0 + + def test_compute_loss_decay_branch(self): + """The static-decay objective path also produces a finite loss.""" + model = self._converted_model("decay", dflash_loss_decay_factor=4.0) + loss, acc = model._compute_loss(*self._make_inputs()) + assert torch.isfinite(loss).item() and loss.item() > 0 + assert 0.0 <= acc <= 1.0 + + def test_compute_loss_dpace_kd_branch(self): + """dpace + KD (base_logits given): confidences use a dedicated no_grad CE pass.""" + vocab = 32 + model = self._converted_model("dpace") + inputs = self._make_inputs(vocab=vocab) + base_logits = torch.randn(1, SEQ_LEN, vocab) + loss, acc = model._compute_loss(*inputs, base_logits=base_logits) + assert torch.isfinite(loss).item() + assert 0.0 <= acc <= 1.0 + + class TestDFlashSaveRestore: """Test DFlash model save and restore."""