diff --git a/renderers/__init__.py b/renderers/__init__.py
index a169186..62bc666 100644
--- a/renderers/__init__.py
+++ b/renderers/__init__.py
@@ -43,6 +43,7 @@
from renderers.gpt_oss import GptOssRenderer
from renderers.kimi_k2 import KimiK2Renderer
from renderers.kimi_k25 import KimiK25Renderer
+from renderers.laguna_xs2 import LagunaXS2Renderer
from renderers.minimax_m2 import MiniMaxM2Renderer
from renderers.nemotron3 import Nemotron3Renderer
from renderers.qwen3 import Qwen3Renderer
@@ -61,6 +62,7 @@
"ImagePart",
"KimiK2Renderer",
"KimiK25Renderer",
+ "LagunaXS2Renderer",
"MULTIMODAL_MODELS",
"Message",
"MiniMaxM2Renderer",
diff --git a/renderers/base.py b/renderers/base.py
index 64e9f1c..2b74e0c 100644
--- a/renderers/base.py
+++ b/renderers/base.py
@@ -603,6 +603,8 @@ def bridge_to_next_turn(self, *args: Any, **kwargs: Any) -> "RenderedTokens | No
    # Nemotron 3.
    "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16": "nemotron-3",
    "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16": "nemotron-3",
+    # Poolside Laguna.
+    "poolside/Laguna-XS.2": "laguna-xs.2",
    # GPT-OSS.
    "openai/gpt-oss-20b": "gpt-oss",
    "openai/gpt-oss-120b": "gpt-oss",
@@ -739,6 +741,7 @@ def _populate_registry():
    from renderers.gpt_oss import GptOssRenderer
    from renderers.kimi_k2 import KimiK2Renderer
    from renderers.kimi_k25 import KimiK25Renderer
+    from renderers.laguna_xs2 import LagunaXS2Renderer
    from renderers.minimax_m2 import MiniMaxM2Renderer
    from renderers.nemotron3 import Nemotron3Renderer
    from renderers.qwen3 import Qwen3Renderer
@@ -760,6 +763,7 @@ def _populate_registry():
"deepseek-v3": DeepSeekV3Renderer,
"kimi-k2": KimiK2Renderer,
"kimi-k2.5": KimiK25Renderer,
+ "laguna-xs.2": LagunaXS2Renderer,
"nemotron-3": Nemotron3Renderer,
"gpt-oss": GptOssRenderer,
}
@@ -824,8 +828,8 @@ def create_renderer(
        tokenizer: HuggingFace tokenizer instance.
        renderer: Renderer name ('qwen3', 'qwen3-vl', 'qwen3.5', 'qwen3.6',
            'glm-5', 'glm-5.1', 'glm-4.5', 'minimax-m2', 'deepseek-v3',
-            'kimi-k2', 'kimi-k2.5', 'nemotron-3', 'gpt-oss', 'default')
-            or 'auto' to detect from model name.
+            'kimi-k2', 'kimi-k2.5', 'laguna-xs.2', 'nemotron-3',
+            'gpt-oss', 'default') or 'auto' to detect from model name.
        tool_parser: Name of a tool parser registered in ``renderers.parsers``.
            Only consumed by DefaultRenderer. Model-specific renderers
            have their own parsing wired in.
diff --git a/renderers/laguna_xs2.py b/renderers/laguna_xs2.py
new file mode 100644
index 0000000..ee8e64d
--- /dev/null
+++ b/renderers/laguna_xs2.py
@@ -0,0 +1,375 @@
+"""Laguna-XS.2 Renderer.
+
+Main properties:
+- Prefix is the single token ``〈|EOS|〉`` (also the EOS / stop token).
+- Role markers are block-style: ``<system>...</system>``, ``<user>...</user>``,
+  ``<assistant>...</assistant>``, ``<tool-response>...</tool-response>``. Of
+  these, only ``<assistant>`` / ``</assistant>`` are single (added) tokens
+  in the tokenizer; everything else is plain text and BPEs into multiple
+  subwords.
+- Assistant turn has an explicit close: ``</assistant>`` is the canonical
+  stop token (alongside ``〈|EOS|〉``).
+- Tool calls: ``<tool-call>`` / ``</tool-call>`` ARE single tokens, but the
+  inner ``<argument-key>`` / ``</argument-key>`` / ``<argument-value>`` /
+  ``</argument-value>`` markers are plain text — parsed via regex on the
+  decoded inner block.
+- The template bakes in a default system prompt when ``messages[0]`` is not
+  a system message. The system block also contains the tools section (under
+  a ``### Tools`` header with an ``<available-tools>`` listing and prose
+  format instructions that vary on ``enable_thinking``).
+- Reasoning is rendered for every assistant message — no last-user-index
+  gating. ``preserve_all_thinking`` and
+  ``preserve_thinking_between_tool_calls`` are accepted for protocol
+  uniformity but are effectively no-ops since past reasoning is preserved
+  by default.
+"""
+
+from __future__ import annotations
+
+import json
+
+from transformers.tokenization_utils import PreTrainedTokenizer
+
+from renderers.base import (
+    Content,
+    Message,
+    ParsedResponse,
+    RenderedTokens,
+    ToolSpec,
+    reject_assistant_in_extension,
+)
+from renderers.parsing import parse_laguna_xs2
+
+_DEFAULT_SYSTEM_MESSAGE = (
+ "You are a helpful, conversationally-fluent assistant made by Poolside. "
+ "You are here to be helpful to users through natural language conversations."
+)
+
+_TOOLS_HEADER = (
+ "\n\n### Tools\n\n"
+ "You may call functions to assist with the user query.\n"
+ "All available function signatures are listed below:\n"
+ "\n"
+)
+
+_TOOLS_FOOTER_THINKING = (
+ "\n\n"
+ "Wrap your thinking in '', '' tags, followed by a function call. "
+ "For each function call, return an unescaped XML-like object with function name "
+ "and arguments within '' and '' tags, like here:\n"
+ " your thoughts here \n"
+ "function-name\n"
+ "argument-key\n"
+ "value-of-argument-key\n"
+ ""
+)
+
+_TOOLS_FOOTER_NO_THINKING = (
+ "\n\n"
+ "For each function call, return an unescaped XML-like object with function name "
+ "and arguments within '' and '' tags, like here:\n"
+ "function-name\n"
+ "argument-key\n"
+ "value-of-argument-key\n"
+ ""
+)
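+
+# Illustrative assembly of the tools section (an example, not a fixture;
+# tool JSON abridged): header, one json.dumps(tool) line per tool, footer,
+# all inside the <system> block:
+#
+#   ### Tools
+#
+#   You may call functions to assist with the user query.
+#   All available function signatures are listed below:
+#   <available-tools>
+#   {"name": "get_weather", "parameters": {...}}
+#   </available-tools>
+#
+#   {format instructions from the footer above}</system>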
+
+
+class LagunaXS2Renderer:
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizer,
+        *,
+        enable_thinking: bool = False,
+        preserve_all_thinking: bool = False,
+        preserve_thinking_between_tool_calls: bool = False,
+    ):
+        self._tokenizer = tokenizer
+        self._enable_thinking = enable_thinking
+        # Accepted for protocol uniformity. The chat template renders
+        # reasoning on every assistant message regardless, so flipping
+        # these flags has no effect on the byte-level output.
+        self._preserve_all_thinking = preserve_all_thinking
+        self._preserve_thinking_between_tool_calls = (
+            preserve_thinking_between_tool_calls
+        )
+
+        self._eos = self._token_id("〈|EOS|〉")
+        self._think = self._token_id("<think>")
+        self._think_end = self._token_id("</think>")
+        self._assistant = self._token_id("<assistant>")
+        self._assistant_end = self._token_id("</assistant>")
+        self._tool_call = self._token_id("<tool-call>")
+        self._tool_call_end = self._token_id("</tool-call>")
+
+    def _token_id(self, token: str) -> int:
+        tid = self._tokenizer.convert_tokens_to_ids(token)
+        assert isinstance(tid, int) and tid != self._tokenizer.unk_token_id, (
+            f"Special token {token!r} not found in tokenizer vocabulary"
+        )
+        return tid
+
+    def _encode(self, text: str) -> list[int]:
+        if not text:
+            return []
+        return self._tokenizer.encode(text, add_special_tokens=False)
+
+    @staticmethod
+    def _visible_text(content: Content | None) -> str:
+        if isinstance(content, str):
+            return content
+        if isinstance(content, list):
+            parts: list[str] = []
+            for item in content:
+                if isinstance(item, dict) and item.get("type") == "text":
+                    parts.append(item.get("text", ""))
+            return "".join(parts)
+        return ""
+
+    @staticmethod
+    def _thinking_text(content: Content | None) -> str:
+        """Concatenate ``ThinkingPart`` entries from list-form content.
+
+        Used as a reasoning source in ``_render_assistant`` when the message
+        carries no string ``reasoning_content``. Returns ``""`` for any
+        non-list input.
+        """
+        if not isinstance(content, list):
+            return ""
+        parts: list[str] = []
+        for item in content:
+            if isinstance(item, dict) and item.get("type") == "thinking":
+                parts.append(item.get("thinking", ""))
+        return "".join(parts)
+
+    def render(
+        self,
+        messages: list[Message],
+        *,
+        tools: list[ToolSpec] | None = None,
+        add_generation_prompt: bool = False,
+    ) -> RenderedTokens:
+        if not messages:
+            raise ValueError("No messages provided.")
+
+        tokens: list[int] = []
+        indices: list[int] = []
+
+        def emit_special(token_id: int, msg_idx: int) -> None:
+            tokens.append(token_id)
+            indices.append(msg_idx)
+
+        def emit_text(text: str, msg_idx: int) -> None:
+            ids = self._encode(text)
+            tokens.extend(ids)
+            indices.extend([msg_idx] * len(ids))
+
+        emit_special(self._eos, -1)
+
+        # ── System header (absorbs messages[0] if it's a system message) ──
+        system_content = _DEFAULT_SYSTEM_MESSAGE
+        system_msg_idx = -1
+        if messages and messages[0].get("role") == "system":
+            system_content = self._visible_text(messages[0].get("content"))
+            system_msg_idx = 0
+
+        has_sys_content = bool(system_content and system_content.strip())
+        # Mirrors the template's ``(system_message and system_message.strip()) or tools``
+        # gate: when the caller passes an empty system message and no tools,
+        # the whole ``<system>...</system>`` block is omitted.
+        if has_sys_content or tools:
+            # The template emits ``<system>\n`` then conditionally a second
+            # ``\n``. Bundle those into one emit so BPE merges ``\n\n`` into
+            # its single-token form (rather than two ``\n`` atoms).
+            emit_text("<system>\n\n" if has_sys_content else "<system>\n", -1)
+            if has_sys_content:
+                emit_text(system_content.rstrip(), system_msg_idx)
+            if tools:
+                tool_text = _TOOLS_HEADER
+                for tool in tools:
+                    tool_text += json.dumps(tool, ensure_ascii=False) + "\n"
+                tool_text += (
+                    _TOOLS_FOOTER_THINKING
+                    if self._enable_thinking
+                    else _TOOLS_FOOTER_NO_THINKING
+                )
+                emit_text(tool_text, -1)
+            emit_text("</system>\n\n", -1)
+
+        # ── Per-message loop ──────────────────────────────────────────
+        for i, msg in enumerate(messages):
+            content = self._visible_text(msg.get("content"))
+
+            match msg["role"]:
+                case "system":
+                    # Already consumed in the header block.
+                    if i == 0:
+                        continue
+                    emit_text("<system>\n" + content + "\n</system>\n\n", i)
+                case "user":
+                    emit_text("<user>\n" + content + "\n</user>\n\n", i)
+                case "assistant":
+                    self._render_assistant(
+                        msg, i, content, emit_special=emit_special, emit_text=emit_text
+                    )
+                case "tool":
+                    emit_text(
+                        "<tool-response>\n" + content + "\n</tool-response>\n\n", i
+                    )
+
+        # ── Generation prompt ─────────────────────────────────────────
+        if add_generation_prompt:
+            emit_special(self._assistant, -1)
+            emit_text("\n", -1)
+            if self._enable_thinking:
+                emit_special(self._think, -1)
+            else:
+                emit_special(self._think_end, -1)
+
+        return RenderedTokens(token_ids=tokens, message_indices=indices)
+
+    def render_ids(
+        self,
+        messages: list[Message],
+        *,
+        tools: list[ToolSpec] | None = None,
+        add_generation_prompt: bool = False,
+    ) -> list[int]:
+        return self.render(
+            messages,
+            tools=tools,
+            add_generation_prompt=add_generation_prompt,
+        ).token_ids
+
+    def parse_response(self, token_ids: list[int]) -> ParsedResponse:
+        return parse_laguna_xs2(
+            self._tokenizer,
+            token_ids,
+            stop_ids={self._assistant_end, self._eos},
+            think_id=self._think,
+            think_end_id=self._think_end,
+            tool_call_id=self._tool_call,
+            tool_call_end_id=self._tool_call_end,
+        )
+
+    def get_stop_token_ids(self) -> list[int]:
+        return [self._assistant_end, self._eos]
+
+    def bridge_to_next_turn(
+        self,
+        previous_prompt_ids: list[int],
+        previous_completion_ids: list[int],
+        new_messages: list[Message],
+        *,
+        tools: list[ToolSpec] | None = None,
+    ) -> RenderedTokens | None:
+        if (
+            not previous_prompt_ids
+            or not new_messages
+            or reject_assistant_in_extension(new_messages)
+        ):
+            return None
+
+        # The canonical assistant-turn close is ``</assistant>``. ``〈|EOS|〉``
+        # also stops generation; either being the final token means the turn
+        # ended cleanly. Truncation (no stop token at the tail) synthesises
+        # ``</assistant>\n`` — the same scaffold the template emits.
+        previous_ids = list(previous_prompt_ids) + list(previous_completion_ids)
+        stop_ids = {self._assistant_end, self._eos}
+        if (
+            not previous_ids[len(previous_prompt_ids) :]
+            or previous_ids[-1] not in stop_ids
+        ):
+            previous_ids.append(self._assistant_end)
+            previous_ids.extend(self._encode("\n"))
+
+        ext: list[int] = []
+
+        def emit_special(token_id: int, _msg_idx: int = -1) -> None:
+            ext.append(token_id)
+
+        def emit_text(text: str, _msg_idx: int = -1) -> None:
+            ext.extend(self._encode(text))
+
+        for msg in new_messages:
+            role = msg.get("role")
+            content = self._visible_text(msg.get("content"))
+            if role == "user":
+                emit_text("<user>\n" + content + "\n</user>\n\n")
+            elif role == "system":
+                emit_text("<system>\n" + content + "\n</system>\n\n")
+            elif role == "tool":
+                emit_text("<tool-response>\n" + content + "\n</tool-response>\n\n")
+            else:
+                return None
+
+        emit_special(self._assistant)
+        emit_text("\n")
+        if self._enable_thinking:
+            emit_special(self._think)
+        else:
+            emit_special(self._think_end)
+
+        return RenderedTokens(token_ids=previous_ids + ext)
+
+    def _render_assistant(
+        self,
+        msg: Message,
+        msg_idx: int,
+        content: str,
+        *,
+        emit_special,
+        emit_text,
+    ) -> None:
+        reasoning_content = ""
+        if isinstance(msg.get("reasoning_content"), str):
+            reasoning_content = msg["reasoning_content"]
+        else:
+            # When the caller stores reasoning as a ``ThinkingPart`` inside
+            # a list-form ``content`` (e.g. after parse_response →
+            # reserialize), pull it out here so it survives the re-render.
+            part_thinking = self._thinking_text(msg.get("content"))
+            if part_thinking:
+                reasoning_content = part_thinking
+
+        emit_special(self._assistant, msg_idx)
+        emit_text("\n", msg_idx)
+
+        if reasoning_content:
+            emit_special(self._think, msg_idx)
+            emit_text("\n" + reasoning_content.strip() + "\n", msg_idx)
+            emit_special(self._think_end, msg_idx)
+        else:
+            emit_special(self._think_end, msg_idx)
+
+        # Combined newline-after-``</think>`` with optional content. Bundling
+        # preserves BPE merges across the boundary.
+        post_think_text = "\n"
+        if content.strip():
+            post_think_text += content.strip() + "\n"
+        emit_text(post_think_text, msg_idx)
+
+        tool_calls = msg.get("tool_calls") or []
+        for tc in tool_calls:
+            func = tc.get("function") or tc
+            name = func.get("name", "")
+            arguments = func.get("arguments", {})
+            if isinstance(arguments, str):
+                try:
+                    arguments = json.loads(arguments)
+                except json.JSONDecodeError:
+                    arguments = {}
+
+            emit_special(self._tool_call, msg_idx)
+            inner = name + "\n"
+            if isinstance(arguments, dict):
+                for k, v in arguments.items():
+                    inner += "<argument-key>" + k + "</argument-key>\n"
+                    if isinstance(v, str):
+                        val_text = v
+                    else:
+                        val_text = json.dumps(v, ensure_ascii=False)
+                    inner += "<argument-value>" + val_text + "</argument-value>\n"
+            emit_text(inner, msg_idx)
+            emit_special(self._tool_call_end, msg_idx)
+            emit_text("\n", msg_idx)
+
+        emit_special(self._assistant_end, msg_idx)
+        emit_text("\n", msg_idx)
diff --git a/renderers/parsing.py b/renderers/parsing.py
index 827cc81..e235c76 100644
--- a/renderers/parsing.py
+++ b/renderers/parsing.py
@@ -424,6 +424,157 @@ def _parse_glm_tool_calls(
return tool_calls
+# ── Laguna-XS.2: <tool-call>name\n<argument-key>k</argument-key>\n<argument-value>v</argument-value></tool-call> ──
+# Same outer skeleton as parse_glm, but <argument-key> / <argument-value> are
+# plain text (multi-token BPE), not single special tokens — so the inner block
+# is decoded to text and the key/value pairs are pulled out by regex.
+
+
+def parse_laguna_xs2(
+    tokenizer,
+    token_ids: list[int],
+    *,
+    stop_ids: set[int],
+    think_id: int,
+    think_end_id: int,
+    tool_call_id: int,
+    tool_call_end_id: int,
+) -> ParsedResponse:
+    """Parse Laguna-XS.2 completion tokens.
+
+    Thinking uses single-token ``<think>`` / ``</think>`` (ids found by
+    scan). Tool calls are delimited by single-token ``<tool-call>`` /
+    ``</tool-call>``, but ``<argument-key>`` / ``<argument-value>`` inside
+    are plain text — regex-extracted from the decoded inner block.
+    """
+    ids = _strip_stop_tokens(token_ids, stop_ids)
+
+    # The template wraps reasoning with ``\n`` on both sides
+    # (``<think>\n{r}\n</think>``) and brackets post-think content with ``\n``
+    # too (``</think>\n{c}\n``). Strip exactly those newlines from each
+    # decoded segment — never a bare ``.strip()``, which would also eat
+    # whitespace the model emitted intentionally.
+    reasoning = None
+    parse_offset = 0
+    think_end = _find(ids, think_end_id)
+    if think_end != -1:
+        reasoning_ids = ids[:think_end]
+        reasoning_ids = [t for t in reasoning_ids if t != think_id]
+        reasoning = _decode(tokenizer, reasoning_ids).strip("\n")
+        ids = ids[think_end + 1 :]
+        parse_offset = think_end + 1
+    elif (think_start := _find(ids, think_id)) != -1:
+        reasoning = _decode(tokenizer, ids[think_start + 1 :]).strip("\n")
+        return ParsedResponse(
+            content="", reasoning_content=reasoning or None, tool_calls=[]
+        )
+
+    tc_start = _find(ids, tool_call_id)
+    tool_calls: list[ParsedToolCall] = []
+    if tc_start != -1:
+        content_text = _decode(tokenizer, ids[:tc_start]).strip("\n")
+        tool_calls = _parse_laguna_xs2_tool_calls(
+            tokenizer,
+            ids[tc_start:],
+            tool_call_id,
+            tool_call_end_id,
+            section_offset=parse_offset + tc_start,
+        )
+    else:
+        content_text = _decode(tokenizer, ids).strip("\n")
+
+    return ParsedResponse(
+        content=content_text,
+        reasoning_content=reasoning or None,
+        tool_calls=tool_calls,
+    )
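+
+# Illustrative behaviour (an example, not a doctest): a completion decoding
+# to "<think>\nplan the reply\n</think>\nHello!\n</assistant>" parses to
+# reasoning_content="plan the reply" and content="Hello!"; only the
+# template's bracketing "\n" are stripped, interior whitespace is kept.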
+
+
+def _parse_laguna_xs2_tool_calls(
+    tokenizer,
+    ids: list[int],
+    tc_id: int,
+    tc_end_id: int,
+    *,
+    section_offset: int,
+) -> list[ParsedToolCall]:
+    """Parse Laguna-XS.2 tool calls.
+
+    Inside each ``<tool-call>...</tool-call>`` block, the format is::
+
+        {name}\\n
+        <argument-key>{k1}</argument-key>\\n<argument-value>{v1}</argument-value>\\n
+        ...
+        <argument-key>{kn}</argument-key>\\n<argument-value>{vn}</argument-value>\\n
+
+    The function name is everything before the first ``<argument-key>``
+    literal in the decoded block.
+    """
+    import re
+
+    tool_calls: list[ParsedToolCall] = []
+    i = 0
+    while i < len(ids):
+        if ids[i] == tc_id:
+            tc_end = _find(ids, tc_end_id, i + 1)
+            if tc_end == -1:
+                raw = _decode(tokenizer, ids[i + 1 :])
+                tool_calls.append(
+                    ParsedToolCall(
+                        raw=raw,
+                        token_span=(section_offset + i, section_offset + len(ids)),
+                        status=ToolCallParseStatus.UNCLOSED_BLOCK,
+                    )
+                )
+                break
+            block_text = _decode(tokenizer, ids[i + 1 : tc_end])
+            span = (section_offset + i, section_offset + tc_end + 1)
+
+            ak_pos = block_text.find("<argument-key>")
+            if ak_pos != -1:
+                name = block_text[:ak_pos].strip()
+                args_section = block_text[ak_pos:]
+            else:
+                name = block_text.strip()
+                args_section = ""
+
+            arguments: dict = {}
+            any_json_fallback = False
+            for m in re.finditer(
+                r"<argument-key>(.*?)</argument-key>\s*"
+                r"<argument-value>(.*?)</argument-value>",
+                args_section,
+                re.DOTALL,
+            ):
+                k = m.group(1).strip()
+                v = m.group(2).strip()
+                try:
+                    arguments[k] = json.loads(v)
+                except (json.JSONDecodeError, ValueError):
+                    arguments[k] = v
+                    any_json_fallback = True
+
+            if not name:
+                status = ToolCallParseStatus.MISSING_NAME
+            elif any_json_fallback:
+                status = ToolCallParseStatus.INVALID_JSON
+            else:
+                status = ToolCallParseStatus.OK
+
+            tool_calls.append(
+                ParsedToolCall(
+                    raw=block_text,
+                    name=name or None,
+                    arguments=arguments,
+                    token_span=span,
+                    status=status,
+                )
+            )
+            i = tc_end + 1
+        else:
+            i += 1
+    return tool_calls
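+
+
+# Illustrative (hypothetical tool name): a decoded inner block of
+#   'get_weather\n<argument-key>city</argument-key>\n<argument-value>"Paris"</argument-value>\n'
+# yields name="get_weather" and arguments={"city": "Paris"}; the value is
+# valid JSON, so the raw-string fallback (and INVALID_JSON status) is not hit.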
+
+
# ── DeepSeek V3: <|tool▁calls▁begin|>...<|tool▁calls▁end|> + text tags ──
diff --git a/tests/conftest.py b/tests/conftest.py
index 5ef65a5..8eea97b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -32,6 +32,7 @@
("moonshotai/Kimi-K2.6", "auto"),
("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", "auto"),
("nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16", "auto"),
+ ("poolside/Laguna-XS.2", "auto"),
("openai/gpt-oss-20b", "gpt-oss"),
("Qwen/Qwen2.5-0.5B-Instruct", "default"),
]
diff --git a/tests/test_preserve_thinking.py b/tests/test_preserve_thinking.py
index b0d6a9e..661d577 100644
--- a/tests/test_preserve_thinking.py
+++ b/tests/test_preserve_thinking.py
@@ -42,6 +42,7 @@ def _make(tokenizer, renderer_name, **flags):
"Qwen/Qwen3-VL-4B-Instruct",
"Qwen/Qwen3-VL-8B-Instruct",
"Qwen/Qwen3-VL-30B-A3B-Instruct",
+ "poolside/Laguna-XS.2",
}
diff --git a/tests/test_roundtrip.py b/tests/test_roundtrip.py
index 1cdac82..a4577fd 100644
--- a/tests/test_roundtrip.py
+++ b/tests/test_roundtrip.py
@@ -43,6 +43,7 @@
("moonshotai/Kimi-K2.6", "auto"),
("nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16", "auto"),
("nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16", "auto"),
+ ("poolside/Laguna-XS.2", "auto"),
("openai/gpt-oss-20b", "gpt-oss"),
("Qwen/Qwen2.5-0.5B-Instruct", "default"),
]