diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 50c30450..9780f940 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -996,6 +996,7 @@ async def generate_gen( max_rq_tokens=self.max_rq_tokens, filters=grammar_handler.filters, ) + self.active_job_ids[request_id] = job generated_tokens = 0 full_response = "" @@ -1013,8 +1014,21 @@ async def generate_gen( if chunk: chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk)) full_response += chunk + + # Extract token IDs as a plain list for downstream consumers if isinstance(chunk_tokens, torch.Tensor): + token_id_list = chunk_tokens.flatten().tolist() generated_tokens += chunk_tokens.size(dim=0) + elif isinstance(chunk_tokens, tuple): + first = chunk_tokens[0] + if isinstance(first, torch.Tensor): + token_id_list = first.flatten().tolist() + else: + token_id_list = list(first) + generated_tokens += len(token_id_list) + else: + token_id_list = list(chunk_tokens) + generated_tokens += len(token_id_list) # Increase penalty range to generated token amount # TODO: @@ -1024,6 +1038,7 @@ async def generate_gen( generation = { "request_id": request_id, "text": chunk, + "token_ids": token_id_list, "prompt_tokens": context_len, "generated_tokens": generated_tokens, "offset": len(full_response), @@ -1044,8 +1059,6 @@ async def generate_gen( yield finish_chunk break - # Assign the active job to the request ID - self.active_job_ids[request_id] = job except asyncio.CancelledError: await job.cancel() diff --git a/common/templating.py b/common/templating.py index cc0cceb1..dda06d85 100644 --- a/common/templating.py +++ b/common/templating.py @@ -12,6 +12,7 @@ from jinja2.ext import loopcontrols from jinja2.sandbox import ImmutableSandboxedEnvironment from loguru import logger +from markupsafe import Markup from packaging import version @@ -24,12 +25,17 @@ class TemplateLoadError(Exception): pass +VALID_TOOL_CALL_FORMATS = {"json", "xml", "auto"} + + @dataclass class 
TemplateMetadata: """Represents the parsed metadata from a template.""" stop_strings: List[str] = field(default_factory=list) tool_start: Optional[str] = None + tool_end: Optional[str] = None + tool_call_format: str = "json" class PromptTemplate: @@ -46,6 +52,22 @@ class PromptTemplate: ) metadata: Optional[TemplateMetadata] = None + @staticmethod + def _tojson_compat(value, indent=None, ensure_ascii=True): + """Compatibility JSON filter for chat templates. + + Some model templates call ``tojson(ensure_ascii=False)`` while the + bundled Jinja filter may not accept that keyword in sandboxed mode. + """ + return Markup( + json.dumps( + value, + indent=indent, + ensure_ascii=ensure_ascii, + separators=(",", ": "), + ) + ) + async def extract_metadata(self, template_vars: dict): """ Returns deserialized template metadata from a chat template. @@ -76,6 +98,22 @@ async def extract_metadata(self, template_vars: dict): if isinstance(template_module.tool_start, str): template_metadata.tool_start = template_module.tool_start + if hasattr(template_module, "tool_end"): + if isinstance(template_module.tool_end, str): + template_metadata.tool_end = template_module.tool_end + + if hasattr(template_module, "tool_call_format"): + fmt = template_module.tool_call_format + if isinstance(fmt, str) and fmt in VALID_TOOL_CALL_FORMATS: + template_metadata.tool_call_format = fmt + logger.debug(f"Template tool_call_format: {fmt}") + else: + logger.warning( + f"Invalid tool_call_format '{fmt}' in template, " + f"defaulting to 'json'. 
" + f"Valid values: {VALID_TOOL_CALL_FORMATS}" + ) + self.metadata = template_metadata return template_metadata @@ -107,6 +145,7 @@ def raise_exception(message): self.environment.globals["strftime_now"] = strftime_now self.environment.globals["raise_exception"] = raise_exception + self.environment.filters["tojson"] = self._tojson_compat return self.environment.from_string(template_str) diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index 52523149..df3a26aa 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -4,7 +4,7 @@ from uuid import uuid4 from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest -from endpoints.OAI.types.tools import ToolSpec, ToolCall +from endpoints.OAI.types.tools import NamedToolChoice, ToolSpec, ToolCall class ChatCompletionLogprob(BaseModel): @@ -71,6 +71,10 @@ class ChatCompletionRequest(CommonCompletionRequest): tools: Optional[List[ToolSpec]] = None functions: Optional[List[Dict]] = None + tool_choice: Optional[ + Union[Literal["none", "auto", "required"], NamedToolChoice] + ] = None + parallel_tool_calls: Optional[bool] = True # Chat completions requests do not have a BOS token preference. Backend # respects the tokenization config for the individual model. diff --git a/endpoints/OAI/types/tools.py b/endpoints/OAI/types/tools.py index b5b9611f..1e572663 100644 --- a/endpoints/OAI/types/tools.py +++ b/endpoints/OAI/types/tools.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import Dict, Literal +from typing import Dict, Literal, Optional from uuid import uuid4 @@ -28,8 +28,28 @@ class Tool(BaseModel): class ToolCall(BaseModel): - """Represents an OAI tool description.""" + """Represents an OAI tool call. 
+ + The ``index`` field is optional so it can be omitted in non-streaming + responses (where OpenAI does not include it) via ``exclude_none=True``, + while being set explicitly for streaming deltas where it is required + by strict validators like the Vercel AI SDK. + """ - id: str = Field(default_factory=lambda: str(uuid4()).replace("-", "")[:9]) + id: str = Field(default_factory=lambda: f"call_{uuid4().hex[:24]}") function: Tool type: Literal["function"] = "function" + index: Optional[int] = None + + +class NamedToolFunction(BaseModel): + """Represents a named function reference for tool_choice.""" + + name: str + + +class NamedToolChoice(BaseModel): + """Represents a named tool choice (forces a specific function call).""" + + function: NamedToolFunction + type: Literal["function"] = "function" diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index b559bb2b..6b82e7fc 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -1,6 +1,7 @@ """Chat completion utilities for OAI server.""" import asyncio +import json import pathlib from asyncio import CancelledError from typing import List, Optional @@ -28,12 +29,32 @@ ChatCompletionStreamChoice, ) from endpoints.OAI.types.common import UsageStats +from endpoints.OAI.types.tools import NamedToolChoice, ToolCall from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA +def _serialize_stream_chunk(chunk) -> str: + """Serialize a streaming chunk with OpenAI-compatible field handling. + + Uses exclude_none=True to strip irrelevant null fields (tool_calls, + tool_call_id, logprobs, usage) while ensuring finish_reason is always + present on each choice (as null when not set), matching OpenAI's + observed streaming behavior. 
+ """ + d = chunk.model_dump(exclude_none=True) + for choice in d.get("choices", []): + if "finish_reason" not in choice: + choice["finish_reason"] = None + return json.dumps(d, ensure_ascii=False) + + def _create_response( - request_id: str, generations: List[dict], model_name: Optional[str] + request_id: str, + generations: List[dict], + model_name: Optional[str], + tool_call_format: str = "json", + tool_choice=None, ): """Create a chat completion response from the provided text.""" @@ -43,9 +64,39 @@ def _create_response( role="assistant", content=unwrap(generation.get("text"), "") ) - tool_calls = generation["tool_calls"] - if tool_calls: - message.tool_calls = ToolCallProcessor.from_json(tool_calls) + tool_calls_raw = generation.get("tool_calls") + if tool_calls_raw: + parsed = ToolCallProcessor.parse(tool_calls_raw, format=tool_call_format) + if parsed and isinstance(tool_choice, NamedToolChoice): + parsed = ToolCallProcessor.filter_by_name( + parsed, tool_choice.function.name + ) + if parsed: + message.tool_calls = parsed + else: + logger.warning( + "Tool call text present but parsing returned no results " + f"(format={tool_call_format})" + ) + + # Fallback: detect bare XML tool calls in content that were not + # caught by the two-pass system (model never emitted tool_start) + if ( + tool_call_format in ("xml", "auto") + and not message.tool_calls + and message.content + and " List[ChatCompletionStreamChunk]: + """Build the OpenAI-standard streaming sequence for tool calls. + + Emits two chunks: + 1. Tool-call chunk: role="assistant", complete tool_calls with + index/id/type/name/arguments (all data in one chunk). + 2. Finish chunk: empty delta, finish_reason="tool_calls". + + Complete arguments are sent in a single chunk rather than streamed + incrementally, which is valid per OpenAI's spec (clients concatenate + argument strings across deltas) and maximizes compatibility with + clients that may not implement multi-chunk tool-call assembly. 
+ + The tool_calls are placed directly into a ChatCompletionMessage + (not a raw dict) so Pydantic validates them as ToolCall objects + with the index field preserved (ToolCall declares index as Optional[int]). + """ + chunk_id = f"chatcmpl-{request_id}" + + # Set index on each tool call for streaming + for idx, tc in enumerate(tool_calls): + tc.index = idx + + # Chunk 1: Complete tool call data + tool_call_message = ChatCompletionMessage( + role="assistant", + tool_calls=tool_calls, + ) + tool_chunk = ChatCompletionStreamChunk( + id=chunk_id, + choices=[ + ChatCompletionStreamChoice( + index=0, + delta=tool_call_message, + finish_reason=None, + ) + ], + model=model_name, + ) + + # Chunk 2: Finish signal + # Use model_construct to prevent Pydantic's smart Union from + # coercing the empty dict {} into ChatCompletionMessage(role="user") + finish_choice = ChatCompletionStreamChoice.model_construct( + index=0, + delta={}, + finish_reason="tool_calls", + logprobs=None, + ) + finish_chunk = ChatCompletionStreamChunk( + id=chunk_id, + choices=[finish_choice], + model=model_name, + ) + + return [tool_chunk, finish_chunk] + + async def _append_template_metadata(data: ChatCompletionRequest, template_vars: dict): """Adding metadata is a one-time process.""" @@ -237,6 +345,24 @@ async def format_messages_with_template( message_dicts.append(message.model_dump(exclude_none=True)) + # Pre-template: convert tool_call arguments from JSON strings to dicts. + # OpenAI-compatible clients (Kilo, Roo, etc.) send arguments as JSON + # strings per the OAI spec, but Qwen3-Coder's template calls + # .items() on arguments which requires a dict/mapping. 
+ for msg in message_dicts: + if msg.get("tool_calls"): + for tc in msg["tool_calls"]: + func = tc.get("function", {}) + args = func.get("arguments") + if isinstance(args, str): + try: + func["arguments"] = json.loads(args) + except (json.JSONDecodeError, ValueError): + logger.warning( + "Failed to parse tool_call arguments JSON " + "string to dict, keeping as string" + ) + # Get all special tokens special_tokens_dict = model.container.get_special_tokens() @@ -319,6 +445,7 @@ async def stream_generate_chat_completion( gen_queue = asyncio.Queue() gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start + tool_call_format = model.container.prompt_template.metadata.tool_call_format disconnect_task = asyncio.create_task(request_disconnect_loop(request)) try: @@ -347,13 +474,26 @@ async def stream_generate_chat_completion( # Consumer loop while True: + # Fast path: items already queued — no task overhead + if not gen_queue.empty(): + generation = gen_queue.get_nowait() + else: + # Slow path: queue empty — race get against disconnect + get_task = asyncio.create_task(gen_queue.get()) + done, _ = await asyncio.wait( + [get_task, disconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) + if disconnect_task in done: + get_task.cancel() + raise CancelledError() + generation = get_task.result() + if disconnect_task.done(): raise CancelledError() - generation = await gen_queue.get() - # Handle options if a tool model is present - if tool_start: + if tool_start and data.tool_choice != "none": if "stop_str" in generation: generations = await generate_tool_calls( prompt, @@ -365,6 +505,50 @@ async def stream_generate_chat_completion( # Only one generation present in this case generation = generations[0] + + # Emit proper three-phase tool-call streaming sequence + if "tool_calls" in generation: + tool_calls_raw = generation["tool_calls"] + parsed = ToolCallProcessor.parse( + tool_calls_raw, format=tool_call_format + ) + if parsed and 
isinstance(data.tool_choice, NamedToolChoice): + parsed = ToolCallProcessor.filter_by_name( + parsed, data.tool_choice.function.name + ) + if parsed: + for tc_chunk in _build_tool_call_chunks( + parsed, + request.state.id, + model_path.name, + ): + yield _serialize_stream_chunk(tc_chunk) + + # Handle completion and usage after tool calls + if ( + all(task.done() for task in gen_tasks) + and gen_queue.empty() + ): + if ( + data.stream_options + and data.stream_options.include_usage + ): + usage_chunk = _create_stream_chunk( + request.state.id, + generation, + model_path.name, + is_usage_chunk=True, + ) + yield _serialize_stream_chunk(usage_chunk) + + logger.info( + "Finished chat completion streaming " + f"request {request.state.id}" + ) + yield "[DONE]" + break + continue + elif "text" in generation: current_generation_text += generation["text"] @@ -373,9 +557,11 @@ async def stream_generate_chat_completion( raise generation response = _create_stream_chunk( - request.state.id, generation, model_path.name + request.state.id, + generation, + model_path.name, ) - yield response.model_dump_json() + yield _serialize_stream_chunk(response) # Check if all tasks are completed if all(task.done() for task in gen_tasks) and gen_queue.empty(): @@ -387,7 +573,7 @@ async def stream_generate_chat_completion( model_path.name, is_usage_chunk=True, ) - yield usage_chunk.model_dump_json() + yield _serialize_stream_chunk(usage_chunk) logger.info( f"Finished chat completion streaming request {request.state.id}" @@ -398,13 +584,14 @@ async def stream_generate_chat_completion( except CancelledError: # Get out if the request gets disconnected - if not abort_event.is_set(): - abort_event.set() - handle_request_disconnect("Chat completion generation cancelled by user.") + handle_request_disconnect("Chat completion generation cancelled by user.") except Exception: yield get_generator_error( "Chat completion aborted. Please check the server console." 
) + finally: + abort_event.set() + disconnect_task.cancel() async def generate_chat_completion( @@ -416,6 +603,7 @@ async def generate_chat_completion( ): gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start + tool_call_format = model.container.prompt_template.metadata.tool_call_format try: logger.info(f"Received chat completion request {request.state.id}") @@ -437,12 +625,21 @@ async def generate_chat_completion( generations = await asyncio.gather(*gen_tasks) # Check all the generations and see if a tool call is required - if tool_start: + force_tool_pass = data.tool_choice == "required" or isinstance( + data.tool_choice, NamedToolChoice + ) + if tool_start or force_tool_pass: generations = await generate_tool_calls( prompt, embeddings, data, generations, request ) - response = _create_response(request.state.id, generations, model_path.name) + response = _create_response( + request.state.id, + generations, + model_path.name, + tool_call_format=tool_call_format, + tool_choice=data.tool_choice, + ) logger.info(f"Finished chat completion request {request.state.id}") @@ -467,24 +664,72 @@ async def generate_tool_calls( ): gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start + tool_call_format = model.container.prompt_template.metadata.tool_call_format + tool_choice = data.tool_choice + + if tool_choice == "none": + return generations # Tracks which generations asked for a tool call tool_idx: List[int] = [] # Copy to make sure the parent JSON schema doesn't get modified tool_data = data.model_copy(deep=True) - tool_data.json_schema = TOOL_CALL_SCHEMA + + if tool_call_format in ("xml", "auto"): + # XML / auto mode: let the model generate its natural output + # without JSON schema constraint + logger.debug( + f"generate_tool_calls: Using '{tool_call_format}' mode " + f"(no JSON schema constraint)" + ) + + # Remove tool_start from stop strings so the model can emit + # multiple 
sequential blocks without stopping early + if ( + tool_start + and isinstance(tool_data.stop, list) + and tool_start in tool_data.stop + ): + tool_data.stop = [s for s in tool_data.stop if s != tool_start] + logger.debug( + f"generate_tool_calls: Removed '{tool_start}' from " + f"second-pass stop strings" + ) + else: + # JSON mode: constrained generation (existing behavior) + tool_data.json_schema = TOOL_CALL_SCHEMA for idx, gen in enumerate(generations): - if gen["stop_str"] != tool_start: + stop_str = gen.get("stop_str") + should_generate = stop_str == tool_start + + # Force tool generation if tool_choice requires it + if not should_generate and ( + tool_choice == "required" or isinstance(tool_choice, NamedToolChoice) + ): + should_generate = True + + if not should_generate: continue - logger.info(f"Detected tool call in chat completion request {request.state.id}") + logger.info( + f"Detected tool call in chat completion request " + f"{request.state.id} (format={tool_call_format})" + ) - # Append the existing generation text if present + # Build per-generation prompt (avoid mutating shared prompt) + tool_prompt = prompt precursor_text = gen.get("full_text") if precursor_text: - prompt = prompt + precursor_text + tool_prompt = tool_prompt + precursor_text + + # For XML/auto mode: append tool_start back to prompt. + # The stop string was consumed by the first pass and not included + # in full_text, but the model expects to continue after . + # Include a trailing newline to match the canonical template format. 
+ if tool_call_format in ("xml", "auto") and tool_start: + tool_prompt = tool_prompt + tool_start + "\n" gen_request_id = gen.get("request_id") tool_request_id = f"{gen_request_id}-tool" @@ -493,7 +738,7 @@ async def generate_tool_calls( asyncio.create_task( model.container.generate( tool_request_id, - prompt, + tool_prompt, tool_data, mm_embeddings=embeddings, ) @@ -507,6 +752,12 @@ async def generate_tool_calls( # Map tool calls to their appropriate generation for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True): - generations[gen_idx]["tool_calls"] = tool_call["text"] + raw_text = tool_call["text"] + + if tool_call_format in ("xml", "auto"): + # Prepend tool_start to reconstruct complete XML for parser + raw_text = tool_start + "\n" + raw_text + + generations[gen_idx]["tool_calls"] = raw_text return generations diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index f66d381d..c11a25bf 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -225,11 +225,24 @@ async def stream_generate_completion( # Consumer loop while True: + # Fast path: items already queued — no task overhead + if not gen_queue.empty(): + generation = gen_queue.get_nowait() + else: + # Slow path: queue empty — race get against disconnect + get_task = asyncio.create_task(gen_queue.get()) + done, _ = await asyncio.wait( + [get_task, disconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) + if disconnect_task in done: + get_task.cancel() + raise CancelledError() + generation = get_task.result() + if disconnect_task.done(): raise CancelledError() - generation = await gen_queue.get() - # Stream collector will push an exception to the queue if it fails if isinstance(generation, Exception): raise generation @@ -245,15 +258,16 @@ async def stream_generate_completion( except CancelledError: # Get out if the request gets disconnected - if not abort_event.is_set(): - abort_event.set() - handle_request_disconnect( - 
f"Completion generation {request.state.id} cancelled by user." - ) + handle_request_disconnect( + f"Completion generation {request.state.id} cancelled by user." + ) except Exception: yield get_generator_error( f"Completion {request.state.id} aborted. Please check the server console." ) + finally: + abort_event.set() + disconnect_task.cancel() async def generate_completion( diff --git a/endpoints/OAI/utils/tools.py b/endpoints/OAI/utils/tools.py index c1ebdedf..05eaf143 100644 --- a/endpoints/OAI/utils/tools.py +++ b/endpoints/OAI/utils/tools.py @@ -1,8 +1,11 @@ +"""Tool call processing utilities for OAI server.""" + import json +import re from loguru import logger -from typing import List +from typing import Any, List, Tuple -from endpoints.OAI.types.tools import ToolCall +from endpoints.OAI.types.tools import ToolCall, Tool TOOL_CALL_SCHEMA = { @@ -27,24 +30,480 @@ }, } +# --------------------------------------------------------------------------- +# XML parsing regex patterns +# Derived from vLLM's Qwen3CoderToolParser and the official Qwen parser. +# These handle both complete and partially-closed tags. +# --------------------------------------------------------------------------- + +# Matches complete ... blocks +TOOL_CALL_BLOCK_RE = re.compile( + r"(.*?)", + re.DOTALL, +) + +# Matches BODY blocks +FUNCTION_RE = re.compile( + r"(.*?)", + re.DOTALL, +) + +# Matches VALUE +# Terminates on: , next , or +PARAMETER_RE = re.compile( + r"(.*?)" + r"(?:|(?=)|(?=))", + re.DOTALL, +) + +# Think block patterns +THINK_BLOCK_RE = re.compile(r".*?\s*", re.DOTALL) +THINK_UNCLOSED_RE = re.compile(r"(?!.*).*$", re.DOTALL) + +# Markdown code fence patterns +CODE_FENCE_RE = re.compile(r"^```(?:json)?\s*", re.MULTILINE) +CODE_FENCE_END_RE = re.compile(r"\s*```\s*$", re.MULTILINE) + + +def _strip_think_blocks(text: str) -> str: + """Strip ... blocks from text. + + Handles both complete and unclosed blocks (quantization can cause + the model to never close a think tag). 
+ """ + original = text + + # Complete blocks first + text = THINK_BLOCK_RE.sub("", text) + + # Unclosed block (think started but never closed — strip to end) + text = THINK_UNCLOSED_RE.sub("", text) + + if text != original: + if THINK_UNCLOSED_RE.search(original): + logger.warning( + "XML Parser: Stripped unclosed block " + "(possible quantization degradation)" + ) + else: + logger.debug("XML Parser: Stripped block(s) from output") + + return text + + +def _coerce_param_value(raw: str) -> Any: + """Coerce a raw parameter value string to the appropriate Python type. + + Strategy (safe, no eval()): + 1. Strip leading/trailing newlines (official template emits \\n + after opening tag and before closing tag). + 2. Try json.loads — handles objects, arrays, numbers, bools, null. + 3. Fall back to plain string. + """ + # Strip template-inserted newlines around values + if raw.startswith("\n"): + raw = raw[1:] + if raw.endswith("\n"): + raw = raw[:-1] + + stripped = raw.strip() + + # Empty string + if not stripped: + return "" + + # Try JSON parse (handles objects, arrays, numbers, booleans, null) + try: + return json.loads(stripped) + except (json.JSONDecodeError, ValueError): + pass + + # Fall back to string — never eval() + return stripped + class ToolCallProcessor: + + # ------------------------------------------------------------------ + # JSON normalization helpers + # ------------------------------------------------------------------ + + @staticmethod + def _normalize_tool_calls(raw) -> list: + """Normalize model-emitted tool call payloads into OAI-like objects. 
+ + Accepted forms: + - [{"type":"function","function":{"name":...,"arguments":{...}}}] + - [{"name":...,"arguments":{...}}] + - {"name":...,"arguments":{...}} + """ + if isinstance(raw, dict): + raw = [raw] + if not isinstance(raw, list): + raise ValueError("tool_calls payload is not list/dict") + + normalized: list = [] + for item in raw: + if not isinstance(item, dict): + continue + + if "function" in item and isinstance(item["function"], dict): + fn = item["function"] + name = fn.get("name") + arguments = fn.get("arguments", {}) + else: + name = item.get("name") + arguments = item.get("arguments", {}) + + if name is None: + continue + + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {"input": arguments} + + normalized.append( + { + "type": "function", + "function": { + "name": name, + "arguments": arguments if isinstance(arguments, dict) else {}, + }, + } + ) + return normalized + + @staticmethod + def _safe_json_loads(payload: str) -> list: + """Best-effort JSON parse for model-emitted tool payloads. + + Handles: clean JSON, markdown-fenced JSON, JSON substrings in + surrounding text, flat {name, arguments} dicts, and single objects. 
+ """ + # Direct parse + try: + return ToolCallProcessor._normalize_tool_calls(json.loads(payload)) + except (json.JSONDecodeError, ValueError): + pass + + # Clean up common model artifacts (markdown fences, whitespace) + cleaned = payload.strip() + cleaned = CODE_FENCE_RE.sub("", cleaned) + cleaned = CODE_FENCE_END_RE.sub("", cleaned) + cleaned = cleaned.strip() + + # Try cleaned + try: + return ToolCallProcessor._normalize_tool_calls(json.loads(cleaned)) + except (json.JSONDecodeError, ValueError): + pass + + # Find JSON array substring + start = cleaned.find("[") + end = cleaned.rfind("]") + if start != -1 and end != -1 and end > start: + try: + return ToolCallProcessor._normalize_tool_calls( + json.loads(cleaned[start : end + 1]) + ) + except (json.JSONDecodeError, ValueError): + pass + + # Find JSON object substring + obj_start = cleaned.find("{") + obj_end = cleaned.rfind("}") + if obj_start != -1 and obj_end != -1 and obj_end > obj_start: + try: + return ToolCallProcessor._normalize_tool_calls( + json.loads(cleaned[obj_start : obj_end + 1]) + ) + except (json.JSONDecodeError, ValueError): + pass + + raise json.JSONDecodeError( + "Could not extract valid JSON from payload", payload, 0 + ) + + # ------------------------------------------------------------------ + # JSON parsing + # ------------------------------------------------------------------ + @staticmethod def from_json(tool_calls_str: str) -> List[ToolCall]: - """Postprocess tool call JSON to a parseable class""" + """Postprocess tool call JSON to a parseable class. - tool_calls = json.loads(tool_calls_str) + Handles clean JSON arrays, markdown-fenced output, flat dicts, + and other common model output variations via _safe_json_loads. 
+ """ + logger.debug(f"JSON Parser: Parsing tool calls ({len(tool_calls_str)} chars)") + + tool_calls = ToolCallProcessor._safe_json_loads(tool_calls_str) for tool_call in tool_calls: tool_call["function"]["arguments"] = json.dumps( tool_call["function"]["arguments"] ) - return [ToolCall(**tool_call) for tool_call in tool_calls] + result = [ToolCall(**tool_call) for tool_call in tool_calls] + logger.debug(f"JSON Parser: Successfully parsed {len(result)} tool call(s)") + return result + + # ------------------------------------------------------------------ + # XML parsing (Qwen3-Coder / GLM-4.5 style) + # ------------------------------------------------------------------ @staticmethod - def dump(tool_calls: List[ToolCall]) -> List[dict]: + def from_xml(raw_text: str) -> List[ToolCall]: + """Parse Qwen3-Coder XML-format tool calls into ToolCall objects. + + Handles: + - Wrapped: ... + - Bare: ... (missing wrapper) + - Multiple sequential tool call blocks + - blocks (stripped) + - Multi-line parameter values + - Missing closing tags + """ + logger.debug(f"XML Parser: Parsing tool calls ({len(raw_text)} chars)") + + # Stage 1: Strip think blocks + text = _strip_think_blocks(raw_text) + + # Stage 2: Check for incomplete XML at end (generation cutoff) + stripped_end = text.rstrip() + if stripped_end.endswith(("<", "]*$", "", text) + + # Stage 3: Extract function blocks + # First, find all wrapped ... 
blocks + wrapped_positions = [ + (m.start(), m.end()) for m in TOOL_CALL_BLOCK_RE.finditer(text) + ] + + # Collect function blocks from inside wrapped regions + function_blocks = [] + for match in TOOL_CALL_BLOCK_RE.finditer(text): + inner = match.group(1) + for func_match in FUNCTION_RE.finditer(inner): + function_blocks.append((func_match.group(1), func_match.group(2))) + + # Find bare blocks NOT inside any wrapped region + for func_match in FUNCTION_RE.finditer(text): + pos = func_match.start() + is_wrapped = any(start <= pos < end for start, end in wrapped_positions) + if not is_wrapped: + logger.debug( + "XML Parser: Found bare block without " + " wrapper" + ) + function_blocks.append((func_match.group(1), func_match.group(2))) + + if not function_blocks: + logger.warning("XML Parser: No blocks found") + return [] + + # Stage 4: Parse each function block into a ToolCall + tool_calls = [] + for func_name_raw, func_body in function_blocks: + func_name = func_name_raw.strip() + + # Extract parameters + params = {} + for param_match in PARAMETER_RE.finditer(func_body): + key = param_match.group(1).strip() + value_raw = param_match.group(2) + value = _coerce_param_value(value_raw) + params[key] = value + + arguments_json = json.dumps(params, ensure_ascii=False) + + tool_call = ToolCall( + function=Tool(name=func_name, arguments=arguments_json) + ) + tool_calls.append(tool_call) + + logger.debug(f"XML Parser: Successfully parsed {len(tool_calls)} tool call(s)") + return tool_calls + + # ------------------------------------------------------------------ + # Auto-detect parsing (JSON → JSON-in-tool_call → XML) + # ------------------------------------------------------------------ + + @staticmethod + def from_auto(raw_text: str) -> List[ToolCall]: + """Auto-detect format and parse. + + Tries in order: + 1. Pure JSON (standard TabbyAPI / Llama) + 2. JSON inside wrappers (Qwen3-Instruct style) + 3. 
XML with tags (Qwen3-Coder style) """ - Convert ToolCall objects to a list of dictionaries. + logger.debug("Auto Parser: Attempting format auto-detection") + + # Attempt 1: Pure JSON array + try: + result = ToolCallProcessor.from_json(raw_text) + logger.debug("Auto Parser: Detected JSON format") + return result + except (json.JSONDecodeError, ValueError, KeyError) as e: + logger.debug(f"Auto Parser: Not JSON ({e}), trying next format") + + # Attempt 2: JSON inside wrappers (Qwen3-Instruct) + try: + all_tool_calls = [] + for match in TOOL_CALL_BLOCK_RE.finditer(raw_text): + inner = match.group(1).strip() + if inner.startswith("{") or inner.startswith("["): + parsed = json.loads(inner) + if isinstance(parsed, dict): + parsed = [parsed] + if isinstance(parsed, list): + for tc in parsed: + name = tc.get("name", "") + arguments = tc.get("arguments", {}) + if isinstance(arguments, dict): + arguments = json.dumps(arguments) + elif not isinstance(arguments, str): + arguments = json.dumps(arguments) + all_tool_calls.append( + ToolCall(function=Tool(name=name, arguments=arguments)) + ) + if all_tool_calls: + logger.debug( + "Auto Parser: Detected JSON-inside-tool_call " + f"format ({len(all_tool_calls)} call(s))" + ) + return all_tool_calls + except (json.JSONDecodeError, ValueError, KeyError) as e: + logger.debug(f"Auto Parser: Not JSON-in-tool_call ({e}), trying XML") + + # Attempt 3: XML format (Qwen3-Coder style) + result = ToolCallProcessor.from_xml(raw_text) + if result: + logger.debug("Auto Parser: Detected XML format") + else: + logger.warning("Auto Parser: All format detection attempts failed") + return result + + # ------------------------------------------------------------------ + # Dispatcher + # ------------------------------------------------------------------ + + @staticmethod + def parse(tool_calls_str: str, format: str = "json") -> List[ToolCall]: + """Dispatch tool call parsing to the appropriate format handler. 
+ + Args: + tool_calls_str: Raw tool call text from model generation. + format: One of ``"json"``, ``"xml"``, ``"auto"``. + + Returns: + List of parsed ToolCall objects. Empty list on parse failure + (never raises). + """ + try: + if format == "xml": + return ToolCallProcessor.from_xml(tool_calls_str) + elif format == "auto": + return ToolCallProcessor.from_auto(tool_calls_str) + else: + return ToolCallProcessor.from_json(tool_calls_str) + except Exception as e: + logger.error( + f"ToolCallProcessor.parse: Failed to parse tool calls " + f"(format={format}): {e}" + ) + return [] + + # ------------------------------------------------------------------ + # Filtering + # ------------------------------------------------------------------ + + @staticmethod + def filter_by_name( + tool_calls: List[ToolCall], function_name: str + ) -> List[ToolCall]: + """Filter parsed tool calls to only those matching a function name.""" + filtered = [tc for tc in tool_calls if tc.function.name == function_name] + if not filtered: + logger.warning( + f"filter_by_name: No tool calls matched '{function_name}' " + f"(had {len(tool_calls)} call(s))" + ) + return filtered + + # ------------------------------------------------------------------ + # Content / tool-call separation + # ------------------------------------------------------------------ + + @staticmethod + def extract_content_and_tools( + raw_text: str, + ) -> Tuple[str, List[ToolCall]]: + """Separate plain text content from XML tool call blocks. + + Used when the model mixes reasoning text with tool calls, e.g.: + ``"I'll help with that: ...`` + + Returns: + Tuple of (remaining_content, tool_calls). 
+ """ + text = _strip_think_blocks(raw_text) + + # Collect all XML regions to exclude from content + xml_regions = [] + + # Wrapped tool call blocks + for match in TOOL_CALL_BLOCK_RE.finditer(text): + xml_regions.append((match.start(), match.end())) + + # Bare function blocks not inside wrappers + for match in FUNCTION_RE.finditer(text): + pos = match.start() + is_wrapped = any(start <= pos < end for start, end in xml_regions) + if not is_wrapped: + xml_regions.append((match.start(), match.end())) + + # Sort and extract content (everything outside XML regions) + xml_regions.sort() + content_parts = [] + last_end = 0 + for start, end in xml_regions: + if start > last_end: + part = text[last_end:start].strip() + if part: + content_parts.append(part) + last_end = end + if last_end < len(text): + part = text[last_end:].strip() + if part: + content_parts.append(part) + + content = " ".join(content_parts).strip() + + # Parse tool calls from the full text + tool_calls = ToolCallProcessor.from_xml(text) + + logger.debug( + f"extract_content_and_tools: Found {len(tool_calls)} tool " + f"call(s), content={'yes' if content else 'no'} " + f"({len(content)} chars)" + ) + + return content, tool_calls + + # ------------------------------------------------------------------ + # Serialisation helpers (unchanged from original) + # ------------------------------------------------------------------ + + @staticmethod + def dump(tool_calls: List[ToolCall]) -> List[dict]: + """Convert ToolCall objects to a list of dictionaries. Args: tool_calls (List[ToolCall]): List of ToolCall objects to convert @@ -65,8 +524,7 @@ def dump(tool_calls: List[ToolCall]) -> List[dict]: @staticmethod def to_json(tool_calls: List[ToolCall]) -> str: - """ - Convert ToolCall objects to JSON string representation. + """Convert ToolCall objects to JSON string representation. 
Args:
            tool_calls (List[ToolCall]): List of ToolCall objects to convert
