From 028752f4ee3f756775bfdd6ca37ab796f7ad76c2 Mon Sep 17 00:00:00 2001 From: Robin Nabel Date: Tue, 12 May 2026 17:42:40 +0200 Subject: [PATCH 01/13] feat: add LagunaXS2Renderer for poolside/Laguna-XS.2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Hard-coded renderer mirroring the laguna_glm_thinking_v5_1 chat template. The format uses block-style role markers (/, /, /, /) — of these only / are single (added) tokens. Tool calls wrap with single-token /, but inner / tags are plain text and parsed via regex on the decoded block. Other notable properties: - Prefix is <|EOS|> (BOS=EOS in this tokenizer) emitted unconditionally. - Default system prompt baked into the template; consumed from messages[0] if present, attributed to msg_idx=0 so build_training_sample sees it. - Reasoning is rendered for every assistant message (no last-user-index gating), so the renderer is listed in NO_OP_MODELS for the preserve_* thinking tests. - _visible_text accepts list-form content with TextPart entries; the new _thinking_text helper routes ThinkingPart entries to reasoning_content so a parse → reserialize → re-render round-trip preserves reasoning. Wires the renderer through __init__, MODEL_RENDERER_MAP, _populate_registry, and adds the model to the standard test conftest + roundtrip matrices. Adds tests/test_laguna_xs2.py with five focused regressions covering the ThinkingPart round-trip path and degenerate content shapes. --- renderers/__init__.py | 2 + renderers/base.py | 8 +- renderers/laguna_xs2.py | 380 ++++++++++++++++++++++++++++++++ renderers/parsing.py | 117 ++++++++++ tests/conftest.py | 1 + tests/test_laguna_xs2.py | 114 ++++++++++ tests/test_preserve_thinking.py | 1 + tests/test_roundtrip.py | 1 + 8 files changed, 622 insertions(+), 2 deletions(-) create mode 100644 renderers/laguna_xs2.py create mode 100644 tests/test_laguna_xs2.py diff --git a/renderers/__init__.py b/renderers/__init__.py index b28a485..b668df0 100644 --- a/renderers/__init__.py +++ b/renderers/__init__.py @@ -33,6 +33,7 @@ from renderers.gpt_oss import GptOssRenderer from renderers.kimi_k2 import KimiK2Renderer from renderers.kimi_k25 import KimiK25Renderer +from renderers.laguna_xs2 import LagunaXS2Renderer from renderers.minimax_m2 import MiniMaxM2Renderer from renderers.nemotron3 import Nemotron3Renderer from renderers.qwen3 import Qwen3Renderer @@ -51,6 +52,7 @@ "ImagePart", "KimiK2Renderer", "KimiK25Renderer", + "LagunaXS2Renderer", "MULTIMODAL_MODELS", "Message", "MiniMaxM2Renderer", diff --git a/renderers/base.py b/renderers/base.py index a8823fe..3c6c485 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -541,6 +541,8 @@ def bridge_to_next_turn(self, *args: Any, **kwargs: Any) -> "RenderedTokens | No # Nemotron 3. "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16": "nemotron-3", "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16": "nemotron-3", + # Poolside Laguna. + "poolside/Laguna-XS.2": "laguna-xs.2", # GPT-OSS. 
"openai/gpt-oss-20b": "gpt-oss", "openai/gpt-oss-120b": "gpt-oss", @@ -677,6 +679,7 @@ def _populate_registry(): from renderers.gpt_oss import GptOssRenderer from renderers.kimi_k2 import KimiK2Renderer from renderers.kimi_k25 import KimiK25Renderer + from renderers.laguna_xs2 import LagunaXS2Renderer from renderers.minimax_m2 import MiniMaxM2Renderer from renderers.nemotron3 import Nemotron3Renderer from renderers.qwen3 import Qwen3Renderer @@ -698,6 +701,7 @@ def _populate_registry(): "deepseek-v3": DeepSeekV3Renderer, "kimi-k2": KimiK2Renderer, "kimi-k2.5": KimiK25Renderer, + "laguna-xs.2": LagunaXS2Renderer, "nemotron-3": Nemotron3Renderer, "gpt-oss": GptOssRenderer, } @@ -762,8 +766,8 @@ def create_renderer( tokenizer: HuggingFace tokenizer instance. renderer: Renderer name ('qwen3', 'qwen3-vl', 'qwen3.5', 'qwen3.6', 'glm-5', 'glm-5.1', 'glm-4.5', 'minimax-m2', 'deepseek-v3', - 'kimi-k2', 'kimi-k2.5', 'nemotron-3', 'gpt-oss', 'default') - or 'auto' to detect from model name. + 'kimi-k2', 'kimi-k2.5', 'laguna-xs.2', 'nemotron-3', + 'gpt-oss', 'default') or 'auto' to detect from model name. tool_parser: Name of a tool parser registered in ``renderers.parsers``. Only consumed by DefaultRenderer. Model-specific renderers have their own parsing wired in. diff --git a/renderers/laguna_xs2.py b/renderers/laguna_xs2.py new file mode 100644 index 0000000..18eef96 --- /dev/null +++ b/renderers/laguna_xs2.py @@ -0,0 +1,380 @@ +"""Laguna-XS.2 Renderer. + +Main properties: +- Prefix is the single token ``〈|EOS|〉`` (also the EOS / stop token). +- Role markers are block-style: ``...``, ``...``, + ``...``, ``...``. Of + these, only ```` / ```` are single (added) tokens + in the tokenizer; everything else is plain text and BPEs into multiple + subwords. +- Assistant turn has an explicit close: ```` is the canonical + stop token (alongside ``〈|EOS|〉``). +- Tool calls: ```` / ```` ARE single tokens, but the + inner ```` / ```` / ```` / ```` + markers are plain text — parsed via regex on the decoded inner block. +- The template bakes in a default system prompt when ``messages[0]`` is not + a system message. The system block also contains the tools section (under + a ``### Tools`` header with an ```` listing and prose + format instructions that vary on ``enable_thinking``). +- Reasoning is rendered for every assistant message — no last-user-index + gating. ``preserve_all_thinking`` and + ``preserve_thinking_between_tool_calls`` are accepted for protocol + uniformity but are effectively no-ops since past reasoning is preserved + by default. +""" + +from __future__ import annotations + +import json +from typing import Any, assert_never + +from transformers.tokenization_utils import PreTrainedTokenizer + +from renderers.base import ( + Content, + Message, + ParsedResponse, + RenderedTokens, + ToolSpec, + reject_assistant_in_extension, +) +from renderers.parsing import parse_laguna_xs2 + +_DEFAULT_SYSTEM_MESSAGE = ( + "You are a helpful, conversationally-fluent assistant made by Poolside. " + "You are here to be helpful to users through natural language conversations." +) + +_TOOLS_HEADER = ( + "\n\n### Tools\n\n" + "You may call functions to assist with the user query.\n" + "All available function signatures are listed below:\n" + "\n" +) + +_TOOLS_FOOTER_THINKING = ( + "\n\n" + "Wrap your thinking in '', '' tags, followed by a function call. 
" + "For each function call, return an unescaped XML-like object with function name " + "and arguments within '' and '' tags, like here:\n" + " your thoughts here \n" + "function-name\n" + "argument-key\n" + "value-of-argument-key\n" + "" +) + +_TOOLS_FOOTER_NO_THINKING = ( + "\n\n" + "For each function call, return an unescaped XML-like object with function name " + "and arguments within '' and '' tags, like here:\n" + "function-name\n" + "argument-key\n" + "value-of-argument-key\n" + "" +) + + +class LagunaXS2Renderer: + """Deterministic message → token renderer for Poolside's Laguna-XS.2 model.""" + + def __init__( + self, + tokenizer: PreTrainedTokenizer, + *, + enable_thinking: bool = False, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, + ): + self._tokenizer = tokenizer + self._enable_thinking = enable_thinking + # Accepted for protocol uniformity. The chat template renders + # reasoning on every assistant message regardless, so flipping + # these flags has no effect on the byte-level output. + self._preserve_all_thinking = preserve_all_thinking + self._preserve_thinking_between_tool_calls = ( + preserve_thinking_between_tool_calls + ) + + self._eos = self._token_id("〈|EOS|〉") + self._think = self._token_id("") + self._think_end = self._token_id("") + self._assistant = self._token_id("") + self._assistant_end = self._token_id("") + self._tool_call = self._token_id("") + self._tool_call_end = self._token_id("") + + def _token_id(self, token: str) -> int: + tid = self._tokenizer.convert_tokens_to_ids(token) + assert isinstance(tid, int) and tid != self._tokenizer.unk_token_id, ( + f"Special token {token!r} not found in tokenizer vocabulary" + ) + return tid + + def _encode(self, text: str) -> list[int]: + if not text: + return [] + return self._tokenizer.encode(text, add_special_tokens=False) + + @staticmethod + def _visible_text(content: Content | None) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + parts.append(item.get("text", "")) + return "".join(parts) + return "" + + @staticmethod + def _thinking_text(content: Content | None) -> str: + """Concatenate ``ThinkingPart`` entries from list-form content. + + Used as a reasoning source in ``_render_assistant`` when neither + ``reasoning`` nor ``reasoning_content`` is present on the message. + Returns ``""`` for any non-list input. 
+ """ + if not isinstance(content, list): + return "" + parts: list[str] = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "thinking": + parts.append(item.get("thinking", "")) + return "".join(parts) + + def render( + self, + messages: list[Message], + *, + tools: list[ToolSpec] | None = None, + add_generation_prompt: bool = False, + ) -> RenderedTokens: + if not messages: + raise ValueError("No messages provided.") + + tokens: list[int] = [] + indices: list[int] = [] + + def emit_special(token_id: int, msg_idx: int) -> None: + tokens.append(token_id) + indices.append(msg_idx) + + def emit_text(text: str, msg_idx: int) -> None: + ids = self._encode(text) + tokens.extend(ids) + indices.extend([msg_idx] * len(ids)) + + emit_special(self._eos, -1) + + # ── System header (absorbs messages[0] if it's a system message) ── + system_content = _DEFAULT_SYSTEM_MESSAGE + system_msg_idx = -1 + if messages and messages[0].get("role") == "system": + system_content = self._visible_text(messages[0].get("content")) + system_msg_idx = 0 + + has_sys_content = bool(system_content and system_content.strip()) + # Mirrors the template's ``(system_message and system_message.strip()) or tools`` + # gate: when the caller passes an empty system message and no tools, + # the whole ``...`` block is omitted. + if has_sys_content or tools: + # The template emits ``\n`` then conditionally a second + # ``\n``. Bundle those into one emit so BPE merges ``\n\n`` into + # its single-token form (rather than two ``\n`` atoms). + emit_text("\n\n" if has_sys_content else "\n", -1) + if has_sys_content: + emit_text(system_content.rstrip(), system_msg_idx) + if tools: + tool_text = _TOOLS_HEADER + for tool in tools: + tool_text += json.dumps(tool, ensure_ascii=False) + "\n" + tool_text += ( + _TOOLS_FOOTER_THINKING + if self._enable_thinking + else _TOOLS_FOOTER_NO_THINKING + ) + emit_text(tool_text, -1) + emit_text("\n\n", -1) + + # ── Per-message loop ────────────────────────────────────────── + for i, msg in enumerate(messages): + content = self._visible_text(msg.get("content")) + + match msg["role"]: + case "system": + # Already consumed in the header block. 
+ if i == 0: + continue + emit_text("\n" + content + "\n\n", i) + case "user": + emit_text("\n" + content + "\n\n", i) + case "assistant": + self._render_assistant( + msg, i, content, emit_special=emit_special, emit_text=emit_text + ) + case "tool": + emit_text("\n" + content + "\n\n", i) + case unexpected_role: + assert_never(unexpected_role) + + # ── Generation prompt ───────────────────────────────────────── + if add_generation_prompt: + emit_special(self._assistant, -1) + emit_text("\n", -1) + if self._enable_thinking: + emit_special(self._think, -1) + else: + emit_special(self._think_end, -1) + + return RenderedTokens(token_ids=tokens, message_indices=indices) + + def render_ids( + self, + messages: list[Message], + *, + tools: list[ToolSpec] | None = None, + add_generation_prompt: bool = False, + ) -> list[int]: + return self.render( + messages, + tools=tools, + add_generation_prompt=add_generation_prompt, + ).token_ids + + def parse_response(self, token_ids: list[int]) -> ParsedResponse: + return parse_laguna_xs2( + self._tokenizer, + token_ids, + stop_ids={self._assistant_end, self._eos}, + think_id=self._think, + think_end_id=self._think_end, + tool_call_id=self._tool_call, + tool_call_end_id=self._tool_call_end, + ) + + def get_stop_token_ids(self) -> list[int]: + return [self._assistant_end, self._eos] + + def bridge_to_next_turn( + self, + previous_prompt_ids: list[int], + previous_completion_ids: list[int], + new_messages: list[Message], + *, + tools: list[ToolSpec] | None = None, + ) -> RenderedTokens | None: + if ( + not previous_prompt_ids + or not new_messages + or reject_assistant_in_extension(new_messages) + ): + return None + + # The canonical assistant-turn close is ````. ``〈|EOS|〉`` + # also stops generation; either being the final token means the turn + # ended cleanly. Truncation (no stop token at the tail) synthesises + # ``\n`` — the same scaffold the template emits. + previous_ids = list(previous_prompt_ids) + list(previous_completion_ids) + stop_ids = {self._assistant_end, self._eos} + if ( + not previous_ids[len(previous_prompt_ids) :] + or previous_ids[-1] not in stop_ids + ): + previous_ids.append(self._assistant_end) + previous_ids.extend(self._encode("\n")) + + ext: list[int] = [] + + def emit_special(token_id: int, _msg_idx: int = -1) -> None: + ext.append(token_id) + + def emit_text(text: str, _msg_idx: int = -1) -> None: + ext.extend(self._encode(text)) + + for msg in new_messages: + role = msg.get("role") + content = self._visible_text(msg.get("content")) + if role == "user": + emit_text("\n" + content + "\n\n") + elif role == "system": + emit_text("\n" + content + "\n\n") + elif role == "tool": + emit_text("\n" + content + "\n\n") + else: + return None + + emit_special(self._assistant) + emit_text("\n") + if self._enable_thinking: + emit_special(self._think) + else: + emit_special(self._think_end) + + return RenderedTokens(token_ids=previous_ids + ext) + + def _render_assistant( + self, + msg: Message, + msg_idx: int, + content: str, + *, + emit_special, + emit_text, + ) -> None: + reasoning_content = "" + if isinstance(msg.get("reasoning_content"), str): + reasoning_content = msg["reasoning_content"] + else: + # When the caller stores reasoning as a ``ThinkingPart`` inside + # a list-form ``content`` (e.g. after parse_response → + # reserialize), pull it out here so it survives the re-render. 
+ part_thinking = self._thinking_text(msg.get("content")) + if part_thinking: + reasoning_content = part_thinking + + emit_special(self._assistant, msg_idx) + emit_text("\n", msg_idx) + + if reasoning_content: + emit_special(self._think, msg_idx) + emit_text("\n" + reasoning_content.strip() + "\n", msg_idx) + emit_special(self._think_end, msg_idx) + else: + emit_special(self._think_end, msg_idx) + + # Combined newline-after- with optional content. Bundling + # preserves BPE merges across the boundary. + post_think_text = "\n" + if content.strip(): + post_think_text += content.strip() + "\n" + emit_text(post_think_text, msg_idx) + + tool_calls = msg.get("tool_calls") or [] + for tc in tool_calls: + func = tc.get("function") or tc + name = func.get("name", "") + arguments = func.get("arguments", {}) + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {} + + emit_special(self._tool_call, msg_idx) + inner = name + "\n" + if isinstance(arguments, dict): + for k, v in arguments.items(): + inner += "" + k + "\n" + if isinstance(v, str): + val_text = v + else: + val_text = json.dumps(v, ensure_ascii=False) + inner += "" + val_text + "\n" + emit_text(inner, msg_idx) + emit_special(self._tool_call_end, msg_idx) + emit_text("\n", msg_idx) + + emit_special(self._assistant_end, msg_idx) + emit_text("\n", msg_idx) diff --git a/renderers/parsing.py b/renderers/parsing.py index 6644103..9540353 100644 --- a/renderers/parsing.py +++ b/renderers/parsing.py @@ -310,6 +310,123 @@ def _parse_glm_tool_calls( return tool_calls +# ── Laguna-XS.2: name\nk\nv +# Same outer skeleton as parse_glm, but / are plain text +# (multi-token BPE), not single special tokens — so the inner block is decoded +# to text and the key/value pairs are pulled out by regex. + + +def parse_laguna_xs2( + tokenizer, + token_ids: list[int], + *, + stop_ids: set[int], + think_id: int, + think_end_id: int, + tool_call_id: int, + tool_call_end_id: int, +) -> ParsedResponse: + """Parse Laguna-XS.2 completion tokens. + + Thinking uses single-token ```` / ```` (ids found by + scan). Tool calls are delimited by single-token ```` / + ````, but ```` / ```` inside are + plain text — regex-extracted from the decoded inner block. + """ + ids = _strip_stop_tokens(token_ids, stop_ids) + + # The template wraps reasoning with ``\n`` on both sides + # (``\n{r}\n``) and brackets post-think content with ``\n`` + # too (``\n{c}\n``). Strip exactly those newlines from each + # decoded segment — never a bare ``.strip()``, which would also eat + # whitespace the model emitted intentionally. 
+ reasoning = None + think_end = _find(ids, think_end_id) + if think_end != -1: + reasoning_ids = ids[:think_end] + reasoning_ids = [t for t in reasoning_ids if t != think_id] + reasoning = _decode(tokenizer, reasoning_ids).strip("\n") + ids = ids[think_end + 1 :] + elif (think_start := _find(ids, think_id)) != -1: + reasoning = _decode(tokenizer, ids[think_start + 1 :]).strip("\n") + return ParsedResponse( + content="", reasoning_content=reasoning or None, tool_calls=None + ) + + tc_start = _find(ids, tool_call_id) + if tc_start != -1: + content_text = _decode(tokenizer, ids[:tc_start]).strip("\n") + tool_calls = _parse_laguna_xs2_tool_calls( + tokenizer, ids[tc_start:], tool_call_id, tool_call_end_id + ) + else: + content_text = _decode(tokenizer, ids).strip("\n") + tool_calls = None + + return ParsedResponse( + content=content_text, + reasoning_content=reasoning or None, + tool_calls=tool_calls or None, + ) + + +def _parse_laguna_xs2_tool_calls( + tokenizer, + ids: list[int], + tc_id: int, + tc_end_id: int, +) -> list[dict]: + """Parse Laguna-XS.2 tool calls. + + Inside each ``...`` block, the format is:: + + {name}\\n + {k1}\\n{v1}\\n + ... + {kn}\\n{vn}\\n + + The function name is everything before the first ```` literal + in the decoded block. + """ + import re + + tool_calls: list[dict] = [] + i = 0 + while i < len(ids): + if ids[i] == tc_id: + tc_end = _find(ids, tc_end_id, i + 1) + if tc_end == -1: + break + block_text = _decode(tokenizer, ids[i + 1 : tc_end]) + + ak_pos = block_text.find("") + if ak_pos != -1: + name = block_text[:ak_pos].strip() + args_section = block_text[ak_pos:] + else: + name = block_text.strip() + args_section = "" + + arguments: dict = {} + for m in re.finditer( + r"(.*?)\s*(.*?)", + args_section, + re.DOTALL, + ): + k = m.group(1).strip() + v = m.group(2).strip() + try: + arguments[k] = json.loads(v) + except (json.JSONDecodeError, ValueError): + arguments[k] = v + + tool_calls.append({"function": {"name": name, "arguments": arguments}}) + i = tc_end + 1 + else: + i += 1 + return tool_calls + + # ── DeepSeek V3: <|tool▁calls▁begin|>...<|tool▁calls▁end|> + text tags ── diff --git a/tests/conftest.py b/tests/conftest.py index 5ef65a5..8eea97b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,6 +32,7 @@ ("moonshotai/Kimi-K2.6", "auto"), ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", "auto"), ("nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16", "auto"), + ("poolside/Laguna-XS.2", "auto"), ("openai/gpt-oss-20b", "gpt-oss"), ("Qwen/Qwen2.5-0.5B-Instruct", "default"), ] diff --git a/tests/test_laguna_xs2.py b/tests/test_laguna_xs2.py new file mode 100644 index 0000000..aa90939 --- /dev/null +++ b/tests/test_laguna_xs2.py @@ -0,0 +1,114 @@ +"""Laguna-XS.2-specific renderer behavior. + +The Jinja template never iterates ``message.content`` as a list — it +coerces non-string content to ``""``. The renderer, however, is the seam +where a parse → reserialize → re-render round-trip can hand it +structured ``content=[ThinkingPart, TextPart]`` (e.g. trajectory storage +that preserves typed parts). Dropping those parts would silently lose +reasoning from previous turns on every re-render — the regression this +file guards against. 
+""" + +from __future__ import annotations + +from functools import lru_cache + +import pytest + + +_MODEL = "poolside/Laguna-XS.2" + + +@lru_cache(maxsize=1) +def _renderer(): + from renderers import create_renderer + from renderers.base import load_tokenizer + + return create_renderer(load_tokenizer(_MODEL)) + + +PROMPT = [{"role": "user", "content": "Hi"}] + + +def test_thinking_part_round_trip_matches_flat_form(): + """``content=[ThinkingPart, TextPart]`` must render identically to + ``reasoning_content=... , content=...`` — otherwise reasoning is + silently dropped on every re-render through the trajectory loop.""" + r = _renderer() + structured = { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "thinking through the problem"}, + {"type": "text", "text": "Hello!"}, + ], + } + flat = { + "role": "assistant", + "reasoning_content": "thinking through the problem", + "content": "Hello!", + } + assert r.render_ids(PROMPT + [structured]) == r.render_ids(PROMPT + [flat]) + + +def test_text_part_only_matches_string_content(): + """List with only ``TextPart`` entries collapses to the equivalent string content.""" + r = _renderer() + listed = {"role": "assistant", "content": [{"type": "text", "text": "Hello!"}]} + string = {"role": "assistant", "content": "Hello!"} + assert r.render_ids(PROMPT + [listed]) == r.render_ids(PROMPT + [string]) + + +def test_reasoning_field_beats_thinking_part(): + """Explicit ``reasoning_content`` wins over ``ThinkingPart`` in the + same message — the field is canonical, the part is the fallback.""" + r = _renderer() + both = { + "role": "assistant", + "reasoning_content": "from field", + "content": [ + {"type": "thinking", "thinking": "from part (should be ignored)"}, + {"type": "text", "text": "Hi"}, + ], + } + field_only = { + "role": "assistant", + "reasoning_content": "from field", + "content": "Hi", + } + assert r.render_ids(PROMPT + [both]) == r.render_ids(PROMPT + [field_only]) + + +def test_multiple_thinking_parts_concatenated(): + """Multiple ``ThinkingPart`` entries concatenate (insertion order).""" + r = _renderer() + multi = { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "first"}, + {"type": "thinking", "thinking": "second"}, + {"type": "text", "text": "out"}, + ], + } + flat = { + "role": "assistant", + "reasoning_content": "firstsecond", + "content": "out", + } + assert r.render_ids(PROMPT + [multi]) == r.render_ids(PROMPT + [flat]) + + +@pytest.mark.parametrize( + "shape", + [ + pytest.param(None, id="none"), + pytest.param([], id="empty-list"), + pytest.param([{"type": "image"}], id="non-text-non-thinking-only"), + ], +) +def test_degenerate_content_collapses_to_empty(shape): + """Content shapes that produce no visible text and no thinking must + render the same as ``content=""`` — no crashes, no extra tokens.""" + r = _renderer() + degen = {"role": "assistant", "content": shape} + empty = {"role": "assistant", "content": ""} + assert r.render_ids(PROMPT + [degen]) == r.render_ids(PROMPT + [empty]) diff --git a/tests/test_preserve_thinking.py b/tests/test_preserve_thinking.py index b0d6a9e..661d577 100644 --- a/tests/test_preserve_thinking.py +++ b/tests/test_preserve_thinking.py @@ -42,6 +42,7 @@ def _make(tokenizer, renderer_name, **flags): "Qwen/Qwen3-VL-4B-Instruct", "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen3-VL-30B-A3B-Instruct", + "poolside/Laguna-XS.2", } diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py index 8c313ed..6df3bf7 100644 --- a/tests/test_roundtrip.py +++ 
b/tests/test_roundtrip.py @@ -43,6 +43,7 @@ ("moonshotai/Kimi-K2.6", "auto"), ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", "auto"), ("nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16", "auto"), + ("poolside/Laguna-XS.2", "auto"), ("openai/gpt-oss-20b", "gpt-oss"), ("Qwen/Qwen2.5-0.5B-Instruct", "default"), ] From fbb4dbc0a2f395305cd0a15ba2e749338d4a7c19 Mon Sep 17 00:00:00 2001 From: Robin Nabel Date: Tue, 12 May 2026 17:48:17 +0200 Subject: [PATCH 02/13] style: drop redundant LagunaXS2Renderer class docstring --- renderers/laguna_xs2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/renderers/laguna_xs2.py b/renderers/laguna_xs2.py index 18eef96..c981265 100644 --- a/renderers/laguna_xs2.py +++ b/renderers/laguna_xs2.py @@ -76,8 +76,6 @@ class LagunaXS2Renderer: - """Deterministic message → token renderer for Poolside's Laguna-XS.2 model.""" - def __init__( self, tokenizer: PreTrainedTokenizer, From a98edd483599808227007e361d7d38e04a4ad7df Mon Sep 17 00:00:00 2001 From: Robin Nabel Date: Tue, 12 May 2026 17:50:50 +0200 Subject: [PATCH 03/13] test: remove Laguna-XS.2-specific test file The list-form content behaviors these tests exercised (TextPart extraction, ThinkingPart routing, reasoning_content precedence, degenerate-shape robustness) are generic Renderer-protocol invariants rather than Laguna-specific quirks. Better suited to the shared conftest matrix with opt-in subsets, see PR description for upstreaming suggestions. --- tests/test_laguna_xs2.py | 114 --------------------------------------- 1 file changed, 114 deletions(-) delete mode 100644 tests/test_laguna_xs2.py diff --git a/tests/test_laguna_xs2.py b/tests/test_laguna_xs2.py deleted file mode 100644 index aa90939..0000000 --- a/tests/test_laguna_xs2.py +++ /dev/null @@ -1,114 +0,0 @@ -"""Laguna-XS.2-specific renderer behavior. - -The Jinja template never iterates ``message.content`` as a list — it -coerces non-string content to ``""``. The renderer, however, is the seam -where a parse → reserialize → re-render round-trip can hand it -structured ``content=[ThinkingPart, TextPart]`` (e.g. trajectory storage -that preserves typed parts). Dropping those parts would silently lose -reasoning from previous turns on every re-render — the regression this -file guards against. -""" - -from __future__ import annotations - -from functools import lru_cache - -import pytest - - -_MODEL = "poolside/Laguna-XS.2" - - -@lru_cache(maxsize=1) -def _renderer(): - from renderers import create_renderer - from renderers.base import load_tokenizer - - return create_renderer(load_tokenizer(_MODEL)) - - -PROMPT = [{"role": "user", "content": "Hi"}] - - -def test_thinking_part_round_trip_matches_flat_form(): - """``content=[ThinkingPart, TextPart]`` must render identically to - ``reasoning_content=... 
, content=...`` — otherwise reasoning is - silently dropped on every re-render through the trajectory loop.""" - r = _renderer() - structured = { - "role": "assistant", - "content": [ - {"type": "thinking", "thinking": "thinking through the problem"}, - {"type": "text", "text": "Hello!"}, - ], - } - flat = { - "role": "assistant", - "reasoning_content": "thinking through the problem", - "content": "Hello!", - } - assert r.render_ids(PROMPT + [structured]) == r.render_ids(PROMPT + [flat]) - - -def test_text_part_only_matches_string_content(): - """List with only ``TextPart`` entries collapses to the equivalent string content.""" - r = _renderer() - listed = {"role": "assistant", "content": [{"type": "text", "text": "Hello!"}]} - string = {"role": "assistant", "content": "Hello!"} - assert r.render_ids(PROMPT + [listed]) == r.render_ids(PROMPT + [string]) - - -def test_reasoning_field_beats_thinking_part(): - """Explicit ``reasoning_content`` wins over ``ThinkingPart`` in the - same message — the field is canonical, the part is the fallback.""" - r = _renderer() - both = { - "role": "assistant", - "reasoning_content": "from field", - "content": [ - {"type": "thinking", "thinking": "from part (should be ignored)"}, - {"type": "text", "text": "Hi"}, - ], - } - field_only = { - "role": "assistant", - "reasoning_content": "from field", - "content": "Hi", - } - assert r.render_ids(PROMPT + [both]) == r.render_ids(PROMPT + [field_only]) - - -def test_multiple_thinking_parts_concatenated(): - """Multiple ``ThinkingPart`` entries concatenate (insertion order).""" - r = _renderer() - multi = { - "role": "assistant", - "content": [ - {"type": "thinking", "thinking": "first"}, - {"type": "thinking", "thinking": "second"}, - {"type": "text", "text": "out"}, - ], - } - flat = { - "role": "assistant", - "reasoning_content": "firstsecond", - "content": "out", - } - assert r.render_ids(PROMPT + [multi]) == r.render_ids(PROMPT + [flat]) - - -@pytest.mark.parametrize( - "shape", - [ - pytest.param(None, id="none"), - pytest.param([], id="empty-list"), - pytest.param([{"type": "image"}], id="non-text-non-thinking-only"), - ], -) -def test_degenerate_content_collapses_to_empty(shape): - """Content shapes that produce no visible text and no thinking must - render the same as ``content=""`` — no crashes, no extra tokens.""" - r = _renderer() - degen = {"role": "assistant", "content": shape} - empty = {"role": "assistant", "content": ""} - assert r.render_ids(PROMPT + [degen]) == r.render_ids(PROMPT + [empty]) From 6e7792a1b9669b5e4900c56bf5fa084293be3d7a Mon Sep 17 00:00:00 2001 From: Robin Nabel Date: Tue, 12 May 2026 17:42:40 +0200 Subject: [PATCH 04/13] feat: add LagunaXS2Renderer for poolside/Laguna-XS.2 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Hard-coded renderer mirroring the laguna_glm_thinking_v5_1 chat template. The format uses block-style role markers (/, /, /, /) — of these only / are single (added) tokens. Tool calls wrap with single-token /, but inner / tags are plain text and parsed via regex on the decoded block. Other notable properties: - Prefix is <|EOS|> (BOS=EOS in this tokenizer) emitted unconditionally. - Default system prompt baked into the template; consumed from messages[0] if present, attributed to msg_idx=0 so build_training_sample sees it. - Reasoning is rendered for every assistant message (no last-user-index gating), so the renderer is listed in NO_OP_MODELS for the preserve_* thinking tests. 
- _visible_text accepts list-form content with TextPart entries; the new _thinking_text helper routes ThinkingPart entries to reasoning_content so a parse → reserialize → re-render round-trip preserves reasoning. Wires the renderer through __init__, MODEL_RENDERER_MAP, _populate_registry, and adds the model to the standard test conftest + roundtrip matrices. Adds tests/test_laguna_xs2.py with five focused regressions covering the ThinkingPart round-trip path and degenerate content shapes. --- renderers/__init__.py | 2 + renderers/base.py | 8 +- renderers/laguna_xs2.py | 380 ++++++++++++++++++++++++++++++++ renderers/parsing.py | 117 ++++++++++ tests/conftest.py | 1 + tests/test_laguna_xs2.py | 114 ++++++++++ tests/test_preserve_thinking.py | 1 + tests/test_roundtrip.py | 1 + 8 files changed, 622 insertions(+), 2 deletions(-) create mode 100644 renderers/laguna_xs2.py create mode 100644 tests/test_laguna_xs2.py diff --git a/renderers/__init__.py b/renderers/__init__.py index a169186..62bc666 100644 --- a/renderers/__init__.py +++ b/renderers/__init__.py @@ -43,6 +43,7 @@ from renderers.gpt_oss import GptOssRenderer from renderers.kimi_k2 import KimiK2Renderer from renderers.kimi_k25 import KimiK25Renderer +from renderers.laguna_xs2 import LagunaXS2Renderer from renderers.minimax_m2 import MiniMaxM2Renderer from renderers.nemotron3 import Nemotron3Renderer from renderers.qwen3 import Qwen3Renderer @@ -61,6 +62,7 @@ "ImagePart", "KimiK2Renderer", "KimiK25Renderer", + "LagunaXS2Renderer", "MULTIMODAL_MODELS", "Message", "MiniMaxM2Renderer", diff --git a/renderers/base.py b/renderers/base.py index 64e9f1c..2b74e0c 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -603,6 +603,8 @@ def bridge_to_next_turn(self, *args: Any, **kwargs: Any) -> "RenderedTokens | No # Nemotron 3. "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16": "nemotron-3", "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16": "nemotron-3", + # Poolside Laguna. + "poolside/Laguna-XS.2": "laguna-xs.2", # GPT-OSS. "openai/gpt-oss-20b": "gpt-oss", "openai/gpt-oss-120b": "gpt-oss", @@ -739,6 +741,7 @@ def _populate_registry(): from renderers.gpt_oss import GptOssRenderer from renderers.kimi_k2 import KimiK2Renderer from renderers.kimi_k25 import KimiK25Renderer + from renderers.laguna_xs2 import LagunaXS2Renderer from renderers.minimax_m2 import MiniMaxM2Renderer from renderers.nemotron3 import Nemotron3Renderer from renderers.qwen3 import Qwen3Renderer @@ -760,6 +763,7 @@ def _populate_registry(): "deepseek-v3": DeepSeekV3Renderer, "kimi-k2": KimiK2Renderer, "kimi-k2.5": KimiK25Renderer, + "laguna-xs.2": LagunaXS2Renderer, "nemotron-3": Nemotron3Renderer, "gpt-oss": GptOssRenderer, } @@ -824,8 +828,8 @@ def create_renderer( tokenizer: HuggingFace tokenizer instance. renderer: Renderer name ('qwen3', 'qwen3-vl', 'qwen3.5', 'qwen3.6', 'glm-5', 'glm-5.1', 'glm-4.5', 'minimax-m2', 'deepseek-v3', - 'kimi-k2', 'kimi-k2.5', 'nemotron-3', 'gpt-oss', 'default') - or 'auto' to detect from model name. + 'kimi-k2', 'kimi-k2.5', 'laguna-xs.2', 'nemotron-3', + 'gpt-oss', 'default') or 'auto' to detect from model name. tool_parser: Name of a tool parser registered in ``renderers.parsers``. Only consumed by DefaultRenderer. Model-specific renderers have their own parsing wired in. diff --git a/renderers/laguna_xs2.py b/renderers/laguna_xs2.py new file mode 100644 index 0000000..18eef96 --- /dev/null +++ b/renderers/laguna_xs2.py @@ -0,0 +1,380 @@ +"""Laguna-XS.2 Renderer. 
+ +Main properties: +- Prefix is the single token ``〈|EOS|〉`` (also the EOS / stop token). +- Role markers are block-style: ``...``, ``...``, + ``...``, ``...``. Of + these, only ```` / ```` are single (added) tokens + in the tokenizer; everything else is plain text and BPEs into multiple + subwords. +- Assistant turn has an explicit close: ```` is the canonical + stop token (alongside ``〈|EOS|〉``). +- Tool calls: ```` / ```` ARE single tokens, but the + inner ```` / ```` / ```` / ```` + markers are plain text — parsed via regex on the decoded inner block. +- The template bakes in a default system prompt when ``messages[0]`` is not + a system message. The system block also contains the tools section (under + a ``### Tools`` header with an ```` listing and prose + format instructions that vary on ``enable_thinking``). +- Reasoning is rendered for every assistant message — no last-user-index + gating. ``preserve_all_thinking`` and + ``preserve_thinking_between_tool_calls`` are accepted for protocol + uniformity but are effectively no-ops since past reasoning is preserved + by default. +""" + +from __future__ import annotations + +import json +from typing import Any, assert_never + +from transformers.tokenization_utils import PreTrainedTokenizer + +from renderers.base import ( + Content, + Message, + ParsedResponse, + RenderedTokens, + ToolSpec, + reject_assistant_in_extension, +) +from renderers.parsing import parse_laguna_xs2 + +_DEFAULT_SYSTEM_MESSAGE = ( + "You are a helpful, conversationally-fluent assistant made by Poolside. " + "You are here to be helpful to users through natural language conversations." +) + +_TOOLS_HEADER = ( + "\n\n### Tools\n\n" + "You may call functions to assist with the user query.\n" + "All available function signatures are listed below:\n" + "\n" +) + +_TOOLS_FOOTER_THINKING = ( + "\n\n" + "Wrap your thinking in '', '' tags, followed by a function call. " + "For each function call, return an unescaped XML-like object with function name " + "and arguments within '' and '' tags, like here:\n" + " your thoughts here \n" + "function-name\n" + "argument-key\n" + "value-of-argument-key\n" + "" +) + +_TOOLS_FOOTER_NO_THINKING = ( + "\n\n" + "For each function call, return an unescaped XML-like object with function name " + "and arguments within '' and '' tags, like here:\n" + "function-name\n" + "argument-key\n" + "value-of-argument-key\n" + "" +) + + +class LagunaXS2Renderer: + """Deterministic message → token renderer for Poolside's Laguna-XS.2 model.""" + + def __init__( + self, + tokenizer: PreTrainedTokenizer, + *, + enable_thinking: bool = False, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, + ): + self._tokenizer = tokenizer + self._enable_thinking = enable_thinking + # Accepted for protocol uniformity. The chat template renders + # reasoning on every assistant message regardless, so flipping + # these flags has no effect on the byte-level output. 
+ self._preserve_all_thinking = preserve_all_thinking + self._preserve_thinking_between_tool_calls = ( + preserve_thinking_between_tool_calls + ) + + self._eos = self._token_id("〈|EOS|〉") + self._think = self._token_id("") + self._think_end = self._token_id("") + self._assistant = self._token_id("") + self._assistant_end = self._token_id("") + self._tool_call = self._token_id("") + self._tool_call_end = self._token_id("") + + def _token_id(self, token: str) -> int: + tid = self._tokenizer.convert_tokens_to_ids(token) + assert isinstance(tid, int) and tid != self._tokenizer.unk_token_id, ( + f"Special token {token!r} not found in tokenizer vocabulary" + ) + return tid + + def _encode(self, text: str) -> list[int]: + if not text: + return [] + return self._tokenizer.encode(text, add_special_tokens=False) + + @staticmethod + def _visible_text(content: Content | None) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + parts.append(item.get("text", "")) + return "".join(parts) + return "" + + @staticmethod + def _thinking_text(content: Content | None) -> str: + """Concatenate ``ThinkingPart`` entries from list-form content. + + Used as a reasoning source in ``_render_assistant`` when neither + ``reasoning`` nor ``reasoning_content`` is present on the message. + Returns ``""`` for any non-list input. + """ + if not isinstance(content, list): + return "" + parts: list[str] = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "thinking": + parts.append(item.get("thinking", "")) + return "".join(parts) + + def render( + self, + messages: list[Message], + *, + tools: list[ToolSpec] | None = None, + add_generation_prompt: bool = False, + ) -> RenderedTokens: + if not messages: + raise ValueError("No messages provided.") + + tokens: list[int] = [] + indices: list[int] = [] + + def emit_special(token_id: int, msg_idx: int) -> None: + tokens.append(token_id) + indices.append(msg_idx) + + def emit_text(text: str, msg_idx: int) -> None: + ids = self._encode(text) + tokens.extend(ids) + indices.extend([msg_idx] * len(ids)) + + emit_special(self._eos, -1) + + # ── System header (absorbs messages[0] if it's a system message) ── + system_content = _DEFAULT_SYSTEM_MESSAGE + system_msg_idx = -1 + if messages and messages[0].get("role") == "system": + system_content = self._visible_text(messages[0].get("content")) + system_msg_idx = 0 + + has_sys_content = bool(system_content and system_content.strip()) + # Mirrors the template's ``(system_message and system_message.strip()) or tools`` + # gate: when the caller passes an empty system message and no tools, + # the whole ``...`` block is omitted. + if has_sys_content or tools: + # The template emits ``\n`` then conditionally a second + # ``\n``. Bundle those into one emit so BPE merges ``\n\n`` into + # its single-token form (rather than two ``\n`` atoms). 
+ emit_text("\n\n" if has_sys_content else "\n", -1) + if has_sys_content: + emit_text(system_content.rstrip(), system_msg_idx) + if tools: + tool_text = _TOOLS_HEADER + for tool in tools: + tool_text += json.dumps(tool, ensure_ascii=False) + "\n" + tool_text += ( + _TOOLS_FOOTER_THINKING + if self._enable_thinking + else _TOOLS_FOOTER_NO_THINKING + ) + emit_text(tool_text, -1) + emit_text("\n\n", -1) + + # ── Per-message loop ────────────────────────────────────────── + for i, msg in enumerate(messages): + content = self._visible_text(msg.get("content")) + + match msg["role"]: + case "system": + # Already consumed in the header block. + if i == 0: + continue + emit_text("\n" + content + "\n\n", i) + case "user": + emit_text("\n" + content + "\n\n", i) + case "assistant": + self._render_assistant( + msg, i, content, emit_special=emit_special, emit_text=emit_text + ) + case "tool": + emit_text("\n" + content + "\n\n", i) + case unexpected_role: + assert_never(unexpected_role) + + # ── Generation prompt ───────────────────────────────────────── + if add_generation_prompt: + emit_special(self._assistant, -1) + emit_text("\n", -1) + if self._enable_thinking: + emit_special(self._think, -1) + else: + emit_special(self._think_end, -1) + + return RenderedTokens(token_ids=tokens, message_indices=indices) + + def render_ids( + self, + messages: list[Message], + *, + tools: list[ToolSpec] | None = None, + add_generation_prompt: bool = False, + ) -> list[int]: + return self.render( + messages, + tools=tools, + add_generation_prompt=add_generation_prompt, + ).token_ids + + def parse_response(self, token_ids: list[int]) -> ParsedResponse: + return parse_laguna_xs2( + self._tokenizer, + token_ids, + stop_ids={self._assistant_end, self._eos}, + think_id=self._think, + think_end_id=self._think_end, + tool_call_id=self._tool_call, + tool_call_end_id=self._tool_call_end, + ) + + def get_stop_token_ids(self) -> list[int]: + return [self._assistant_end, self._eos] + + def bridge_to_next_turn( + self, + previous_prompt_ids: list[int], + previous_completion_ids: list[int], + new_messages: list[Message], + *, + tools: list[ToolSpec] | None = None, + ) -> RenderedTokens | None: + if ( + not previous_prompt_ids + or not new_messages + or reject_assistant_in_extension(new_messages) + ): + return None + + # The canonical assistant-turn close is ````. ``〈|EOS|〉`` + # also stops generation; either being the final token means the turn + # ended cleanly. Truncation (no stop token at the tail) synthesises + # ``\n`` — the same scaffold the template emits. 
+ previous_ids = list(previous_prompt_ids) + list(previous_completion_ids) + stop_ids = {self._assistant_end, self._eos} + if ( + not previous_ids[len(previous_prompt_ids) :] + or previous_ids[-1] not in stop_ids + ): + previous_ids.append(self._assistant_end) + previous_ids.extend(self._encode("\n")) + + ext: list[int] = [] + + def emit_special(token_id: int, _msg_idx: int = -1) -> None: + ext.append(token_id) + + def emit_text(text: str, _msg_idx: int = -1) -> None: + ext.extend(self._encode(text)) + + for msg in new_messages: + role = msg.get("role") + content = self._visible_text(msg.get("content")) + if role == "user": + emit_text("\n" + content + "\n\n") + elif role == "system": + emit_text("\n" + content + "\n\n") + elif role == "tool": + emit_text("\n" + content + "\n\n") + else: + return None + + emit_special(self._assistant) + emit_text("\n") + if self._enable_thinking: + emit_special(self._think) + else: + emit_special(self._think_end) + + return RenderedTokens(token_ids=previous_ids + ext) + + def _render_assistant( + self, + msg: Message, + msg_idx: int, + content: str, + *, + emit_special, + emit_text, + ) -> None: + reasoning_content = "" + if isinstance(msg.get("reasoning_content"), str): + reasoning_content = msg["reasoning_content"] + else: + # When the caller stores reasoning as a ``ThinkingPart`` inside + # a list-form ``content`` (e.g. after parse_response → + # reserialize), pull it out here so it survives the re-render. + part_thinking = self._thinking_text(msg.get("content")) + if part_thinking: + reasoning_content = part_thinking + + emit_special(self._assistant, msg_idx) + emit_text("\n", msg_idx) + + if reasoning_content: + emit_special(self._think, msg_idx) + emit_text("\n" + reasoning_content.strip() + "\n", msg_idx) + emit_special(self._think_end, msg_idx) + else: + emit_special(self._think_end, msg_idx) + + # Combined newline-after- with optional content. Bundling + # preserves BPE merges across the boundary. + post_think_text = "\n" + if content.strip(): + post_think_text += content.strip() + "\n" + emit_text(post_think_text, msg_idx) + + tool_calls = msg.get("tool_calls") or [] + for tc in tool_calls: + func = tc.get("function") or tc + name = func.get("name", "") + arguments = func.get("arguments", {}) + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {} + + emit_special(self._tool_call, msg_idx) + inner = name + "\n" + if isinstance(arguments, dict): + for k, v in arguments.items(): + inner += "" + k + "\n" + if isinstance(v, str): + val_text = v + else: + val_text = json.dumps(v, ensure_ascii=False) + inner += "" + val_text + "\n" + emit_text(inner, msg_idx) + emit_special(self._tool_call_end, msg_idx) + emit_text("\n", msg_idx) + + emit_special(self._assistant_end, msg_idx) + emit_text("\n", msg_idx) diff --git a/renderers/parsing.py b/renderers/parsing.py index 827cc81..19cf6ec 100644 --- a/renderers/parsing.py +++ b/renderers/parsing.py @@ -424,6 +424,123 @@ def _parse_glm_tool_calls( return tool_calls +# ── Laguna-XS.2: name\nk\nv +# Same outer skeleton as parse_glm, but / are plain text +# (multi-token BPE), not single special tokens — so the inner block is decoded +# to text and the key/value pairs are pulled out by regex. + + +def parse_laguna_xs2( + tokenizer, + token_ids: list[int], + *, + stop_ids: set[int], + think_id: int, + think_end_id: int, + tool_call_id: int, + tool_call_end_id: int, +) -> ParsedResponse: + """Parse Laguna-XS.2 completion tokens. 
+ + Thinking uses single-token ```` / ```` (ids found by + scan). Tool calls are delimited by single-token ```` / + ````, but ```` / ```` inside are + plain text — regex-extracted from the decoded inner block. + """ + ids = _strip_stop_tokens(token_ids, stop_ids) + + # The template wraps reasoning with ``\n`` on both sides + # (``\n{r}\n``) and brackets post-think content with ``\n`` + # too (``\n{c}\n``). Strip exactly those newlines from each + # decoded segment — never a bare ``.strip()``, which would also eat + # whitespace the model emitted intentionally. + reasoning = None + think_end = _find(ids, think_end_id) + if think_end != -1: + reasoning_ids = ids[:think_end] + reasoning_ids = [t for t in reasoning_ids if t != think_id] + reasoning = _decode(tokenizer, reasoning_ids).strip("\n") + ids = ids[think_end + 1 :] + elif (think_start := _find(ids, think_id)) != -1: + reasoning = _decode(tokenizer, ids[think_start + 1 :]).strip("\n") + return ParsedResponse( + content="", reasoning_content=reasoning or None, tool_calls=None + ) + + tc_start = _find(ids, tool_call_id) + if tc_start != -1: + content_text = _decode(tokenizer, ids[:tc_start]).strip("\n") + tool_calls = _parse_laguna_xs2_tool_calls( + tokenizer, ids[tc_start:], tool_call_id, tool_call_end_id + ) + else: + content_text = _decode(tokenizer, ids).strip("\n") + tool_calls = None + + return ParsedResponse( + content=content_text, + reasoning_content=reasoning or None, + tool_calls=tool_calls or None, + ) + + +def _parse_laguna_xs2_tool_calls( + tokenizer, + ids: list[int], + tc_id: int, + tc_end_id: int, +) -> list[dict]: + """Parse Laguna-XS.2 tool calls. + + Inside each ``...`` block, the format is:: + + {name}\\n + {k1}\\n{v1}\\n + ... + {kn}\\n{vn}\\n + + The function name is everything before the first ```` literal + in the decoded block. + """ + import re + + tool_calls: list[dict] = [] + i = 0 + while i < len(ids): + if ids[i] == tc_id: + tc_end = _find(ids, tc_end_id, i + 1) + if tc_end == -1: + break + block_text = _decode(tokenizer, ids[i + 1 : tc_end]) + + ak_pos = block_text.find("") + if ak_pos != -1: + name = block_text[:ak_pos].strip() + args_section = block_text[ak_pos:] + else: + name = block_text.strip() + args_section = "" + + arguments: dict = {} + for m in re.finditer( + r"(.*?)\s*(.*?)", + args_section, + re.DOTALL, + ): + k = m.group(1).strip() + v = m.group(2).strip() + try: + arguments[k] = json.loads(v) + except (json.JSONDecodeError, ValueError): + arguments[k] = v + + tool_calls.append({"function": {"name": name, "arguments": arguments}}) + i = tc_end + 1 + else: + i += 1 + return tool_calls + + # ── DeepSeek V3: <|tool▁calls▁begin|>...<|tool▁calls▁end|> + text tags ── diff --git a/tests/conftest.py b/tests/conftest.py index 5ef65a5..8eea97b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -32,6 +32,7 @@ ("moonshotai/Kimi-K2.6", "auto"), ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", "auto"), ("nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16", "auto"), + ("poolside/Laguna-XS.2", "auto"), ("openai/gpt-oss-20b", "gpt-oss"), ("Qwen/Qwen2.5-0.5B-Instruct", "default"), ] diff --git a/tests/test_laguna_xs2.py b/tests/test_laguna_xs2.py new file mode 100644 index 0000000..aa90939 --- /dev/null +++ b/tests/test_laguna_xs2.py @@ -0,0 +1,114 @@ +"""Laguna-XS.2-specific renderer behavior. + +The Jinja template never iterates ``message.content`` as a list — it +coerces non-string content to ``""``. 
The renderer, however, is the seam +where a parse → reserialize → re-render round-trip can hand it +structured ``content=[ThinkingPart, TextPart]`` (e.g. trajectory storage +that preserves typed parts). Dropping those parts would silently lose +reasoning from previous turns on every re-render — the regression this +file guards against. +""" + +from __future__ import annotations + +from functools import lru_cache + +import pytest + + +_MODEL = "poolside/Laguna-XS.2" + + +@lru_cache(maxsize=1) +def _renderer(): + from renderers import create_renderer + from renderers.base import load_tokenizer + + return create_renderer(load_tokenizer(_MODEL)) + + +PROMPT = [{"role": "user", "content": "Hi"}] + + +def test_thinking_part_round_trip_matches_flat_form(): + """``content=[ThinkingPart, TextPart]`` must render identically to + ``reasoning_content=... , content=...`` — otherwise reasoning is + silently dropped on every re-render through the trajectory loop.""" + r = _renderer() + structured = { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "thinking through the problem"}, + {"type": "text", "text": "Hello!"}, + ], + } + flat = { + "role": "assistant", + "reasoning_content": "thinking through the problem", + "content": "Hello!", + } + assert r.render_ids(PROMPT + [structured]) == r.render_ids(PROMPT + [flat]) + + +def test_text_part_only_matches_string_content(): + """List with only ``TextPart`` entries collapses to the equivalent string content.""" + r = _renderer() + listed = {"role": "assistant", "content": [{"type": "text", "text": "Hello!"}]} + string = {"role": "assistant", "content": "Hello!"} + assert r.render_ids(PROMPT + [listed]) == r.render_ids(PROMPT + [string]) + + +def test_reasoning_field_beats_thinking_part(): + """Explicit ``reasoning_content`` wins over ``ThinkingPart`` in the + same message — the field is canonical, the part is the fallback.""" + r = _renderer() + both = { + "role": "assistant", + "reasoning_content": "from field", + "content": [ + {"type": "thinking", "thinking": "from part (should be ignored)"}, + {"type": "text", "text": "Hi"}, + ], + } + field_only = { + "role": "assistant", + "reasoning_content": "from field", + "content": "Hi", + } + assert r.render_ids(PROMPT + [both]) == r.render_ids(PROMPT + [field_only]) + + +def test_multiple_thinking_parts_concatenated(): + """Multiple ``ThinkingPart`` entries concatenate (insertion order).""" + r = _renderer() + multi = { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "first"}, + {"type": "thinking", "thinking": "second"}, + {"type": "text", "text": "out"}, + ], + } + flat = { + "role": "assistant", + "reasoning_content": "firstsecond", + "content": "out", + } + assert r.render_ids(PROMPT + [multi]) == r.render_ids(PROMPT + [flat]) + + +@pytest.mark.parametrize( + "shape", + [ + pytest.param(None, id="none"), + pytest.param([], id="empty-list"), + pytest.param([{"type": "image"}], id="non-text-non-thinking-only"), + ], +) +def test_degenerate_content_collapses_to_empty(shape): + """Content shapes that produce no visible text and no thinking must + render the same as ``content=""`` — no crashes, no extra tokens.""" + r = _renderer() + degen = {"role": "assistant", "content": shape} + empty = {"role": "assistant", "content": ""} + assert r.render_ids(PROMPT + [degen]) == r.render_ids(PROMPT + [empty]) diff --git a/tests/test_preserve_thinking.py b/tests/test_preserve_thinking.py index b0d6a9e..661d577 100644 --- a/tests/test_preserve_thinking.py +++ 
b/tests/test_preserve_thinking.py @@ -42,6 +42,7 @@ def _make(tokenizer, renderer_name, **flags): "Qwen/Qwen3-VL-4B-Instruct", "Qwen/Qwen3-VL-8B-Instruct", "Qwen/Qwen3-VL-30B-A3B-Instruct", + "poolside/Laguna-XS.2", } diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py index 1cdac82..a4577fd 100644 --- a/tests/test_roundtrip.py +++ b/tests/test_roundtrip.py @@ -43,6 +43,7 @@ ("moonshotai/Kimi-K2.6", "auto"), ("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", "auto"), ("nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16", "auto"), + ("poolside/Laguna-XS.2", "auto"), ("openai/gpt-oss-20b", "gpt-oss"), ("Qwen/Qwen2.5-0.5B-Instruct", "default"), ] From 167044c75f543da04db218f225cef3cc722f3f7c Mon Sep 17 00:00:00 2001 From: Robin Nabel Date: Tue, 12 May 2026 17:48:17 +0200 Subject: [PATCH 05/13] style: drop redundant LagunaXS2Renderer class docstring --- renderers/laguna_xs2.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/renderers/laguna_xs2.py b/renderers/laguna_xs2.py index 18eef96..c981265 100644 --- a/renderers/laguna_xs2.py +++ b/renderers/laguna_xs2.py @@ -76,8 +76,6 @@ class LagunaXS2Renderer: - """Deterministic message → token renderer for Poolside's Laguna-XS.2 model.""" - def __init__( self, tokenizer: PreTrainedTokenizer, From 1b7e4692a269fa57d042c263c6891988002556db Mon Sep 17 00:00:00 2001 From: Robin Nabel Date: Tue, 12 May 2026 17:50:50 +0200 Subject: [PATCH 06/13] test: remove Laguna-XS.2-specific test file The list-form content behaviors these tests exercised (TextPart extraction, ThinkingPart routing, reasoning_content precedence, degenerate-shape robustness) are generic Renderer-protocol invariants rather than Laguna-specific quirks. Better suited to the shared conftest matrix with opt-in subsets, see PR description for upstreaming suggestions. --- tests/test_laguna_xs2.py | 114 --------------------------------------- 1 file changed, 114 deletions(-) delete mode 100644 tests/test_laguna_xs2.py diff --git a/tests/test_laguna_xs2.py b/tests/test_laguna_xs2.py deleted file mode 100644 index aa90939..0000000 --- a/tests/test_laguna_xs2.py +++ /dev/null @@ -1,114 +0,0 @@ -"""Laguna-XS.2-specific renderer behavior. - -The Jinja template never iterates ``message.content`` as a list — it -coerces non-string content to ``""``. The renderer, however, is the seam -where a parse → reserialize → re-render round-trip can hand it -structured ``content=[ThinkingPart, TextPart]`` (e.g. trajectory storage -that preserves typed parts). Dropping those parts would silently lose -reasoning from previous turns on every re-render — the regression this -file guards against. -""" - -from __future__ import annotations - -from functools import lru_cache - -import pytest - - -_MODEL = "poolside/Laguna-XS.2" - - -@lru_cache(maxsize=1) -def _renderer(): - from renderers import create_renderer - from renderers.base import load_tokenizer - - return create_renderer(load_tokenizer(_MODEL)) - - -PROMPT = [{"role": "user", "content": "Hi"}] - - -def test_thinking_part_round_trip_matches_flat_form(): - """``content=[ThinkingPart, TextPart]`` must render identically to - ``reasoning_content=... 
, content=...`` — otherwise reasoning is - silently dropped on every re-render through the trajectory loop.""" - r = _renderer() - structured = { - "role": "assistant", - "content": [ - {"type": "thinking", "thinking": "thinking through the problem"}, - {"type": "text", "text": "Hello!"}, - ], - } - flat = { - "role": "assistant", - "reasoning_content": "thinking through the problem", - "content": "Hello!", - } - assert r.render_ids(PROMPT + [structured]) == r.render_ids(PROMPT + [flat]) - - -def test_text_part_only_matches_string_content(): - """List with only ``TextPart`` entries collapses to the equivalent string content.""" - r = _renderer() - listed = {"role": "assistant", "content": [{"type": "text", "text": "Hello!"}]} - string = {"role": "assistant", "content": "Hello!"} - assert r.render_ids(PROMPT + [listed]) == r.render_ids(PROMPT + [string]) - - -def test_reasoning_field_beats_thinking_part(): - """Explicit ``reasoning_content`` wins over ``ThinkingPart`` in the - same message — the field is canonical, the part is the fallback.""" - r = _renderer() - both = { - "role": "assistant", - "reasoning_content": "from field", - "content": [ - {"type": "thinking", "thinking": "from part (should be ignored)"}, - {"type": "text", "text": "Hi"}, - ], - } - field_only = { - "role": "assistant", - "reasoning_content": "from field", - "content": "Hi", - } - assert r.render_ids(PROMPT + [both]) == r.render_ids(PROMPT + [field_only]) - - -def test_multiple_thinking_parts_concatenated(): - """Multiple ``ThinkingPart`` entries concatenate (insertion order).""" - r = _renderer() - multi = { - "role": "assistant", - "content": [ - {"type": "thinking", "thinking": "first"}, - {"type": "thinking", "thinking": "second"}, - {"type": "text", "text": "out"}, - ], - } - flat = { - "role": "assistant", - "reasoning_content": "firstsecond", - "content": "out", - } - assert r.render_ids(PROMPT + [multi]) == r.render_ids(PROMPT + [flat]) - - -@pytest.mark.parametrize( - "shape", - [ - pytest.param(None, id="none"), - pytest.param([], id="empty-list"), - pytest.param([{"type": "image"}], id="non-text-non-thinking-only"), - ], -) -def test_degenerate_content_collapses_to_empty(shape): - """Content shapes that produce no visible text and no thinking must - render the same as ``content=""`` — no crashes, no extra tokens.""" - r = _renderer() - degen = {"role": "assistant", "content": shape} - empty = {"role": "assistant", "content": ""} - assert r.render_ids(PROMPT + [degen]) == r.render_ids(PROMPT + [empty]) From efd04f16f495058257e291ed3f09437d9a4580a1 Mon Sep 17 00:00:00 2001 From: Konstantin Haller Date: Wed, 13 May 2026 13:19:12 +0000 Subject: [PATCH 07/13] fix(laguna): migrate parser to ParsedToolCall API + drop broken assert_never MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After rebasing onto current main, the Laguna parser broke against ParsedResponse.tool_calls' new list[ParsedToolCall] shape (introduced in #22). Mirror parse_glm's structure: emit ParsedToolCall with status (UNCLOSED_BLOCK / MISSING_NAME / INVALID_JSON / OK) and token_span relative to the stop-stripped stream. Also drop ``assert_never(unexpected_role)``: msg["role"] is plain ``str`` (TypedDict), so ty flags a type-assertion-failure and any unknown role would crash at runtime — every other renderer silently skips unknown roles. 
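A consumer-side sketch of the migrated shape (rough; ``ToolCallParseStatus``
is assumed importable from ``renderers.parsing``, ``renderer`` and
``completion_ids`` are placeholders, and ``dispatch`` is a hypothetical
executor — the field names are the ones used by the constructor calls in
this patch):

    from renderers.parsing import ToolCallParseStatus

    resp = renderer.parse_response(completion_ids)
    for tc in resp.tool_calls:
        if tc.status is ToolCallParseStatus.OK:
            dispatch(tc.name, tc.arguments)  # arguments: dict, JSON-decoded per key
        else:
            # UNCLOSED_BLOCK / MISSING_NAME / INVALID_JSON: tc.raw plus
            # tc.token_span (relative to the stop-stripped stream)
            # localise the malformed block for logging.
            print(tc.status, tc.token_span, repr(tc.raw))
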
Co-Authored-By: Claude Opus 4.7 (1M context) --- renderers/laguna_xs2.py | 3 --- renderers/parsing.py | 48 +++++++++++++++++++++++++++++++++++------ 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/renderers/laguna_xs2.py b/renderers/laguna_xs2.py index c981265..ee8e64d 100644 --- a/renderers/laguna_xs2.py +++ b/renderers/laguna_xs2.py @@ -26,7 +26,6 @@ from __future__ import annotations import json -from typing import Any, assert_never from transformers.tokenization_utils import PreTrainedTokenizer @@ -214,8 +213,6 @@ def emit_text(text: str, msg_idx: int) -> None: ) case "tool": emit_text("\n" + content + "\n\n", i) - case unexpected_role: - assert_never(unexpected_role) # ── Generation prompt ───────────────────────────────────────── if add_generation_prompt: diff --git a/renderers/parsing.py b/renderers/parsing.py index 19cf6ec..e235c76 100644 --- a/renderers/parsing.py +++ b/renderers/parsing.py @@ -455,32 +455,38 @@ def parse_laguna_xs2( # decoded segment — never a bare ``.strip()``, which would also eat # whitespace the model emitted intentionally. reasoning = None + parse_offset = 0 think_end = _find(ids, think_end_id) if think_end != -1: reasoning_ids = ids[:think_end] reasoning_ids = [t for t in reasoning_ids if t != think_id] reasoning = _decode(tokenizer, reasoning_ids).strip("\n") ids = ids[think_end + 1 :] + parse_offset = think_end + 1 elif (think_start := _find(ids, think_id)) != -1: reasoning = _decode(tokenizer, ids[think_start + 1 :]).strip("\n") return ParsedResponse( - content="", reasoning_content=reasoning or None, tool_calls=None + content="", reasoning_content=reasoning or None, tool_calls=[] ) tc_start = _find(ids, tool_call_id) + tool_calls: list[ParsedToolCall] = [] if tc_start != -1: content_text = _decode(tokenizer, ids[:tc_start]).strip("\n") tool_calls = _parse_laguna_xs2_tool_calls( - tokenizer, ids[tc_start:], tool_call_id, tool_call_end_id + tokenizer, + ids[tc_start:], + tool_call_id, + tool_call_end_id, + section_offset=parse_offset + tc_start, ) else: content_text = _decode(tokenizer, ids).strip("\n") - tool_calls = None return ParsedResponse( content=content_text, reasoning_content=reasoning or None, - tool_calls=tool_calls or None, + tool_calls=tool_calls, ) @@ -489,7 +495,9 @@ def _parse_laguna_xs2_tool_calls( ids: list[int], tc_id: int, tc_end_id: int, -) -> list[dict]: + *, + section_offset: int, +) -> list[ParsedToolCall]: """Parse Laguna-XS.2 tool calls. 
Inside each ``...`` block, the format is:: @@ -504,14 +512,23 @@ def _parse_laguna_xs2_tool_calls( """ import re - tool_calls: list[dict] = [] + tool_calls: list[ParsedToolCall] = [] i = 0 while i < len(ids): if ids[i] == tc_id: tc_end = _find(ids, tc_end_id, i + 1) if tc_end == -1: + raw = _decode(tokenizer, ids[i + 1 :]) + tool_calls.append( + ParsedToolCall( + raw=raw, + token_span=(section_offset + i, section_offset + len(ids)), + status=ToolCallParseStatus.UNCLOSED_BLOCK, + ) + ) break block_text = _decode(tokenizer, ids[i + 1 : tc_end]) + span = (section_offset + i, section_offset + tc_end + 1) ak_pos = block_text.find("") if ak_pos != -1: @@ -522,6 +539,7 @@ def _parse_laguna_xs2_tool_calls( args_section = "" arguments: dict = {} + any_json_fallback = False for m in re.finditer( r"(.*?)\s*(.*?)", args_section, @@ -533,8 +551,24 @@ def _parse_laguna_xs2_tool_calls( arguments[k] = json.loads(v) except (json.JSONDecodeError, ValueError): arguments[k] = v + any_json_fallback = True + + if not name: + status = ToolCallParseStatus.MISSING_NAME + elif any_json_fallback: + status = ToolCallParseStatus.INVALID_JSON + else: + status = ToolCallParseStatus.OK - tool_calls.append({"function": {"name": name, "arguments": arguments}}) + tool_calls.append( + ParsedToolCall( + raw=block_text, + name=name or None, + arguments=arguments, + token_span=span, + status=status, + ) + ) i = tc_end + 1 else: i += 1 From 85320a3a758dd424c604ea229365f9b0066662a2 Mon Sep 17 00:00:00 2001 From: hallerite Date: Wed, 13 May 2026 13:30:37 +0000 Subject: [PATCH 08/13] build: derive version from git tags via hatch-vcs (#20) --- .github/workflows/publish.yml | 19 ++++--------------- .gitignore | 2 ++ pyproject.toml | 27 +++++++++++++++++++++++++-- renderers/__init__.py | 9 +++++++++ 4 files changed, 40 insertions(+), 17 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index f4331e1..681fb0e 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -41,6 +41,10 @@ jobs: TAG="$PUSHED_REF" fi + # The package version is derived from this tag by hatch-vcs + # at build time (see [tool.hatch.version] in pyproject.toml). + # We only need to validate the tag shape — there's no + # ``project.version`` field to cross-check anymore. case "$TAG" in renderers-v*) ;; *) @@ -49,21 +53,6 @@ jobs: ;; esac - VERSION="${TAG#renderers-v}" - FILE_VERSION=$(python - <<'PY' - import tomllib - from pathlib import Path - with Path('pyproject.toml').open('rb') as f: - data = tomllib.load(f) - print(data['project']['version']) - PY - ) - - if [ "$FILE_VERSION" != "$VERSION" ]; then - echo "Version mismatch: tag requests '$VERSION' but pyproject.toml defines '$FILE_VERSION'" >&2 - exit 1 - fi - echo "tag=$TAG" >> "$GITHUB_OUTPUT" - uses: astral-sh/setup-uv@v7 diff --git a/.gitignore b/.gitignore index 2f3dfb0..efbf202 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,8 @@ __pycache__/ *.pyc *.pyo *.pyd +# generated by hatch-vcs at build time (see [tool.hatch.build.hooks.vcs]) +renderers/_version.py # tooling caches .pytest_cache/ diff --git a/pyproject.toml b/pyproject.toml index 3fc98c8..0b4a00c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,14 @@ [build-system] -requires = ["hatchling"] +requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [project] name = "renderers" -version = "0.1.7" +# Derived from git tags by hatch-vcs (see [tool.hatch.version] below). 
+# Untagged commits get PEP 440 dev versions like ``0.1.8.dev3+g4c877be4``
+# so any commit is uniquely installable; tagged commits get clean
+# release versions like ``0.1.8``.
+dynamic = ["version"]
 description = "Chat template renderers — deterministic message-to-token conversion for LLM training"
 readme = "README.md"
 requires-python = ">=3.10,<3.14"
@@ -20,6 +24,25 @@ dependencies = [
     "openai-harmony>=0.0.8",
 ]

+[tool.hatch.version]
+source = "vcs"
+# Tags look like ``renderers-v0.1.8`` (prefix matches the publish.yml
+# release contract); strip the prefix to get a PEP 440 version. The
+# regex accepts any PEP 440-valid suffix after the prefix so we can
+# tag pre-releases like ``renderers-v0.2.0rc1`` later if needed.
+tag-pattern = '^renderers-v(?P<version>.+)$'
+# Used when building from a context without VCS metadata (e.g. an
+# sdist consumed by a downstream that doesn't ship .git). Real
+# builds from a checkout get the resolved version; this fallback
+# only fires when the resolver has nothing to go on.
+fallback-version = "0.0.0"
+
+[tool.hatch.build.hooks.vcs]
+# Write the resolved version to a Python file so it can be inspected
+# at runtime via ``renderers.__version__`` without re-parsing the
+# wheel metadata.
+version-file = "renderers/_version.py"
+
 [tool.hatch.build.targets.wheel]
 packages = ["renderers"]

diff --git a/renderers/__init__.py b/renderers/__init__.py
index b28a485..24f076c 100644
--- a/renderers/__init__.py
+++ b/renderers/__init__.py
@@ -1,3 +1,11 @@
+try:
+    from renderers._version import __version__
+except ImportError:
+    # Source checkout without a built artifact (e.g. editable install
+    # before the first ``uv build`` populates ``_version.py``). Real
+    # installs always have it.
+    __version__ = "0+unknown"
+
 from renderers.base import (
     Content,
     ContentPart,
@@ -73,6 +81,7 @@
     "ToolCallFunction",
     "ToolSpec",
     "VideoPart",
+    "__version__",
     "build_training_sample",
     "build_trajectory_step",
     "create_renderer",

From 99e28108671b794aa8e09ee21728847f433333a1 Mon Sep 17 00:00:00 2001
From: hallerite
Date: Wed, 13 May 2026 13:30:38 +0000
Subject: [PATCH 09/13] =?UTF-8?q?feat:=20renderer=20emits=20numpy,=20not?=
 =?UTF-8?q?=20torch=20=E2=80=94=20drop=20torch=20from=20renderer=20surface?=
 =?UTF-8?q?=20(#18)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 renderers/client.py   | 11 +++++++++--
 renderers/kimi_k25.py |  2 +-
 renderers/qwen35.py   |  2 +-
 renderers/qwen3_vl.py |  2 +-
 4 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/renderers/client.py b/renderers/client.py
index d0fa563..7d585f4 100644
--- a/renderers/client.py
+++ b/renderers/client.py
@@ -263,8 +263,15 @@
         image_items = mm_data.mm_items.get("image") or []
         if image_items:
-            pixel_values = torch.cat([it["pixel_values"] for it in image_items], dim=0)
-            image_grid_thw = torch.cat([it["image_grid_thw"] for it in image_items], dim=0)
+            # mm_items now ship numpy arrays (the renderer is torch-free);
+            # convert at this vLLM-glue boundary where torch is already a
+            # hard dependency.
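+            # (torch.as_tensor reuses the numpy buffer where dtype and
+            # layout allow, so the hop back to torch is effectively
+            # zero-copy for these arrays.)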
+ pixel_values = torch.cat( + [torch.as_tensor(it["pixel_values"]) for it in image_items], dim=0 + ) + image_grid_thw = torch.cat( + [torch.as_tensor(it["image_grid_thw"]) for it in image_items], dim=0 + ) hf_inputs = BatchFeature( data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw} ) diff --git a/renderers/kimi_k25.py b/renderers/kimi_k25.py index a4afa73..0984498 100644 --- a/renderers/kimi_k25.py +++ b/renderers/kimi_k25.py @@ -621,7 +621,7 @@ def _process_image(self, part: dict[str, Any]): img_proc = proc.image_processor # Kimi's vision processor takes a media-dict shape, not raw PIL. media_item = {"type": "image", "image": pil} - out = img_proc.preprocess([media_item], return_tensors="pt") + out = img_proc.preprocess([media_item], return_tensors="np") # Patch count via the processor's own calculator (matches the # model's per-patch attention count); kept for debugging. num_patches = int(img_proc.media_tokens_calculator(media_item)) diff --git a/renderers/qwen35.py b/renderers/qwen35.py index 2ee4a1c..05f155e 100644 --- a/renderers/qwen35.py +++ b/renderers/qwen35.py @@ -182,7 +182,7 @@ def _process_image(self, part: dict[str, Any]): out, num_image_tokens = cached return pil, out, num_image_tokens, h proc = self._get_processor() - out = proc.image_processor(images=[pil], return_tensors="pt") + out = proc.image_processor(images=[pil], return_tensors="np") grid_thw = out["image_grid_thw"][0] merge_size = proc.image_processor.merge_size num_image_tokens = int(grid_thw.prod()) // (merge_size * merge_size) diff --git a/renderers/qwen3_vl.py b/renderers/qwen3_vl.py index eabbdc3..7f8e39f 100644 --- a/renderers/qwen3_vl.py +++ b/renderers/qwen3_vl.py @@ -357,7 +357,7 @@ def _process_image(self, part: dict[str, Any]): out, num_image_tokens = cached return pil, out, num_image_tokens, h proc = self._get_processor() - out = proc.image_processor(images=[pil], return_tensors="pt") + out = proc.image_processor(images=[pil], return_tensors="np") grid_thw = out["image_grid_thw"][0] merge_size = proc.image_processor.merge_size num_image_tokens = int(grid_thw.prod()) // (merge_size * merge_size) From 577a4f650bdb287e9585d0fac23dd782d01f9bd4 Mon Sep 17 00:00:00 2001 From: hallerite Date: Wed, 13 May 2026 13:30:39 +0000 Subject: [PATCH 10/13] examples: add online SGLang HTTP recipe (#23) --- examples/README.md | 18 ++ examples/sglang/online_multiturn_sglang.py | 255 +++++++++++++++++++++ 2 files changed, 273 insertions(+) create mode 100644 examples/sglang/online_multiturn_sglang.py diff --git a/examples/README.md b/examples/README.md index 806a229..08d79a6 100644 --- a/examples/README.md +++ b/examples/README.md @@ -31,6 +31,24 @@ CUDA_VISIBLE_DEVICES=1 uv run --script examples/sglang/multiturn_generate_sglang The SGLang script uses `input_ids`, so SGLang does not apply a chat template. It leaves `openai-harmony` at SGLang's pinned version for dependency resolution. +### SGLang HTTP Recipe (online) + +For a token-in/token-out HTTP path against an already-running SGLang server: + +```bash +sglang serve --model-path Qwen/Qwen3.5-4B \ + --host 0.0.0.0 --port 30000 --tensor-parallel-size 1 --trust-remote-code & + +uv run python examples/sglang/online_multiturn_sglang.py \ + --base-url http://localhost:30000 --model Qwen/Qwen3.5-4B +``` + +The HTTP recipe posts `input_ids` to `/generate`; streaming is intentionally +unsupported because `parse_response`/`bridge_to_next_turn` require the full +completion. 
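+
+For illustration, the request body the recipe posts (the token IDs below are
+placeholders; real ones come from `renderer.render_ids`):
+
+```python
+import httpx
+
+body = {
+    "input_ids": [151644, 872, 198],  # placeholder prompt ids
+    "sampling_params": {"temperature": 0.0, "max_new_tokens": 64},
+    "stream": False,
+}
+print(httpx.post("http://localhost:30000/generate", json=body).json())
+```
+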
The source-checkout command above uses the local `renderers` +package; the PEP 723 `uv run --script` form requires a published package that +satisfies the script header. + ## Transformers Multi-Turn Recipe ```bash diff --git a/examples/sglang/online_multiturn_sglang.py b/examples/sglang/online_multiturn_sglang.py new file mode 100644 index 0000000..60df8b2 --- /dev/null +++ b/examples/sglang/online_multiturn_sglang.py @@ -0,0 +1,255 @@ +#!/usr/bin/env -S uv run --script +# /// script +# requires-python = ">=3.10,<3.14" +# dependencies = [ +# "renderers>=0.1.6", +# "transformers>=5.3.0", +# "httpx>=0.27", +# "openai-harmony==0.0.4", +# "tiktoken", +# "jinja2", +# "numpy", +# ] +# /// +"""SGLang online generation from renderer-owned prompt token IDs. + +Mirrors `multiturn_generate_sglang.py` but talks to an already-running SGLang +HTTP server over `/generate` instead of an in-process `sgl.Engine`. The +renderer owns chat templating and parsing; SGLang only does token-in, +token-out. + +Streaming is intentionally not supported: `parse_response` and +`bridge_to_next_turn` both need the complete `completion_ids`. + +Launch a server first, e.g. + + sglang serve --model-path Qwen/Qwen3.5-4B \\ + --host 0.0.0.0 --port 30000 --tensor-parallel-size 1 --trust-remote-code + +then, from a source checkout, + + uv run python examples/sglang/online_multiturn_sglang.py \\ + --base-url http://localhost:30000 --model Qwen/Qwen3.5-4B + +The PEP 723 `uv run --script` form requires a published `renderers` package +that satisfies the script header. +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +from typing import Any + +import httpx +from renderers.base import Renderer +from renderers.gpt_oss import GptOssRenderer +from renderers.qwen35 import Qwen35Renderer +from transformers import AutoTokenizer + + +TOOLS = [ + { + "type": "function", + "function": { + "name": "multiply", + "description": "Multiply two integers.", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "integer"}, + }, + "required": ["a", "b"], + }, + }, + } +] + + +def make_renderer(model: str, enable_thinking: bool | None) -> Renderer: + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=False) + if model.startswith("Qwen/Qwen3.5-"): + return Qwen35Renderer(tokenizer, enable_thinking=enable_thinking) + if model == "openai/gpt-oss-20b": + return GptOssRenderer(tokenizer) + raise ValueError(f"unsupported demo model: {model}") + + +def completion_ids(output: dict, prompt_ids: list[int]) -> list[int]: + ids = list(output.get("output_ids") or output.get("token_ids") or []) + if not ids: + raise RuntimeError("SGLang did not return completion token IDs") + # Match offline recipe: strip prefix only if SGLang echoed the prompt back. 
+ return ids[len(prompt_ids) :] if ids[: len(prompt_ids)] == prompt_ids else ids + + +async def generate_sglang( + *, + client: httpx.AsyncClient, + base_url: str, + renderer: Renderer, + prompt_ids: list[int], + max_new_tokens: int, + extra_key: str | None = None, +) -> dict[str, Any]: + body: dict[str, Any] = { + "input_ids": prompt_ids, + "sampling_params": { + "temperature": 0.0, + "max_new_tokens": max_new_tokens, + "stop_token_ids": renderer.get_stop_token_ids(), + "skip_special_tokens": False, + "no_stop_trim": True, + }, + "stream": False, + } + if extra_key is not None: + body["extra_key"] = extra_key + response = await client.post(f"{base_url.rstrip('/')}/generate", json=body) + response.raise_for_status() + return response.json() + + +def print_parsed(label: str, turn: str, parsed) -> None: + print(f"\n[{label}] {turn}") + if parsed.reasoning_content: + print(f"reasoning: {parsed.reasoning_content[:240]}") + if parsed.tool_calls: + print(f"tool_calls: {json.dumps(parsed.tool_calls, ensure_ascii=False)}") + if parsed.content: + print(f"content: {parsed.content}") + + +async def run_one( + *, + client: httpx.AsyncClient, + base_url: str, + model: str, + enable_thinking: bool | None, + max_new_tokens: int, +) -> None: + label = ( + model + if enable_thinking is None + else f"{model} enable_thinking={enable_thinking}" + ) + print(f"\n=== {label} ===") + + renderer = make_renderer(model, enable_thinking) + + messages: list[dict[str, Any]] = [ + {"role": "system", "content": "You are a concise tool-using assistant."}, + { + "role": "user", + "content": "Use the multiply tool for 17 * 23, then summarize.", + }, + ] + + # Turn 1: render locally, send token IDs. SGLang never sees messages. + prompt_ids = renderer.render_ids(messages, tools=TOOLS, add_generation_prompt=True) + output1 = await generate_sglang( + client=client, + base_url=base_url, + renderer=renderer, + prompt_ids=prompt_ids, + max_new_tokens=max_new_tokens, + ) + completion1 = completion_ids(output1, prompt_ids) + parsed1 = renderer.parse_response(completion1) + print_parsed(label, "turn 1", parsed1) + + assistant: dict[str, Any] = {"role": "assistant", "content": parsed1.content} + if parsed1.reasoning_content: + assistant["reasoning_content"] = parsed1.reasoning_content + if parsed1.tool_calls: + assistant["tool_calls"] = parsed1.tool_calls + messages.append(assistant) + + if parsed1.tool_calls: + new_messages: list[dict[str, Any]] = [] + for idx, tool_call in enumerate(parsed1.tool_calls): + fn = tool_call.get("function") or tool_call + tool_args = fn.get("arguments") or {} + if isinstance(tool_args, str): + tool_args = json.loads(tool_args) + new_messages.append( + { + "role": "tool", + "tool_call_id": tool_call.get("id", f"call_{idx}"), + "name": fn.get("name", "multiply"), + "content": json.dumps( + {"result": int(tool_args["a"]) * int(tool_args["b"])} + ), + } + ) + else: + new_messages = [ + {"role": "user", "content": "Give the final answer in one sentence."} + ] + + # Turn 2: bridge extends prompt_ids + completion1 exactly. 
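+    # The assert below pins that property: bridging must only append
+    # tokens after prompt_ids + completion1, never rewrite them.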
+ bridged_ids = renderer.bridge_to_next_turn( + prompt_ids, completion1, new_messages, tools=TOOLS + ) + if bridged_ids is None: + raise RuntimeError("bridge_to_next_turn returned None") + assert bridged_ids[: len(prompt_ids) + len(completion1)] == ( + prompt_ids + completion1 + ) + + output2 = await generate_sglang( + client=client, + base_url=base_url, + renderer=renderer, + prompt_ids=bridged_ids, + max_new_tokens=max_new_tokens, + ) + completion2 = completion_ids(output2, bridged_ids) + print_parsed(label, "turn 2", renderer.parse_response(completion2)) + + +async def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--base-url", + default="http://localhost:30000", + help="SGLang HTTP server base URL.", + ) + parser.add_argument( + "--model", + default="Qwen/Qwen3.5-4B", + help="Must match the model the SGLang server is serving.", + ) + parser.add_argument( + "--enable-thinking", + choices=["true", "false", "both"], + default="both", + help="Qwen3.5 thinking mode. Ignored for gpt-oss.", + ) + parser.add_argument("--max-new-tokens", type=int, default=512) + parser.add_argument("--timeout", type=float, default=600.0) + args = parser.parse_args() + + if args.model.startswith("Qwen/Qwen3.5-"): + if args.enable_thinking == "both": + modes: list[bool | None] = [True, False] + else: + modes = [args.enable_thinking == "true"] + else: + modes = [None] + + async with httpx.AsyncClient(timeout=args.timeout) as client: + for mode in modes: + await run_one( + client=client, + base_url=args.base_url, + model=args.model, + enable_thinking=mode, + max_new_tokens=args.max_new_tokens, + ) + + +if __name__ == "__main__": + asyncio.run(main()) From 0e22d2cc202408aa6f5c883d989c951cb0c35039 Mon Sep 17 00:00:00 2001 From: hallerite Date: Wed, 13 May 2026 13:30:39 +0000 Subject: [PATCH 11/13] feat(parsing): per-attempt ParsedToolCall with status + token spans (#22) --- renderers/__init__.py | 4 + renderers/base.py | 66 ++- renderers/client.py | 23 +- renderers/default.py | 2 +- renderers/kimi_k25.py | 157 ++++-- renderers/parsers.py | 218 +++++++-- renderers/parsing.py | 604 +++++++++++++++++------- tests/test_client.py | 92 ++-- tests/test_parse_response.py | 86 +++- tests/test_parse_response_robustness.py | 20 +- tests/test_parsers.py | 54 ++- tests/test_roundtrip.py | 24 +- 12 files changed, 992 insertions(+), 358 deletions(-) diff --git a/renderers/__init__.py b/renderers/__init__.py index 24f076c..a169186 100644 --- a/renderers/__init__.py +++ b/renderers/__init__.py @@ -15,6 +15,7 @@ MultiModalData, MultimodalRenderer, ParsedResponse, + ParsedToolCall, PlaceholderRange, RenderedConversation, RenderedTokens, @@ -24,6 +25,7 @@ ThinkingPart, ToolCall, ToolCallFunction, + ToolCallParseStatus, ToolSpec, VideoPart, build_training_sample, @@ -66,6 +68,7 @@ "MultimodalRenderer", "Nemotron3Renderer", "ParsedResponse", + "ParsedToolCall", "PlaceholderRange", "Qwen3Renderer", "Qwen3VLRenderer", @@ -79,6 +82,7 @@ "ThinkingPart", "ToolCall", "ToolCallFunction", + "ToolCallParseStatus", "ToolSpec", "VideoPart", "__version__", diff --git a/renderers/base.py b/renderers/base.py index a8823fe..64e9f1c 100644 --- a/renderers/base.py +++ b/renderers/base.py @@ -1,5 +1,6 @@ from __future__ import annotations +import enum import logging import queue import threading @@ -160,13 +161,74 @@ class RenderedTokens: multi_modal_data: "MultiModalData | None" = None +class ToolCallParseStatus(str, enum.Enum): + """Per-attempt outcome of parsing a single ```` block. 
+ + The renderer parser's job is JSON-syntax → ``dict`` (the parser-level + contract). Schema validation — required fields, argument types, tool + name lookup — is the *tool*'s job and is intentionally not done here. + See ``ParsedToolCall.status`` for what each value means. + + Diverges from vLLM/SGLang on purpose. Both engines collapse parse + failures into either a single ``tools_called: bool`` (vLLM) or silent + drops (SGLang), with no way to express "the model emitted three + parallel tool calls and the second was malformed." Renderers expose + that information because verifier / RL-loss code needs it for + schema-adherence rubrics and selective token masking — use cases the + inference engines don't serve. + """ + + OK = "ok" + INVALID_JSON = "invalid_json" # body wasn't valid JSON + UNCLOSED_BLOCK = "unclosed_block" # opening delim hit EOS / stop + MISSING_NAME = "missing_name" # parsed structurally, but no function name + MALFORMED_STRUCTURE = "malformed_structure" # format-specific shape error + + +@dataclass +class ParsedToolCall: + """A single ```` block as the renderer parsed it. + + One record per *attempt* — successful and malformed calls both land + here, distinguished by ``status``. Ordering is preserved across the + response, so ``[OK, INVALID_JSON, OK]`` is a faithful record of "the + model emitted three parallel calls; the second was broken." + + ``token_span`` is a half-open ``[start, end)`` slice into the + completion's stripped token id stream (i.e. ``token_ids`` after + ``_strip_stop_tokens``); some text-based parsers can't cheaply + recover token offsets and leave it ``None``. Useful for trainer-side + selective loss masking: zero the mask over the spans of non-OK + entries to avoid reinforcing malformed structures. + + ``raw`` is the decoded text of the block as the model emitted it + (before any JSON normalization). Always populated — for failed + attempts it's the only way to see what actually went wrong. + """ + + raw: str + name: str | None = None + arguments: dict[str, Any] | str | None = None + token_span: tuple[int, int] | None = None + status: ToolCallParseStatus = ToolCallParseStatus.OK + id: str | None = None # native tool-call id when the format carries one (Kimi K2) + + @dataclass class ParsedResponse: - """Result of parsing completion tokens back into a structured message.""" + """Result of parsing completion tokens back into a structured message. + + ``tool_calls`` is a list of every parse attempt — successful and + malformed alike. Filter with ``[tc for tc in r.tool_calls if + tc.status == ToolCallParseStatus.OK]`` to get only the calls that + came out clean. Empty list = the model didn't emit any tool calls + (different from "tried and failed entirely", which produces a list + with non-OK entries). 
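+
+    A usage sketch (all names as defined in this module)::
+
+        parsed = renderer.parse_response(completion_ids)
+        ok = [tc for tc in parsed.tool_calls
+              if tc.status is ToolCallParseStatus.OK]
+        if parsed.tool_calls and not ok:
+            ...  # every attempt malformed; inspect tc.raw / tc.status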
+ """ content: str reasoning_content: str | None = None - tool_calls: list[dict[str, Any]] | None = None + tool_calls: list[ParsedToolCall] = field(default_factory=list) @dataclass diff --git a/renderers/client.py b/renderers/client.py index 7d585f4..cff5c7f 100644 --- a/renderers/client.py +++ b/renderers/client.py @@ -19,7 +19,14 @@ import numpy as np from openai import AsyncOpenAI, BadRequestError -from renderers.base import Message, MultiModalData, Renderer, RendererPool, ToolSpec +from renderers.base import ( + Message, + MultiModalData, + Renderer, + RendererPool, + ToolCallParseStatus, + ToolSpec, +) _request_logger = logging.getLogger("renderers.client") @@ -162,11 +169,17 @@ def _prepare(): # /inference/v1/generate returns finish_reason in {"stop","length",...} — # never "tool_calls" (a chat-completions concept). Promote stop→tool_calls - # when we extracted tool calls client-side, so OpenAI-compatible agent - # loops continue past the tool turn instead of treating the response as - # final. + # when we extracted at least one well-formed tool call client-side, so + # OpenAI-compatible agent loops continue past the tool turn instead of + # treating the response as final. Malformed attempts (INVALID_JSON, + # UNCLOSED_BLOCK, ...) don't qualify — those still surface on + # ``parsed.tool_calls`` so verifiers can inspect them, but they don't + # trigger the tool-loop continuation. finish_reason = choice.get("finish_reason") - if parsed.tool_calls and finish_reason == "stop": + ok_tool_calls = [ + tc for tc in parsed.tool_calls if tc.status == ToolCallParseStatus.OK + ] + if ok_tool_calls and finish_reason == "stop": finish_reason = "tool_calls" return { diff --git a/renderers/default.py b/renderers/default.py index dfc4e11..f755cfc 100644 --- a/renderers/default.py +++ b/renderers/default.py @@ -174,7 +174,7 @@ def parse_response(self, token_ids: list[int]) -> ParsedResponse: content_ids, tool_calls = self._tool_parser.extract(list(token_ids)) else: content_ids = list(token_ids) - tool_calls = None + tool_calls = [] # 2. Decode (keep special tokens so a downstream reasoning parser can # still see things like / when they're tokens). diff --git a/renderers/kimi_k25.py b/renderers/kimi_k25.py index 0984498..086a4d5 100644 --- a/renderers/kimi_k25.py +++ b/renderers/kimi_k25.py @@ -31,13 +31,16 @@ Message, MultiModalData, ParsedResponse, + ParsedToolCall, PlaceholderRange, RenderedTokens, + ToolCallParseStatus, ToolSpec, reject_assistant_in_extension, should_preserve_past_thinking, trim_to_turn_close, ) +from renderers.parsing import parse_kimi_k2_section from renderers.qwen3_vl import ( _image_hash, _is_image_part, @@ -420,12 +423,27 @@ def _parse_kimi_k2_response( stop_ids: set[int], think_open_ids: list[int], think_close_ids: list[int], + tool_calls_section_begin_id: int | None, + tool_calls_section_end_id: int | None, + tool_call_begin_id: int | None, + tool_call_argument_begin_id: int | None, + tool_call_end_id: int | None, ) -> ParsedResponse: """Parse Kimi K2/K2.5 completion tokens. - Strips the stop token, decodes to text, then extracts: - - reasoning from ``...`` blocks - - tool calls from ``<|tool_calls_section_begin|>...<|tool_calls_section_end|>`` + Primary path: walk token IDs via :func:`parse_kimi_k2_section`. That gives + every ``ParsedToolCall`` a ``token_span`` pointing back into the + (stop-stripped) input — what the trainer needs for selective loss masking. + + Fallback path: regex on decoded text. 
Only used when none of the section + delimiters appear as special tokens, which in practice means the model + emitted the literal ``<|tool_call_section_begin|>`` string (the singular + variant is *not* in the K2.5 special-token vocab — confirmed by tokenizer + probe). Spans stay ``None`` here since text positions don't cheaply map + back to token offsets across BPE. + + ``...`` is always text-extracted from the content slice + (K2.5 emits them as plain text, not special tokens). """ # Strip stop token ids = list(token_ids) @@ -434,13 +452,79 @@ def _parse_kimi_k2_response( ids = ids[:i] break - # Decode all tokens (including any special tokens that are text-like) - text = tokenizer.decode(ids, skip_special_tokens=False) if ids else "" + # Token-ID path — produces spans. Only run if every relevant special + # token resolved at init (i.e. is in the tokenizer's vocab). + tool_calls: list[ParsedToolCall] = [] + have_special_tokens = ( + tool_calls_section_begin_id is not None + and tool_calls_section_end_id is not None + and tool_call_begin_id is not None + and tool_call_argument_begin_id is not None + and tool_call_end_id is not None + ) + if have_special_tokens: + content_ids, tool_calls = parse_kimi_k2_section( + tokenizer, + ids, + tool_calls_section_begin_ids={tool_calls_section_begin_id}, + tool_calls_section_end_ids={tool_calls_section_end_id}, + tool_call_begin_id=tool_call_begin_id, + tool_call_argument_begin_id=tool_call_argument_begin_id, + tool_call_end_id=tool_call_end_id, + ) + text = ( + tokenizer.decode(content_ids, skip_special_tokens=False) + if content_ids + else "" + ) + else: + text = tokenizer.decode(ids, skip_special_tokens=False) if ids else "" + + # Fallback path: model emitted literal-text section delimiters (singular + # variant) rather than special tokens. Spans unavailable here. + if not tool_calls: + tc_match = _TOOL_CALLS_SECTION_RE.search(text) + if tc_match: + text = text[: tc_match.start()] + tool_section = ( + tc_match.group(1) + if tc_match.group(1) is not None + else tc_match.group(2) + ) + for m in _TOOL_CALL_RE.finditer(tool_section): + tool_id = m.group(1).strip() + args_str = m.group(2).strip() + name_part = tool_id.split(":", 1)[0] + func_name = ( + name_part.split(".", 1)[1] if "." in name_part else name_part + ) + arguments: dict[str, Any] | str + invalid_json = False + try: + arguments = json.loads(args_str) + except json.JSONDecodeError: + arguments = args_str + invalid_json = True + if not func_name: + status = ToolCallParseStatus.MISSING_NAME + elif invalid_json: + status = ToolCallParseStatus.INVALID_JSON + else: + status = ToolCallParseStatus.OK + tool_calls.append( + ParsedToolCall( + raw=m.group(0), + name=func_name or None, + arguments=arguments, + status=status, + id=tool_id or None, + ) + ) - # Extract reasoning from .... Partition on first so - # any tokens BEFORE the open tag (e.g. the assistant role tag, when the - # caller slices the completion to include the prompt's gen-prompt-equivalent) - # don't leak into reasoning_content. + # Extract reasoning from ... in the content text. Partition + # on first so any tokens BEFORE the open tag (e.g. the assistant + # role tag, when the caller slices the completion to include the prompt's + # gen-prompt-equivalent) don't leak into reasoning_content. 
reasoning: str | None = None if "" in text: _, _, after_open = text.partition("") @@ -449,11 +533,12 @@ def _parse_kimi_k2_response( reasoning = reasoning_raw.strip("\n") or None text = text.strip("\n") else: - # Truncated reasoning (no closing tag) + # Truncated reasoning (no closing tag) — discard any partial + # tool-call attempts since the model never finished thinking. return ParsedResponse( content="", reasoning_content=after_open.strip() or None, - tool_calls=None, + tool_calls=[], ) elif "" in text: # Sampler stripped the prefilled open tag — see @@ -463,35 +548,6 @@ def _parse_kimi_k2_response( reasoning = before.strip("\n") or None text = after.strip("\n") - # Extract tool calls section - tool_calls: list[dict[str, Any]] | None = None - tc_match = _TOOL_CALLS_SECTION_RE.search(text) - if tc_match: - text = text[: tc_match.start()] - tool_section = ( - tc_match.group(1) if tc_match.group(1) is not None else tc_match.group(2) - ) - parsed_calls = [] - for m in _TOOL_CALL_RE.finditer(tool_section): - tool_id = m.group(1).strip() - args_str = m.group(2).strip() - # Extract function name from "functions.name:index" format - name_part = tool_id.split(":", 1)[0] - func_name = name_part.split(".", 1)[1] if "." in name_part else name_part - try: - arguments = json.loads(args_str) - except json.JSONDecodeError: - arguments = args_str # preserve raw string if invalid JSON - parsed_calls.append( - { - "type": "function", - "id": tool_id, - "function": {"name": func_name, "arguments": arguments}, - } - ) - if parsed_calls: - tool_calls = parsed_calls - return ParsedResponse( content=text.strip(), reasoning_content=reasoning.strip() if reasoning else None, @@ -853,17 +909,34 @@ def parse_response(self, token_ids: list[int]) -> ParsedResponse: if self._endoftext is not None: stop_ids.add(self._endoftext) - # Restore the synthetic prefill if it was stripped by the sampler + # Restore the synthetic prefill if it was stripped by the + # sampler. ``parse`` then walks ``normalized``, so any token_span we + # emit is in the *normalized* frame. We track the prepend offset and + # shift spans back so they refer to the caller's ``token_ids``. normalized = self._normalize_response_tokens(list(token_ids)) + prepend_offset = len(normalized) - len(token_ids) - return _parse_kimi_k2_response( + parsed = _parse_kimi_k2_response( self._tokenizer, normalized, stop_ids=stop_ids, think_open_ids=self._think_open_ids, think_close_ids=self._think_close_ids, + tool_calls_section_begin_id=self._tool_calls_section_begin, + tool_calls_section_end_id=self._tool_calls_section_end, + tool_call_begin_id=self._tool_call_begin, + tool_call_argument_begin_id=self._tool_call_argument_begin, + tool_call_end_id=self._tool_call_end, ) + if prepend_offset: + for tc in parsed.tool_calls: + if tc.token_span is not None: + start, end = tc.token_span + tc.token_span = (start - prepend_offset, end - prepend_offset) + + return parsed + def get_stop_token_ids(self) -> list[int]: stop = [self._im_end] if self._endoftext is not None: diff --git a/renderers/parsers.py b/renderers/parsers.py index 96ec304..77fa850 100644 --- a/renderers/parsers.py +++ b/renderers/parsers.py @@ -19,6 +19,8 @@ import re from typing import Protocol, runtime_checkable +from renderers.base import ParsedToolCall, ToolCallParseStatus + # ── Shared helpers ─────────────────────────────────────────────────── @@ -54,15 +56,17 @@ def _token_id(tokenizer, token: str) -> int | None: class ToolParser(Protocol): """Extracts tool calls from completion token ids. 
- ``extract`` returns a tuple ``(content_ids, tool_calls)`` where - ``content_ids`` is the remaining content token ids with the tool-call - section removed, and ``tool_calls`` is a list of - ``{"function": {"name": str, "arguments": dict | str}}`` dicts or - ``None`` when no tool call was found. + ``extract`` returns ``(content_ids, tool_calls)`` where ``content_ids`` + is the remaining content token ids with the tool-call section removed, + and ``tool_calls`` is a list of :class:`ParsedToolCall` records — one + per attempted block. Empty list = the model emitted no tool calls; + callers filter by ``status == OK`` for the clean subset. """ def __init__(self, tokenizer): ... - def extract(self, token_ids: list[int]) -> tuple[list[int], list[dict] | None]: ... + def extract( + self, token_ids: list[int] + ) -> tuple[list[int], list[ParsedToolCall]]: ... @runtime_checkable @@ -87,14 +91,14 @@ def __init__(self, tokenizer): self._tc_id = _token_id(tokenizer, "") self._tc_end_id = _token_id(tokenizer, "") - def extract(self, ids: list[int]) -> tuple[list[int], list[dict] | None]: + def extract(self, ids: list[int]) -> tuple[list[int], list[ParsedToolCall]]: if self._tc_id is None: - return ids, None + return ids, [] tc_start = _find(ids, self._tc_id) if tc_start == -1: - return ids, None + return ids, [] content_ids = ids[:tc_start] - tool_calls: list[dict] = [] + tool_calls: list[ParsedToolCall] = [] i = tc_start while i < len(ids): if ids[i] == self._tc_id: @@ -103,25 +107,52 @@ def extract(self, ids: list[int]) -> tuple[list[int], list[dict] | None]: if self._tc_end_id is not None else -1 ) + unclosed = end == -1 if end == -1: end = len(ids) tc_text = _decode(self._tokenizer, ids[i + 1 : end]).strip() + span = (i, end + (0 if unclosed else 1)) + if unclosed: + tool_calls.append( + ParsedToolCall( + raw=tc_text, + token_span=span, + status=ToolCallParseStatus.UNCLOSED_BLOCK, + ) + ) + break try: parsed = json.loads(tc_text) + except json.JSONDecodeError: tool_calls.append( - { - "function": { - "name": parsed.get("name", ""), - "arguments": parsed.get("arguments", {}), - } - } + ParsedToolCall( + raw=tc_text, + token_span=span, + status=ToolCallParseStatus.INVALID_JSON, + ) + ) + else: + name = parsed.get("name", "") if isinstance(parsed, dict) else "" + arguments = ( + parsed.get("arguments", {}) if isinstance(parsed, dict) else {} + ) + tool_calls.append( + ParsedToolCall( + raw=tc_text, + name=name or None, + arguments=arguments, + token_span=span, + status=( + ToolCallParseStatus.MISSING_NAME + if not name + else ToolCallParseStatus.OK + ), + ) ) - except json.JSONDecodeError: - pass i = end + 1 else: i += 1 - return content_ids, (tool_calls or None) + return content_ids, tool_calls class Qwen35ToolParser: @@ -132,13 +163,13 @@ def __init__(self, tokenizer): self._tc_id = _token_id(tokenizer, "") self._tc_end_id = _token_id(tokenizer, "") - def extract(self, ids: list[int]) -> tuple[list[int], list[dict] | None]: + def extract(self, ids: list[int]) -> tuple[list[int], list[ParsedToolCall]]: if self._tc_id is None: - return ids, None + return ids, [] tc_start = _find(ids, self._tc_id) if tc_start == -1: - return ids, None - tool_calls: list[dict] = [] + return ids, [] + tool_calls: list[ParsedToolCall] = [] i = tc_start while i < len(ids): if ids[i] == self._tc_id: @@ -148,30 +179,60 @@ def extract(self, ids: list[int]) -> tuple[list[int], list[dict] | None]: else -1 ) if end == -1: + raw = _decode(self._tokenizer, ids[i + 1 :]) + tool_calls.append( + ParsedToolCall( + raw=raw, + 
token_span=(i, len(ids)), + status=ToolCallParseStatus.UNCLOSED_BLOCK, + ) + ) break block_text = _decode(self._tokenizer, ids[i + 1 : end]) + span = (i, end + 1) name_match = re.search(r"]+)>", block_text) - if name_match: - name = name_match.group(1) - arguments: dict = {} - for pm in re.finditer( - r"]+)>\n?(.*?)\n?", - block_text, - re.DOTALL, - ): - arg_name = pm.group(1) - arg_value = pm.group(2).strip() - try: - arguments[arg_name] = json.loads(arg_value) - except (json.JSONDecodeError, ValueError): - arguments[arg_name] = arg_value + if not name_match: tool_calls.append( - {"function": {"name": name, "arguments": arguments}} + ParsedToolCall( + raw=block_text, + token_span=span, + status=ToolCallParseStatus.MALFORMED_STRUCTURE, + ) + ) + i = end + 1 + continue + name = name_match.group(1) + arguments: dict = {} + any_json_fallback = False + for pm in re.finditer( + r"]+)>\n?(.*?)\n?", + block_text, + re.DOTALL, + ): + arg_name = pm.group(1) + arg_value = pm.group(2).strip() + try: + arguments[arg_name] = json.loads(arg_value) + except (json.JSONDecodeError, ValueError): + arguments[arg_name] = arg_value + any_json_fallback = True + tool_calls.append( + ParsedToolCall( + raw=block_text, + name=name, + arguments=arguments, + token_span=span, + status=( + ToolCallParseStatus.INVALID_JSON + if any_json_fallback + else ToolCallParseStatus.OK + ), ) + ) i = end + 1 else: i += 1 - return ids[:tc_start], (tool_calls or None) + return ids[:tc_start], tool_calls class GlmToolParser: @@ -186,13 +247,13 @@ def __init__(self, tokenizer): self._av_id = _token_id(tokenizer, "") self._ave_id = _token_id(tokenizer, "") - def extract(self, ids: list[int]) -> tuple[list[int], list[dict] | None]: + def extract(self, ids: list[int]) -> tuple[list[int], list[ParsedToolCall]]: if self._tc_id is None: - return ids, None + return ids, [] tc_start = _find(ids, self._tc_id) if tc_start == -1: - return ids, None - tool_calls: list[dict] = [] + return ids, [] + tool_calls: list[ParsedToolCall] = [] i = tc_start while i < len(ids): if ids[i] == self._tc_id: @@ -202,9 +263,21 @@ def extract(self, ids: list[int]) -> tuple[list[int], list[dict] | None]: else -1 ) if end == -1: + raw = _decode(self._tokenizer, ids[i + 1 :]) + tool_calls.append( + ParsedToolCall( + raw=raw, + token_span=(i, len(ids)), + status=ToolCallParseStatus.UNCLOSED_BLOCK, + ) + ) break block = ids[i + 1 : end] + block_text = _decode(self._tokenizer, block) + span = (i, end + 1) first_ak = _find(block, self._ak_id) if self._ak_id is not None else -1 + any_json_fallback = False + structure_broke = False if first_ak == -1: name = _decode(self._tokenizer, block).strip() arguments: dict = {} @@ -220,6 +293,7 @@ def extract(self, ids: list[int]) -> tuple[list[int], list[dict] | None]: else -1 ) if ake == -1: + structure_broke = True break key = _decode(self._tokenizer, block[j + 1 : ake]).strip() av = ( @@ -228,6 +302,7 @@ def extract(self, ids: list[int]) -> tuple[list[int], list[dict] | None]: else -1 ) if av == -1: + structure_broke = True break ave = ( _find(block, self._ave_id, av + 1) @@ -235,6 +310,7 @@ def extract(self, ids: list[int]) -> tuple[list[int], list[dict] | None]: else -1 ) if ave == -1: + structure_broke = True break val_text = _decode( self._tokenizer, block[av + 1 : ave] @@ -243,14 +319,31 @@ def extract(self, ids: list[int]) -> tuple[list[int], list[dict] | None]: arguments[key] = json.loads(val_text) except (json.JSONDecodeError, ValueError): arguments[key] = val_text + any_json_fallback = True j = ave + 1 else: j += 1 - 
tool_calls.append({"function": {"name": name, "arguments": arguments}}) + if not name: + status = ToolCallParseStatus.MISSING_NAME + elif structure_broke: + status = ToolCallParseStatus.MALFORMED_STRUCTURE + elif any_json_fallback: + status = ToolCallParseStatus.INVALID_JSON + else: + status = ToolCallParseStatus.OK + tool_calls.append( + ParsedToolCall( + raw=block_text, + name=name or None, + arguments=arguments, + token_span=span, + status=status, + ) + ) i = end + 1 else: i += 1 - return ids[:tc_start], (tool_calls or None) + return ids[:tc_start], tool_calls class DeepSeekV3ToolParser: @@ -264,12 +357,12 @@ def __init__(self, tokenizer): self._tc_end = _token_id(tokenizer, "<|tool▁call▁end|>") self._sep = _token_id(tokenizer, "<|tool▁sep|>") - def extract(self, ids: list[int]) -> tuple[list[int], list[dict] | None]: + def extract(self, ids: list[int]) -> tuple[list[int], list[ParsedToolCall]]: if self._tcs_begin is None: - return ids, None + return ids, [] section_start = _find(ids, self._tcs_begin) if section_start == -1: - return ids, None + return ids, [] content_ids = ids[:section_start] section_end = ( _find(ids, self._tcs_end, section_start + 1) @@ -278,9 +371,10 @@ def extract(self, ids: list[int]) -> tuple[list[int], list[dict] | None]: ) if section_end == -1: section_end = len(ids) + inner_offset = section_start + 1 section_ids = ids[section_start + 1 : section_end] - tool_calls: list[dict] = [] + tool_calls: list[ParsedToolCall] = [] i = 0 while i < len(section_ids): if self._tc_begin is None or section_ids[i] != self._tc_begin: @@ -291,23 +385,45 @@ def extract(self, ids: list[int]) -> tuple[list[int], list[dict] | None]: if self._tc_end is not None else -1 ) + unclosed = end == -1 if end == -1: end = len(section_ids) block_text = _decode(self._tokenizer, section_ids[i + 1 : end]) + span = (inner_offset + i, inner_offset + end + (0 if unclosed else 1)) # Format: "function<|tool▁sep|>{name}\n```json\n{args}\n```" name_match = re.search(r"^\s*\w+.*?([A-Za-z0-9_]+)\s*\n", block_text) name = name_match.group(1) if name_match else "" args: dict | str = {} + invalid_json = False json_match = re.search(r"```json\s*(.*?)\s*```", block_text, re.DOTALL) if json_match: try: args = json.loads(json_match.group(1)) except (json.JSONDecodeError, ValueError): args = json_match.group(1) - tool_calls.append({"function": {"name": name, "arguments": args}}) + invalid_json = True + if unclosed: + status = ToolCallParseStatus.UNCLOSED_BLOCK + elif not name: + status = ToolCallParseStatus.MISSING_NAME + elif invalid_json: + status = ToolCallParseStatus.INVALID_JSON + else: + status = ToolCallParseStatus.OK + tool_calls.append( + ParsedToolCall( + raw=block_text, + name=name or None, + arguments=args, + token_span=span, + status=status, + ) + ) i = end + 1 + if unclosed: + break - return content_ids, (tool_calls or None) + return content_ids, tool_calls # ── Reasoning parsers ──────────────────────────────────────────────── diff --git a/renderers/parsing.py b/renderers/parsing.py index 6644103..827cc81 100644 --- a/renderers/parsing.py +++ b/renderers/parsing.py @@ -3,13 +3,22 @@ Finds special token boundaries by scanning token IDs, then decodes only the text segments between them. No regex on decoded text, no false positives from content that happens to look like special tokens. 
+ +Every parser emits ``list[ParsedToolCall]`` covering every attempt — +successful and malformed alike — with a ``status`` enum classifying the +outcome and a ``token_span`` recording where in the (stop-stripped) +token stream the attempt sat. Callers filter on ``status == OK`` for the +clean subset; verifier and RL-loss code uses the rest. This diverges from +vLLM's ``ExtractedToolCallInformation`` (single ``tools_called`` bool, no +per-call status) and SGLang's ``StreamingParseResult`` (silent drop on +failure) — see ``ToolCallParseStatus`` docstring for the rationale. """ from __future__ import annotations import json -from renderers.base import ParsedResponse +from renderers.base import ParsedResponse, ParsedToolCall, ToolCallParseStatus def _find(ids: list[int], target: int, start: int = 0) -> int: @@ -20,6 +29,14 @@ def _find(ids: list[int], target: int, start: int = 0) -> int: return -1 +def _find_any(ids: list[int], targets: set[int], start: int = 0) -> int: + """Find first index in ids whose value is in targets, or -1.""" + for i in range(start, len(ids)): + if ids[i] in targets: + return i + return -1 + + def _find_all(ids: list[int], target: int) -> list[int]: """Find all indices of target in ids.""" return [i for i, t in enumerate(ids) if t == target] @@ -54,52 +71,67 @@ def parse_qwen3( """Parse Qwen3 completion tokens. Hermes-style JSON tool calls.""" ids = _strip_stop_tokens(token_ids, stop_ids) - # No thinking tokens in Qwen3 gen prompt — model may or may not think - # Parse from decoded text since / may be multi-token in Qwen3 - # Actually in Qwen3, IS a special token (151657) - # So we can find it by token ID - - # Find tool calls by token ID tc_start = _find(ids, tool_call_id) + tool_calls: list[ParsedToolCall] = [] if tc_start != -1: content_ids = ids[:tc_start] - # Extract all tool call blocks - tool_calls = [] i = tc_start while i < len(ids): if ids[i] == tool_call_id: end = _find(ids, tool_call_end_id, i + 1) if end == -1: - end = len(ids) + # No closing delim — block runs to end of stripped ids. + raw = _decode(tokenizer, ids[i + 1 :]).strip() + tool_calls.append( + ParsedToolCall( + raw=raw, + token_span=(i, len(ids)), + status=ToolCallParseStatus.UNCLOSED_BLOCK, + ) + ) + break tc_text = _decode(tokenizer, ids[i + 1 : end]).strip() + span = (i, end + 1) try: parsed = json.loads(tc_text) + except json.JSONDecodeError: tool_calls.append( - { - "function": { - "name": parsed.get("name", ""), - "arguments": parsed.get("arguments", {}), - } - } + ParsedToolCall( + raw=tc_text, + token_span=span, + status=ToolCallParseStatus.INVALID_JSON, + ) ) - except json.JSONDecodeError: - pass + else: + name = parsed.get("name", "") if isinstance(parsed, dict) else "" + arguments = ( + parsed.get("arguments", {}) if isinstance(parsed, dict) else {} + ) + if not name: + tool_calls.append( + ParsedToolCall( + raw=tc_text, + name=None, + arguments=arguments, + token_span=span, + status=ToolCallParseStatus.MISSING_NAME, + ) + ) + else: + tool_calls.append( + ParsedToolCall( + raw=tc_text, + name=name, + arguments=arguments, + token_span=span, + status=ToolCallParseStatus.OK, + ) + ) i = end + 1 else: i += 1 - # Match vLLM hermes_tool_parser: when no tool calls parse successfully, - # preserve the raw tokens as content instead of returning an empty - # response. vLLM/hermes_tool_parser.py::extract_tool_calls catches - # json.JSONDecodeError and falls through with content=model_output. - # Without this, clients raise EmptyModelResponseError on any - # ... 
block with malformed JSON, which - # wastes inference compute on retries and diverges from main's - # behavior on hermes tool envs. - if not tool_calls: - content_ids = ids else: content_ids = ids - tool_calls = None text = _decode(tokenizer, content_ids) # Extract reasoning from text (Qwen3 doesn't have as special token) @@ -112,7 +144,7 @@ def parse_qwen3( return ParsedResponse( content=text.strip(), reasoning_content=reasoning or None, - tool_calls=tool_calls or None, + tool_calls=tool_calls, ) @@ -134,68 +166,109 @@ def parse_qwen35( # Thinking: find by token ID reasoning = None + parse_offset = 0 # shift to map local indices back to stop-stripped ids think_end = _find(ids, think_end_id) if think_end != -1: - # Everything before is reasoning reasoning_ids = ids[:think_end] - # Strip if present at start reasoning_ids = [t for t in reasoning_ids if t != think_id] reasoning = _decode(tokenizer, reasoning_ids).strip() ids = ids[think_end + 1 :] + parse_offset = think_end + 1 elif think_id in set(ids): # present but no — truncated reasoning think_start = _find(ids, think_id) reasoning = _decode(tokenizer, ids[think_start + 1 :]).strip() return ParsedResponse( - content="", reasoning_content=reasoning or None, tool_calls=None + content="", reasoning_content=reasoning or None, tool_calls=[] ) - # Tool calls by token ID tc_start = _find(ids, tool_call_id) + tool_calls: list[ParsedToolCall] = [] if tc_start != -1: content_text = _decode(tokenizer, ids[:tc_start]).strip() tool_calls = _parse_xml_tool_calls( - tokenizer, ids[tc_start:], tool_call_id, tool_call_end_id + tokenizer, + ids[tc_start:], + tool_call_id, + tool_call_end_id, + section_offset=parse_offset + tc_start, ) else: content_text = _decode(tokenizer, ids).strip() - tool_calls = None return ParsedResponse( content=content_text, reasoning_content=reasoning or None, - tool_calls=tool_calls or None, + tool_calls=tool_calls, ) def _parse_xml_tool_calls( - tokenizer, ids: list[int], tc_id: int, tc_end_id: int -) -> list[dict]: + tokenizer, + ids: list[int], + tc_id: int, + tc_end_id: int, + *, + section_offset: int, +) -> list[ParsedToolCall]: """Parse Qwen3.5-style XML tool calls from token IDs.""" import re - tool_calls = [] + tool_calls: list[ParsedToolCall] = [] i = 0 while i < len(ids): if ids[i] == tc_id: end = _find(ids, tc_end_id, i + 1) if end == -1: + raw = _decode(tokenizer, ids[i + 1 :]) + tool_calls.append( + ParsedToolCall( + raw=raw, + token_span=(section_offset + i, section_offset + len(ids)), + status=ToolCallParseStatus.UNCLOSED_BLOCK, + ) + ) break block_text = _decode(tokenizer, ids[i + 1 : end]) + span = (section_offset + i, section_offset + end + 1) name_match = re.search(r"]+)>", block_text) - if name_match: - name = name_match.group(1) - arguments = {} - for pm in re.finditer( - r"]+)>\n?(.*?)\n?", block_text, re.DOTALL - ): - arg_name = pm.group(1) - arg_value = pm.group(2).strip() - try: - arguments[arg_name] = json.loads(arg_value) - except (json.JSONDecodeError, ValueError): - arguments[arg_name] = arg_value - tool_calls.append({"function": {"name": name, "arguments": arguments}}) + if not name_match: + tool_calls.append( + ParsedToolCall( + raw=block_text, + token_span=span, + status=ToolCallParseStatus.MALFORMED_STRUCTURE, + ) + ) + i = end + 1 + continue + + name = name_match.group(1) + arguments: dict = {} + any_json_fallback = False + for pm in re.finditer( + r"]+)>\n?(.*?)\n?", block_text, re.DOTALL + ): + arg_name = pm.group(1) + arg_value = pm.group(2).strip() + try: + arguments[arg_name] = 
json.loads(arg_value) + except (json.JSONDecodeError, ValueError): + arguments[arg_name] = arg_value + any_json_fallback = True + tool_calls.append( + ParsedToolCall( + raw=block_text, + name=name, + arguments=arguments, + token_span=span, + status=( + ToolCallParseStatus.INVALID_JSON + if any_json_fallback + else ToolCallParseStatus.OK + ), + ) + ) i = end + 1 else: i += 1 @@ -222,23 +295,24 @@ def parse_glm( """Parse GLM completion tokens. Token-level thinking + arg_key/arg_value tool calls.""" ids = _strip_stop_tokens(token_ids, stop_ids) - # Thinking by token ID reasoning = None + parse_offset = 0 think_end = _find(ids, think_end_id) if think_end != -1: reasoning_ids = ids[:think_end] reasoning_ids = [t for t in reasoning_ids if t != think_id] reasoning = _decode(tokenizer, reasoning_ids).strip() ids = ids[think_end + 1 :] + parse_offset = think_end + 1 elif think_id in set(ids): think_start = _find(ids, think_id) reasoning = _decode(tokenizer, ids[think_start + 1 :]).strip() return ParsedResponse( - content="", reasoning_content=reasoning or None, tool_calls=None + content="", reasoning_content=reasoning or None, tool_calls=[] ) - # Tool calls by token ID tc_start = _find(ids, tool_call_id) + tool_calls: list[ParsedToolCall] = [] if tc_start != -1: content_text = _decode(tokenizer, ids[:tc_start]).strip() tool_calls = _parse_glm_tool_calls( @@ -250,35 +324,55 @@ def parse_glm( arg_key_end_id, arg_value_id, arg_value_end_id, + section_offset=parse_offset + tc_start, ) else: content_text = _decode(tokenizer, ids).strip() - tool_calls = None return ParsedResponse( content=content_text, reasoning_content=reasoning or None, - tool_calls=tool_calls or None, + tool_calls=tool_calls, ) def _parse_glm_tool_calls( - tokenizer, ids, tc_id, tc_end_id, ak_id, ake_id, av_id, ave_id -) -> list[dict]: + tokenizer, + ids, + tc_id, + tc_end_id, + ak_id, + ake_id, + av_id, + ave_id, + *, + section_offset: int, +) -> list[ParsedToolCall]: """Parse GLM-style tool calls: name + arg_key/arg_value pairs, all by token ID.""" - tool_calls = [] + tool_calls: list[ParsedToolCall] = [] i = 0 while i < len(ids): if ids[i] == tc_id: end = _find(ids, tc_end_id, i + 1) if end == -1: + raw = _decode(tokenizer, ids[i + 1 :]) + tool_calls.append( + ParsedToolCall( + raw=raw, + token_span=(section_offset + i, section_offset + len(ids)), + status=ToolCallParseStatus.UNCLOSED_BLOCK, + ) + ) break block = ids[i + 1 : end] - # Name is everything before first + block_text = _decode(tokenizer, block) + span = (section_offset + i, section_offset + end + 1) first_ak = _find(block, ak_id) + any_json_fallback = False + structure_broke = False if first_ak == -1: name = _decode(tokenizer, block).strip() - arguments = {} + arguments: dict = {} else: name = _decode(tokenizer, block[:first_ak]).strip() arguments = {} @@ -287,23 +381,43 @@ def _parse_glm_tool_calls( if block[j] == ak_id: ake = _find(block, ake_id, j + 1) if ake == -1: + structure_broke = True break key = _decode(tokenizer, block[j + 1 : ake]).strip() av = _find(block, av_id, ake + 1) if av == -1: + structure_broke = True break ave = _find(block, ave_id, av + 1) if ave == -1: + structure_broke = True break val_text = _decode(tokenizer, block[av + 1 : ave]).strip() try: arguments[key] = json.loads(val_text) except (json.JSONDecodeError, ValueError): arguments[key] = val_text + any_json_fallback = True j = ave + 1 else: j += 1 - tool_calls.append({"function": {"name": name, "arguments": arguments}}) + if not name: + status = ToolCallParseStatus.MISSING_NAME + elif 
structure_broke: + status = ToolCallParseStatus.MALFORMED_STRUCTURE + elif any_json_fallback: + status = ToolCallParseStatus.INVALID_JSON + else: + status = ToolCallParseStatus.OK + tool_calls.append( + ParsedToolCall( + raw=block_text, + name=name or None, + arguments=arguments, + token_span=span, + status=status, + ) + ) i = end + 1 else: i += 1 @@ -334,8 +448,8 @@ def parse_deepseek_v3( """ ids = _strip_stop_tokens(token_ids, stop_ids) - # ── Tool calls ────────────────────────────────────────────────── tc_section_start = _find(ids, tool_calls_begin_id) + tool_calls: list[ParsedToolCall] = [] if tc_section_start != -1: content_ids = ids[:tc_section_start] tool_calls = _parse_deepseek_tool_calls( @@ -346,14 +460,13 @@ def parse_deepseek_v3( tool_call_begin_id, tool_call_end_id, tool_sep_id, + section_offset=tc_section_start, ) else: content_ids = ids - tool_calls = None text = _decode(tokenizer, content_ids) - # ── Thinking from text tags ──────────────────────────────────── reasoning = None if "" in text: before, _, after = text.partition("") @@ -363,7 +476,7 @@ def parse_deepseek_v3( return ParsedResponse( content=text.strip(), reasoning_content=reasoning or None, - tool_calls=tool_calls or None, + tool_calls=tool_calls, ) @@ -375,80 +488,106 @@ def _parse_deepseek_tool_calls( call_begin_id: int, call_end_id: int, sep_id: int, -) -> list[dict] | None: - """Parse DeepSeek V3-style tool calls from token IDs. - - Each individual tool call is delimited by <|tool▁call▁begin|> ... <|tool▁call▁end|>. - Inside, <|tool▁sep|> separates the call type (e.g. "function") from the - function name and JSON arguments block. - """ + *, + section_offset: int, +) -> list[ParsedToolCall]: + """Parse DeepSeek V3-style tool calls from token IDs.""" import re - tool_calls: list[dict] = [] + tool_calls: list[ParsedToolCall] = [] - # Find the outer section boundaries. section_start = _find(ids, tc_begin_id) if section_start == -1: - return None + return tool_calls section_end = _find(ids, tc_end_id, section_start + 1) + section_end_clipped = section_end == -1 if section_end == -1: section_end = len(ids) + inner_offset = section_offset + section_start + 1 section_ids = ids[section_start + 1 : section_end] i = 0 while i < len(section_ids): if section_ids[i] == call_begin_id: end = _find(section_ids, call_end_id, i + 1) - if end == -1: + unclosed = end == -1 + if unclosed: end = len(section_ids) - call_ids = section_ids[i + 1 : end] + block_text = _decode(tokenizer, call_ids) + # Span for this call covers its .. range + # within the (stop-stripped) parent token stream. + span = ( + inner_offset + i, + inner_offset + end + (0 if unclosed else 1), + ) - # Find <|tool▁sep|> to split type from name+args. sep_pos = _find(call_ids, sep_id) if sep_pos == -1: - # Malformed — skip. + tool_calls.append( + ParsedToolCall( + raw=block_text, + token_span=span, + status=ToolCallParseStatus.MALFORMED_STRUCTURE, + ) + ) i = end + 1 continue - # Everything after <|tool▁sep|> is the name and args block. after_sep_ids = call_ids[sep_pos + 1 :] after_sep_text = _decode(tokenizer, after_sep_ids).strip() - # Extract function name and JSON arguments. - # Format: "{name}\n```json\n{args}\n```" - # But we also gracefully handle raw JSON without the code fence. name = "" args_str = "" - - # Try to split on first newline to get name, then find JSON. newline_pos = after_sep_text.find("\n") if newline_pos != -1: name = after_sep_text[:newline_pos].strip() rest = after_sep_text[newline_pos + 1 :].strip() - # Strip optional ```json ... 
@@ -375,80 +488,106 @@ def _parse_deepseek_tool_calls(
     call_begin_id: int,
     call_end_id: int,
     sep_id: int,
-) -> list[dict] | None:
-    """Parse DeepSeek V3-style tool calls from token IDs.
-
-    Each individual tool call is delimited by <|tool▁call▁begin|> ... <|tool▁call▁end|>.
-    Inside, <|tool▁sep|> separates the call type (e.g. "function") from the
-    function name and JSON arguments block.
-    """
+    *,
+    section_offset: int,
+) -> list[ParsedToolCall]:
+    """Parse DeepSeek V3-style tool calls from token IDs."""
     import re
 
-    tool_calls: list[dict] = []
+    tool_calls: list[ParsedToolCall] = []
 
-    # Find the outer section boundaries.
     section_start = _find(ids, tc_begin_id)
     if section_start == -1:
-        return None
+        return tool_calls
     section_end = _find(ids, tc_end_id, section_start + 1)
+    section_end_clipped = section_end == -1
    if section_end == -1:
         section_end = len(ids)
+    inner_offset = section_offset + section_start + 1
     section_ids = ids[section_start + 1 : section_end]
 
     i = 0
     while i < len(section_ids):
         if section_ids[i] == call_begin_id:
             end = _find(section_ids, call_end_id, i + 1)
-            if end == -1:
+            unclosed = end == -1
+            if unclosed:
                 end = len(section_ids)
-
             call_ids = section_ids[i + 1 : end]
+            block_text = _decode(tokenizer, call_ids)
+            # Span for this call covers its <|tool▁call▁begin|> ..
+            # <|tool▁call▁end|> range within the (stop-stripped) parent
+            # token stream.
+            span = (
+                inner_offset + i,
+                inner_offset + end + (0 if unclosed else 1),
+            )
 
-            # Find <|tool▁sep|> to split type from name+args.
             sep_pos = _find(call_ids, sep_id)
             if sep_pos == -1:
-                # Malformed — skip.
+                tool_calls.append(
+                    ParsedToolCall(
+                        raw=block_text,
+                        token_span=span,
+                        status=ToolCallParseStatus.MALFORMED_STRUCTURE,
+                    )
+                )
                 i = end + 1
                 continue
 
-            # Everything after <|tool▁sep|> is the name and args block.
             after_sep_ids = call_ids[sep_pos + 1 :]
             after_sep_text = _decode(tokenizer, after_sep_ids).strip()
 
-            # Extract function name and JSON arguments.
-            # Format: "{name}\n```json\n{args}\n```"
-            # But we also gracefully handle raw JSON without the code fence.
             name = ""
             args_str = ""
-
-            # Try to split on first newline to get name, then find JSON.
             newline_pos = after_sep_text.find("\n")
             if newline_pos != -1:
                 name = after_sep_text[:newline_pos].strip()
                 rest = after_sep_text[newline_pos + 1 :].strip()
-                # Strip optional ```json ... ``` fence.
                 fence_match = re.match(r"```(?:json)?\s*([\s\S]*?)\s*```$", rest)
-                if fence_match:
-                    args_str = fence_match.group(1).strip()
-                else:
-                    args_str = rest
+                args_str = fence_match.group(1).strip() if fence_match else rest
             else:
-                # No newline — treat entire text as name, no args.
                 name = after_sep_text
 
-            # Parse arguments as JSON.
+            arguments: dict | str
+            invalid_json = False
             try:
                 arguments = json.loads(args_str) if args_str else {}
             except json.JSONDecodeError:
-                arguments = args_str  # preserve raw string on failure
+                arguments = args_str
+                invalid_json = True
+
+            if unclosed:
+                status = ToolCallParseStatus.UNCLOSED_BLOCK
+            elif not name:
+                status = ToolCallParseStatus.MISSING_NAME
+            elif invalid_json:
+                status = ToolCallParseStatus.INVALID_JSON
+            else:
+                status = ToolCallParseStatus.OK
 
-            tool_calls.append({"function": {"name": name, "arguments": arguments}})
+            tool_calls.append(
+                ParsedToolCall(
+                    raw=block_text,
+                    name=name or None,
+                    arguments=arguments,
+                    token_span=span,
+                    status=status,
+                )
+            )
             i = end + 1
+            if unclosed:
+                break
         else:
             i += 1
 
-    return tool_calls if tool_calls else None
+    # If the outer <|tool▁calls▁begin|> had no matching <|tool▁calls▁end|>,
+    # any call inside that didn't itself flag UNCLOSED_BLOCK is still nested in
+    # a truncated section — but we already mark individual unclosed calls,
+    # so we don't double-flag here. The section_end_clipped variable is
+    # carried for the (rare) caller that wants section-level UX.
+    _ = section_end_clipped
+    return tool_calls
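Every parser in this file now assigns status with the same precedence: truncation beats a missing name, which beats bad JSON. A compact restatement of that rule, with a local enum standing in for `renderers.base.ToolCallParseStatus` so the snippet runs on its own:

```python
from enum import Enum, auto


class Status(Enum):
    # Local stand-in; the real enum is renderers.base.ToolCallParseStatus.
    OK = auto()
    UNCLOSED_BLOCK = auto()
    MISSING_NAME = auto()
    INVALID_JSON = auto()


def classify(unclosed: bool, name: str, invalid_json: bool) -> Status:
    # Truncation dominates, then a missing name, then unparseable JSON.
    if unclosed:
        return Status.UNCLOSED_BLOCK
    if not name:
        return Status.MISSING_NAME
    if invalid_json:
        return Status.INVALID_JSON
    return Status.OK


assert classify(False, "get_weather", False) is Status.OK
assert classify(True, "get_weather", False) is Status.UNCLOSED_BLOCK
```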
 
 
 # ── MiniMax: <minimax:tool_call> ... </minimax:tool_call> ────────────
 
 
@@ -465,73 +604,155 @@ def parse_minimax(
     tool_call_end_id: int,
 ) -> ParsedResponse:
     """Parse MiniMax M2 completion tokens."""
+    import re
+
     ids = _strip_stop_tokens(token_ids, stop_ids)
 
-    # Thinking: by token ID. MiniMax doesn't generate <think> start.
     reasoning = None
+    parse_offset = 0
     think_end = _find(ids, think_end_id)
     if think_end != -1:
         reasoning_ids = ids[:think_end]
         reasoning_ids = [t for t in reasoning_ids if t != think_id]
         reasoning = _decode(tokenizer, reasoning_ids).strip()
         ids = ids[think_end + 1 :]
+        parse_offset = think_end + 1
     elif think_id in set(ids):
         think_start = _find(ids, think_id)
         reasoning = _decode(tokenizer, ids[think_start + 1 :]).strip()
         return ParsedResponse(
-            content="", reasoning_content=reasoning or None, tool_calls=None
+            content="", reasoning_content=reasoning or None, tool_calls=[]
         )
 
-    # Tool calls by token ID
     tc_start = _find(ids, tool_call_id)
+    tool_calls: list[ParsedToolCall] = []
     if tc_start != -1:
         content_text = _decode(tokenizer, ids[:tc_start]).strip()
-        # Decode the tool call blocks and parse with regex (invoke/parameter are text, not tokens)
-        tool_calls = []
         i = tc_start
         while i < len(ids):
             if ids[i] == tool_call_id:
                 end = _find(ids, tool_call_end_id, i + 1)
                 if end == -1:
+                    raw = _decode(tokenizer, ids[i + 1 :])
+                    tool_calls.append(
+                        ParsedToolCall(
+                            raw=raw,
+                            token_span=(
+                                parse_offset + i,
+                                parse_offset + len(ids),
+                            ),
+                            status=ToolCallParseStatus.UNCLOSED_BLOCK,
+                        )
+                    )
                     break
                 block_text = _decode(tokenizer, ids[i + 1 : end])
-                import re
-
-                for invoke_match in re.finditer(
-                    r'<invoke name="(.*?)">(.*?)</invoke>', block_text, re.DOTALL
-                ):
-                    name = invoke_match.group(1)
-                    body = invoke_match.group(2)
-                    arguments = {}
-                    for pm in re.finditer(
-                        r'<parameter name="(.*?)">(.*?)</parameter>', body, re.DOTALL
-                    ):
-                        pname = pm.group(1)
-                        pval = pm.group(2).strip()
-                        try:
-                            arguments[pname] = json.loads(pval)
-                        except (json.JSONDecodeError, ValueError):
-                            arguments[pname] = pval
+                span = (parse_offset + i, parse_offset + end + 1)
+
+                invokes = list(
+                    re.finditer(
+                        r'<invoke name="(.*?)">(.*?)</invoke>',
+                        block_text,
+                        re.DOTALL,
+                    )
+                )
+                if not invokes:
+                    # Block exists but contains no <invoke> — model emitted
+                    # the wrapper without a usable body.
                     tool_calls.append(
-                        {"function": {"name": name, "arguments": arguments}}
+                        ParsedToolCall(
+                            raw=block_text,
+                            token_span=span,
+                            status=ToolCallParseStatus.MALFORMED_STRUCTURE,
+                        )
                     )
+                else:
+                    for invoke_match in invokes:
+                        name = invoke_match.group(1)
+                        body = invoke_match.group(2)
+                        arguments: dict = {}
+                        any_json_fallback = False
+                        for pm in re.finditer(
+                            r'<parameter name="(.*?)">(.*?)</parameter>',
+                            body,
+                            re.DOTALL,
+                        ):
+                            pname = pm.group(1)
+                            pval = pm.group(2).strip()
+                            try:
+                                arguments[pname] = json.loads(pval)
+                            except (json.JSONDecodeError, ValueError):
+                                arguments[pname] = pval
+                                any_json_fallback = True
+                        tool_calls.append(
+                            ParsedToolCall(
+                                raw=block_text,
+                                name=name,
+                                arguments=arguments,
+                                # All invokes in a block share the wrapper span.
+                                token_span=span,
+                                status=(
+                                    ToolCallParseStatus.INVALID_JSON
+                                    if any_json_fallback
+                                    else ToolCallParseStatus.OK
+                                ),
+                            )
+                        )
                 i = end + 1
             else:
                 i += 1
     else:
         content_text = _decode(tokenizer, ids).strip()
-        tool_calls = None
 
     return ParsedResponse(
         content=content_text,
         reasoning_content=reasoning or None,
-        tool_calls=tool_calls or None,
+        tool_calls=tool_calls,
     )
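Because `<invoke>` / `<parameter>` are plain text rather than special tokens, the regex walk above can be exercised on a bare string, no tokenizer required. The block literal below is hand-written for illustration, not real model output:

```python
import json
import re

block = (
    '<invoke name="get_weather">'
    '<parameter name="city">"Tokyo"</parameter>'
    '<parameter name="days">3</parameter>'
    "</invoke>"
)

for m in re.finditer(r'<invoke name="(.*?)">(.*?)</invoke>', block, re.DOTALL):
    name, body = m.group(1), m.group(2)
    arguments = {}
    for pm in re.finditer(r'<parameter name="(.*?)">(.*?)</parameter>', body, re.DOTALL):
        value = pm.group(2).strip()
        try:
            arguments[pm.group(1)] = json.loads(value)  # JSON value when it parses
        except json.JSONDecodeError:
            arguments[pm.group(1)] = value  # raw-string fallback
    print(name, arguments)  # get_weather {'city': 'Tokyo', 'days': 3}
```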
 
 
 # ── Kimi K2: <|tool_calls_section_begin|> ... <|tool_calls_section_end|> ────
 
 
+def parse_kimi_k2_section(
+    tokenizer,
+    ids: list[int],
+    *,
+    tool_calls_section_begin_ids: set[int],
+    tool_calls_section_end_ids: set[int],
+    tool_call_begin_id: int,
+    tool_call_argument_begin_id: int,
+    tool_call_end_id: int,
+) -> tuple[list[int], list[ParsedToolCall]]:
+    """Split ``ids`` into ``(content_before_section, tool_calls)`` by finding
+    the Kimi-style tool-call section delimiters.
+
+    Accepts *sets* of begin/end token IDs so callers can express models with
+    multiple delimiter variants (K2.5 has both plural ``<|tool_calls_section_*|>``
+    and singular ``<|tool_call_section_*|>`` forms, though only the plural form
+    is in the special-token vocab in practice). Returns the content ids ahead
+    of the section and a list of ``ParsedToolCall`` covering every attempted
+    block inside it; an unclosed section is still walked to whatever the model
+    emitted before EOS. Returns ``(ids, [])`` when no section is present.
+    """
+    section_start = _find_any(ids, tool_calls_section_begin_ids)
+    if section_start == -1:
+        return list(ids), []
+    content_ids = ids[:section_start]
+    section_end = _find_any(ids, tool_calls_section_end_ids, section_start + 1)
+    if section_end == -1:
+        section_end = len(ids)
+    section_ids = ids[section_start + 1 : section_end]
+    tool_calls = _parse_kimi_k2_tool_calls(
+        tokenizer,
+        section_ids,
+        tool_call_begin_id,
+        tool_call_argument_begin_id,
+        tool_call_end_id,
+        section_offset=section_start + 1,
+    )
+    return content_ids, tool_calls
+
+
 def parse_kimi_k2(
     tokenizer,
     token_ids: list[int],
@@ -551,26 +772,16 @@
     """
     ids = _strip_stop_tokens(token_ids, stop_ids)
 
-    # ── Tool calls ────────────────────────────────────────────────
-    section_start = _find(ids, tool_calls_section_begin_id)
-    if section_start != -1:
-        content_ids = ids[:section_start]
-        section_end = _find(ids, tool_calls_section_end_id, section_start + 1)
-        if section_end == -1:
-            section_end = len(ids)
-        section_ids = ids[section_start + 1 : section_end]
-        tool_calls = _parse_kimi_k2_tool_calls(
-            tokenizer,
-            section_ids,
-            tool_call_begin_id,
-            tool_call_argument_begin_id,
-            tool_call_end_id,
-        )
-    else:
-        content_ids = ids
-        tool_calls = None
+    content_ids, tool_calls = parse_kimi_k2_section(
+        tokenizer,
+        ids,
+        tool_calls_section_begin_ids={tool_calls_section_begin_id},
+        tool_calls_section_end_ids={tool_calls_section_end_id},
+        tool_call_begin_id=tool_call_begin_id,
+        tool_call_argument_begin_id=tool_call_argument_begin_id,
+        tool_call_end_id=tool_call_end_id,
+    )
 
-    # ── Thinking from text tags ───────────────────────────────────
     text = _decode(tokenizer, content_ids)
     reasoning: str | None = None
     if "</think>" in text:
@@ -585,13 +796,13 @@
         return ParsedResponse(
             content="",
             reasoning_content=reasoning,
-            tool_calls=None,
+            tool_calls=[],
         )
 
     return ParsedResponse(
         content=text.strip(),
         reasoning_content=reasoning,
-        tool_calls=tool_calls or None,
+        tool_calls=tool_calls,
     )
 
 
@@ -601,7 +812,9 @@ def _parse_kimi_k2_tool_calls(
     tc_begin_id: int,
     tc_arg_begin_id: int,
     tc_end_id: int,
-) -> list[dict]:
+    *,
+    section_offset: int,
+) -> list[ParsedToolCall]:
     """Parse individual Kimi K2 tool calls from the section token IDs.
 
     Format per call:
@@ -610,45 +823,70 @@
     The ``id`` is in format ``functions.name:index``; the function name is
     extracted by stripping the ``functions.`` prefix and ``:index`` suffix.
     """
-    tool_calls: list[dict] = []
+    tool_calls: list[ParsedToolCall] = []
     i = 0
     while i < len(ids):
         if ids[i] == tc_begin_id:
-            # Find <|tool_call_argument_begin|>
             arg_begin = _find(ids, tc_arg_begin_id, i + 1)
             if arg_begin == -1:
+                raw = _decode(tokenizer, ids[i + 1 :])
+                tool_calls.append(
+                    ParsedToolCall(
+                        raw=raw,
+                        token_span=(section_offset + i, section_offset + len(ids)),
+                        status=ToolCallParseStatus.MALFORMED_STRUCTURE,
+                    )
+                )
                 break
-            # Find <|tool_call_end|>
             tc_end = _find(ids, tc_end_id, arg_begin + 1)
+            unclosed = tc_end == -1
             if tc_end == -1:
                 tc_end = len(ids)
 
             raw_id = _decode(tokenizer, ids[i + 1 : arg_begin]).strip()
             args_str = _decode(tokenizer, ids[arg_begin + 1 : tc_end]).strip()
+            block_text = _decode(tokenizer, ids[i + 1 : tc_end])
+            span = (
+                section_offset + i,
+                section_offset + tc_end + (0 if unclosed else 1),
+            )
 
-            # Extract function name from "functions.name:index"
             name_part = raw_id.split(":", 1)[0]
             if "." 
in name_part: _, func_name = name_part.split(".", 1) else: func_name = name_part + arguments: dict | str + invalid_json = False try: arguments = json.loads(args_str) except json.JSONDecodeError: arguments = args_str + invalid_json = True + + if unclosed: + status = ToolCallParseStatus.UNCLOSED_BLOCK + elif not func_name: + status = ToolCallParseStatus.MISSING_NAME + elif invalid_json: + status = ToolCallParseStatus.INVALID_JSON + else: + status = ToolCallParseStatus.OK tool_calls.append( - { - "id": raw_id, - "type": "function", - "function": { - "name": func_name, - "arguments": arguments, - }, - } + ParsedToolCall( + raw=block_text, + name=func_name or None, + arguments=arguments, + token_span=span, + status=status, + id=raw_id or None, + ) ) i = tc_end + 1 + if unclosed: + break else: i += 1 return tool_calls @@ -692,7 +930,7 @@ def parse_gpt_oss( reasoning_parts: list[str] = [] content_parts: list[str] = [] - tool_calls: list[dict] = [] + tool_calls: list[ParsedToolCall] = [] i = 0 while i < len(ids): @@ -700,18 +938,14 @@ def parse_gpt_oss( i += 1 continue - # Find <|message|> that terminates this block's header + block_start = i msg_pos = _find(ids, message_id, i + 1) if msg_pos == -1: break - # Header: tokens between <|start|> and <|message|> header_ids = ids[i + 1 : msg_pos] header_text = _decode(tokenizer, header_ids) - # Body: tokens from after <|message|> up to the next block boundary - # (<|start|>, <|end|>, or <|call|> — the last closes a tool-call - # commentary block within the same turn). body_start = msg_pos + 1 candidates = [ pos @@ -723,39 +957,54 @@ def parse_gpt_oss( if pos != -1 ] body_end = min(candidates) if candidates else len(ids) + body_closed = bool(candidates) and ids[body_end] in (end_id, call_id) body_text = _decode(tokenizer, ids[body_start:body_end]) - # Extract channel: token after <|channel|> in header_ids channel = _gptoss_extract_after_token(tokenizer, header_ids, channel_id) - # Extract recipient: "to=..." 
field in header text recipient_match = re.search(r"to=([^\s<]+)", header_text) recipient = recipient_match.group(1) if recipient_match else None if recipient and recipient.startswith("functions."): tool_name = recipient[len("functions.") :] + block_end = body_end + 1 if body_closed else body_end + span = (block_start, block_end) try: arguments = json.loads(body_text) except json.JSONDecodeError: - arguments = body_text # preserve raw string on failure - tool_calls.append( - { - "function": { - "name": tool_name, - "arguments": arguments, - } - } - ) + tool_calls.append( + ParsedToolCall( + raw=body_text, + name=tool_name or None, + arguments=body_text, + token_span=span, + status=ToolCallParseStatus.INVALID_JSON, + ) + ) + else: + if not body_closed: + status = ToolCallParseStatus.UNCLOSED_BLOCK + elif not tool_name: + status = ToolCallParseStatus.MISSING_NAME + else: + status = ToolCallParseStatus.OK + tool_calls.append( + ParsedToolCall( + raw=body_text, + name=tool_name or None, + arguments=arguments, + token_span=span, + status=status, + ) + ) elif channel == "analysis": reasoning_parts.append(body_text) elif channel == "final": content_parts.append(body_text) elif channel == "commentary": - # Commentary without a tool recipient is a user-visible preamble content_parts.append(body_text) - # Advance: skip body + any trailing <|end|> / <|call|> i = body_end if i < len(ids) and ids[i] in (end_id, call_id): i += 1 @@ -766,7 +1015,7 @@ def parse_gpt_oss( return ParsedResponse( content=content, reasoning_content=reasoning, - tool_calls=tool_calls or None, + tool_calls=tool_calls, ) @@ -780,5 +1029,4 @@ def _gptoss_extract_after_token( if pos == -1: return None after = _decode(tokenizer, header_ids[pos + 1 :]).strip() - # Take first whitespace-delimited word (channel name) return after.split()[0] if after else None diff --git a/tests/test_client.py b/tests/test_client.py index 3f9c189..fbf8742 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -3,7 +3,12 @@ import numpy as np -from renderers.base import ParsedResponse, RenderedTokens +from renderers.base import ( + ParsedResponse, + ParsedToolCall, + RenderedTokens, + ToolCallParseStatus, +) from renderers.client import generate @@ -30,12 +35,12 @@ def parse_response(self, completion_ids: list[int]) -> ParsedResponse: content="done", reasoning_content="think", tool_calls=[ - { - "function": { - "name": "echo", - "arguments": {"text": "hello"}, - } - } + ParsedToolCall( + raw='{"name": "echo", "arguments": {"text": "hello"}}', + name="echo", + arguments={"text": "hello"}, + status=ToolCallParseStatus.OK, + ) ], ) @@ -111,26 +116,59 @@ def test_generate_builds_request_body_and_parses_response(): }, } # finish_reason promoted from "stop" → "tool_calls" because the renderer - # extracted tool calls client-side. - assert result == { - "request_id": "gen-test", - "prompt_ids": [1, 2, 3], - "completion_ids": [7, 8], - "completion_logprobs": [-0.1, -0.2], - "content": "done", - "reasoning_content": "think", - "tool_calls": [ - { - "function": { - "name": "echo", - "arguments": {"text": "hello"}, - } - } - ], - "finish_reason": "tool_calls", - "routed_experts": [[[1]], [[2]]], - "multi_modal_data": None, - } + # extracted at least one well-formed tool call client-side. 
+    assert result["finish_reason"] == "tool_calls"
+    assert result["content"] == "done"
+    assert result["reasoning_content"] == "think"
+    assert result["prompt_ids"] == [1, 2, 3]
+    assert result["completion_ids"] == [7, 8]
+    assert result["completion_logprobs"] == [-0.1, -0.2]
+    assert result["routed_experts"] == [[[1]], [[2]]]
+    assert result["multi_modal_data"] is None
+    assert result["request_id"] == "gen-test"
+    assert len(result["tool_calls"]) == 1
+    tc = result["tool_calls"][0]
+    assert tc.name == "echo"
+    assert tc.arguments == {"text": "hello"}
+    assert tc.status == ToolCallParseStatus.OK
+
+
+class _MalformedToolRenderer(_FakeRenderer):
+    """Returns only a malformed tool-call attempt — finish_reason must stay "stop"."""
+
+    def parse_response(self, completion_ids: list[int]) -> ParsedResponse:
+        return ParsedResponse(
+            content="",
+            reasoning_content=None,
+            tool_calls=[
+                ParsedToolCall(
+                    raw='{"name": "echo", broken',
+                    status=ToolCallParseStatus.INVALID_JSON,
+                )
+            ],
+        )
+
+
+def test_generate_does_not_promote_finish_reason_for_malformed_tool_calls():
+    """A malformed tool-call attempt must NOT promote finish_reason to
+    "tool_calls" — only well-formed (status=OK) calls qualify. The
+    malformed attempt is still preserved in ``tool_calls`` for verifier
+    inspection, but the agent loop should not treat the turn as a
+    successful tool invocation.
+    """
+    client = _FakeClient()
+    result = asyncio.run(
+        generate(
+            client=client,
+            renderer=_MalformedToolRenderer(),
+            messages=[{"role": "user", "content": "hi"}],
+            model="test-model",
+            tools=[{"type": "function", "function": {"name": "echo"}}],
+        )
+    )
+    assert result["finish_reason"] == "stop"
+    assert len(result["tool_calls"]) == 1
+    assert result["tool_calls"][0].status == ToolCallParseStatus.INVALID_JSON
 
 
 class _NoRenderRenderer(_FakeRenderer):
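Together these two tests pin down the promotion rule applied in `renderers.client.generate`: a turn counts as a tool call only if at least one attempt parsed clean. Sketched as a standalone predicate (the helper name is hypothetical; the real logic lives inside `generate`):

```python
from renderers.base import ParsedToolCall, ToolCallParseStatus


def promote_finish_reason(
    finish_reason: str, tool_calls: list[ParsedToolCall]
) -> str:
    """Promote "stop" to "tool_calls" only for a well-formed attempt."""
    if finish_reason == "stop" and any(
        tc.status == ToolCallParseStatus.OK for tc in tool_calls
    ):
        return "tool_calls"
    return finish_reason
```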
diff --git a/tests/test_parse_response.py b/tests/test_parse_response.py
index 403b609..bc17544 100644
--- a/tests/test_parse_response.py
+++ b/tests/test_parse_response.py
@@ -7,7 +7,7 @@
 from functools import lru_cache
 
 from renderers import create_renderer
-from renderers.base import load_tokenizer
+from renderers.base import ToolCallParseStatus, load_tokenizer
 
 
 @lru_cache
@@ -62,18 +62,26 @@ def test_qwen3_vl_parse_json_tool_call():
     parsed = renderer.parse_response(tokenizer.encode(text, add_special_tokens=False))
 
     assert parsed.content == "Need a tool."
-    assert parsed.tool_calls == [
-        {"function": {"name": "get_weather", "arguments": {"city": "Paris"}}}
-    ]
-
-
-def test_qwen3_vl_malformed_tool_call_falls_back_to_content():
-    """When <tool_call>...</tool_call> contains malformed JSON, match
-    vLLM's hermes_tool_parser behavior: preserve the raw tokens as
-    content rather than returning empty content + empty tool_calls.
-    Without this, the orchestrator raises EmptyModelResponseError and
-    wastes inference compute on retries — diverging from main's
-    behavior on hermes tool envs (Qwen3, etc.).
+    assert len(parsed.tool_calls) == 1
+    tc = parsed.tool_calls[0]
+    assert tc.status == ToolCallParseStatus.OK
+    assert tc.name == "get_weather"
+    assert tc.arguments == {"city": "Paris"}
+
+
+def test_qwen3_vl_malformed_tool_call_surfaces_as_invalid_json():
+    """A malformed ``<tool_call>`` block lands as a non-OK ``ParsedToolCall``
+    rather than getting silently merged back into ``content``.
+
+    Before the per-call status redesign, the parser mirrored vLLM's
+    hermes parser and stuffed the raw block into ``content`` to avoid
+    downstream ``EmptyModelResponseError``. That hid the malformed signal
+    from verifiers — they couldn't tell "model wrote prose" from "model
+    tried a tool call and produced broken JSON." Now the failed attempt
+    is preserved with ``status=INVALID_JSON`` and ``raw`` text, which
+    also satisfies the EmptyModelResponseError prevention contract: the
+    response is non-empty (it has a tool-call attempt) without lying
+    about what kind of output the model produced.
     """
     tokenizer, renderer = _qwen3_vl()
     # Note the trailing comma — malformed JSON
     text = (
         "Need a tool.\n<tool_call>\n"
         '{"name": "get_weather", "arguments": {"city": "Paris",}}\n'
         "</tool_call>"
     )
     parsed = renderer.parse_response(tokenizer.encode(text, add_special_tokens=False))
 
-    # Parser must not collapse response: either content has the raw
-    # tokens OR there's at least a tool_call attempt. Concretely, we
-    # want content to be non-empty so the caller doesn't raise
-    # EmptyModelResponseError.
-    assert parsed.tool_calls is None, "Malformed JSON should not parse as a tool call"
-    assert parsed.content, (
-        "Malformed tool_call should fall back to raw content, not empty "
-        "(else caller raises EmptyModelResponseError)"
+    assert len(parsed.tool_calls) == 1
+    tc = parsed.tool_calls[0]
+    assert tc.status == ToolCallParseStatus.INVALID_JSON
+    assert "get_weather" in tc.raw
+    assert tc.token_span is not None
+
+
+@lru_cache
+def _kimi_k25():
+    tokenizer = load_tokenizer("moonshotai/Kimi-K2.5")
+    renderer = create_renderer(tokenizer, renderer="auto")
+    return tokenizer, renderer
+
+
+def test_kimi_k25_tool_call_carries_token_span():
+    """K2.5 was the lone parser without token spans before — its inline
+    text-walking implementation couldn't cheaply map regex hits back to
+    token offsets. We now walk token IDs via ``parse_kimi_k2_section`` for
+    the special-token path; spans must round-trip and point at a sensible
+    range within the original input token_ids.
+    """
+    tokenizer, renderer = _kimi_k25()
+    # K2.5 tool-call wire shape: section + per-call special tokens.
+    text = (
+        "<|tool_calls_section_begin|>"
+        "<|tool_call_begin|>functions.get_weather:0"
+        "<|tool_call_argument_begin|>"
+        '{"city": "Tokyo"}'
+        "<|tool_call_end|>"
+        "<|tool_calls_section_end|>"
+    )
+    token_ids = tokenizer.encode(text, add_special_tokens=False)
+    parsed = renderer.parse_response(token_ids)
+
+    assert len(parsed.tool_calls) == 1
+    tc = parsed.tool_calls[0]
+    assert tc.status == ToolCallParseStatus.OK
+    assert tc.name == "get_weather"
+    assert tc.arguments == {"city": "Tokyo"}
+    assert tc.token_span is not None
+    start, end = tc.token_span
+    assert 0 <= start < end <= len(token_ids), (
+        f"span {tc.token_span} out of range for {len(token_ids)} input tokens"
     )
-    assert "get_weather" in parsed.content
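The span contract this new test checks amounts to one invariant: slicing the input stream with `tc.token_span` and decoding should give back the attempted block. A usage sketch under that assumption, reusing the wire string from the test above (spans are half-open and relative to the stop-stripped stream):

```python
from renderers import create_renderer
from renderers.base import load_tokenizer

tokenizer = load_tokenizer("moonshotai/Kimi-K2.5")
renderer = create_renderer(tokenizer, renderer="auto")

wire = (
    "<|tool_calls_section_begin|>"
    "<|tool_call_begin|>functions.get_weather:0"
    "<|tool_call_argument_begin|>"
    '{"city": "Tokyo"}'
    "<|tool_call_end|>"
    "<|tool_calls_section_end|>"
)
token_ids = tokenizer.encode(wire, add_special_tokens=False)
tc = renderer.parse_response(token_ids).tool_calls[0]

start, end = tc.token_span  # half-open indices into the stop-stripped stream
assert "get_weather" in tokenizer.decode(token_ids[start:end])
```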
""" -from renderers.base import ParsedResponse +from renderers.base import ParsedResponse, ParsedToolCall # ── Truncation ─────────────────────────────────────────────────────── @@ -78,7 +78,7 @@ def test_content_only_no_thinking(model_name, tokenizer, renderer): ids = tokenizer.encode(text, add_special_tokens=False) parsed = renderer.parse_response(ids) assert "Hello" in parsed.content - assert parsed.tool_calls is None + assert parsed.tool_calls == [] # ── Tool call edge cases ───────────────────────────────────────────── @@ -128,11 +128,17 @@ def test_reasoning_is_string_or_none(model_name, tokenizer, renderer): assert parsed.reasoning_content is None or isinstance(parsed.reasoning_content, str) -def test_tool_calls_is_list_or_none(model_name, tokenizer, renderer): - """tool_calls must be list or None, never empty list.""" +def test_tool_calls_is_list_of_parsed_tool_call(model_name, tokenizer, renderer): + """tool_calls is always a (possibly empty) list of ParsedToolCall — never None. + + Empty list = "model did not emit any tool calls". A list with non-OK + entries = "model tried and the parser caught the failure"; those are + deliberately preserved so verifier / RL-loss code can see them. This + replaces the older list-or-None convention. + """ text = "Hello!" ids = tokenizer.encode(text, add_special_tokens=False) parsed = renderer.parse_response(ids) - assert parsed.tool_calls is None or ( - isinstance(parsed.tool_calls, list) and len(parsed.tool_calls) > 0 - ) + assert isinstance(parsed.tool_calls, list) + for tc in parsed.tool_calls: + assert isinstance(tc, ParsedToolCall) diff --git a/tests/test_parsers.py b/tests/test_parsers.py index c0ccbf4..3ec5bb5 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -4,7 +4,7 @@ import pytest -from renderers.base import load_tokenizer +from renderers.base import ToolCallParseStatus, load_tokenizer from renderers.parsers import ( REASONING_PARSERS, TOOL_PARSERS, @@ -40,9 +40,12 @@ def test_qwen3_tool_parser_roundtrip(): completion_text = 'hello\n\n{"name": "search", "arguments": {"q": "rain"}}\n' token_ids = tok.encode(completion_text, add_special_tokens=False) content_ids, tool_calls = parser.extract(list(token_ids)) - assert tool_calls is not None - assert tool_calls[0]["function"]["name"] == "search" - assert tool_calls[0]["function"]["arguments"] == {"q": "rain"} + assert len(tool_calls) == 1 + tc = tool_calls[0] + assert tc.status == ToolCallParseStatus.OK + assert tc.name == "search" + assert tc.arguments == {"q": "rain"} + assert tc.token_span is not None and tc.token_span[0] < tc.token_span[1] # content_ids should cover everything up to (but not including) content_text = tok.decode(content_ids, skip_special_tokens=False) assert "hello" in content_text @@ -54,10 +57,44 @@ def test_qwen3_tool_parser_no_tool_call(): parser = get_tool_parser("qwen3", tok) ids = tok.encode("just plain text response", add_special_tokens=False) content_ids, tool_calls = parser.extract(list(ids)) - assert tool_calls is None + assert tool_calls == [] assert content_ids == list(ids) +def test_qwen3_tool_parser_records_invalid_json(): + """Malformed JSON in a block surfaces as INVALID_JSON, not silently dropped.""" + tok = load_tokenizer("Qwen/Qwen3-0.6B") + parser = get_tool_parser("qwen3", tok) + completion = ( + 'hi\n\n{"name": "f", "arguments": {broken json\n' + ) + ids = tok.encode(completion, add_special_tokens=False) + _, tool_calls = parser.extract(list(ids)) + assert len(tool_calls) == 1 + assert tool_calls[0].status == 
diff --git a/tests/test_parsers.py b/tests/test_parsers.py
index c0ccbf4..3ec5bb5 100644
--- a/tests/test_parsers.py
+++ b/tests/test_parsers.py
@@ -4,7 +4,7 @@
 
 import pytest
 
-from renderers.base import load_tokenizer
+from renderers.base import ToolCallParseStatus, load_tokenizer
 from renderers.parsers import (
     REASONING_PARSERS,
     TOOL_PARSERS,
@@ -40,9 +40,12 @@ def test_qwen3_tool_parser_roundtrip():
     completion_text = 'hello\n<tool_call>\n{"name": "search", "arguments": {"q": "rain"}}\n</tool_call>'
     token_ids = tok.encode(completion_text, add_special_tokens=False)
     content_ids, tool_calls = parser.extract(list(token_ids))
-    assert tool_calls is not None
-    assert tool_calls[0]["function"]["name"] == "search"
-    assert tool_calls[0]["function"]["arguments"] == {"q": "rain"}
+    assert len(tool_calls) == 1
+    tc = tool_calls[0]
+    assert tc.status == ToolCallParseStatus.OK
+    assert tc.name == "search"
+    assert tc.arguments == {"q": "rain"}
+    assert tc.token_span is not None and tc.token_span[0] < tc.token_span[1]
     # content_ids should cover everything up to (but not including) <tool_call>
     content_text = tok.decode(content_ids, skip_special_tokens=False)
     assert "hello" in content_text
@@ -54,10 +57,44 @@ def test_qwen3_tool_parser_no_tool_call():
     parser = get_tool_parser("qwen3", tok)
     ids = tok.encode("just plain text response", add_special_tokens=False)
     content_ids, tool_calls = parser.extract(list(ids))
-    assert tool_calls is None
+    assert tool_calls == []
     assert content_ids == list(ids)
 
 
+def test_qwen3_tool_parser_records_invalid_json():
+    """Malformed JSON in a <tool_call> block surfaces as INVALID_JSON, not silently dropped."""
+    tok = load_tokenizer("Qwen/Qwen3-0.6B")
+    parser = get_tool_parser("qwen3", tok)
+    completion = (
+        'hi\n<tool_call>\n{"name": "f", "arguments": {broken json\n</tool_call>'
+    )
+    ids = tok.encode(completion, add_special_tokens=False)
+    _, tool_calls = parser.extract(list(ids))
+    assert len(tool_calls) == 1
+    assert tool_calls[0].status == ToolCallParseStatus.INVALID_JSON
+    assert tool_calls[0].raw  # raw block text preserved
+
+
+def test_qwen3_tool_parser_parallel_partial_success():
+    """Parallel calls: parser keeps the good ones AND records the broken one."""
+    tok = load_tokenizer("Qwen/Qwen3-0.6B")
+    parser = get_tool_parser("qwen3", tok)
+    completion = (
+        "pre\n"
+        '<tool_call>\n{"name": "a", "arguments": {}}\n</tool_call>\n'
+        "<tool_call>\n{broken\n</tool_call>\n"
+        '<tool_call>\n{"name": "c", "arguments": {"x": 1}}\n</tool_call>'
+    )
+    ids = tok.encode(completion, add_special_tokens=False)
+    _, tool_calls = parser.extract(list(ids))
+    assert [tc.status for tc in tool_calls] == [
+        ToolCallParseStatus.OK,
+        ToolCallParseStatus.INVALID_JSON,
+        ToolCallParseStatus.OK,
+    ]
+    assert [tc.name for tc in tool_calls] == ["a", None, "c"]
+
+
 def test_think_reasoning_parser_extracts_block():
     tok = load_tokenizer("Qwen/Qwen3-0.6B")
     parser = get_reasoning_parser("think", tok)
@@ -90,8 +127,9 @@ def test_default_renderer_uses_parsers():
     parsed = renderer.parse_response(list(ids))
     assert parsed.reasoning_content == "think"
     assert parsed.content.startswith("ok")
-    assert parsed.tool_calls is not None
-    assert parsed.tool_calls[0]["function"]["name"] == "f"
+    assert len(parsed.tool_calls) == 1
+    assert parsed.tool_calls[0].name == "f"
+    assert parsed.tool_calls[0].status == ToolCallParseStatus.OK
 
 
 def test_default_renderer_without_parsers_is_backward_compatible():
@@ -106,4 +144,4 @@
     parsed = renderer.parse_response(list(ids))
     assert parsed.reasoning_content == "r"
     assert parsed.content == "a"
-    assert parsed.tool_calls is None
+    assert parsed.tool_calls == []
diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py
index 8c313ed..1cdac82 100644
--- a/tests/test_roundtrip.py
+++ b/tests/test_roundtrip.py
@@ -158,8 +158,8 @@
 
 
 def _maybe_skip_tool_calls(renderer_name: str) -> None:
-    """DefaultRenderer without a tool_parser configured always returns
-    tool_calls=None. That's a documented limitation, not a bug — skip."""
+    """DefaultRenderer without a tool_parser configured always returns an
+    empty tool_calls list. 
That's a documented limitation, not a bug — skip.""" if renderer_name == "default": pytest.skip( "DefaultRenderer requires an explicit tool_parser to parse tool " @@ -194,11 +194,9 @@ def test_roundtrip_single_tool_call( assert parsed.tool_calls, f"{rt_model}: tool_calls lost, got {parsed.tool_calls!r}" assert len(parsed.tool_calls) == 1 tc = parsed.tool_calls[0] - assert tc["function"]["name"] == "get_weather", ( - f"{rt_model}: name mangled, got {tc!r}" - ) - assert _normalize_args(tc["function"]["arguments"]) == {"city": "Tokyo"}, ( - f"{rt_model}: args mangled, got {tc['function']['arguments']!r}" + assert tc.name == "get_weather", f"{rt_model}: name mangled, got {tc!r}" + assert _normalize_args(tc.arguments) == {"city": "Tokyo"}, ( + f"{rt_model}: args mangled, got {tc.arguments!r}" ) @@ -235,19 +233,15 @@ def test_roundtrip_multiple_tool_calls( completion_ids = _extract_assistant_tokens(rt_renderer, PROMPT, msg) parsed = rt_renderer.parse_response(completion_ids) - assert parsed.tool_calls is not None and len(parsed.tool_calls) == 2, ( + assert len(parsed.tool_calls) == 2, ( f"{rt_model}: expected 2 tool_calls, got {parsed.tool_calls!r}" ) - names = [tc["function"]["name"] for tc in parsed.tool_calls] + names = [tc.name for tc in parsed.tool_calls] assert names == ["get_weather", "get_time"], ( f"{rt_model}: names/order wrong, got {names}" ) - assert _normalize_args(parsed.tool_calls[0]["function"]["arguments"]) == { - "city": "Tokyo" - } - assert _normalize_args(parsed.tool_calls[1]["function"]["arguments"]) == { - "zone": "JST" - } + assert _normalize_args(parsed.tool_calls[0].arguments) == {"city": "Tokyo"} + assert _normalize_args(parsed.tool_calls[1].arguments) == {"zone": "JST"} # ── byte-exact re-render invariant ───────────────────────────────────── From 1cefc5c4e4f10251efffbd5a1db7c115217ef8fe Mon Sep 17 00:00:00 2001 From: hallerite Date: Wed, 13 May 2026 13:30:40 +0000 Subject: [PATCH 12/13] chore: license under Apache 2.0 (#27) --- LICENSE | 201 +++++++++++++++++++++++++++++++++++++++++++++++++ README.md | 4 + pyproject.toml | 2 + 3 files changed, 207 insertions(+) create mode 100644 LICENSE diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..307bb07 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for describing the origin of the Work and + reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Support. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or support. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2026 Prime Intellect + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index bff2afe..51e4d19 100644 --- a/README.md +++ b/README.md @@ -141,3 +141,7 @@ uv run pytest ``` Round-trip parity (render → parse → original) and token-level parity against `apply_chat_template` are tested per renderer. End-to-end validation runs against Reverse-Text, Wordle, OpenCode-Math, and RLM-SWE environments. + +## License + +Licensed under the [Apache License, Version 2.0](LICENSE). 
diff --git a/pyproject.toml b/pyproject.toml
index 0b4a00c..f511b65 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,6 +11,8 @@
 name = "renderers"
 dynamic = ["version"]
 description = "Chat template renderers — deterministic message-to-token conversion for LLM training"
 readme = "README.md"
+license = "Apache-2.0"
+license-files = ["LICENSE"]
 requires-python = ">=3.10,<3.14"
 dependencies = [
     "numpy",

From 9e59cc901cc0ff04923200e76e6806eceb41bb34 Mon Sep 17 00:00:00 2001
From: hallerite
Date: Wed, 13 May 2026 13:30:42 +0000
Subject: [PATCH 13/13] fix(laguna): migrate parser to ParsedToolCall API +
 drop broken assert_never
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

After merging main, the Laguna parser was still on the old
``list[dict]`` shape and 4 tests failed against
``ParsedResponse.tool_calls``'s new ``list[ParsedToolCall]`` type
(introduced in #22). Mirror ``parse_glm``'s structure: emit
``ParsedToolCall`` with a ``status`` enum (UNCLOSED_BLOCK /
MISSING_NAME / INVALID_JSON / OK) and a ``token_span`` relative to the
stop-stripped stream.

Also drop ``assert_never(unexpected_role)``: ``msg["role"]`` is plain
``str`` (TypedDict), so ``ty`` flags a type-assertion-failure and any
unknown role would crash at runtime — every other renderer silently
skips unknown roles.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 renderers/laguna_xs2.py |  3 ---
 renderers/parsing.py    | 48 +++++++++++++++++++++++++++++++++++------
 2 files changed, 41 insertions(+), 10 deletions(-)

diff --git a/renderers/laguna_xs2.py b/renderers/laguna_xs2.py
index c981265..ee8e64d 100644
--- a/renderers/laguna_xs2.py
+++ b/renderers/laguna_xs2.py
@@ -26,7 +26,6 @@
 from __future__ import annotations
 
 import json
-from typing import Any, assert_never
 
 from transformers.tokenization_utils import PreTrainedTokenizer
 
@@ -214,8 +213,6 @@ def emit_text(text: str, msg_idx: int) -> None:
             )
             case "tool":
                 emit_text("<tool>\n" + content + "\n</tool>\n\n", i)
-            case unexpected_role:
-                assert_never(unexpected_role)
 
     # ── Generation prompt ─────────────────────────────────────────
     if add_generation_prompt:
diff --git a/renderers/parsing.py b/renderers/parsing.py
index 19cf6ec..e235c76 100644
--- a/renderers/parsing.py
+++ b/renderers/parsing.py
@@ -455,32 +455,38 @@ def parse_laguna_xs2(
     # decoded segment — never a bare ``.strip()``, which would also eat
     # whitespace the model emitted intentionally.
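For readers who don't use ``ty``: `assert_never` only typechecks when the matched value's type has been narrowed to `Never`, which a plain `str` role field never is, and at runtime it raises `AssertionError` on the first unexpected role. The tolerant shape the other renderers use looks roughly like this (the marker strings are placeholders, not the exact Laguna template):

```python
from typing import TypedDict


class Message(TypedDict):
    role: str     # plain str, not a Literal union: match arms can
    content: str  # never be proven exhaustive by a type checker


def render_block(msg: Message) -> str:
    match msg["role"]:
        case "user":
            return "<user>\n" + msg["content"] + "\n</user>\n\n"
        case "tool":
            return "<tool>\n" + msg["content"] + "\n</tool>\n\n"
        case _:
            return ""  # silently skip unknown roles instead of crashing
```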
reasoning = None
+    parse_offset = 0
     think_end = _find(ids, think_end_id)
     if think_end != -1:
         reasoning_ids = ids[:think_end]
         reasoning_ids = [t for t in reasoning_ids if t != think_id]
         reasoning = _decode(tokenizer, reasoning_ids).strip("\n")
         ids = ids[think_end + 1 :]
+        parse_offset = think_end + 1
     elif (think_start := _find(ids, think_id)) != -1:
         reasoning = _decode(tokenizer, ids[think_start + 1 :]).strip("\n")
         return ParsedResponse(
-            content="", reasoning_content=reasoning or None, tool_calls=None
+            content="", reasoning_content=reasoning or None, tool_calls=[]
         )
 
     tc_start = _find(ids, tool_call_id)
+    tool_calls: list[ParsedToolCall] = []
     if tc_start != -1:
         content_text = _decode(tokenizer, ids[:tc_start]).strip("\n")
         tool_calls = _parse_laguna_xs2_tool_calls(
-            tokenizer, ids[tc_start:], tool_call_id, tool_call_end_id
+            tokenizer,
+            ids[tc_start:],
+            tool_call_id,
+            tool_call_end_id,
+            section_offset=parse_offset + tc_start,
         )
     else:
         content_text = _decode(tokenizer, ids).strip("\n")
-        tool_calls = None
 
     return ParsedResponse(
         content=content_text,
         reasoning_content=reasoning or None,
-        tool_calls=tool_calls or None,
+        tool_calls=tool_calls,
     )
 
 
@@ -489,7 +495,9 @@ def _parse_laguna_xs2_tool_calls(
     ids: list[int],
     tc_id: int,
     tc_end_id: int,
-) -> list[dict]:
+    *,
+    section_offset: int,
+) -> list[ParsedToolCall]:
     """Parse Laguna-XS.2 tool calls.
 
     Inside each ``<tool_call>...</tool_call>`` block, the format is::
 
         {name}
         <arg_key>{key}</arg_key><arg_value>{value}</arg_value>
 
     """
     import re
 
-    tool_calls: list[dict] = []
+    tool_calls: list[ParsedToolCall] = []
     i = 0
     while i < len(ids):
         if ids[i] == tc_id:
             tc_end = _find(ids, tc_end_id, i + 1)
             if tc_end == -1:
+                raw = _decode(tokenizer, ids[i + 1 :])
+                tool_calls.append(
+                    ParsedToolCall(
+                        raw=raw,
+                        token_span=(section_offset + i, section_offset + len(ids)),
+                        status=ToolCallParseStatus.UNCLOSED_BLOCK,
+                    )
+                )
                 break
             block_text = _decode(tokenizer, ids[i + 1 : tc_end])
+            span = (section_offset + i, section_offset + tc_end + 1)
 
             ak_pos = block_text.find("<arg_key>")
             if ak_pos != -1:
@@ -522,6 +539,7 @@
                 args_section = ""
 
             arguments: dict = {}
+            any_json_fallback = False
             for m in re.finditer(
                 r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>",
                 args_section,
@@ -533,8 +551,24 @@
                     arguments[k] = json.loads(v)
                 except (json.JSONDecodeError, ValueError):
                     arguments[k] = v
+                    any_json_fallback = True
+
+            if not name:
+                status = ToolCallParseStatus.MISSING_NAME
+            elif any_json_fallback:
+                status = ToolCallParseStatus.INVALID_JSON
+            else:
+                status = ToolCallParseStatus.OK
 
-            tool_calls.append({"function": {"name": name, "arguments": arguments}})
+            tool_calls.append(
+                ParsedToolCall(
+                    raw=block_text,
+                    name=name or None,
+                    arguments=arguments,
+                    token_span=span,
+                    status=status,
+                )
+            )
             i = tc_end + 1
         else:
             i += 1
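Since the inner markers are plain text, the decoded-block walk above can be sanity-checked without a tokenizer. A standalone sketch of the same name / arg_key / arg_value extraction (hand-written block literal; marker spellings as reconstructed in the docstring above):

```python
import json
import re

block = 'get_weather<arg_key>city</arg_key><arg_value>"Tokyo"</arg_value>'

ak_pos = block.find("<arg_key>")
name = block[:ak_pos].strip() if ak_pos != -1 else block.strip()
arguments = {}
for m in re.finditer(
    r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", block, re.DOTALL
):
    value = m.group(2).strip()
    try:
        arguments[m.group(1).strip()] = json.loads(value)  # JSON when it parses
    except json.JSONDecodeError:
        arguments[m.group(1).strip()] = value  # raw-string fallback

print(name, arguments)  # get_weather {'city': 'Tokyo'}
```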