diff --git a/Makefile b/Makefile index c54c56906ae..2b0e11cdb77 100644 --- a/Makefile +++ b/Makefile @@ -91,7 +91,7 @@ # # ============================================================================== -.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu lfm_2_5-mlx llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda gemma4_31b-mlx qwen3_5_moe-cuda qwen3_5_moe-metal clean help +.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral-mlx voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal voxtral_realtime-mlx voxtral_tts-cpu voxtral_tts-cuda whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal parakeet-mlx parakeet-vulkan dinov2-cuda dinov2-cuda-debug sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu lfm_2_5-mlx llava-cpu gemma3-cuda gemma3-cpu gemma4_31b-cuda gemma4_31b-mlx eagle3-cuda eagle3-mlx qwen3_5_moe-cuda qwen3_5_moe-metal clean help help: @echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make \`. Available targets:" @@ -129,6 +129,8 @@ help: @echo " gemma3-cpu - Build Gemma3 runner with CPU backend" @echo " gemma4_31b-cuda - Build Gemma 4 31B runner with CUDA backend" @echo " gemma4_31b-mlx - Build Gemma 4 31B runner with MLX backend" + @echo " eagle3-cuda - Build EAGLE-3 speculator runner with CUDA backend" + @echo " eagle3-mlx - Build EAGLE-3 speculator runner with MLX backend" @echo " qwen3_5_moe-cuda - Build Qwen3.5 MoE runner with CUDA backend" @echo " qwen3_5_moe-metal - Build Qwen3.5 MoE runner with Metal backend" @echo " clean - Clean build artifacts" @@ -457,6 +459,24 @@ gemma4_31b-mlx: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/gemma4_31b/gemma4_31b_runner" +eagle3-cuda: + @echo "==> Building and installing ExecuTorch with CUDA..." + cmake --workflow --preset llm-release-cuda + @echo "==> Building EAGLE-3 speculator runner with CUDA..." + cd examples/models/eagle3 && cmake --workflow --preset eagle3-cuda + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/eagle3/eagle3_speculator_runner" + +eagle3-mlx: + @echo "==> Building and installing ExecuTorch with MLX..." + cmake --workflow --preset mlx-release + @echo "==> Building EAGLE-3 speculator runner with MLX..." + cd examples/models/eagle3 && cmake --workflow --preset eagle3-mlx + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/eagle3/eagle3_speculator_runner" + qwen3_5_moe-metal: @echo "==> Building and installing ExecuTorch with Metal..." cmake --workflow --preset llm-release-metal diff --git a/examples/models/eagle3/CMakeLists.txt b/examples/models/eagle3/CMakeLists.txt index f7fe225f4f3..4e2b9afc99e 100644 --- a/examples/models/eagle3/CMakeLists.txt +++ b/examples/models/eagle3/CMakeLists.txt @@ -42,14 +42,19 @@ list( extension_flat_tensor ) -# Backend: CUDA (AOTI). The EAGLE-3 speculator export is CUDA-only. +# Backend: CUDA (AOTI) or MLX (exactly one required). CUDA returns greedy ids; +# MLX returns logits and the runner argmaxes + maps draft ids via d2t on the +# host. if(EXECUTORCH_BUILD_CUDA) find_package(CUDAToolkit REQUIRED) list(APPEND link_libraries aoti_cuda_backend) executorch_target_link_options_shared_lib(aoti_cuda_backend) add_compile_definitions(EXECUTORCH_BUILD_CUDA) +elseif(TARGET mlxdelegate) + list(APPEND link_libraries mlxdelegate mlx) + executorch_target_link_options_shared_lib(mlxdelegate) else() - message(FATAL_ERROR "EAGLE-3 speculator runner requires EXECUTORCH_BUILD_CUDA=ON") + message(FATAL_ERROR "Set EXECUTORCH_BUILD_CUDA=ON or EXECUTORCH_BUILD_MLX=ON") endif() # Tokenizer (HuggingFace tokenizer.json) @@ -67,3 +72,7 @@ if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") target_link_options(eagle3_speculator_runner PRIVATE "LINKER:-s") endif() endif() + +if(TARGET mlxdelegate) + executorch_target_copy_mlx_metallib(eagle3_speculator_runner) +endif() diff --git a/examples/models/eagle3/CMakePresets.json b/examples/models/eagle3/CMakePresets.json index 9d9b75b262d..ffed1b41f96 100644 --- a/examples/models/eagle3/CMakePresets.json +++ b/examples/models/eagle3/CMakePresets.json @@ -16,6 +16,21 @@ "string": "${hostSystemName}", "list": ["Linux", "Windows"] } + }, + { + "name": "eagle3-mlx", + "displayName": "EAGLE-3 speculator runner (MLX)", + "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/eagle3", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out", + "CMAKE_PREFIX_PATH": "${sourceDir}/../../../cmake-out" + }, + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Darwin" + } } ], "buildPresets": [ @@ -24,6 +39,30 @@ "displayName": "Build EAGLE-3 speculator runner (CUDA)", "configurePreset": "eagle3-cuda", "targets": ["eagle3_speculator_runner"] + }, + { + "name": "eagle3-mlx", + "displayName": "Build EAGLE-3 speculator runner (MLX)", + "configurePreset": "eagle3-mlx", + "targets": ["eagle3_speculator_runner"] + } + ], + "workflowPresets": [ + { + "name": "eagle3-cuda", + "displayName": "Configure and build EAGLE-3 speculator runner (CUDA)", + "steps": [ + {"type": "configure", "name": "eagle3-cuda"}, + {"type": "build", "name": "eagle3-cuda"} + ] + }, + { + "name": "eagle3-mlx", + "displayName": "Configure and build EAGLE-3 speculator runner (MLX)", + "steps": [ + {"type": "configure", "name": "eagle3-mlx"}, + {"type": "build", "name": "eagle3-mlx"} + ] } ] } diff --git a/examples/models/eagle3/export.py b/examples/models/eagle3/export.py index c3a4decf8f4..c9da1e10084 100644 --- a/examples/models/eagle3/export.py +++ b/examples/models/eagle3/export.py @@ -27,11 +27,20 @@ (``test_shifted_speculative_decode_is_lossless`` drives the full loop through only these three methods). -Export runs with the model on the host (CPU); AOTInductor streams weights to the -GPU per kernel during compilation, so peak GPU memory stays low even for the INT4 -31B target. The target is loaded from a prequantized (INT4) directory and the -draft from a vLLM-speculator checkpoint; only the CUDA (AOTI) backend is -supported. +Export runs with the model on the host (CPU). For ``--backend cuda`` AOTInductor +streams weights to the GPU per kernel during compilation, so peak GPU memory +stays low even for the INT4 31B target. The target is loaded from a prequantized +(INT4) directory and the draft from a vLLM-speculator checkpoint. + +Backends: + - ``cuda`` (AOTI): three methods (prefill, target_verify, draft_decode) sharing + KV caches by FQN; bf16 draft. + - ``mlx`` (Apple silicon): MLX has no cross-method KV-cache sharing, so prefill + and verify are merged into one dynamic-seq ``target_forward`` (sharing the + target cache within a single handle) plus ``draft_decode``; the draft is bf16 + by default (``--quantize-draft`` for int4). Both methods return logits (target + soft-capped + draft) so the host applies temperature / sampling — greedy is + host-side argmax. Scope (this is a fixed-shape ExecuTorch artifact, not a generic EAGLE runtime): chain length, the chain_len+1 verify window, the prefill/draft dynamic ranges, @@ -95,6 +104,35 @@ def forward(self, tokens, feature, input_pos): return self.spec.draft_decode(tokens, feature, input_pos) +# Logit-returning variants for the MLX sampling path: the host applies +# temperature + modified rejection sampling, so the methods return distributions +# (soft-capped target logits / draft logits) instead of the greedy argmax. Greedy +# (--temperature 0) just argmaxes these host-side. + + +class _TargetForwardLogits(nn.Module): + def __init__(self, spec: Eagle3Speculator): + super().__init__() + self.spec = spec + + def forward(self, tokens, input_pos): + logits, taps = self.spec.target.forward_logits_taps( + tokens, input_pos, last_logits_only=False + ) + return logits, self.spec.draft.fuse(taps) + + +class _DraftDecodeLogits(nn.Module): + def __init__(self, spec: Eagle3Speculator): + super().__init__() + self.spec = spec + + def forward(self, tokens, feature, input_pos): + emb = self.spec.draft.embed(tokens) + draft_logits, g = self.spec.draft.forward_cached(emb, feature, input_pos) + return draft_logits, g + + def _export_cuda( spec: Eagle3Speculator, output_dir: str, @@ -253,39 +291,145 @@ def _partitioner(name: str): print("Done.") -def main() -> None: - p = argparse.ArgumentParser(description="Export an EAGLE-3 speculator to .pte.") - p.add_argument( - "--target-model", - default="gemma4_31b", - choices=list(TARGETS), - help="Registered target model (see eagle3/target.py).", +def _export_mlx( + spec: Eagle3Speculator, + output_dir: str, + max_prefill: int, + chain_len: int, + share_base_embedding: bool = False, +) -> None: + import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 + import executorch.extension.llm.export.gguf # noqa: F401 + import executorch.extension.llm.export.int4 # noqa: F401 + from executorch.backends.mlx import MLXPartitioner + from executorch.backends.mlx.passes import get_default_passes + from executorch.examples.models.gemma4_31b.mlx_source_transformations import ( + install_mlx_tap_forward, + mlx_source_transformations, ) - p.add_argument( - "--target", required=True, help="Prequantized (INT4) target directory." + from executorch.examples.models.gemma4_31b.model import materialize_runtime_buffers + from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, ) - p.add_argument("--draft", required=True, help="EAGLE-3 draft head directory.") - p.add_argument("--output-dir", default="./eagle3_exports") - p.add_argument("--max-seq-len", type=int, default=4096) - p.add_argument( - "--max-prefill", - type=int, - default=512, - help="Max prefill length: AOTI compiles prefill kernels for up to this T " - "and the whole prompt must fit in one prefill (the runner does not chunk). " - "Smaller compiles faster.", + from executorch.exir.passes import MemoryPlanningPass + from torch.export import Dim, export + + target_config = spec.target.config + hidden = spec.draft.config.hidden_size + draft_vocab_size = spec.draft.config.draft_vocab_size + + # MLX rewrites the target to mask-free layers + MLX KV caches; install a + # matching mask-free tap forward so the speculator's target methods trace. + mlx_source_transformations(spec.target, dtype=torch.bfloat16) + install_mlx_tap_forward(spec.target) + materialize_runtime_buffers(spec.target, dtype=torch.bfloat16) + + if share_base_embedding: + # Point the draft at the target's packed embedding so both methods emit + # identical bytes; the NamedDataStore then content-dedups them to one + # copy. Safe because the draft embed is a frozen copy of the target's and + # the draft's input_layernorm (RMSNorm) is invariant to the embed scale. + spec.draft.embed_tokens = spec.target.embed_tokens + + # MLX has no cross-method KV-cache sharing, so prefill and verify are one + # dynamic-seq method that shares the target cache within a single handle. The + # method returns per-position logits; the host samples (or argmaxes). + print(f"Exporting target_forward (T in [1, {max_prefill}])...") + target_dim = Dim("target_len", min=1, max=max_prefill) + with torch.no_grad(): + target_ep = export( + _TargetForwardLogits(spec), + ( + torch.zeros((1, max_prefill), dtype=torch.long), + torch.arange(max_prefill, dtype=torch.long), + ), + dynamic_shapes=({1: target_dim}, {0: target_dim}), + strict=True, + ) + + draft_max = max(max_prefill, chain_len + 1) + print(f"Exporting draft_decode (T in [1, {draft_max}])...") + draft_dim = Dim("draft_len", min=1, max=draft_max) + with torch.no_grad(): + draft_ep = export( + _DraftDecodeLogits(spec), + ( + torch.zeros((1, draft_max), dtype=torch.long), + torch.zeros((1, draft_max, hidden), dtype=torch.bfloat16), + torch.arange(draft_max, dtype=torch.long), + ), + dynamic_shapes=({1: draft_dim}, {1: draft_dim}, {0: draft_dim}), + strict=True, + ) + + # Capture d2t before freeing the speculator; baked in as get_d2t below. + d2t_const = spec.draft.d2t.to(torch.long).cpu().contiguous() + del spec + gc.collect() + print("Lowering to ExecuTorch with MLX backend...") + et_prog = to_edge_transform_and_lower( + {"target_forward": target_ep, "draft_decode": draft_ep}, + transform_passes=get_default_passes(), + partitioner={ + "target_forward": [MLXPartitioner()], + "draft_decode": [MLXPartitioner()], + }, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods={ + "get_max_seq_len": target_config.max_seq_len, + "get_vocab_size": target_config.vocab_size, + "get_n_layers": target_config.num_hidden_layers, + "get_max_prefill_chunk": max_prefill, + "get_min_prefill_chunk": 1, + "get_chain_len": chain_len, + "get_draft_vocab_size": draft_vocab_size, + # draft->target vocab map (target_id = draft_id + d2t[draft_id]); the + # MLX draft_decode returns draft-vocab logits, so a logits-consuming + # runner reads this to map proposals back to target ids. + "get_d2t": d2t_const, + "use_kv_cache": True, + "enable_dynamic_shape": True, + }, ) - p.add_argument( - "--chain", type=int, default=4, help="Draft chain length K (verify K+1)." + del target_ep, draft_ep + gc.collect() + + et_program = et_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ), ) - args = p.parse_args() + del et_prog + gc.collect() - spec_t = TARGETS[args.target_model] - if not torch.cuda.is_available(): - p.error("CUDA is required to compile the EAGLE-3 export.") + os.makedirs(output_dir, exist_ok=True) + pte_path = os.path.join(output_dir, "model.pte") + print(f"Saving to {pte_path}...") + with open(pte_path, "wb") as f: + et_program.write_to_file(f) + print(f" {os.path.getsize(pte_path) / 1024**2:.1f} MB") + if et_program._tensor_data: + et_program.write_tensor_data_to_file(output_dir) + print(f" Saved tensor data (.ptd) to {output_dir}/") + print("Done.") + + +def _validate_backend_flags(p, args) -> None: + if args.share_draft_embedding and args.backend != "mlx": + p.error("--share-draft-embedding is only supported with --backend mlx.") + if args.quantize_draft and args.backend != "mlx": + p.error("--quantize-draft is only supported with --backend mlx.") - print(f"Loading {args.target_model} target from {args.target}...") - target = spec_t.load(args.target, args.max_seq_len) + +def _load_target_and_draft(p, args, spec_t): + print(f"Loading {args.target_model} target ({args.backend}) from {args.target}...") + target = spec_t.load(args.target, args.max_seq_len, args.backend) print(f"Loading draft head from {args.draft}...") draft, _ = Eagle3Draft.from_checkpoint( @@ -308,12 +452,38 @@ def main() -> None: f"[{int(target_ids.min())}, {int(target_ids.max())}]; the draft and " f"target are likely not a matched pair" ) + return target, draft + +def _run_mlx(p, args, target, draft, max_prefill, verify_len) -> None: + if args.quantize_draft: + from executorch.examples.models.eagle3.quant_mlx import ( + quantize_pack_draft_for_mlx, + ) + + print("Quantizing + packing draft for MLX (int4)...") + draft = quantize_pack_draft_for_mlx(draft) + else: + print("Keeping draft in bf16 (pass --quantize-draft for int4)...") + # MLX builds attention masks internally, so a single forward accepts T>=1. + if max_prefill < verify_len: + p.error( + f"computed max_prefill={max_prefill} < verify window {verify_len}; " + f"raise --max-prefill (got {args.max_prefill}) or --max-seq-len " + f"(got {args.max_seq_len})" + ) spec = Eagle3Speculator(target, draft).eval() + _export_mlx( + spec, + args.output_dir, + max_prefill=max_prefill, + chain_len=args.chain, + share_base_embedding=args.share_draft_embedding, + ) - # A single target forward accepts min_forward_len .. max_forward_len tokens. - max_forward = spec_t.max_forward_len(target.config) - max_prefill = min(args.max_prefill, args.max_seq_len - 1, max_forward) + +def _run_cuda(p, args, spec_t, target, draft, max_prefill, verify_len) -> None: + spec = Eagle3Speculator(target, draft).eval() # prefill's dynamic min (see _export_cuda target_min): the target's own # specialization (min_forward_len) and the INT4 dispatch (> MATVEC_MAX_M). prefill_min = max(spec_t.min_forward_len, _MATVEC_MAX_M + 1) @@ -324,9 +494,7 @@ def main() -> None: f"{args.max_seq_len})" ) # target_verify is a single static forward of chain+1 tokens: it must fit the - # small-M GEMM (chain+1 <= _MATVEC_MAX_M) and the target's per-forward bounds - # [min_forward_len, max_forward]. - verify_len = args.chain + 1 + # small-M GEMM (chain+1 <= _MATVEC_MAX_M) and the target's minimum forward. if verify_len > _MATVEC_MAX_M: p.error( f"--chain {args.chain} (verify window {verify_len}) exceeds the " @@ -337,14 +505,8 @@ def main() -> None: f"--chain {args.chain} (verify window {verify_len}) is below the " f"target's minimum forward length {spec_t.min_forward_len}" ) - if verify_len > min(args.max_seq_len - 1, max_forward): - p.error( - f"--chain {args.chain} (verify window {verify_len}) exceeds the " - f"target's per-forward limit {min(args.max_seq_len - 1, max_forward)}" - ) # Route the static chain_len+1 verify forward to the small-M INT4 GEMM by - # raising the dispatch threshold for this export only; restore it so the - # process-global default (4) is unchanged for any later use. + # raising the dispatch threshold for this export only; restore it after. import executorch.backends.cuda.int4_dispatch as int4_dispatch saved_threshold = int4_dispatch.MATVEC_MAX_M @@ -361,5 +523,75 @@ def main() -> None: int4_dispatch.MATVEC_MAX_M = saved_threshold +def main() -> None: + p = argparse.ArgumentParser(description="Export an EAGLE-3 speculator to .pte.") + p.add_argument( + "--target-model", + default="gemma4_31b", + choices=list(TARGETS), + help="Registered target model (see eagle3/target.py).", + ) + p.add_argument( + "--target", required=True, help="Prequantized (INT4) target directory." + ) + p.add_argument("--draft", required=True, help="EAGLE-3 draft head directory.") + p.add_argument("--output-dir", default="./eagle3_exports") + p.add_argument("--max-seq-len", type=int, default=4096) + p.add_argument( + "--max-prefill", + type=int, + default=512, + help="Max prefill length: AOTI compiles prefill kernels for up to this T " + "and the whole prompt must fit in one prefill (the runner does not chunk). " + "Smaller compiles faster.", + ) + p.add_argument( + "--chain", type=int, default=4, help="Draft chain length K (verify K+1)." + ) + p.add_argument( + "--backend", + default="cuda", + choices=["cuda", "mlx"], + help="Target backend: cuda (AOTI, INT4 target, bf16 draft) or mlx " + "(Apple silicon, INT4 target; bf16 draft, --quantize-draft for int4).", + ) + p.add_argument( + "--share-draft-embedding", + action="store_true", + help="MLX only: reuse the target's packed embedding for the draft so the " + "NamedDataStore dedups it to one copy (drops the draft's own embedding).", + ) + p.add_argument( + "--quantize-draft", + action="store_true", + help="MLX only: int4-pack the draft head (default keeps it bf16). bf16 " + "gives higher acceptance but a larger draft; pair with " + "--share-draft-embedding to avoid a separate draft embedding copy.", + ) + args = p.parse_args() + _validate_backend_flags(p, args) + + spec_t = TARGETS[args.target_model] + if args.backend == "cuda" and not torch.cuda.is_available(): + p.error("CUDA is required to compile the CUDA EAGLE-3 export.") + + target, draft = _load_target_and_draft(p, args, spec_t) + + # A single target forward accepts up to max_forward_len tokens. + max_forward = spec_t.max_forward_len(target.config) + max_prefill = min(args.max_prefill, args.max_seq_len - 1, max_forward) + verify_len = args.chain + 1 + if verify_len > min(args.max_seq_len - 1, max_forward): + p.error( + f"--chain {args.chain} (verify window {verify_len}) exceeds the " + f"target's per-forward limit {min(args.max_seq_len - 1, max_forward)}" + ) + + if args.backend == "mlx": + _run_mlx(p, args, target, draft, max_prefill, verify_len) + else: + _run_cuda(p, args, spec_t, target, draft, max_prefill, verify_len) + + if __name__ == "__main__": main() diff --git a/examples/models/eagle3/main.cpp b/examples/models/eagle3/main.cpp index 6a68e89eaaa..7b81f6ef6d1 100644 --- a/examples/models/eagle3/main.cpp +++ b/examples/models/eagle3/main.cpp @@ -6,15 +6,22 @@ * LICENSE file in the root directory of this source tree. */ -// EAGLE-3 speculative-decoding runner for ExecuTorch (CUDA/AOTI backend). +// EAGLE-3 speculative-decoding runner for ExecuTorch (CUDA/AOTI or MLX +// backend). // -// Loads the speculator .pte (examples/models/eagle3/export.py) exposing three -// methods that share the target / draft KV caches: +// Loads the speculator .pte (examples/models/eagle3/export.py). The CUDA export +// exposes three methods that share the target / draft KV caches: // prefill(tokens[1,T], pos[T]) -> (next_token[1,1], feat[1,T,H]) // target_verify(tokens[1,C], pos[C]) -> (greedy_ids[1,C], feat[1,C,H]) -// draft_decode(tokens[1,T], feat[1,T,H], pos[T]) -> (target_ids[1,T], g[1,T,H]) -// where feat is the fused (hidden-size) draft feature and H is the draft hidden -// size. Verification is greedy (argmax), so emitted tokens equal greedy target +// draft_decode(tokens[1,T], feat[1,T,H], pos[T]) -> (target_ids[1,T], +// g[1,T,H]) +// MLX has no cross-method KV sharing, so prefill+verify are one dynamic-seq +// method and the methods return logits instead of ids: +// target_forward(tokens[1,T], pos[T]) -> (logits[1,T,V], feat[1,T,H]) +// draft_decode(...) -> (draft_logits[1,T,Vd], g[1,T,H]) +// This runner argmaxes those host-side and maps draft ids via d2t (get_d2t). +// feat is the fused (hidden-size) draft feature and H is the draft hidden size. +// Verification is greedy (argmax), so emitted tokens equal greedy target // decoding (lossless) by construction. // // Scheme: the shifted EAGLE convention (vllm/v1/spec_decode/eagle.py, @@ -30,23 +37,25 @@ // negligible next to the INT4 31B target forward, and it keeps device-tensor // lifetimes simple. // -// Run (after exporting model.pte + aoti_cuda_blob.ptd via export.py, sourcing the -// CUDA env, and building the eagle3-cuda preset): +// Run (CUDA: export model.pte + aoti_cuda_blob.ptd and source the CUDA env; +// MLX: a single model.pte, no --data_path): // eagle3_speculator_runner --model_path /model.pte \ -// --data_path /aoti_cuda_blob.ptd --tokenizer_path \ +// [--data_path /aoti_cuda_blob.ptd] --tokenizer_path +// \ // --prompt "..." --max_new_tokens 128 // The chat template and stop tokens default to Gemma 4 IT; override -// --chat_prefix/--chat_suffix/--stop_ids/--stop_token (and --bos_id -1) for other -// target/tokenizer pairs. Per-run timing counters (tau, verify/draft ms) print at -// the end. +// --chat_prefix/--chat_suffix/--stop_ids/--stop_token (and --bos_id -1) for +// other target/tokenizer pairs. Per-run timing counters (tau, verify/draft ms) +// print at the end. // -// Scope: a single-sequence, greedy, fixed-shape demo runner -- not a generic -// EAGLE serving path. No batching, sampler stack (top-k/p/temperature), grammar/ -// tool constraints, streaming API, or integration with the standard ExecuTorch -// LLM runner. The host feature round-trip above is a first-implementation choice -// (the target forward dominates here); a device-resident handoff is future work. -// The target, draft, and tokenizer must be a matched, co-trained set -- a -// mismatch can pass export and silently degrade acceptance/output. +// Scope: a single-sequence, fixed-shape demo runner -- not a generic EAGLE +// serving path. CUDA is greedy; MLX adds --temperature (modified rejection +// sampling) but no top-k/p. No batching, grammar/tool constraints, streaming +// API, or integration with the standard ExecuTorch LLM runner. The host feature +// round-trip above is a first-implementation choice (the target forward +// dominates here); a device-resident handoff is future work. The target, draft, +// and tokenizer must be a matched, co-trained set -- a mismatch can pass export +// and silently degrade acceptance/output. #include @@ -62,9 +71,11 @@ #include #include +#include #include #include #include +#include #include #include @@ -96,9 +107,15 @@ DEFINE_bool(raw_prompt, false, "Skip the Gemma 4 IT chat template."); DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate."); DEFINE_int32(bos_id, 2, "BOS token id (-1 to skip; Gemma convention: 2)."); DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1)."); -DEFINE_bool(cuda_graph, true, "Capture target_verify as a CUDA graph (CUDA only)."); +DEFINE_bool( + cuda_graph, + true, + "Capture target_verify as a CUDA graph (CUDA only)."); // Chat template + stop tokens default to Gemma 4 IT; override for other models. -DEFINE_string(chat_prefix, "<|turn>user\n", "Chat-template text before the prompt."); +DEFINE_string( + chat_prefix, + "<|turn>user\n", + "Chat-template text before the prompt."); DEFINE_string( chat_suffix, "\n<|turn>model\n<|channel>thought\n", @@ -111,6 +128,17 @@ DEFINE_string( stop_token, "", "A stop-delimiter string to encode and add to EOS (empty to skip)."); +DEFINE_double( + temperature, + 0.0, + "Sampling temperature (0 = greedy/argmax; >0 = modified rejection " + "sampling; MLX only)."); +DEFINE_int64(seed, 0, "RNG seed for --temperature sampling."); +DEFINE_int64( + chain, + 0, + "Override draft chain length K (MLX only; 0 = use the exported " + "get_chain_len). Ignored on CUDA, whose verify shape is static."); using executorch::extension::from_blob; using executorch::extension::Module; @@ -130,7 +158,8 @@ std::vector to_host_bytes(const executorch::aten::Tensor& t) { cudaPointerAttributes attrs{}; if (cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess && attrs.type == cudaMemoryTypeDevice) { - cudaError_t err = cudaMemcpy(out.data(), ptr, out.size(), cudaMemcpyDeviceToHost); + cudaError_t err = + cudaMemcpy(out.data(), ptr, out.size(), cudaMemcpyDeviceToHost); if (err != cudaSuccess) { ET_LOG(Error, "D2H copy failed: %s", cudaGetErrorString(err)); exit(1); @@ -169,6 +198,99 @@ HostFeature read_feature(const executorch::aten::Tensor& t) { return f; } +#ifndef EXECUTORCH_BUILD_CUDA +// MLX returns logits; read one element of a (1, T, V) Float or BFloat16 row. +inline float +logit_at(const uint8_t* row, int64_t v, executorch::aten::ScalarType st) { + if (st == executorch::aten::ScalarType::BFloat16) { + uint32_t u = + static_cast(reinterpret_cast(row)[v]) << 16; + float f; + memcpy(&f, &u, sizeof(f)); + return f; + } + return reinterpret_cast(row)[v]; +} + +// Per-position argmax over a (1, T, V) logits tensor -> T ids. +std::vector argmax_rows(const executorch::aten::Tensor& t) { + const int64_t T = t.size(1), V = t.size(2); + const auto st = t.scalar_type(); + const int64_t esz = t.numel() > 0 ? (int64_t)(t.nbytes() / t.numel()) : 0; + auto bytes = to_host_bytes(t); + std::vector ids(T); + for (int64_t r = 0; r < T; r++) { + const uint8_t* row = bytes.data() + (size_t)r * V * esz; + int64_t best = 0; + float bv = logit_at(row, 0, st); + for (int64_t v = 1; v < V; v++) { + float x = logit_at(row, v, st); + if (x > bv) { + bv = x; + best = v; + } + } + ids[r] = best; + } + return ids; +} + +// Numerically-stable softmax in place (same max-subtract approach as the static +// softmax in extension/llm/sampler/sampler.cpp). Apply temperature by scaling +// logits first (see temperature_softmax). +void softmax_inplace(std::vector& x) { + if (x.empty()) { + return; + } + float max_val = *std::max_element(x.begin(), x.end()); + float sum = 0.0f; + for (float& v : x) { + v = std::exp(v - max_val); + sum += v; + } + for (float& v : x) { + v /= sum; + } +} + +// Temperature-scale logits then softmax (mirrors Sampler::sample: logits *= +// inv_temperature; softmax(...)). +void temperature_softmax(std::vector& x, float inv_temp) { + for (float& v : x) { + v *= inv_temp; + } + softmax_inplace(x); +} + +// Multinomial sample from a probability vector (mirrors Sampler::sample_mult: +// walk the CDF and return the first index past the coin). +int64_t sample_mult(const std::vector& p, std::mt19937_64& rng) { + float coin = std::uniform_real_distribution(0.0f, 1.0f)(rng); + float cdf = 0.0f; + for (size_t i = 0; i < p.size(); i++) { + cdf += p[i]; + if (coin < cdf) { + return static_cast(i); + } + } + return static_cast(p.size()) - 1; +} + +// Copy row r of a (1, T, V) Float/BFloat16 logits tensor to float[V]. +std::vector row_to_floats(const executorch::aten::Tensor& t, int64_t r) { + const int64_t V = t.size(2); + const auto st = t.scalar_type(); + const int64_t esz = t.numel() > 0 ? (int64_t)(t.nbytes() / t.numel()) : 0; + auto bytes = to_host_bytes(t); + const uint8_t* row = bytes.data() + (size_t)r * V * esz; + std::vector out(V); + for (int64_t v = 0; v < V; v++) { + out[v] = logit_at(row, v, st); + } + return out; +} +#endif + } // namespace int main(int argc, char** argv) { @@ -183,7 +305,10 @@ int main(int argc, char** argv) { auto tokenizer = std::make_unique(); if (tokenizer->load(FLAGS_tokenizer_path) != tokenizers::Error::Ok) { - ET_LOG(Error, "Failed to load tokenizer from %s", FLAGS_tokenizer_path.c_str()); + ET_LOG( + Error, + "Failed to load tokenizer from %s", + FLAGS_tokenizer_path.c_str()); return 1; } @@ -218,7 +343,11 @@ int main(int argc, char** argv) { } #endif +#ifdef EXECUTORCH_BUILD_CUDA for (const char* m : {"prefill", "target_verify", "draft_decode"}) { +#else + for (const char* m : {"target_forward", "draft_decode"}) { +#endif if (module->load_method(m) != Error::Ok) { ET_LOG(Error, "Failed to load method %s", m); return 1; @@ -244,7 +373,42 @@ int main(int argc, char** argv) { const int64_t max_prefill = meta("get_max_prefill_chunk"); const int64_t min_prefill = meta("get_min_prefill_chunk"); const int64_t max_seq_len = meta("get_max_seq_len"); - const int64_t K = chain_len; + int64_t K = chain_len; +#ifndef EXECUTORCH_BUILD_CUDA + // MLX methods are dynamic-seq, so the chain can be set at runtime as long as + // the verify window (K+1) fits the exported prefill range. + if (FLAGS_chain > 0) { + K = FLAGS_chain; + } + if (K + 1 > max_prefill) { + ET_LOG( + Error, + "--chain %" PRId64 " (verify window %" PRId64 + ") exceeds exported max_prefill %" PRId64, + K, + K + 1, + max_prefill); + return 1; + } +#endif + +#ifndef EXECUTORCH_BUILD_CUDA + // MLX draft_decode returns draft-vocab logits; load the draft->target map + // (target_id = draft_id + d2t[draft_id]) baked in by export.py. + std::vector d2t; + { + auto r = module->get("get_d2t"); + if (!r.ok()) { + ET_LOG(Error, "missing get_d2t metadata in .pte"); + return 1; + } + d2t = read_ids(r->toTensor()); + } + const float temp = + FLAGS_temperature > 0.0 ? static_cast(FLAGS_temperature) : 0.0f; + const float inv_temp = temp > 0.0f ? 1.0f / temp : 0.0f; + std::mt19937_64 rng(static_cast(FLAGS_seed)); +#endif // EOS: tokenizer/metadata ids, the configured eos, any --stop_ids, and the // encoded --stop_token delimiter (all default to the Gemma 4 IT conventions). @@ -282,24 +446,38 @@ int main(int argc, char** argv) { } const int64_t L = static_cast(prompt.size()); // The runner does not chunk: the whole prompt must fit one prefill, and its - // length must be within the exported prefill range [min_prefill, max_prefill]. + // length must be within the exported prefill range [min_prefill, + // max_prefill]. if (L > max_prefill) { - ET_LOG(Error, "Prompt (%" PRId64 " tokens) exceeds max_prefill %" PRId64 - "; this runner does not chunk prefill.", L, max_prefill); + ET_LOG( + Error, + "Prompt (%" PRId64 " tokens) exceeds max_prefill %" PRId64 + "; this runner does not chunk prefill.", + L, + max_prefill); return 1; } if (L < min_prefill) { - ET_LOG(Error, "Prompt (%" PRId64 " tokens) is below the exported prefill " - "minimum %" PRId64 "; use a longer prompt.", L, min_prefill); + ET_LOG( + Error, + "Prompt (%" PRId64 + " tokens) is below the exported prefill " + "minimum %" PRId64 "; use a longer prompt.", + L, + min_prefill); return 1; } // The prefill bonus token is always emittable (no KV write past the prompt). - // Each speculative round, however, writes a K-token verify window, so it needs - // anchor_pos + K <= max_seq_len - 1 (enforced in the loop below). Cap the total - // at the positions available; max_new >= 1 since L <= max_prefill < max_seq_len. + // Each speculative round, however, writes a K-token verify window, so it + // needs anchor_pos + K <= max_seq_len - 1 (enforced in the loop below). Cap + // the total at the positions available; max_new >= 1 since L <= max_prefill < + // max_seq_len. int64_t max_new = std::min(FLAGS_max_new_tokens, max_seq_len - L); - printf("Prompt tokens: %" PRId64 ", chain K=%" PRId64 ", max_new=%" PRId64 - "\n", L, K, max_new); + printf( + "Prompt tokens: %" PRId64 ", chain K=%" PRId64 ", max_new=%" PRId64 "\n", + L, + K, + max_new); auto S = [](int64_t v) { return static_cast(v); }; @@ -309,11 +487,15 @@ int main(int argc, char** argv) { auto long_tensor = [&](std::vector& buf) { return from_blob( - buf.data(), {1, S((int64_t)buf.size())}, executorch::aten::ScalarType::Long); + buf.data(), + {1, S((int64_t)buf.size())}, + executorch::aten::ScalarType::Long); }; auto pos_tensor = [&](std::vector& buf) { return from_blob( - buf.data(), {S((int64_t)buf.size())}, executorch::aten::ScalarType::Long); + buf.data(), + {S((int64_t)buf.size())}, + executorch::aten::ScalarType::Long); }; // draft_decode over (tokens, feat rows, positions); returns proposals + the @@ -333,7 +515,9 @@ int main(int argc, char** argv) { feat_buf.assign(feat_rows, feat_rows + feat_T * H); auto t_tok = long_tensor(tok_buf); auto t_feat = from_blob( - feat_buf.data(), {1, S(feat_T), S(H)}, executorch::aten::ScalarType::BFloat16); + feat_buf.data(), + {1, S(feat_T), S(H)}, + executorch::aten::ScalarType::BFloat16); auto t_pos = pos_tensor(pos_buf); auto r = module->execute( "draft_decode", {EValue(t_tok), EValue(t_feat), EValue(t_pos)}); @@ -341,12 +525,19 @@ int main(int argc, char** argv) { ET_LOG(Error, "draft_decode failed"); exit(1); } +#ifdef EXECUTORCH_BUILD_CUDA out_ids = read_ids(r->at(0).toTensor()); +#else + // draft-vocab argmax -> target ids via d2t. + out_ids = argmax_rows(r->at(0).toTensor()); + for (auto& id : out_ids) { + id += d2t[id]; + } +#endif HostFeature g = read_feature(r->at(1).toTensor()); out_last_g.T = 1; out_last_g.H = g.H; - out_last_g.data.assign( - g.data.end() - g.H, g.data.end()); // last row of g + out_last_g.data.assign(g.data.end() - g.H, g.data.end()); // last row of g }; // Run a draft chain seeded by (seed_tokens, seed_feat) at seed positions; the @@ -358,38 +549,139 @@ int main(int argc, char** argv) { std::vector ids; HostFeature last_g; draft_decode( - seed_tokens, seed_feat.data.data(), seed_feat.T, seed_feat.H, - seed_start_pos, ids, last_g); + seed_tokens, + seed_feat.data.data(), + seed_feat.T, + seed_feat.H, + seed_start_pos, + ids, + last_g); proposals.push_back(ids.back()); int64_t last_pos = seed_start_pos + seed_feat.T - 1; for (int64_t k = 1; k < K; k++) { std::vector step_ids; HostFeature step_g; draft_decode( - {proposals.back()}, last_g.data.data(), 1, last_g.H, - last_pos + k, step_ids, step_g); + {proposals.back()}, + last_g.data.data(), + 1, + last_g.H, + last_pos + k, + step_ids, + step_g); proposals.push_back(step_ids[0]); last_g = step_g; } return proposals; }; +#ifndef EXECUTORCH_BUILD_CUDA + // Run draft_decode and return the draft-vocab probabilities for the last + // (predicting) row plus the recurrent feature; the sampling counterpart of + // draft_decode that keeps the distribution for rejection sampling. + auto draft_logits_step = [&](const std::vector& tokens, + const uint16_t* feat_rows, + int64_t feat_T, + int64_t Hh, + int64_t start_pos, + HostFeature& out_last_g) -> std::vector { + tok_buf.assign(tokens.begin(), tokens.end()); + pos_buf.resize(tokens.size()); + for (size_t i = 0; i < tokens.size(); i++) { + pos_buf[i] = start_pos + static_cast(i); + } + feat_buf.assign(feat_rows, feat_rows + feat_T * Hh); + auto t_tok = long_tensor(tok_buf); + auto t_feat = from_blob( + feat_buf.data(), + {1, S(feat_T), S(Hh)}, + executorch::aten::ScalarType::BFloat16); + auto t_pos = pos_tensor(pos_buf); + auto r = module->execute( + "draft_decode", {EValue(t_tok), EValue(t_feat), EValue(t_pos)}); + if (r.error() != Error::Ok) { + ET_LOG(Error, "draft_decode failed"); + exit(1); + } + auto dl = r->at(0).toTensor(); + std::vector q = row_to_floats(dl, dl.size(1) - 1); + temperature_softmax(q, inv_temp); + HostFeature g = read_feature(r->at(1).toTensor()); + out_last_g.T = 1; + out_last_g.H = g.H; + out_last_g.data.assign(g.data.end() - g.H, g.data.end()); + return q; + }; + + // Sampling counterpart of `chain`: sample each proposal from the draft and + // record (draft_id, q) for the rejection test. + auto chain_sample = [&](const std::vector& seed_tokens, + const HostFeature& seed_feat, + int64_t seed_start_pos, + std::vector& draft_ids, + std::vector>& q_rows) { + std::vector proposals; + HostFeature last_g; + std::vector q = draft_logits_step( + seed_tokens, + seed_feat.data.data(), + seed_feat.T, + seed_feat.H, + seed_start_pos, + last_g); + int64_t d = sample_mult(q, rng); + draft_ids.push_back(d); + proposals.push_back(d + d2t[d]); + q_rows.push_back(std::move(q)); + int64_t last_pos = seed_start_pos + seed_feat.T - 1; + for (int64_t k = 1; k < K; k++) { + HostFeature step_g; + std::vector qk = draft_logits_step( + {proposals.back()}, + last_g.data.data(), + 1, + last_g.H, + last_pos + k, + step_g); + int64_t dk = sample_mult(qk, rng); + draft_ids.push_back(dk); + proposals.push_back(dk + d2t[dk]); + q_rows.push_back(std::move(qk)); + last_g = step_g; + } + return proposals; + }; +#endif + stats.model_load_end_ms = llm::time_in_ms(); stats.inference_start_ms = stats.model_load_end_ms; - // --- Prefill: target over the prompt -> bonus token + per-position feature. --- + // --- Prefill: target over the prompt -> bonus token + per-position feature. + // --- tok_buf = prompt; pos_buf.resize(L); for (int64_t i = 0; i < L; i++) { pos_buf[i] = i; } +#ifdef EXECUTORCH_BUILD_CUDA auto pf = module->execute( "prefill", {EValue(long_tensor(tok_buf)), EValue(pos_tensor(pos_buf))}); +#else + auto pf = module->execute( + "target_forward", + {EValue(long_tensor(tok_buf)), EValue(pos_tensor(pos_buf))}); +#endif if (pf.error() != Error::Ok) { ET_LOG(Error, "prefill failed"); return 1; } - int64_t anchor = read_ids(pf->at(0).toTensor())[0]; // bonus token at position L +#ifdef EXECUTORCH_BUILD_CUDA + int64_t anchor = + read_ids(pf->at(0).toTensor())[0]; // bonus token at position L +#else + // target_forward returns per-position logits; the bonus is the last argmax. + int64_t anchor = argmax_rows(pf->at(0).toTensor()).back(); +#endif HostFeature feat_prompt = read_feature(pf->at(1).toTensor()); const int64_t H = feat_prompt.H; int64_t anchor_pos = L; @@ -401,23 +693,37 @@ int main(int argc, char** argv) { uint64_t prev = static_cast(prompt.back()); { auto s = tokenizer->decode(prev, static_cast(anchor)); - if (s.ok()) { printf("%s", s->c_str()); fflush(stdout); } + if (s.ok()) { + printf("%s", s->c_str()); + fflush(stdout); + } prev = static_cast(anchor); } // We only run the speculative loop if more than the (already emitted) prefill - // bonus is wanted, the bonus wasn't EOS, and there is room for a K-token verify - // window. Otherwise we are done -- no draft seeding needed. + // bonus is wanted, the bonus wasn't EOS, and there is room for a K-token + // verify window. Otherwise we are done -- no draft seeding needed. bool hit_eos = eos_ids.count(static_cast(anchor)) > 0; bool speculate = max_new > 1 && !hit_eos && anchor_pos + K <= max_seq_len - 1; std::vector proposals; +#ifndef EXECUTORCH_BUILD_CUDA + // Per-round draft distributions kept for --temperature rejection sampling. + std::vector draft_ids; + std::vector> q_rows; +#endif if (speculate) { // Seed the first chain (shifted): draft slot p pairs feat_prompt[p] with // token_{p+1}; the last slot pairs feat_prompt[L-1] with the bonus and // predicts position L+1. std::vector seed_tokens(prompt.begin() + 1, prompt.end()); seed_tokens.push_back(anchor); +#ifdef EXECUTORCH_BUILD_CUDA proposals = chain(seed_tokens, feat_prompt, 0); +#else + proposals = (temp == 0.0f) + ? chain(seed_tokens, feat_prompt, 0) + : chain_sample(seed_tokens, feat_prompt, 0, draft_ids, q_rows); +#endif } // Stable buffers for target_verify (fixed length K+1) so the CUDA graph @@ -443,18 +749,27 @@ int main(int argc, char** argv) { vpos_buf[i] = anchor_pos + i; } int64_t t_v = llm::time_in_ms(); - auto vr = module->execute("target_verify", {EValue(vtok_t), EValue(vpos_t)}); +#ifdef EXECUTORCH_BUILD_CUDA + auto vr = + module->execute("target_verify", {EValue(vtok_t), EValue(vpos_t)}); +#else + auto vr = + module->execute("target_forward", {EValue(vtok_t), EValue(vpos_t)}); +#endif if (vr.error() != Error::Ok) { - ET_LOG(Error, "target_verify failed"); + ET_LOG(Error, "target forward failed"); return 1; } - std::vector verify_ids = read_ids(vr->at(0).toTensor()); HostFeature verify_feat = read_feature(vr->at(1).toTensor()); verify_ms += llm::time_in_ms() - t_v; - // Greedy acceptance: verify_ids[j] is the greedy token after token j, so it - // checks proposal j (which sits at position anchor_pos+1+j). + // Acceptance: count the leading proposals that pass, then a corrected + // token. verify slot j is the target distribution after token j (proposal j + // sits at position anchor_pos+1+j). int64_t a = 0; + int64_t corrected = 0; +#ifdef EXECUTORCH_BUILD_CUDA + std::vector verify_ids = read_ids(vr->at(0).toTensor()); for (int64_t j = 0; j < K; j++) { if (proposals[j] == verify_ids[j]) { a++; @@ -462,15 +777,75 @@ int main(int argc, char** argv) { break; } } - int64_t corrected = verify_ids[a]; + corrected = verify_ids[a]; +#else + auto verify_logits = vr->at(0).toTensor(); + if (temp == 0.0f) { + // Greedy acceptance against the per-position argmax. + std::vector verify_ids = argmax_rows(verify_logits); + for (int64_t j = 0; j < K; j++) { + if (proposals[j] == verify_ids[j]) { + a++; + } else { + break; + } + } + corrected = verify_ids[a]; + } else { + // Modified rejection sampling (lossless w.r.t. target sampling): accept + // proposal j with prob min(1, p_j[x]/q_j[d]); on reject resample from the + // residual (p - q)_+ over the target vocab; the all-accepted bonus is a + // sample from p_K. + std::uniform_real_distribution u(0.0, 1.0); + bool rejected = false; + for (int64_t j = 0; j < K; j++) { + std::vector p = row_to_floats(verify_logits, j); + temperature_softmax(p, inv_temp); + const int64_t x = proposals[j], d = draft_ids[j]; + const float qx = q_rows[j][d]; + const float ratio = qx > 0.0f ? std::min(1.0f, p[x] / qx) : 0.0f; + if (u(rng) <= ratio) { + a++; + continue; + } + // residual: subtract q mapped to the target vocab, clamp, renormalize. + const std::vector& q = q_rows[j]; + for (int64_t dd = 0; dd < (int64_t)q.size(); dd++) { + p[dd + d2t[dd]] -= q[dd]; + } + double sum = 0.0; + for (float& pv : p) { + if (pv < 0.0f) { + pv = 0.0f; + } + sum += pv; + } + for (float& pv : p) { + pv = static_cast(pv / sum); + } + corrected = sample_mult(p, rng); + rejected = true; + break; + } + if (!rejected) { + std::vector p = row_to_floats(verify_logits, K); + temperature_softmax(p, inv_temp); + corrected = sample_mult(p, rng); + } + } +#endif std::vector newly(proposals.begin(), proposals.begin() + a); newly.push_back(corrected); for (int64_t t : newly) { - if ((int64_t)emitted.size() >= max_new) break; + if ((int64_t)emitted.size() >= max_new) + break; emitted.push_back(t); auto s = tokenizer->decode(prev, static_cast(t)); - if (s.ok()) { printf("%s", s->c_str()); fflush(stdout); } + if (s.ok()) { + printf("%s", s->c_str()); + fflush(stdout); + } prev = static_cast(t); if (eos_ids.count(static_cast(t)) > 0) { // Stop at the first accepted EOS; do not emit the rest of this batch. @@ -481,13 +856,15 @@ int main(int argc, char** argv) { break; } } - if (hit_eos || (int64_t)emitted.size() >= max_new) break; + if (hit_eos || (int64_t)emitted.size() >= max_new) + break; // Reseed the draft (shifted): slot anchor_pos+i holds (verify_feat[i], - // token_{anchor_pos+i+1}) where token = p_i (i