diff --git a/templates/tool_calls/qwen3_coder.jinja b/templates/tool_calls/qwen3_coder.jinja
new file mode 100644
index 00000000..15272747
--- /dev/null
+++ b/templates/tool_calls/qwen3_coder.jinja
@@ -0,0 +1,123 @@
+{# TabbyAPI Metadata #}
+{%- set tool_call_format = "xml" -%}
+{%- set tool_start = "<tool_call>" -%}
+{%- set tool_end = "</tool_call>" -%}
+{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%}
+
+{% macro render_extra_keys(json_dict, handled_keys) %}
+    {%- if json_dict is mapping %}
+        {%- for json_key in json_dict if json_key not in handled_keys %}
+            {%- if json_dict[json_key] is string %}
+                {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
+            {%- else %}
+                {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
+            {%- endif %}
+        {%- endfor %}
+    {%- endif %}
+{%- endmacro %}
+
+{%- if messages[0]["role"] == "system" %}
+    {%- set system_message = messages[0]["content"] %}
+    {%- set loop_messages = messages[1:] %}
+{%- else %}
+    {%- set loop_messages = messages %}
+{%- endif %}
+
+{%- if not tools is defined %}
+    {%- set tools = [] %}
+{%- endif %}
+
+{%- if system_message is defined %}
+    {{- "<|im_start|>system\n" + system_message }}
+{%- else %}
+    {%- if tools is iterable and tools | length > 0 %}
+        {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks."
 }}
+    {%- endif %}
+{%- endif %}
+{%- if tools is iterable and tools | length > 0 %}
+    {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }}
+    {{- "<tools>" }}
+    {%- for tool in tools %}
+        {%- if tool.function is defined %}
+            {%- set tool = tool.function %}
+        {%- endif %}
+        {{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
+        {%- if tool.description is defined %}
+            {{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
+        {%- endif %}
+        {{- '\n<parameters>' }}
+        {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
+            {%- for param_name, param_fields in tool.parameters.properties|items %}
+                {{- '\n<parameter>' }}
+                {{- '\n<name>' ~ param_name ~ '</name>' }}
+                {%- if param_fields.type is defined %}
+                    {{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
+                {%- endif %}
+                {%- if param_fields.description is defined %}
+                    {{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
+                {%- endif %}
+                {%- set handled_keys = ['name', 'type', 'description'] %}
+                {{- render_extra_keys(param_fields, handled_keys) }}
+                {{- '\n</parameter>' }}
+            {%- endfor %}
+        {%- endif %}
+        {%- set handled_keys = ['type', 'properties'] %}
+        {{- render_extra_keys(tool.parameters, handled_keys) }}
+        {{- '\n</parameters>' }}
+        {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
+        {{- render_extra_keys(tool, handled_keys) }}
+        {{- '\n</function>' }}
+    {%- endfor %}
+    {{- "\n</tools>" }}
+    {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
+{%- endif %}
+{%- if system_message is defined %}
+    {{- '<|im_end|>\n' }}
+{%- else %}
+    {%- if tools is iterable and tools | length > 0 %}
+        {{- '<|im_end|>\n' }}
+    {%- endif %}
+{%- endif %}
+{%- for message in loop_messages %}
+    {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
+        {{- '<|im_start|>' + message.role }}
+        {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %}
+            {{- '\n' + message.content | trim + '\n' }}
+        {%- endif %}
+        {%- for tool_call in message.tool_calls %}
+            {%- if tool_call.function is defined %}
+                {%- set tool_call = tool_call.function %}
+            {%- endif %}
+            {{- '\n<tool_call>\n<function=' + tool_call.name + '>' }}
+            {%- if tool_call.arguments is defined %}
+                {%- for args_name, args_value in tool_call.arguments|items %}
+                    {{- '\n<parameter=' + args_name + '>\n' }}
+                    {%- set args_value = args_value if args_value is string else args_value | tojson | safe %}
+                    {{- args_value }}
+                    {{- '\n</parameter>' }}
+                {%- endfor %}
+            {%- endif %}
+            {{- '\n</function>\n</tool_call>' }}
+        {%- endfor %}
+        {{- '<|im_end|>\n' }}
+    {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %}
+        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
+    {%- elif message.role == "tool" %}
+        {%- if loop.previtem and loop.previtem.role != "tool" %}
+            {{- '<|im_start|>user' }}
+        {%- endif %}
+        {{- '\n<tool_response>\n' }}
+        {{- message.content }}
+        {{- '\n</tool_response>' }}
+        {%- if not loop.last and loop.nextitem.role != "tool" %}
+            {{- '<|im_end|>\n' }}
+        {%- elif loop.last %}
+            {{- '<|im_end|>\n' }}
+        {%- endif %}
+    {%- else %}
+        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
+    {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+    {{- '<|im_start|>assistant\n' }}
+{%- endif %}
\ No newline at end of file