From eff68218bb497256ba76edf4d70daa3f19998092 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 10 Jun 2026 16:53:53 -0700 Subject: [PATCH 1/2] up --- examples/models/eagle3/export.py | 317 ++++++++++++++--- examples/models/eagle3/quant_mlx.py | 59 ++++ examples/models/eagle3/run_mlx.py | 323 ++++++++++++++++++ examples/models/eagle3/target.py | 13 +- .../models/eagle3/test_run_mlx_sampling.py | 59 ++++ .../gemma4_31b/mlx_source_transformations.py | 48 +++ 6 files changed, 769 insertions(+), 50 deletions(-) create mode 100644 examples/models/eagle3/quant_mlx.py create mode 100644 examples/models/eagle3/run_mlx.py create mode 100644 examples/models/eagle3/test_run_mlx_sampling.py diff --git a/examples/models/eagle3/export.py b/examples/models/eagle3/export.py index c3a4decf8f4..57730bb6541 100644 --- a/examples/models/eagle3/export.py +++ b/examples/models/eagle3/export.py @@ -27,11 +27,20 @@ (``test_shifted_speculative_decode_is_lossless`` drives the full loop through only these three methods). -Export runs with the model on the host (CPU); AOTInductor streams weights to the -GPU per kernel during compilation, so peak GPU memory stays low even for the INT4 -31B target. The target is loaded from a prequantized (INT4) directory and the -draft from a vLLM-speculator checkpoint; only the CUDA (AOTI) backend is -supported. +Export runs with the model on the host (CPU). For ``--backend cuda`` AOTInductor +streams weights to the GPU per kernel during compilation, so peak GPU memory +stays low even for the INT4 31B target. The target is loaded from a prequantized +(INT4) directory and the draft from a vLLM-speculator checkpoint. + +Backends: + - ``cuda`` (AOTI): three methods (prefill, target_verify, draft_decode) sharing + KV caches by FQN; bf16 draft. 
+ - ``mlx`` (Apple silicon): MLX has no cross-method KV-cache sharing, so prefill + and verify are merged into one dynamic-seq ``target_forward`` (sharing the + target cache within a single handle) plus ``draft_decode``; the draft is bf16 + by default (``--quantize-draft`` for int4). Both methods return logits (target + soft-capped + draft) so the host applies temperature / sampling — greedy is + host-side argmax. Scope (this is a fixed-shape ExecuTorch artifact, not a generic EAGLE runtime): chain length, the chain_len+1 verify window, the prefill/draft dynamic ranges, @@ -95,6 +104,35 @@ def forward(self, tokens, feature, input_pos): return self.spec.draft_decode(tokens, feature, input_pos) +# Logit-returning variants for the MLX sampling path: the host applies +# temperature + modified rejection sampling, so the methods return distributions +# (soft-capped target logits / draft logits) instead of the greedy argmax. Greedy +# (--temperature 0) just argmaxes these host-side. + + +class _TargetForwardLogits(nn.Module): + def __init__(self, spec: Eagle3Speculator): + super().__init__() + self.spec = spec + + def forward(self, tokens, input_pos): + logits, taps = self.spec.target.forward_logits_taps( + tokens, input_pos, last_logits_only=False + ) + return logits, self.spec.draft.fuse(taps) + + +class _DraftDecodeLogits(nn.Module): + def __init__(self, spec: Eagle3Speculator): + super().__init__() + self.spec = spec + + def forward(self, tokens, feature, input_pos): + emb = self.spec.draft.embed(tokens) + draft_logits, g = self.spec.draft.forward_cached(emb, feature, input_pos) + return draft_logits, g + + def _export_cuda( spec: Eagle3Speculator, output_dir: str, @@ -253,39 +291,140 @@ def _partitioner(name: str): print("Done.") -def main() -> None: - p = argparse.ArgumentParser(description="Export an EAGLE-3 speculator to .pte.") - p.add_argument( - "--target-model", - default="gemma4_31b", - choices=list(TARGETS), - help="Registered target model (see 
eagle3/target.py).", +def _export_mlx( + spec: Eagle3Speculator, + output_dir: str, + max_prefill: int, + chain_len: int, + share_base_embedding: bool = False, +) -> None: + import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 + import executorch.extension.llm.export.gguf # noqa: F401 + import executorch.extension.llm.export.int4 # noqa: F401 + from executorch.backends.mlx import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes + from executorch.examples.models.gemma4_31b.mlx_source_transformations import ( + install_mlx_tap_forward, + mlx_source_transformations, ) - p.add_argument( - "--target", required=True, help="Prequantized (INT4) target directory." + from executorch.examples.models.gemma4_31b.model import materialize_runtime_buffers + from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, ) - p.add_argument("--draft", required=True, help="EAGLE-3 draft head directory.") - p.add_argument("--output-dir", default="./eagle3_exports") - p.add_argument("--max-seq-len", type=int, default=4096) - p.add_argument( - "--max-prefill", - type=int, - default=512, - help="Max prefill length: AOTI compiles prefill kernels for up to this T " - "and the whole prompt must fit in one prefill (the runner does not chunk). " - "Smaller compiles faster.", + from executorch.exir.passes import MemoryPlanningPass + from torch.export import Dim, export + + target_config = spec.target.config + hidden = spec.draft.config.hidden_size + draft_vocab_size = spec.draft.config.draft_vocab_size + + # MLX rewrites the target to mask-free layers + MLX KV caches; install a + # matching mask-free tap forward so the speculator's target methods trace. 
+ mlx_source_transformations(spec.target, dtype=torch.bfloat16) + install_mlx_tap_forward(spec.target) + materialize_runtime_buffers(spec.target, dtype=torch.bfloat16) + + if share_base_embedding: + # Point the draft at the target's packed embedding so both methods emit + # identical bytes; the NamedDataStore then content-dedups them to one + # copy. Safe because the draft embed is a frozen copy of the target's and + # the draft's input_layernorm (RMSNorm) is invariant to the embed scale. + spec.draft.embed_tokens = spec.target.embed_tokens + + # MLX has no cross-method KV-cache sharing, so prefill and verify are one + # dynamic-seq method that shares the target cache within a single handle. The + # method returns per-position logits; the host samples (or argmaxes). + print(f"Exporting target_forward (T in [1, {max_prefill}])...") + target_dim = Dim("target_len", min=1, max=max_prefill) + with torch.no_grad(): + target_ep = export( + _TargetForwardLogits(spec), + ( + torch.zeros((1, max_prefill), dtype=torch.long), + torch.arange(max_prefill, dtype=torch.long), + ), + dynamic_shapes=({1: target_dim}, {0: target_dim}), + strict=True, + ) + + draft_max = max(max_prefill, chain_len + 1) + print(f"Exporting draft_decode (T in [1, {draft_max}])...") + draft_dim = Dim("draft_len", min=1, max=draft_max) + with torch.no_grad(): + draft_ep = export( + _DraftDecodeLogits(spec), + ( + torch.zeros((1, draft_max), dtype=torch.long), + torch.zeros((1, draft_max, hidden), dtype=torch.bfloat16), + torch.arange(draft_max, dtype=torch.long), + ), + dynamic_shapes=({1: draft_dim}, {1: draft_dim}, {0: draft_dim}), + strict=True, + ) + + del spec + gc.collect() + + print("Lowering to ExecuTorch with MLX backend...") + et_prog = to_edge_transform_and_lower( + {"target_forward": target_ep, "draft_decode": draft_ep}, + transform_passes=get_default_passes(), + partitioner={ + "target_forward": [MLXPartitioner()], + "draft_decode": [MLXPartitioner()], + }, + 
compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods={ + "get_max_seq_len": target_config.max_seq_len, + "get_vocab_size": target_config.vocab_size, + "get_n_layers": target_config.num_hidden_layers, + "get_max_prefill_chunk": max_prefill, + "get_min_prefill_chunk": 1, + "get_chain_len": chain_len, + "get_draft_vocab_size": draft_vocab_size, + "use_kv_cache": True, + "enable_dynamic_shape": True, + }, ) - p.add_argument( - "--chain", type=int, default=4, help="Draft chain length K (verify K+1)." + del target_ep, draft_ep + gc.collect() + + et_program = et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), ) - args = p.parse_args() + del et_prog + gc.collect() - spec_t = TARGETS[args.target_model] - if not torch.cuda.is_available(): - p.error("CUDA is required to compile the EAGLE-3 export.") + os.makedirs(output_dir, exist_ok=True) + pte_path = os.path.join(output_dir, "model.pte") + print(f"Saving to {pte_path}...") + with open(pte_path, "wb") as f: + et_program.write_to_file(f) + print(f" {os.path.getsize(pte_path) / 1024**2:.1f} MB") + if et_program._tensor_data: + et_program.write_tensor_data_to_file(output_dir) + print(f" Saved tensor data (.ptd) to {output_dir}/") + print("Done.") + + +def _validate_backend_flags(p, args) -> None: + if args.share_draft_embedding and args.backend != "mlx": + p.error("--share-draft-embedding is only supported with --backend mlx.") + if args.quantize_draft and args.backend != "mlx": + p.error("--quantize-draft is only supported with --backend mlx.") - print(f"Loading {args.target_model} target from {args.target}...") - target = spec_t.load(args.target, args.max_seq_len) + +def _load_target_and_draft(p, args, spec_t): + print(f"Loading {args.target_model} target ({args.backend}) from {args.target}...") + target = spec_t.load(args.target, args.max_seq_len, 
args.backend) print(f"Loading draft head from {args.draft}...") draft, _ = Eagle3Draft.from_checkpoint( @@ -308,12 +447,38 @@ def main() -> None: f"[{int(target_ids.min())}, {int(target_ids.max())}]; the draft and " f"target are likely not a matched pair" ) + return target, draft + +def _run_mlx(p, args, target, draft, max_prefill, verify_len) -> None: + if args.quantize_draft: + from executorch.examples.models.eagle3.quant_mlx import ( + quantize_pack_draft_for_mlx, + ) + + print("Quantizing + packing draft for MLX (int4)...") + draft = quantize_pack_draft_for_mlx(draft) + else: + print("Keeping draft in bf16 (pass --quantize-draft for int4)...") + # MLX builds attention masks internally, so a single forward accepts T>=1. + if max_prefill < verify_len: + p.error( + f"computed max_prefill={max_prefill} < verify window {verify_len}; " + f"raise --max-prefill (got {args.max_prefill}) or --max-seq-len " + f"(got {args.max_seq_len})" + ) spec = Eagle3Speculator(target, draft).eval() + _export_mlx( + spec, + args.output_dir, + max_prefill=max_prefill, + chain_len=args.chain, + share_base_embedding=args.share_draft_embedding, + ) - # A single target forward accepts min_forward_len .. max_forward_len tokens. - max_forward = spec_t.max_forward_len(target.config) - max_prefill = min(args.max_prefill, args.max_seq_len - 1, max_forward) + +def _run_cuda(p, args, spec_t, target, draft, max_prefill, verify_len) -> None: + spec = Eagle3Speculator(target, draft).eval() # prefill's dynamic min (see _export_cuda target_min): the target's own # specialization (min_forward_len) and the INT4 dispatch (> MATVEC_MAX_M). prefill_min = max(spec_t.min_forward_len, _MATVEC_MAX_M + 1) @@ -324,9 +489,7 @@ def main() -> None: f"{args.max_seq_len})" ) # target_verify is a single static forward of chain+1 tokens: it must fit the - # small-M GEMM (chain+1 <= _MATVEC_MAX_M) and the target's per-forward bounds - # [min_forward_len, max_forward]. 
- verify_len = args.chain + 1 + # small-M GEMM (chain+1 <= _MATVEC_MAX_M) and the target's minimum forward. if verify_len > _MATVEC_MAX_M: p.error( f"--chain {args.chain} (verify window {verify_len}) exceeds the " @@ -337,14 +500,8 @@ def main() -> None: f"--chain {args.chain} (verify window {verify_len}) is below the " f"target's minimum forward length {spec_t.min_forward_len}" ) - if verify_len > min(args.max_seq_len - 1, max_forward): - p.error( - f"--chain {args.chain} (verify window {verify_len}) exceeds the " - f"target's per-forward limit {min(args.max_seq_len - 1, max_forward)}" - ) # Route the static chain_len+1 verify forward to the small-M INT4 GEMM by - # raising the dispatch threshold for this export only; restore it so the - # process-global default (4) is unchanged for any later use. + # raising the dispatch threshold for this export only; restore it after. import executorch.backends.cuda.int4_dispatch as int4_dispatch saved_threshold = int4_dispatch.MATVEC_MAX_M @@ -361,5 +518,75 @@ def main() -> None: int4_dispatch.MATVEC_MAX_M = saved_threshold +def main() -> None: + p = argparse.ArgumentParser(description="Export an EAGLE-3 speculator to .pte.") + p.add_argument( + "--target-model", + default="gemma4_31b", + choices=list(TARGETS), + help="Registered target model (see eagle3/target.py).", + ) + p.add_argument( + "--target", required=True, help="Prequantized (INT4) target directory." + ) + p.add_argument("--draft", required=True, help="EAGLE-3 draft head directory.") + p.add_argument("--output-dir", default="./eagle3_exports") + p.add_argument("--max-seq-len", type=int, default=4096) + p.add_argument( + "--max-prefill", + type=int, + default=512, + help="Max prefill length: AOTI compiles prefill kernels for up to this T " + "and the whole prompt must fit in one prefill (the runner does not chunk). " + "Smaller compiles faster.", + ) + p.add_argument( + "--chain", type=int, default=4, help="Draft chain length K (verify K+1)." 
+ ) + p.add_argument( + "--backend", + default="cuda", + choices=["cuda", "mlx"], + help="Target backend: cuda (AOTI, INT4 target, bf16 draft) or mlx " + "(Apple silicon, INT4 target; bf16 draft, --quantize-draft for int4).", + ) + p.add_argument( + "--share-draft-embedding", + action="store_true", + help="MLX only: reuse the target's packed embedding for the draft so the " + "NamedDataStore dedups it to one copy (drops the draft's own embedding).", + ) + p.add_argument( + "--quantize-draft", + action="store_true", + help="MLX only: int4-pack the draft head (default keeps it bf16). bf16 " + "gives higher acceptance but a larger draft; pair with " + "--share-draft-embedding to avoid a separate draft embedding copy.", + ) + args = p.parse_args() + _validate_backend_flags(p, args) + + spec_t = TARGETS[args.target_model] + if args.backend == "cuda" and not torch.cuda.is_available(): + p.error("CUDA is required to compile the CUDA EAGLE-3 export.") + + target, draft = _load_target_and_draft(p, args, spec_t) + + # A single target forward accepts up to max_forward_len tokens. + max_forward = spec_t.max_forward_len(target.config) + max_prefill = min(args.max_prefill, args.max_seq_len - 1, max_forward) + verify_len = args.chain + 1 + if verify_len > min(args.max_seq_len - 1, max_forward): + p.error( + f"--chain {args.chain} (verify window {verify_len}) exceeds the " + f"target's per-forward limit {min(args.max_seq_len - 1, max_forward)}" + ) + + if args.backend == "mlx": + _run_mlx(p, args, target, draft, max_prefill, verify_len) + else: + _run_cuda(p, args, spec_t, target, draft, max_prefill, verify_len) + + if __name__ == "__main__": main() diff --git a/examples/models/eagle3/quant_mlx.py b/examples/models/eagle3/quant_mlx.py new file mode 100644 index 00000000000..ba55e3535ee --- /dev/null +++ b/examples/models/eagle3/quant_mlx.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Int4 quantization + MLX packing for the EAGLE-3 draft head. + +Mirrors the target's MLX recipe: int4 group-32 linears, int8 per-axis embedding, +norms left in bf16. ``quantize_model`` only captures persistent state, so the +draft's non-persistent ``d2t``/``t2d`` buffers and KV cache are restored here. +""" + +import torch + +from executorch.examples.models.eagle3.draft import Eagle3Draft +from executorch.examples.models.gemma4_31b.quant import ( + DEFAULT_MLX_PACKERS, + pack_model, + QuantConfig, + QuantRecipe, + QuantRule, + quantize_model, +) + + +def _draft_recipe(hidden_size: int, group_size: int) -> QuantRecipe: + int4 = QuantConfig( + bits=4, group_size=group_size, symmetric=False, method="min_max" + ) + int8_per_axis = QuantConfig( + bits=8, group_size=hidden_size, symmetric=True, method="min_max" + ) + return QuantRecipe( + rules=[ + QuantRule(r"embed_tokens\.weight", int8_per_axis), + QuantRule(r".*norm\.weight", None), + QuantRule(r".*\.weight", int4), + ] + ) + + +def quantize_pack_draft_for_mlx( + draft: Eagle3Draft, + dtype: torch.dtype = torch.bfloat16, + group_size: int = 32, +) -> Eagle3Draft: + """Return a new MLX-packed int4 draft (linears int4, embedding int8, norms bf16).""" + config = draft.config + state = quantize_model(draft, _draft_recipe(config.hidden_size, group_size), dtype) + + packed = Eagle3Draft(config) + pack_model(packed, state, packers=DEFAULT_MLX_PACKERS) + # quantize_model skips non-persistent buffers; carry the vocab maps over and + # allocate the KV cache in the compute dtype. 
+ packed.register_buffer("d2t", draft.d2t.clone(), persistent=False) + packed.register_buffer("t2d", draft.t2d.clone(), persistent=False) + packed.allocate_kv_cache(dtype, device="cpu") + return packed.eval() diff --git a/examples/models/eagle3/run_mlx.py b/examples/models/eagle3/run_mlx.py new file mode 100644 index 00000000000..0b59157c0f5 --- /dev/null +++ b/examples/models/eagle3/run_mlx.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Run an exported MLX EAGLE-3 ``.pte`` (gemma4-31B target + draft) on device. + +The MLX export exposes two methods that each own a persistent KV cache within +their handle (MLX has no cross-method sharing) and return *distributions* so the +host applies temperature + (rejection) sampling: + - ``target_forward(tokens, input_pos) -> (logits[1,T,V], fused_feature[1,T,H])`` + - ``draft_decode(tokens, feature, input_pos) -> (draft_logits[1,T,Vd], g[1,T,H])`` + +The shifted (vLLM-EAGLE) loop from ``eagle3/test_speculator.py`` runs over these +two methods, mapping the test's ``prefill``/``target_verify`` onto the merged +``target_forward``. + +``--temperature 0`` is greedy: proposals/verify use argmax and the output equals +the target's greedy decode (lossless, token-exact). ``--temperature > 0`` uses +modified rejection sampling (Leviathan/Chen), which reproduces the target's +sampling distribution. The draft's reduced vocab is bridged to the target vocab +via ``d2t`` (target_id = draft_id + d2t[draft_id]); the rejection residual is +taken over the full target vocab so tokens the draft cannot propose are still +reachable. + +Each run loads a fresh program so the KV caches start zeroed. 
+ +Usage: + python -m executorch.examples.models.eagle3.run_mlx \\ + --pte ./eagle3_mlx/model.pte \\ + --tokenizer-path ./gemma-4-31B-it-HQQ-INT4/tokenizer.json \\ + --draft "$SPEC" \\ + --prompt "Write a short joke about saving RAM." \\ + --chain 3 --num-gen 128 --temperature 0.8 --seed 0 +""" + +import argparse +import os +import time + +import torch + +from executorch.examples.models.gemma4_31b.inference import apply_chat_template + +EOS_TOKEN_IDS = {1, 50, 106} +BOS_TOKEN_ID = 2 + + +def _toks(ids: list[int]) -> torch.Tensor: + return torch.tensor([ids], dtype=torch.long) + + +def _target_forward(method, token_ids, positions): + logits, feat = method.execute( + [_toks(token_ids), torch.tensor(positions, dtype=torch.long)] + ) + return logits, feat + + +def _pick(row_logits, temp, gen): + """Greedy argmax (temp==0) or a temperature sample. Returns (index, probs).""" + if temp == 0.0: + return int(row_logits.argmax()), None + probs = torch.softmax(row_logits.float() / temp, dim=-1) + return int(torch.multinomial(probs, 1, generator=gen)), probs + + +def _propose_chain( + draft_method, seed_tokens, seed_feat, seed_pos, chain_len, temp, d2t, gen +): + """Seed the draft (last slot predicts proposal 0), then chain_len-1 steps. + + Returns (proposals, draft_ids, q_list) where ``proposals`` are target-vocab + ids; ``q_list`` holds the draft distribution per proposal (None when greedy). 
+ """ + dl, g = draft_method.execute([seed_tokens, seed_feat, seed_pos]) + proposals, draft_ids, qs = [], [], [] + d, q = _pick(dl[0, -1], temp, gen) + proposals.append(d + int(d2t[d])) + draft_ids.append(d) + qs.append(q) + last = int(seed_pos[-1]) + tok, feat = _toks([proposals[-1]]), g[:, -1:] + for k in range(1, chain_len): + dl, g = draft_method.execute( + [tok, feat, torch.tensor([last + k], dtype=torch.long)] + ) + d, q = _pick(dl[0, 0], temp, gen) + proposals.append(d + int(d2t[d])) + draft_ids.append(d) + qs.append(q) + tok, feat = _toks([proposals[-1]]), g + return proposals, draft_ids, qs + + +def _accept(p_logits, proposal_x, draft_d, q, temp, target_ids_all, gen): + """Accept one proposal (greedy or modified rejection sampling). + + Returns (accepted, fallback_token): on rejection ``fallback_token`` is the + corrected token (greedy argmax, or a residual sample over the target vocab). + """ + if temp == 0.0: + greedy = int(p_logits.argmax()) + return (proposal_x == greedy), greedy + p = torch.softmax(p_logits.float() / temp, dim=-1) + qx = q[draft_d] + ratio = (p[proposal_x] / qx).clamp(max=1.0) + if torch.rand((), generator=gen) <= ratio: + return True, None + # Resample from the residual (p - q)_+ over the full target vocab; the draft + # contributes q only on its reachable target ids. + q_target = torch.zeros_like(p) + q_target[target_ids_all] = q + resid = (p - q_target).clamp(min=0) + resid = resid / resid.sum() + return False, int(torch.multinomial(resid, 1, generator=gen)) + + +def speculative_decode( + target_method, draft_method, prompt, chain_len, num_gen, temp, d2t, gen +): + """Shifted one-target-forward-per-round speculative decode (greedy or sampling). + + Returns (generated, num_steps, num_accepted, prefill_s) where prefill_s is the + time of the initial prompt forward, so the caller can report decode-only tok/s. 
+ """ + target_ids_all = torch.arange(d2t.numel(), dtype=torch.long) + d2t + L = len(prompt) + _t = time.perf_counter() + logits, feat_prompt = _target_forward(target_method, prompt, list(range(L))) + prefill_s = time.perf_counter() - _t + anchor, _ = _pick(logits[0, -1], temp, gen) + anchor_pos = L + emitted = [anchor] + num_steps = num_accepted = 0 + if anchor in EOS_TOKEN_IDS: + return emitted, num_steps, num_accepted, prefill_s + + proposals, draft_ids, qs = _propose_chain( + draft_method, + _toks(prompt[1:] + [anchor]), + feat_prompt, + torch.arange(L, dtype=torch.long), + chain_len, + temp, + d2t, + gen, + ) + + while len(emitted) < num_gen: + num_steps += 1 + vlogits, vfeat = _target_forward( + target_method, + [anchor] + proposals, + list(range(anchor_pos, anchor_pos + chain_len + 1)), + ) + a = 0 + corrected = None + for j in range(chain_len): + accepted, fallback = _accept( + vlogits[0, j], proposals[j], draft_ids[j], qs[j], temp, target_ids_all, gen + ) + if accepted: + a += 1 + else: + corrected = fallback + break + if corrected is None: # whole chain accepted -> bonus from the next dist + corrected, _ = _pick(vlogits[0, chain_len], temp, gen) + num_accepted += a + + new = proposals[:a] + [corrected] + eos_pos = next((i for i, t in enumerate(new) if t in EOS_TOKEN_IDS), None) + if eos_pos is not None: + new = new[: eos_pos + 1] + new = new[: num_gen - len(emitted)] + emitted += new + if eos_pos is not None or len(emitted) >= num_gen: + break + + proposals, draft_ids, qs = _propose_chain( + draft_method, + _toks(proposals[:a] + [corrected]), + vfeat[:, : a + 1], + torch.arange(anchor_pos, anchor_pos + a + 1, dtype=torch.long), + chain_len, + temp, + d2t, + gen, + ) + anchor, anchor_pos = corrected, anchor_pos + 1 + a + + return emitted[:num_gen], num_steps, num_accepted, prefill_s + + +def greedy_decode(target_method, prompt, num_gen): + """Greedy baseline using only ``target_forward`` (the lossless reference). + + Returns (generated, prefill_s). 
+ """ + L = len(prompt) + _t = time.perf_counter() + logits, _ = _target_forward(target_method, prompt, list(range(L))) + prefill_s = time.perf_counter() - _t + tok, pos, out = int(logits[0, -1].argmax()), L, [] + out.append(tok) + while len(out) < num_gen and tok not in EOS_TOKEN_IDS: + logits, _ = _target_forward(target_method, [tok], [pos]) + tok, pos = int(logits[0, -1].argmax()), pos + 1 + out.append(tok) + return out[:num_gen], prefill_s + + +def _load_d2t(draft_dir: str) -> torch.Tensor: + from safetensors.torch import load_file + + return load_file(os.path.join(draft_dir, "model.safetensors"))["d2t"].to(torch.long) + + +def _encode(tokenizer, prompt, raw_prompt): + text = prompt if raw_prompt else apply_chat_template(prompt) + ids = tokenizer.encode(text).ids + if not ids or ids[0] != BOS_TOKEN_ID: + ids = [BOS_TOKEN_ID] + ids + return ids + + +def _program_const(program, name, default): + """Read a scalar constant_method from the program; fall back to default.""" + try: + return int(program.load_method(name).execute([])[0]) + except Exception: + return default + + +def main() -> None: + p = argparse.ArgumentParser(description="Run an MLX EAGLE-3 .pte (gemma4-31B).") + p.add_argument("--pte", required=True, help="Path to the exported model.pte.") + p.add_argument("--tokenizer-path", required=True) + p.add_argument( + "--draft", + default=None, + help="EAGLE-3 draft dir (for d2t); required for --mode speculative.", + ) + p.add_argument("--prompt", default="Write a short joke about saving RAM.") + p.add_argument("--raw-prompt", action="store_true") + p.add_argument("--num-gen", type=int, default=128) + p.add_argument( + "--chain", + type=int, + default=None, + help="Draft chain length K (default: the export-time get_chain_len).", + ) + p.add_argument( + "--temperature", + type=float, + default=0.0, + help="0 = greedy (lossless); >0 = modified rejection sampling.", + ) + p.add_argument("--seed", type=int, default=0, help="RNG seed for sampling.") + 
p.add_argument( + "--mode", + default="speculative", + choices=["speculative", "greedy"], + help="speculative = draft+verify; greedy = target-only baseline.", + ) + args = p.parse_args() + if args.mode == "speculative" and not args.draft: + p.error("--mode speculative requires --draft (for the d2t vocab map).") + + from executorch.runtime import Runtime, Verification + from tokenizers import Tokenizer + + tokenizer = Tokenizer.from_file(args.tokenizer_path) + prompt_ids = _encode(tokenizer, args.prompt, args.raw_prompt) + + program = Runtime.get().load_program(args.pte, verification=Verification.Minimal) + target_method = program.load_method("target_forward") + chain = args.chain if args.chain is not None else _program_const(program, "get_chain_len", 3) + + print(f"\nPrompt: {args.prompt}\n" + "-" * 40) + t0 = time.perf_counter() + if args.mode == "greedy": + generated, prefill_s = greedy_decode(target_method, prompt_ids, args.num_gen) + else: + gen = torch.Generator().manual_seed(args.seed) + d2t = _load_d2t(args.draft) + draft_method = program.load_method("draft_decode") + generated, num_steps, num_accepted, prefill_s = speculative_decode( + target_method, + draft_method, + prompt_ids, + chain, + args.num_gen, + args.temperature, + d2t, + gen, + ) + if num_steps: + accept_len = (num_accepted + num_steps) / num_steps + print( + f" speculative (chain={chain}, temp={args.temperature}): " + f"{num_steps} target steps, {num_accepted} draft tokens accepted, " + f"mean accept length {accept_len:.2f}" + ) + elapsed = time.perf_counter() - t0 + # Decode tok/s excludes the prompt prefill (matches the CUDA runner). 
+ decode_s = max(elapsed - prefill_s, 1e-9) + + print(tokenizer.decode(generated)) + print("-" * 40) + print(f"token ids: {generated}") + print( + f"prefill: {len(prompt_ids)} tokens in {prefill_s:.2f}s | " + f"decode: {len(generated) / decode_s:.2f} tok/s " + f"({len(generated)} tokens in {decode_s:.2f}s)" + ) + print(f"Generated in {elapsed:.2f}s") + + +if __name__ == "__main__": + main() diff --git a/examples/models/eagle3/target.py b/examples/models/eagle3/target.py index 30fb68126fb..788fe81fad2 100644 --- a/examples/models/eagle3/target.py +++ b/examples/models/eagle3/target.py @@ -61,9 +61,10 @@ def forward_logits_taps( class TargetSpec: """How to load a target for export and its export-shape constraints.""" - # (target_dir, max_seq_len) -> a CPU TapTarget with runtime buffers - # materialized (export keeps the model on the host). - load: Callable[[str, int], TapTarget] + # (target_dir, max_seq_len, backend) -> a CPU TapTarget with runtime buffers + # materialized (export keeps the model on the host). ``backend`` selects the + # weight packing ("cuda" or "mlx"). + load: Callable[..., TapTarget] # config -> max tokens accepted in one target forward (e.g. a sliding ring # buffer caps it at 2*window; a flat-cache model uses max_seq_len-1). 
max_forward_len: Callable[[Any], int] @@ -73,12 +74,14 @@ class TargetSpec: min_forward_len: int -def _load_gemma4_31b(target_dir: str, max_seq_len: int) -> TapTarget: +def _load_gemma4_31b( + target_dir: str, max_seq_len: int, backend: str = "cuda" +) -> TapTarget: from executorch.examples.models.gemma4_31b.export import load_prequantized_model from executorch.examples.models.gemma4_31b.model import materialize_runtime_buffers target, _ = load_prequantized_model( - target_dir, max_seq_len=max_seq_len, backend="cuda" + target_dir, max_seq_len=max_seq_len, backend=backend ) materialize_runtime_buffers(target, dtype=torch.bfloat16, device="cpu") return target.eval() diff --git a/examples/models/eagle3/test_run_mlx_sampling.py b/examples/models/eagle3/test_run_mlx_sampling.py new file mode 100644 index 00000000000..1ff7a49d2d3 --- /dev/null +++ b/examples/models/eagle3/test_run_mlx_sampling.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Distributional test for the MLX runner's modified rejection sampling. + +One speculative-sampling step (propose from the draft ``q``, then accept/reject +against the target ``p`` with a residual resample over the full target vocab) +must reproduce ``p`` exactly, including when the draft's reduced vocab covers +only some target ids. Monte-Carlo checks the implemented ``_accept`` path. +""" + +import torch + +from executorch.examples.models.eagle3.run_mlx import _accept + + +def test_rejection_sampling_reproduces_target_distribution(): + gen = torch.Generator().manual_seed(0) + + V, Vd = 8, 5 + # Draft reaches target ids [0, 3, 4, 6, 7]; {1, 2, 5} are draft-unreachable. 
+ d2t = torch.tensor([0, 2, 2, 3, 3], dtype=torch.long) + target_ids_all = torch.arange(Vd, dtype=torch.long) + d2t + assert target_ids_all.tolist() == [0, 3, 4, 6, 7] + + p = torch.softmax(torch.randn(V, generator=gen), dim=-1) + q = torch.softmax(torch.randn(Vd, generator=gen), dim=-1) + p_logits = p.log() # softmax(log p) == p at temp 1 + + n = 60000 + counts = torch.zeros(V) + for _ in range(n): + d = int(torch.multinomial(q, 1, generator=gen)) + x0 = int(target_ids_all[d]) + accepted, fallback = _accept(p_logits, x0, d, q, 1.0, target_ids_all, gen) + counts[x0 if accepted else fallback] += 1 + + empirical = counts / n + assert torch.max(torch.abs(empirical - p)) < 0.02, ( + f"empirical={empirical.tolist()} target={p.tolist()}" + ) + + +def test_greedy_accept_is_exact_match(): + target_ids_all = torch.arange(4, dtype=torch.long) + logits = torch.tensor([0.1, 5.0, 0.2, 0.3]) # argmax == 1 + accepted, corrected = _accept(logits, 1, 0, None, 0.0, target_ids_all, None) + assert accepted and corrected == 1 + accepted, corrected = _accept(logits, 2, 0, None, 0.0, target_ids_all, None) + assert (not accepted) and corrected == 1 + + +if __name__ == "__main__": + import pytest + + raise SystemExit(pytest.main([__file__, "-q"])) diff --git a/examples/models/gemma4_31b/mlx_source_transformations.py b/examples/models/gemma4_31b/mlx_source_transformations.py index 0bbd4f7b250..28210948bbc 100644 --- a/examples/models/gemma4_31b/mlx_source_transformations.py +++ b/examples/models/gemma4_31b/mlx_source_transformations.py @@ -217,3 +217,51 @@ def mlx_source_transformations( _replace_layer_forward(layer) _replace_model_forward(model) + + +def install_mlx_tap_forward(model: nn.Module) -> None: + """Install a mask-free EAGLE-3 tap forward for the MLX path. + + ``Gemma4_31B.forward_logits_taps`` / ``_decode`` call layers with the 4-arg + mask-based signature, which is incompatible with the mask-free ``(x, + input_pos)`` layers produced by :func:`mlx_source_transformations`. 
This + installs an MLX-compatible ``forward_logits_taps`` that mirrors ``_decode``'s + tap-index convention (index 0 = embedding output; index k = output after + decoder layer k-1) but uses the mask-free layer signature, so the EAGLE-3 + speculator's target methods trace through the MLX layers. + + Call after :func:`mlx_source_transformations` (which rewrites the layers). + """ + import types + + from executorch.examples.models.gemma4_31b.model import validate_eagle_tap_layers + + def _mlx_forward_logits_taps(self, tokens, input_pos, last_logits_only=True): + x = self.embed_tokens(tokens) * self.embed_normalizer + + tap_layers = self.config.eagle_tap_layers + validate_eagle_tap_layers(tap_layers, len(self.layers)) + taps = [] + if 0 in tap_layers: + taps.append(x) # index 0 == embedding output + for i, layer in enumerate(self.layers): + x = layer(x, input_pos) # mask-free MLX layer + if (i + 1) in tap_layers: + taps.append(x) # output of layer i == hidden-state index i+1 + + if len(taps) != len(tap_layers): + raise ValueError( + f"collected {len(taps)} taps but eagle_tap_layers requests " + f"{len(tap_layers)} ({tap_layers}); check the index convention" + ) + + x = self.norm(x) + if last_logits_only: + x = x[:, -1:, :] + logits = self.lm_head(x).float() + cap = self.logit_softcap.float() + logits = torch.tanh(logits / cap) * cap + taps_out = torch.cat(taps, dim=-1) if taps else None + return logits, taps_out + + model.forward_logits_taps = types.MethodType(_mlx_forward_logits_taps, model) From 263bc33f09110cfe76bead79ef2f6bd9e563efca Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 10 Jun 2026 23:33:21 -0700 Subject: [PATCH 2/2] up --- Makefile | 22 +- examples/models/eagle3/CMakeLists.txt | 13 +- examples/models/eagle3/CMakePresets.json | 39 ++ examples/models/eagle3/export.py | 7 +- examples/models/eagle3/main.cpp | 536 +++++++++++++++--- examples/models/eagle3/quant_mlx.py | 6 +- examples/models/eagle3/run_mlx.py | 14 +- 
.../models/eagle3/test_run_mlx_sampling.py | 6 +- 8 files changed, 558 insertions(+), 85 deletions(-) diff --git a/Makefile b/Makefile index c54c56906ae..2b0e11cdb77 100644 --- a/Makefile +++ b/Makefile @@ -91,7 +91,7 @@ # # ============================================================================== -.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu lfm_2_5-mlx llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda gemma4_31b-mlx qwen3_5_moe-cuda qwen3_5_moe-metal clean help +.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu lfm_2_5-mlx llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda gemma4_31b-mlx eagle3-cuda eagle3-mlx qwen3_5_moe-cuda qwen3_5_moe-metal clean help help: @echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make <target>\`.
Available targets:" @@ -129,6 +129,8 @@ help: @echo " gemma3-cpu - Build Gemma3 runner with CPU backend" @echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend" @echo " gemma4_31b-mlx - Build Gemma 4 31B runner with MLX backend" + @echo " eagle3-cuda - Build EAGLE-3 speculator runner with CUDA backend" + @echo " eagle3-mlx - Build EAGLE-3 speculator runner with MLX backend" @echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend" @echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend" @echo " clean - Clean build artifacts" @@ -457,6 +459,24 @@ gemma4_31b-mlx: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner" +eagle3-cuda: + @echo "==> Building and installing ExecuTorch with CUDA..." + cmake --workflow --preset llm-release-cuda + @echo "==> Building EAGLE-3 speculator runner with CUDA..." + cd examples/models/eagle3 && cmake --workflow --preset eagle3-cuda + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/eagle3/eagle3_speculator_runner" + +eagle3-mlx: + @echo "==> Building and installing ExecuTorch with MLX..." + cmake --workflow --preset mlx-release + @echo "==> Building EAGLE-3 speculator runner with MLX..." + cd examples/models/eagle3 && cmake --workflow --preset eagle3-mlx + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/eagle3/eagle3_speculator_runner" + qwen3_5_moe-metal: @echo "==> Building and installing ExecuTorch with Metal..." cmake --workflow --preset llm-release-metal diff --git a/examples/models/eagle3/CMakeLists.txt b/examples/models/eagle3/CMakeLists.txt index f7fe225f4f3..4e2b9afc99e 100644 --- a/examples/models/eagle3/CMakeLists.txt +++ b/examples/models/eagle3/CMakeLists.txt @@ -42,14 +42,19 @@ list( extension_flat_tensor ) -# Backend: CUDA (AOTI). The EAGLE-3 speculator export is CUDA-only. +# Backend: CUDA (AOTI) or MLX (exactly one required). 
CUDA returns greedy ids; +# MLX returns logits and the runner argmaxes + maps draft ids via d2t on the +# host. if(EXECUTORCH_BUILD_CUDA) find_package(CUDAToolkit REQUIRED) list(APPEND link_libraries aoti_cuda_backend) executorch_target_link_options_shared_lib(aoti_cuda_backend) add_compile_definitions(EXECUTORCH_BUILD_CUDA) +elseif(TARGET mlxdelegate) + list(APPEND link_libraries mlxdelegate mlx) + executorch_target_link_options_shared_lib(mlxdelegate) else() - message(FATAL_ERROR "EAGLE-3 speculator runner requires EXECUTORCH_BUILD_CUDA=ON") + message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON or EXECUTORCH_BUILD_MLX=ON") endif() # Tokenizer (HuggingFace tokenizer.json) @@ -67,3 +72,7 @@ if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") target_link_options(eagle3_speculator_runner PRIVATE "LINKER:-s") endif() endif() + +if(TARGET mlxdelegate) + executorch_target_copy_mlx_metallib(eagle3_speculator_runner) +endif() diff --git a/examples/models/eagle3/CMakePresets.json b/examples/models/eagle3/CMakePresets.json index 9d9b75b262d..ffed1b41f96 100644 --- a/examples/models/eagle3/CMakePresets.json +++ b/examples/models/eagle3/CMakePresets.json @@ -16,6 +16,21 @@ "string": "${hostSystemName}", "list": ["Linux", "Windows"] } + }, + { + "name": "eagle3-mlx", + "displayName": "EAGLE-3 speculator runner (MLX)", + "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/eagle3", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out", + "CMAKE_PREFIX_PATH": "${sourceDir}/../../../cmake-out" + }, + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Darwin" + } } ], "buildPresets": [ @@ -24,6 +39,30 @@ "displayName": "Build EAGLE-3 speculator runner (CUDA)", "configurePreset": "eagle3-cuda", "targets": ["eagle3_speculator_runner"] + }, + { + "name": "eagle3-mlx", + "displayName": "Build EAGLE-3 speculator runner (MLX)", + "configurePreset": "eagle3-mlx", + "targets": ["eagle3_speculator_runner"] + 
} + ], + "workflowPresets": [ + { + "name": "eagle3-cuda", + "displayName": "Configure and build EAGLE-3 speculator runner (CUDA)", + "steps": [ + {"type": "configure", "name": "eagle3-cuda"}, + {"type": "build", "name": "eagle3-cuda"} + ] + }, + { + "name": "eagle3-mlx", + "displayName": "Configure and build EAGLE-3 speculator runner (MLX)", + "steps": [ + {"type": "configure", "name": "eagle3-mlx"}, + {"type": "build", "name": "eagle3-mlx"} + ] } ] } diff --git a/examples/models/eagle3/export.py b/examples/models/eagle3/export.py index 57730bb6541..c9da1e10084 100644 --- a/examples/models/eagle3/export.py +++ b/examples/models/eagle3/export.py @@ -364,9 +364,10 @@ def _export_mlx( strict=True, ) + # Capture d2t before freeing the speculator; baked in as get_d2t below. + d2t_const = spec.draft.d2t.to(torch.long).cpu().contiguous() del spec gc.collect() - print("Lowering to ExecuTorch with MLX backend...") et_prog = to_edge_transform_and_lower( {"target_forward": target_ep, "draft_decode": draft_ep}, @@ -387,6 +388,10 @@ def _export_mlx( "get_min_prefill_chunk": 1, "get_chain_len": chain_len, "get_draft_vocab_size": draft_vocab_size, + # draft->target vocab map (target_id = draft_id + d2t[draft_id]); the + # MLX draft_decode returns draft-vocab logits, so a logits-consuming + # runner reads this to map proposals back to target ids. + "get_d2t": d2t_const, "use_kv_cache": True, "enable_dynamic_shape": True, }, diff --git a/examples/models/eagle3/main.cpp b/examples/models/eagle3/main.cpp index 6a68e89eaaa..7b81f6ef6d1 100644 --- a/examples/models/eagle3/main.cpp +++ b/examples/models/eagle3/main.cpp @@ -6,15 +6,22 @@ * LICENSE file in the root directory of this source tree. */ -// EAGLE-3 speculative-decoding runner for ExecuTorch (CUDA/AOTI backend). +// EAGLE-3 speculative-decoding runner for ExecuTorch (CUDA/AOTI or MLX +// backend). 
// // -// Loads the speculator .pte (examples/models/eagle3/export.py) exposing three -// methods that share the target / draft KV caches: +// Loads the speculator .pte (examples/models/eagle3/export.py). The CUDA export +// exposes three methods that share the target / draft KV caches: // prefill(tokens[1,T], pos[T]) -> (next_token[1,1], feat[1,T,H]) // target_verify(tokens[1,C], pos[C]) -> (greedy_ids[1,C], feat[1,C,H]) -// draft_decode(tokens[1,T], feat[1,T,H], pos[T]) -> (target_ids[1,T], g[1,T,H]) -// where feat is the fused (hidden-size) draft feature and H is the draft hidden -// size. Verification is greedy (argmax), so emitted tokens equal greedy target +// draft_decode(tokens[1,T], feat[1,T,H], pos[T]) -> (target_ids[1,T], +// g[1,T,H]) +// MLX has no cross-method KV sharing, so prefill+verify are one dynamic-seq +// method and the methods return logits instead of ids: +// target_forward(tokens[1,T], pos[T]) -> (logits[1,T,V], feat[1,T,H]) +// draft_decode(...) -> (draft_logits[1,T,Vd], g[1,T,H]) +// This runner argmaxes those host-side and maps draft ids via d2t (get_d2t). +// feat is the fused (hidden-size) draft feature and H is the draft hidden size. +// At temperature 0 verification is greedy (argmax); tokens equal greedy target // decoding (lossless) by construction. // // Scheme: the shifted EAGLE convention (vllm/v1/spec_decode/eagle.py, @@ -30,23 +37,25 @@ // negligible next to the INT4 31B target forward, and it keeps device-tensor // lifetimes simple. // -// Run (after exporting model.pte + aoti_cuda_blob.ptd via export.py, sourcing the -// CUDA env, and building the eagle3-cuda preset): +// Run (CUDA: export model.pte + aoti_cuda_blob.ptd and source the CUDA env; +// MLX: a single model.pte, no --data_path): // eagle3_speculator_runner --model_path /model.pte \ -// --data_path /aoti_cuda_blob.ptd --tokenizer_path \ +// [--data_path /aoti_cuda_blob.ptd] --tokenizer_path +// \ // --prompt "..."
--max_new_tokens 128 // The chat template and stop tokens default to Gemma 4 IT; override -// --chat_prefix/--chat_suffix/--stop_ids/--stop_token (and --bos_id -1) for other -// target/tokenizer pairs. Per-run timing counters (tau, verify/draft ms) print at -// the end. +// --chat_prefix/--chat_suffix/--stop_ids/--stop_token (and --bos_id -1) for +// other target/tokenizer pairs. Per-run timing counters (tau, verify/draft ms) +// print at the end. // -// Scope: a single-sequence, greedy, fixed-shape demo runner -- not a generic -// EAGLE serving path. No batching, sampler stack (top-k/p/temperature), grammar/ -// tool constraints, streaming API, or integration with the standard ExecuTorch -// LLM runner. The host feature round-trip above is a first-implementation choice -// (the target forward dominates here); a device-resident handoff is future work. -// The target, draft, and tokenizer must be a matched, co-trained set -- a -// mismatch can pass export and silently degrade acceptance/output. +// Scope: a single-sequence, fixed-shape demo runner -- not a generic EAGLE +// serving path. CUDA is greedy; MLX adds --temperature (modified rejection +// sampling) but no top-k/p. No batching, grammar/tool constraints, streaming +// API, or integration with the standard ExecuTorch LLM runner. The host feature +// round-trip above is a first-implementation choice (the target forward +// dominates here); a device-resident handoff is future work. The target, draft, +// and tokenizer must be a matched, co-trained set -- a mismatch can pass export +// and silently degrade acceptance/output. 
#include @@ -62,9 +71,11 @@ #include #include +#include #include #include #include +#include #include #include @@ -96,9 +107,15 @@ DEFINE_bool(raw_prompt, false, "Skip the Gemma 4 IT chat template."); DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); DEFINE_int32(bos_id, 2, "BOS token id (-1 to skip; Gemma convention: 2)."); DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1)."); -DEFINE_bool(cuda_graph, true, "Capture target_verify as a CUDA graph (CUDA only)."); +DEFINE_bool( + cuda_graph, + true, + "Capture target_verify as a CUDA graph (CUDA only)."); // Chat template + stop tokens default to Gemma 4 IT; override for other models. -DEFINE_string(chat_prefix, "<|turn>user\n", "Chat-template text before the prompt."); +DEFINE_string( + chat_prefix, + "<|turn>user\n", + "Chat-template text before the prompt."); DEFINE_string( chat_suffix, "\n<|turn>model\n<|channel>thought\n", @@ -111,6 +128,17 @@ DEFINE_string( stop_token, "", "A stop-delimiter string to encode and add to EOS (empty to skip)."); +DEFINE_double( + temperature, + 0.0, + "Sampling temperature (0 = greedy/argmax; >0 = modified rejection " + "sampling; MLX only)."); +DEFINE_int64(seed, 0, "RNG seed for --temperature sampling."); +DEFINE_int64( + chain, + 0, + "Override draft chain length K (MLX only; 0 = use the exported " + "get_chain_len). 
Ignored on CUDA, whose verify shape is static."); using executorch::extension::from_blob; using executorch::extension::Module; @@ -130,7 +158,8 @@ std::vector to_host_bytes(const executorch::aten::Tensor& t) { cudaPointerAttributes attrs{}; if (cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && attrs.type == cudaMemoryTypeDevice) { - cudaError_t err = cudaMemcpy(out.data(), ptr, out.size(), cudaMemcpyDeviceToHost); + cudaError_t err = + cudaMemcpy(out.data(), ptr, out.size(), cudaMemcpyDeviceToHost); if (err != cudaSuccess) { ET_LOG(Error, "D2H copy failed: %s", cudaGetErrorString(err)); exit(1); @@ -169,6 +198,99 @@ HostFeature read_feature(const executorch::aten::Tensor& t) { return f; } +#ifndef EXECUTORCH_BUILD_CUDA +// MLX returns logits; read one element of a (1, T, V) Float or BFloat16 row. +inline float +logit_at(const uint8_t* row, int64_t v, executorch::aten::ScalarType st) { + if (st == executorch::aten::ScalarType::BFloat16) { + uint32_t u = + static_cast(reinterpret_cast(row)[v]) << 16; + float f; + memcpy(&f, &u, sizeof(f)); + return f; + } + return reinterpret_cast(row)[v]; +} + +// Per-position argmax over a (1, T, V) logits tensor -> T ids. +std::vector argmax_rows(const executorch::aten::Tensor& t) { + const int64_t T = t.size(1), V = t.size(2); + const auto st = t.scalar_type(); + const int64_t esz = t.numel() > 0 ? (int64_t)(t.nbytes() / t.numel()) : 0; + auto bytes = to_host_bytes(t); + std::vector ids(T); + for (int64_t r = 0; r < T; r++) { + const uint8_t* row = bytes.data() + (size_t)r * V * esz; + int64_t best = 0; + float bv = logit_at(row, 0, st); + for (int64_t v = 1; v < V; v++) { + float x = logit_at(row, v, st); + if (x > bv) { + bv = x; + best = v; + } + } + ids[r] = best; + } + return ids; +} + +// Numerically-stable softmax in place (same max-subtract approach as the static +// softmax in extension/llm/sampler/sampler.cpp). Apply temperature by scaling +// logits first (see temperature_softmax). 
+void softmax_inplace(std::vector& x) { + if (x.empty()) { + return; + } + float max_val = *std::max_element(x.begin(), x.end()); + float sum = 0.0f; + for (float& v : x) { + v = std::exp(v - max_val); + sum += v; + } + for (float& v : x) { + v /= sum; + } +} + +// Temperature-scale logits then softmax (mirrors Sampler::sample: logits *= +// inv_temperature; softmax(...)). +void temperature_softmax(std::vector& x, float inv_temp) { + for (float& v : x) { + v *= inv_temp; + } + softmax_inplace(x); +} + +// Multinomial sample from a probability vector (mirrors Sampler::sample_mult: +// walk the CDF and return the first index past the coin). +int64_t sample_mult(const std::vector& p, std::mt19937_64& rng) { + float coin = std::uniform_real_distribution(0.0f, 1.0f)(rng); + float cdf = 0.0f; + for (size_t i = 0; i < p.size(); i++) { + cdf += p[i]; + if (coin < cdf) { + return static_cast(i); + } + } + return static_cast(p.size()) - 1; +} + +// Copy row r of a (1, T, V) Float/BFloat16 logits tensor to float[V]. +std::vector row_to_floats(const executorch::aten::Tensor& t, int64_t r) { + const int64_t V = t.size(2); + const auto st = t.scalar_type(); + const int64_t esz = t.numel() > 0 ? 
(int64_t)(t.nbytes() / t.numel()) : 0; + auto bytes = to_host_bytes(t); + const uint8_t* row = bytes.data() + (size_t)r * V * esz; + std::vector out(V); + for (int64_t v = 0; v < V; v++) { + out[v] = logit_at(row, v, st); + } + return out; +} +#endif + } // namespace int main(int argc, char** argv) { @@ -183,7 +305,10 @@ int main(int argc, char** argv) { auto tokenizer = std::make_unique(); if (tokenizer->load(FLAGS_tokenizer_path) != tokenizers::Error::Ok) { - ET_LOG(Error, "Failed to load tokenizer from %s", FLAGS_tokenizer_path.c_str()); + ET_LOG( + Error, + "Failed to load tokenizer from %s", + FLAGS_tokenizer_path.c_str()); return 1; } @@ -218,7 +343,11 @@ int main(int argc, char** argv) { } #endif +#ifdef EXECUTORCH_BUILD_CUDA for (const char* m : {"prefill", "target_verify", "draft_decode"}) { +#else + for (const char* m : {"target_forward", "draft_decode"}) { +#endif if (module->load_method(m) != Error::Ok) { ET_LOG(Error, "Failed to load method %s", m); return 1; @@ -244,7 +373,42 @@ int main(int argc, char** argv) { const int64_t max_prefill = meta("get_max_prefill_chunk"); const int64_t min_prefill = meta("get_min_prefill_chunk"); const int64_t max_seq_len = meta("get_max_seq_len"); - const int64_t K = chain_len; + int64_t K = chain_len; +#ifndef EXECUTORCH_BUILD_CUDA + // MLX methods are dynamic-seq, so the chain can be set at runtime as long as + // the verify window (K+1) fits the exported prefill range. + if (FLAGS_chain > 0) { + K = FLAGS_chain; + } + if (K + 1 > max_prefill) { + ET_LOG( + Error, + "--chain %" PRId64 " (verify window %" PRId64 + ") exceeds exported max_prefill %" PRId64, + K, + K + 1, + max_prefill); + return 1; + } +#endif + +#ifndef EXECUTORCH_BUILD_CUDA + // MLX draft_decode returns draft-vocab logits; load the draft->target map + // (target_id = draft_id + d2t[draft_id]) baked in by export.py. 
+ std::vector d2t; + { + auto r = module->get("get_d2t"); + if (!r.ok()) { + ET_LOG(Error, "missing get_d2t metadata in .pte"); + return 1; + } + d2t = read_ids(r->toTensor()); + } + const float temp = + FLAGS_temperature > 0.0 ? static_cast(FLAGS_temperature) : 0.0f; + const float inv_temp = temp > 0.0f ? 1.0f / temp : 0.0f; + std::mt19937_64 rng(static_cast(FLAGS_seed)); +#endif // EOS: tokenizer/metadata ids, the configured eos, any --stop_ids, and the // encoded --stop_token delimiter (all default to the Gemma 4 IT conventions). @@ -282,24 +446,38 @@ int main(int argc, char** argv) { } const int64_t L = static_cast(prompt.size()); // The runner does not chunk: the whole prompt must fit one prefill, and its - // length must be within the exported prefill range [min_prefill, max_prefill]. + // length must be within the exported prefill range [min_prefill, + // max_prefill]. if (L > max_prefill) { - ET_LOG(Error, "Prompt (%" PRId64 " tokens) exceeds max_prefill %" PRId64 - "; this runner does not chunk prefill.", L, max_prefill); + ET_LOG( + Error, + "Prompt (%" PRId64 " tokens) exceeds max_prefill %" PRId64 + "; this runner does not chunk prefill.", + L, + max_prefill); return 1; } if (L < min_prefill) { - ET_LOG(Error, "Prompt (%" PRId64 " tokens) is below the exported prefill " - "minimum %" PRId64 "; use a longer prompt.", L, min_prefill); + ET_LOG( + Error, + "Prompt (%" PRId64 + " tokens) is below the exported prefill " + "minimum %" PRId64 "; use a longer prompt.", + L, + min_prefill); return 1; } // The prefill bonus token is always emittable (no KV write past the prompt). - // Each speculative round, however, writes a K-token verify window, so it needs - // anchor_pos + K <= max_seq_len - 1 (enforced in the loop below). Cap the total - // at the positions available; max_new >= 1 since L <= max_prefill < max_seq_len. 
+ // Each speculative round, however, writes a K-token verify window, so it + // needs anchor_pos + K <= max_seq_len - 1 (enforced in the loop below). Cap + // the total at the positions available; max_new >= 1 since L <= max_prefill < + // max_seq_len. int64_t max_new = std::min(FLAGS_max_new_tokens, max_seq_len - L); - printf("Prompt tokens: %" PRId64 ", chain K=%" PRId64 ", max_new=%" PRId64 - "\n", L, K, max_new); + printf( + "Prompt tokens: %" PRId64 ", chain K=%" PRId64 ", max_new=%" PRId64 "\n", + L, + K, + max_new); auto S = [](int64_t v) { return static_cast(v); }; @@ -309,11 +487,15 @@ int main(int argc, char** argv) { auto long_tensor = [&](std::vector& buf) { return from_blob( - buf.data(), {1, S((int64_t)buf.size())}, executorch::aten::ScalarType::Long); + buf.data(), + {1, S((int64_t)buf.size())}, + executorch::aten::ScalarType::Long); }; auto pos_tensor = [&](std::vector& buf) { return from_blob( - buf.data(), {S((int64_t)buf.size())}, executorch::aten::ScalarType::Long); + buf.data(), + {S((int64_t)buf.size())}, + executorch::aten::ScalarType::Long); }; // draft_decode over (tokens, feat rows, positions); returns proposals + the @@ -333,7 +515,9 @@ int main(int argc, char** argv) { feat_buf.assign(feat_rows, feat_rows + feat_T * H); auto t_tok = long_tensor(tok_buf); auto t_feat = from_blob( - feat_buf.data(), {1, S(feat_T), S(H)}, executorch::aten::ScalarType::BFloat16); + feat_buf.data(), + {1, S(feat_T), S(H)}, + executorch::aten::ScalarType::BFloat16); auto t_pos = pos_tensor(pos_buf); auto r = module->execute( "draft_decode", {EValue(t_tok), EValue(t_feat), EValue(t_pos)}); @@ -341,12 +525,19 @@ int main(int argc, char** argv) { ET_LOG(Error, "draft_decode failed"); exit(1); } +#ifdef EXECUTORCH_BUILD_CUDA out_ids = read_ids(r->at(0).toTensor()); +#else + // draft-vocab argmax -> target ids via d2t. 
+ out_ids = argmax_rows(r->at(0).toTensor()); + for (auto& id : out_ids) { + id += d2t[id]; + } +#endif HostFeature g = read_feature(r->at(1).toTensor()); out_last_g.T = 1; out_last_g.H = g.H; - out_last_g.data.assign( - g.data.end() - g.H, g.data.end()); // last row of g + out_last_g.data.assign(g.data.end() - g.H, g.data.end()); // last row of g }; // Run a draft chain seeded by (seed_tokens, seed_feat) at seed positions; the @@ -358,38 +549,139 @@ int main(int argc, char** argv) { std::vector ids; HostFeature last_g; draft_decode( - seed_tokens, seed_feat.data.data(), seed_feat.T, seed_feat.H, - seed_start_pos, ids, last_g); + seed_tokens, + seed_feat.data.data(), + seed_feat.T, + seed_feat.H, + seed_start_pos, + ids, + last_g); proposals.push_back(ids.back()); int64_t last_pos = seed_start_pos + seed_feat.T - 1; for (int64_t k = 1; k < K; k++) { std::vector step_ids; HostFeature step_g; draft_decode( - {proposals.back()}, last_g.data.data(), 1, last_g.H, - last_pos + k, step_ids, step_g); + {proposals.back()}, + last_g.data.data(), + 1, + last_g.H, + last_pos + k, + step_ids, + step_g); proposals.push_back(step_ids[0]); last_g = step_g; } return proposals; }; +#ifndef EXECUTORCH_BUILD_CUDA + // Run draft_decode and return the draft-vocab probabilities for the last + // (predicting) row plus the recurrent feature; the sampling counterpart of + // draft_decode that keeps the distribution for rejection sampling. 
+ auto draft_logits_step = [&](const std::vector& tokens, + const uint16_t* feat_rows, + int64_t feat_T, + int64_t Hh, + int64_t start_pos, + HostFeature& out_last_g) -> std::vector { + tok_buf.assign(tokens.begin(), tokens.end()); + pos_buf.resize(tokens.size()); + for (size_t i = 0; i < tokens.size(); i++) { + pos_buf[i] = start_pos + static_cast(i); + } + feat_buf.assign(feat_rows, feat_rows + feat_T * Hh); + auto t_tok = long_tensor(tok_buf); + auto t_feat = from_blob( + feat_buf.data(), + {1, S(feat_T), S(Hh)}, + executorch::aten::ScalarType::BFloat16); + auto t_pos = pos_tensor(pos_buf); + auto r = module->execute( + "draft_decode", {EValue(t_tok), EValue(t_feat), EValue(t_pos)}); + if (r.error() != Error::Ok) { + ET_LOG(Error, "draft_decode failed"); + exit(1); + } + auto dl = r->at(0).toTensor(); + std::vector q = row_to_floats(dl, dl.size(1) - 1); + temperature_softmax(q, inv_temp); + HostFeature g = read_feature(r->at(1).toTensor()); + out_last_g.T = 1; + out_last_g.H = g.H; + out_last_g.data.assign(g.data.end() - g.H, g.data.end()); + return q; + }; + + // Sampling counterpart of `chain`: sample each proposal from the draft and + // record (draft_id, q) for the rejection test. 
+ auto chain_sample = [&](const std::vector& seed_tokens, + const HostFeature& seed_feat, + int64_t seed_start_pos, + std::vector& draft_ids, + std::vector>& q_rows) { + std::vector proposals; + HostFeature last_g; + std::vector q = draft_logits_step( + seed_tokens, + seed_feat.data.data(), + seed_feat.T, + seed_feat.H, + seed_start_pos, + last_g); + int64_t d = sample_mult(q, rng); + draft_ids.push_back(d); + proposals.push_back(d + d2t[d]); + q_rows.push_back(std::move(q)); + int64_t last_pos = seed_start_pos + seed_feat.T - 1; + for (int64_t k = 1; k < K; k++) { + HostFeature step_g; + std::vector qk = draft_logits_step( + {proposals.back()}, + last_g.data.data(), + 1, + last_g.H, + last_pos + k, + step_g); + int64_t dk = sample_mult(qk, rng); + draft_ids.push_back(dk); + proposals.push_back(dk + d2t[dk]); + q_rows.push_back(std::move(qk)); + last_g = step_g; + } + return proposals; + }; +#endif + stats.model_load_end_ms = llm::time_in_ms(); stats.inference_start_ms = stats.model_load_end_ms; - // --- Prefill: target over the prompt -> bonus token + per-position feature. --- + // --- Prefill: target over the prompt -> bonus token + per-position feature. + // --- tok_buf = prompt; pos_buf.resize(L); for (int64_t i = 0; i < L; i++) { pos_buf[i] = i; } +#ifdef EXECUTORCH_BUILD_CUDA auto pf = module->execute( "prefill", {EValue(long_tensor(tok_buf)), EValue(pos_tensor(pos_buf))}); +#else + auto pf = module->execute( + "target_forward", + {EValue(long_tensor(tok_buf)), EValue(pos_tensor(pos_buf))}); +#endif if (pf.error() != Error::Ok) { ET_LOG(Error, "prefill failed"); return 1; } - int64_t anchor = read_ids(pf->at(0).toTensor())[0]; // bonus token at position L +#ifdef EXECUTORCH_BUILD_CUDA + int64_t anchor = + read_ids(pf->at(0).toTensor())[0]; // bonus token at position L +#else + // target_forward returns per-position logits; the bonus is the last argmax. 
+ int64_t anchor = argmax_rows(pf->at(0).toTensor()).back(); +#endif HostFeature feat_prompt = read_feature(pf->at(1).toTensor()); const int64_t H = feat_prompt.H; int64_t anchor_pos = L; @@ -401,23 +693,37 @@ int main(int argc, char** argv) { uint64_t prev = static_cast(prompt.back()); { auto s = tokenizer->decode(prev, static_cast(anchor)); - if (s.ok()) { printf("%s", s->c_str()); fflush(stdout); } + if (s.ok()) { + printf("%s", s->c_str()); + fflush(stdout); + } prev = static_cast(anchor); } // We only run the speculative loop if more than the (already emitted) prefill - // bonus is wanted, the bonus wasn't EOS, and there is room for a K-token verify - // window. Otherwise we are done -- no draft seeding needed. + // bonus is wanted, the bonus wasn't EOS, and there is room for a K-token + // verify window. Otherwise we are done -- no draft seeding needed. bool hit_eos = eos_ids.count(static_cast(anchor)) > 0; bool speculate = max_new > 1 && !hit_eos && anchor_pos + K <= max_seq_len - 1; std::vector proposals; +#ifndef EXECUTORCH_BUILD_CUDA + // Per-round draft distributions kept for --temperature rejection sampling. + std::vector draft_ids; + std::vector> q_rows; +#endif if (speculate) { // Seed the first chain (shifted): draft slot p pairs feat_prompt[p] with // token_{p+1}; the last slot pairs feat_prompt[L-1] with the bonus and // predicts position L+1. std::vector seed_tokens(prompt.begin() + 1, prompt.end()); seed_tokens.push_back(anchor); +#ifdef EXECUTORCH_BUILD_CUDA proposals = chain(seed_tokens, feat_prompt, 0); +#else + proposals = (temp == 0.0f) + ? 
chain(seed_tokens, feat_prompt, 0) + : chain_sample(seed_tokens, feat_prompt, 0, draft_ids, q_rows); +#endif } // Stable buffers for target_verify (fixed length K+1) so the CUDA graph @@ -443,18 +749,27 @@ int main(int argc, char** argv) { vpos_buf[i] = anchor_pos + i; } int64_t t_v = llm::time_in_ms(); - auto vr = module->execute("target_verify", {EValue(vtok_t), EValue(vpos_t)}); +#ifdef EXECUTORCH_BUILD_CUDA + auto vr = + module->execute("target_verify", {EValue(vtok_t), EValue(vpos_t)}); +#else + auto vr = + module->execute("target_forward", {EValue(vtok_t), EValue(vpos_t)}); +#endif if (vr.error() != Error::Ok) { - ET_LOG(Error, "target_verify failed"); + ET_LOG(Error, "target forward failed"); return 1; } - std::vector verify_ids = read_ids(vr->at(0).toTensor()); HostFeature verify_feat = read_feature(vr->at(1).toTensor()); verify_ms += llm::time_in_ms() - t_v; - // Greedy acceptance: verify_ids[j] is the greedy token after token j, so it - // checks proposal j (which sits at position anchor_pos+1+j). + // Acceptance: count the leading proposals that pass, then a corrected + // token. verify slot j is the target distribution after token j (proposal j + // sits at position anchor_pos+1+j). int64_t a = 0; + int64_t corrected = 0; +#ifdef EXECUTORCH_BUILD_CUDA + std::vector verify_ids = read_ids(vr->at(0).toTensor()); for (int64_t j = 0; j < K; j++) { if (proposals[j] == verify_ids[j]) { a++; @@ -462,15 +777,75 @@ int main(int argc, char** argv) { break; } } - int64_t corrected = verify_ids[a]; + corrected = verify_ids[a]; +#else + auto verify_logits = vr->at(0).toTensor(); + if (temp == 0.0f) { + // Greedy acceptance against the per-position argmax. + std::vector verify_ids = argmax_rows(verify_logits); + for (int64_t j = 0; j < K; j++) { + if (proposals[j] == verify_ids[j]) { + a++; + } else { + break; + } + } + corrected = verify_ids[a]; + } else { + // Modified rejection sampling (lossless w.r.t. 
target sampling): accept + // proposal j with prob min(1, p_j[x]/q_j[d]); on reject resample from the + // residual (p - q)_+ over the target vocab; the all-accepted bonus is a + // sample from p_K. + std::uniform_real_distribution u(0.0, 1.0); + bool rejected = false; + for (int64_t j = 0; j < K; j++) { + std::vector p = row_to_floats(verify_logits, j); + temperature_softmax(p, inv_temp); + const int64_t x = proposals[j], d = draft_ids[j]; + const float qx = q_rows[j][d]; + const float ratio = qx > 0.0f ? std::min(1.0f, p[x] / qx) : 0.0f; + if (u(rng) <= ratio) { + a++; + continue; + } + // residual: subtract q mapped to the target vocab, clamp, renormalize. + const std::vector& q = q_rows[j]; + for (int64_t dd = 0; dd < (int64_t)q.size(); dd++) { + p[dd + d2t[dd]] -= q[dd]; + } + double sum = 0.0; + for (float& pv : p) { + if (pv < 0.0f) { + pv = 0.0f; + } + sum += pv; + } + for (float& pv : p) { + pv = static_cast(pv / sum); + } + corrected = sample_mult(p, rng); + rejected = true; + break; + } + if (!rejected) { + std::vector p = row_to_floats(verify_logits, K); + temperature_softmax(p, inv_temp); + corrected = sample_mult(p, rng); + } + } +#endif std::vector newly(proposals.begin(), proposals.begin() + a); newly.push_back(corrected); for (int64_t t : newly) { - if ((int64_t)emitted.size() >= max_new) break; + if ((int64_t)emitted.size() >= max_new) + break; emitted.push_back(t); auto s = tokenizer->decode(prev, static_cast(t)); - if (s.ok()) { printf("%s", s->c_str()); fflush(stdout); } + if (s.ok()) { + printf("%s", s->c_str()); + fflush(stdout); + } prev = static_cast(t); if (eos_ids.count(static_cast(t)) > 0) { // Stop at the first accepted EOS; do not emit the rest of this batch. 
@@ -481,13 +856,15 @@ int main(int argc, char** argv) { break; } } - if (hit_eos || (int64_t)emitted.size() >= max_new) break; + if (hit_eos || (int64_t)emitted.size() >= max_new) + break; // Reseed the draft (shifted): slot anchor_pos+i holds (verify_feat[i], - // token_{anchor_pos+i+1}) where token = p_i (i