From c5b65e3508f53bc2e5fe547263a94cdb9fe8a824 Mon Sep 17 00:00:00 2001 From: x1a0y4o Date: Wed, 18 Feb 2026 13:05:07 +0800 Subject: [PATCH] fix: support anthropic stream requests - switch model runner to stream_events parsing for anthropic stream compatibility - only pass parallel_tool_calls=False for anthropic/vertexaianthropic - prioritize stream errors before tool follow-up decisions - add regression coverage for stream event compatibility and error handling --- src/bub/core/model_runner.py | 106 +++++++-- .../skills/telegram/scripts/telegram_send.py | 7 +- src/bub/tape/context.py | 110 ++++++++- tests/test_model_runner.py | 221 +++++++++++++++++- tests/test_tape_context.py | 84 ++++++- 5 files changed, 497 insertions(+), 31 deletions(-) diff --git a/src/bub/core/model_runner.py b/src/bub/core/model_runner.py index 0493fba..248d8e9 100644 --- a/src/bub/core/model_runner.py +++ b/src/bub/core/model_runner.py @@ -6,10 +6,10 @@ import re from collections.abc import Callable from dataclasses import dataclass, field -from typing import ClassVar +from typing import Any, ClassVar from loguru import logger -from republic import Tool, ToolAutoResult +from republic import Tool from bub.core.router import AssistantRouteResult, InputRouter from bub.skills.loader import SkillMetadata @@ -47,6 +47,7 @@ class ModelRunner: """Runs assistant loop over tape with command-aware follow-up handling.""" DEFAULT_HEADERS: ClassVar[dict[str, str]] = {"HTTP-Referer": "https://bub.build/", "X-Title": "Bub"} + SERIAL_TOOL_CALL_PROVIDERS: ClassVar[frozenset[str]] = frozenset({"anthropic", "vertexaianthropic"}) def __init__( self, @@ -166,14 +167,20 @@ async def _chat(self, prompt: str) -> _ChatResult: system_prompt = self._render_system_prompt() try: async with asyncio.timeout(self._model_timeout_seconds): - output = await self._tape.tape.run_tools_async( - prompt=prompt, - system_prompt=system_prompt, - max_tokens=self._max_tokens, - tools=self._tools, - extra_headers=self.DEFAULT_HEADERS, + stream_kwargs: dict[str, Any] = { + "prompt": prompt, + "system_prompt": system_prompt, + "max_tokens": self._max_tokens, + "tools": self._tools, + "extra_headers": self.DEFAULT_HEADERS, + } + if self._needs_serial_tool_calls(): + stream_kwargs["parallel_tool_calls"] = False + + stream = await self._tape.tape.stream_events_async( + **stream_kwargs, ) - return _ChatResult.from_tool_auto(output) + return await self._read_stream_result(stream) except TimeoutError: return _ChatResult( text="", @@ -183,6 +190,31 @@ async def _chat(self, prompt: str) -> _ChatResult: logger.exception("model.call.error") return _ChatResult(text="", error=f"model_call_error: {exc!s}") + def _needs_serial_tool_calls(self) -> bool: + provider, separator, _ = self._model.partition(":") + if not separator: + return False + return provider.casefold() in self.SERIAL_TOOL_CALL_PROVIDERS + + async def _read_stream_result(self, stream: Any) -> _ChatResult: + final_event: dict[str, Any] | None = None + error_event: dict[str, Any] | None = None + async for event in stream: + event_kind = getattr(event, "kind", None) + event_data = getattr(event, "data", None) + if not isinstance(event_data, dict): + continue + if event_kind == "error": + error_event = event_data + elif event_kind == "final": + final_event = event_data + + return _ChatResult.from_stream_events( + final_event=final_event, + stream_error=getattr(stream, "error", None), + error_event=error_event, + ) + def _render_system_prompt(self) -> str: blocks: list[str] = [] if self._base_system_prompt: @@ -223,18 +255,54 @@ class _ChatResult: followup_prompt: str | None = None @classmethod - def from_tool_auto(cls, output: ToolAutoResult) -> _ChatResult: - if output.kind == "text": - return cls(text=output.text or "") - if output.kind == "tools": - return cls(text="", followup_prompt=TOOL_CONTINUE_PROMPT) - - if output.tool_calls or output.tool_results: + def from_stream_events( + cls, + *, + final_event: dict[str, Any] | None, + stream_error: object | None, + error_event: dict[str, Any] | None, + ) -> _ChatResult: + if stream_error is not None: + return cls(text="", error=_format_stream_error(stream_error)) + + if final_event is None: + if error_event is not None: + return cls(text="", error=_format_error_event(error_event)) + return cls(text="", error="stream_events_error: missing final event") + + if final_event.get("ok") is False or error_event is not None: + return cls(text="", error=_format_error_event(error_event)) + + if final_event.get("tool_calls") or final_event.get("tool_results"): return cls(text="", followup_prompt=TOOL_CONTINUE_PROMPT) - if output.error is None: - return cls(text="", error="tool_auto_error: unknown") - return cls(text="", error=f"{output.error.kind.value}: {output.error.message}") + if isinstance(final_text := final_event.get("text"), str): + return cls(text=final_text) + + return cls(text="", error="tool_auto_error: unknown") + + +def _format_stream_error(error: object) -> str: + kind = getattr(error, "kind", None) + message = getattr(error, "message", None) + kind_value = getattr(kind, "value", kind) + if isinstance(kind_value, str) and isinstance(message, str): + return f"{kind_value}: {message}" + if isinstance(message, str): + return message + return str(error) + + +def _format_error_event(error_event: dict[str, Any] | None) -> str: + if error_event is None: + return "tool_auto_error: unknown" + kind = error_event.get("kind") + message = error_event.get("message") + if isinstance(kind, str) and isinstance(message, str): + return f"{kind}: {message}" + if isinstance(message, str): + return message + return "tool_auto_error: unknown" def _runtime_contract() -> str: diff --git a/src/bub/skills/telegram/scripts/telegram_send.py b/src/bub/skills/telegram/scripts/telegram_send.py index bac0f93..371d61c 100755 --- a/src/bub/skills/telegram/scripts/telegram_send.py +++ b/src/bub/skills/telegram/scripts/telegram_send.py @@ -161,15 +161,16 @@ def main(): mention_username = args.source_username # Send messages + # Use ASCII status tags for better PowerShell rendering. try: send_message(bot_token, chat_id, args.message, reply_to, mention_username) - print(f"✅ Message sent successfully to {chat_id} (MarkdownV2)") + print(f"[OK] Message sent successfully to {chat_id} (MarkdownV2)") except requests.HTTPError as e: - print(f"❌ HTTP Error: {e}") + print(f"[ERROR] HTTP Error: {e}") print(f" Response: {e.response.text}") sys.exit(1) except Exception as e: - print(f"❌ Error: {e}") + print(f"[ERROR] {e}") sys.exit(1) diff --git a/src/bub/tape/context.py b/src/bub/tape/context.py index 4bd4538..ac37bc3 100644 --- a/src/bub/tape/context.py +++ b/src/bub/tape/context.py @@ -54,9 +54,15 @@ def _append_tool_result_entry( entry: TapeEntry, ) -> None: results = entry.payload.get("results") - if not isinstance(results, list): + if not isinstance(results, list) or not pending_calls: return - for index, result in enumerate(results): + paired_count = min(len(results), len(pending_calls)) + if paired_count <= 0: + return + if paired_count < len(pending_calls): + _trim_last_tool_call_message(messages, paired_count) + pending_calls = pending_calls[:paired_count] + for index, result in enumerate(results[:paired_count]): messages.append(_build_tool_result_message(result, pending_calls, index)) @@ -82,16 +88,112 @@ def _build_tool_result_message( return message +def _trim_last_tool_call_message(messages: list[dict[str, Any]], count: int) -> None: + if not messages: + return + candidate = messages[-1] + if candidate.get("role") != "assistant": + return + tool_calls = candidate.get("tool_calls") + if not isinstance(tool_calls, list): + return + if count <= 0: + messages.pop() + return + candidate["tool_calls"] = tool_calls[:count] + + def _normalize_tool_calls(value: object) -> list[dict[str, Any]]: if not isinstance(value, list): return [] calls: list[dict[str, Any]] = [] for item in value: - if isinstance(item, dict): - calls.append(dict(item)) + calls.extend(_normalize_tool_call(item)) + return calls + + +def _normalize_tool_call(item: object) -> list[dict[str, Any]]: + if not isinstance(item, dict): + return [] + + normalized = dict(item) + function = normalized.get("function") + if not isinstance(function, dict): + return [] + + name = function.get("name") + if not isinstance(name, str) or not name: + return [] + + raw_arguments = function.get("arguments") + argument_chunks = _normalize_tool_arguments(raw_arguments) + if not argument_chunks: + return [] + + call_id = normalized.get("id") + calls: list[dict[str, Any]] = [] + for index, arguments in enumerate(argument_chunks): + cloned = dict(normalized) + cloned_function = dict(function) + cloned_function["arguments"] = arguments + cloned["function"] = cloned_function + if isinstance(call_id, str) and call_id and index > 0: + cloned["id"] = f"{call_id}__{index + 1}" + calls.append(cloned) return calls +def _normalize_tool_arguments(value: object) -> list[str]: + if isinstance(value, dict): + return [json.dumps(value, ensure_ascii=False)] + if not isinstance(value, str): + return [] + + raw = value.strip() + if not raw: + return [] + + parsed = _parse_json_object(raw) + if parsed is not None: + return [raw] + + chunks = _split_json_objects(raw) + if len(chunks) <= 1: + return [] + return chunks + + +def _parse_json_object(raw: str) -> dict[str, Any] | None: + try: + parsed = json.loads(raw) + except json.JSONDecodeError: + return None + if not isinstance(parsed, dict): + return None + return parsed + + +def _split_json_objects(raw: str) -> list[str]: + decoder = json.JSONDecoder() + chunks: list[str] = [] + position = 0 + total = len(raw) + while position < total: + while position < total and raw[position].isspace(): + position += 1 + if position >= total: + break + try: + parsed, end = decoder.raw_decode(raw, position) + except json.JSONDecodeError: + return [] + if not isinstance(parsed, dict): + return [] + chunks.append(raw[position:end]) + position = end + return chunks + + def _render_tool_result(result: object) -> str: if isinstance(result, str): return result diff --git a/tests/test_model_runner.py b/tests/test_model_runner.py index aebf9df..6f7bd3b 100644 --- a/tests/test_model_runner.py +++ b/tests/test_model_runner.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import Any import pytest from republic import ToolAutoResult @@ -80,22 +81,106 @@ def note_hint(self, hint: str) -> bool: return False +@dataclass(frozen=True) +class FakeStreamEvent: + kind: str + data: dict[str, Any] + + +@dataclass +class FakeAsyncStreamEvents: + events: list[FakeStreamEvent] + error: object | None = None + + def __aiter__(self): + async def _iterator(): + for event in self.events: + yield event + + return _iterator() + + +def _stream_from_tool_auto(output: ToolAutoResult) -> FakeAsyncStreamEvents: + if output.kind == "text": + text = output.text or "" + return FakeAsyncStreamEvents( + events=[ + FakeStreamEvent("text", {"delta": text}), + FakeStreamEvent( + "final", + { + "text": text, + "tool_calls": [], + "tool_results": [], + "usage": None, + "ok": True, + }, + ), + ] + ) + + if output.kind == "tools": + events = [ + FakeStreamEvent("tool_call", {"index": idx, "call": call}) for idx, call in enumerate(output.tool_calls) + ] + events.extend( + [FakeStreamEvent("tool_result", {"index": idx, "result": result}) for idx, result in enumerate(output.tool_results)] + ) + events.append( + FakeStreamEvent( + "final", + { + "text": None, + "tool_calls": output.tool_calls, + "tool_results": output.tool_results, + "usage": None, + "ok": True, + }, + ) + ) + return FakeAsyncStreamEvents(events=events) + + error_kind = output.error.kind.value if output.error is not None else "unknown" + error_message = output.error.message if output.error is not None else "unknown" + return FakeAsyncStreamEvents( + events=[ + FakeStreamEvent("error", {"kind": error_kind, "message": error_message}), + FakeStreamEvent( + "final", + { + "text": None, + "tool_calls": output.tool_calls, + "tool_results": output.tool_results, + "usage": None, + "ok": False, + }, + ), + ] + ) + + @dataclass class FakeTapeImpl: - outputs: list[ToolAutoResult] + outputs: list[ToolAutoResult | FakeAsyncStreamEvents] calls: list[tuple[str, str, int]] = field(default_factory=list) + parallel_tool_calls_values: list[bool | None] = field(default_factory=list) - async def run_tools_async( + async def stream_events_async( self, *, prompt: str, system_prompt: str, max_tokens: int, tools: list[object], + parallel_tool_calls: bool | None = None, extra_headers: dict[str, str] | None = None, - ) -> ToolAutoResult: + ) -> FakeAsyncStreamEvents: self.calls.append((prompt, system_prompt, max_tokens)) - return self.outputs.pop(0) + self.parallel_tool_calls_values.append(parallel_tool_calls) + output = self.outputs.pop(0) + if isinstance(output, FakeAsyncStreamEvents): + return output + return _stream_from_tool_auto(output) @dataclass @@ -376,3 +461,131 @@ async def test_model_runner_refreshes_skills_from_provider_between_runs() -> Non _, second_system_prompt, _ = tape.tape.calls[1] assert "" in second_system_prompt assert "friendly-python" in second_system_prompt + + +@pytest.mark.asyncio +async def test_model_runner_reports_stream_error_event() -> None: + tape = FakeTapeService( + FakeTapeImpl( + outputs=[ + FakeAsyncStreamEvents( + events=[ + FakeStreamEvent( + "error", + { + "kind": "provider", + "message": "non-streaming is not supported", + }, + ), + FakeStreamEvent( + "final", + { + "text": None, + "tool_calls": [], + "tool_results": [], + "usage": None, + "ok": False, + }, + ), + ] + ) + ] + ) + ) + runner = ModelRunner( + tape=tape, # type: ignore[arg-type] + router=AnySingleStepRouter(), # type: ignore[arg-type] + tool_view=FakeToolView(), # type: ignore[arg-type] + tools=[], + list_skills=lambda: [], + load_skill_body=lambda name: None, + model="anthropic:test", + max_steps=1, + max_tokens=512, + model_timeout_seconds=90, + base_system_prompt="base", + get_workspace_system_prompt=lambda: "", + ) + + result = await runner.run("start") + assert result.error == "provider: non-streaming is not supported" + assert tape.tape.parallel_tool_calls_values == [False] + + +@pytest.mark.asyncio +async def test_model_runner_does_not_send_parallel_tool_calls_to_non_anthropic() -> None: + tape = FakeTapeService(FakeTapeImpl(outputs=[ToolAutoResult.text_result("assistant-only")])) + runner = ModelRunner( + tape=tape, # type: ignore[arg-type] + router=AnySingleStepRouter(), # type: ignore[arg-type] + tool_view=FakeToolView(), # type: ignore[arg-type] + tools=[], + list_skills=lambda: [], + load_skill_body=lambda name: None, + model="gemini:test", + max_steps=2, + max_tokens=512, + model_timeout_seconds=90, + base_system_prompt="base", + get_workspace_system_prompt=lambda: "", + ) + + result = await runner.run("start") + assert result.error is None + assert tape.tape.parallel_tool_calls_values == [None] + + +@pytest.mark.asyncio +async def test_model_runner_prefers_stream_error_over_tool_followup() -> None: + tape = FakeTapeService( + FakeTapeImpl( + outputs=[ + FakeAsyncStreamEvents( + events=[ + FakeStreamEvent( + "error", + { + "kind": "tool", + "message": "No runnable tools are available.", + }, + ), + FakeStreamEvent( + "final", + { + "text": None, + "tool_calls": [ + { + "id": "call-1", + "type": "function", + "function": {"name": "fs.read", "arguments": '{"path":"a.txt"}'}, + } + ], + "tool_results": [], + "usage": None, + "ok": False, + }, + ), + ] + ), + ToolAutoResult.text_result("assistant-only"), + ] + ) + ) + runner = ModelRunner( + tape=tape, # type: ignore[arg-type] + router=AnySingleStepRouter(), # type: ignore[arg-type] + tool_view=FakeToolView(), # type: ignore[arg-type] + tools=[], + list_skills=lambda: [], + load_skill_body=lambda name: None, + model="anthropic:test", + max_steps=2, + max_tokens=512, + model_timeout_seconds=90, + base_system_prompt="base", + get_workspace_system_prompt=lambda: "", + ) + + result = await runner.run("start") + assert result.error == "tool: No runnable tools are available." + assert len(tape.tape.calls) == 1 diff --git a/tests/test_tape_context.py b/tests/test_tape_context.py index ec4bbd8..69c3058 100644 --- a/tests/test_tape_context.py +++ b/tests/test_tape_context.py @@ -46,4 +46,86 @@ def test_default_tape_context_handles_result_without_calls() -> None: entries = [TapeEntry.tool_result([{"status": "ok"}])] messages = context.select(entries, context) - assert messages == [{"role": "tool", "content": '{"status": "ok"}'}] + assert messages == [] + + +def test_default_tape_context_splits_concatenated_tool_arguments() -> None: + context = default_tape_context() + assert context.select is not None + + entries = [ + TapeEntry.tool_call( + [ + { + "id": "call-1", + "type": "function", + "function": { + "name": "bash", + "arguments": '{"cmd":"echo 1"}{"cmd":"echo 2"}', + }, + } + ] + ), + TapeEntry.tool_result(["ok-1", "ok-2"]), + ] + + messages = context.select(entries, context) + assert messages[0]["role"] == "assistant" + assert len(messages[0]["tool_calls"]) == 2 + assert messages[0]["tool_calls"][0]["function"]["arguments"] == '{"cmd":"echo 1"}' + assert messages[0]["tool_calls"][1]["function"]["arguments"] == '{"cmd":"echo 2"}' + assert messages[1]["tool_call_id"] == "call-1" + assert messages[2]["tool_call_id"] == "call-1__2" + + +def test_default_tape_context_drops_invalid_tool_call_arguments() -> None: + context = default_tape_context() + assert context.select is not None + + entries = [ + TapeEntry.tool_call( + [ + { + "id": "call-1", + "type": "function", + "function": { + "name": "bash", + "arguments": '{"cmd":"echo 1"}this-is-bad', + }, + } + ] + ), + TapeEntry.tool_result(["ignored"]), + ] + + messages = context.select(entries, context) + assert messages == [] + + +def test_default_tape_context_trims_unmatched_split_tool_calls() -> None: + context = default_tape_context() + assert context.select is not None + + entries = [ + TapeEntry.tool_call( + [ + { + "id": "call-1", + "type": "function", + "function": { + "name": "bash", + "arguments": '{"cmd":"echo 1"}{"cmd":"echo 2"}', + }, + } + ] + ), + TapeEntry.tool_result(["only-one-result"]), + ] + + messages = context.select(entries, context) + assert len(messages) == 2 + assert messages[0]["role"] == "assistant" + assert len(messages[0]["tool_calls"]) == 1 + assert messages[0]["tool_calls"][0]["id"] == "call-1" + assert messages[1]["role"] == "tool" + assert messages[1]["tool_call_id"] == "call-1"