diff --git a/src/engram/__init__.py b/src/engram/__init__.py index d05fdf0..5e6eee2 100644 --- a/src/engram/__init__.py +++ b/src/engram/__init__.py @@ -10,7 +10,9 @@ RunStatus, SearchResults, StringContent, - ToolCallMetadata, + ToolCallCustomInput, + ToolCallFuncInput, + ToolCallInput, ) from .async_client import AsyncEngramClient from .client import EngramClient @@ -43,7 +45,9 @@ "RunStatus", "SearchResults", "StringContent", - "ToolCallMetadata", + "ToolCallCustomInput", + "ToolCallFuncInput", + "ToolCallInput", "ValidationError", "__version__", ] diff --git a/src/engram/_models/__init__.py b/src/engram/_models/__init__.py index 594e196..06656aa 100644 --- a/src/engram/_models/__init__.py +++ b/src/engram/_models/__init__.py @@ -7,7 +7,9 @@ RetrievalConfig, SearchResults, StringContent, - ToolCallMetadata, + ToolCallCustomInput, + ToolCallFuncInput, + ToolCallInput, ) from .run import CommittedOperation, CommittedOperations, Run, RunStatus @@ -24,5 +26,7 @@ "RunStatus", "SearchResults", "StringContent", - "ToolCallMetadata", + "ToolCallCustomInput", + "ToolCallFuncInput", + "ToolCallInput", ] diff --git a/src/engram/_models/memory.py b/src/engram/_models/memory.py index 775865a..6153f1a 100644 --- a/src/engram/_models/memory.py +++ b/src/engram/_models/memory.py @@ -21,21 +21,48 @@ class StringContent: @dataclass(slots=True) -class ToolCallMetadata: - """Tool call metadata.""" +class ToolCallFuncInput: + """The function details of an OpenAI-format function tool call.""" name: str + arguments: str + + +@dataclass(slots=True) +class ToolCallCustomInput: + """The details of an OpenAI-format custom tool call.""" + + name: str + input: str + + +@dataclass(slots=True) +class ToolCallInput: + """A single tool call in OpenAI Chat Completions format. + + Set either `function` or `custom` depending on the tool type. + """ + id: str + type: str = "function" + function: ToolCallFuncInput | None = None + custom: ToolCallCustomInput | None = None @dataclass(slots=True) class MessageContent: - """A message in a conversation.""" + """A message in a conversation using the OpenAI Chat Completions format. - role: Literal["user", "assistant", "system"] - content: str + - 'tool' role (tool results) is mapped to 'user' by the server. + - 'developer' role is mapped to 'system' by the server. + """ + + role: Literal["user", "assistant", "system", "tool", "developer"] + content: str = "" created_at: str | None = None - tool_call_metadata: ToolCallMetadata | None = None + tool_call_id: str | None = None + name: str | None = None + tool_calls: list[ToolCallInput] | None = None @dataclass(slots=True) diff --git a/src/engram/_serialization/_builders.py b/src/engram/_serialization/_builders.py index ce80559..d0fa161 100644 --- a/src/engram/_serialization/_builders.py +++ b/src/engram/_serialization/_builders.py @@ -8,9 +8,19 @@ PreExtractedContent, RetrievalConfig, StringContent, + ToolCallInput, ) +def _serialize_tool_call(tc: ToolCallInput) -> dict[str, Any]: + out: dict[str, Any] = {"id": tc.id, "type": tc.type} + if tc.function is not None: + out["function"] = {"name": tc.function.name, "arguments": tc.function.arguments} + if tc.custom is not None: + out["custom"] = {"name": tc.custom.name, "input": tc.custom.input} + return out + + def _serialize_content(content: AddContent) -> dict[str, Any]: """Build the content envelope with the type discriminator.""" if isinstance(content, str): @@ -39,11 +49,12 @@ def _serialize_conversation_content(content: ConversationContent) -> dict[str, A m: dict[str, Any] = {"role": msg.role, "content": msg.content} if msg.created_at is not None: m["created_at"] = msg.created_at - if msg.tool_call_metadata is not None: - m["tool_call_metadata"] = { - "name": msg.tool_call_metadata.name, - "id": msg.tool_call_metadata.id, - } + if msg.tool_call_id is not None: + m["tool_call_id"] = msg.tool_call_id + if msg.name is not None: + m["name"] = msg.name + if msg.tool_calls is not None: + m["tool_calls"] = [_serialize_tool_call(tc) for tc in msg.tool_calls] messages.append(m) conversation: dict[str, Any] = {"messages": messages} if content.metadata is not None: diff --git a/tests/test_client_async.py b/tests/test_client_async.py index 2228431..bbc036b 100644 --- a/tests/test_client_async.py +++ b/tests/test_client_async.py @@ -11,7 +11,8 @@ PreExtractedContent, RetrievalConfig, StringContent, - ToolCallMetadata, + ToolCallFuncInput, + ToolCallInput, ) from engram.async_client import DEFAULT_BASE_URL, AsyncEngramClient from engram.errors import APIError, AuthenticationError, ValidationError @@ -202,8 +203,11 @@ def handler(request: httpx.Request) -> httpx.Response: MessageContent(role="user", content="hi"), MessageContent( role="assistant", - content="using tool", - tool_call_metadata=ToolCallMetadata(name="search", id="tc1"), + tool_calls=[ + ToolCallInput( + id="tc1", function=ToolCallFuncInput(name="search", arguments="{}") + ) + ], ), ], metadata={"session_id": "s1"}, @@ -214,7 +218,9 @@ def handler(request: httpx.Request) -> httpx.Response: assert body["content"]["type"] == "conversation" conv = body["content"]["conversation"] assert conv["metadata"] == {"session_id": "s1"} - assert conv["messages"][1]["tool_call_metadata"] == {"name": "search", "id": "tc1"} + assert conv["messages"][1]["tool_calls"] == [ + {"id": "tc1", "type": "function", "function": {"name": "search", "arguments": "{}"}} + ] assert body["conversation_id"] == "c1" diff --git a/tests/test_client_sync.py b/tests/test_client_sync.py index 98b4c97..fea7666 100644 --- a/tests/test_client_sync.py +++ b/tests/test_client_sync.py @@ -11,7 +11,8 @@ PreExtractedContent, RetrievalConfig, StringContent, - ToolCallMetadata, + ToolCallFuncInput, + ToolCallInput, ) from engram.client import DEFAULT_BASE_URL, EngramClient from engram.errors import APIError, AuthenticationError, ValidationError @@ -217,8 +218,11 @@ def handler(request: httpx.Request) -> httpx.Response: MessageContent(role="user", content="hi"), MessageContent( role="assistant", - content="using tool", - tool_call_metadata=ToolCallMetadata(name="search", id="tc1"), + tool_calls=[ + ToolCallInput( + id="tc1", function=ToolCallFuncInput(name="search", arguments="{}") + ) + ], ), ], metadata={"session_id": "s1"}, @@ -229,7 +233,9 @@ def handler(request: httpx.Request) -> httpx.Response: assert body["content"]["type"] == "conversation" conv = body["content"]["conversation"] assert conv["metadata"] == {"session_id": "s1"} - assert conv["messages"][1]["tool_call_metadata"] == {"name": "search", "id": "tc1"} + assert conv["messages"][1]["tool_calls"] == [ + {"id": "tc1", "type": "function", "function": {"name": "search", "arguments": "{}"}} + ] assert body["conversation_id"] == "c1" diff --git a/tests/test_imports.py b/tests/test_imports.py index b802205..0a8e8e4 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -19,7 +19,9 @@ def test_public_imports() -> None: RunStatus, SearchResults, StringContent, - ToolCallMetadata, + ToolCallCustomInput, + ToolCallFuncInput, + ToolCallInput, ValidationError, ) @@ -41,7 +43,9 @@ def test_public_imports() -> None: assert isinstance(ConversationContent, type) assert isinstance(MessageContent, type) assert isinstance(StringContent, type) - assert isinstance(ToolCallMetadata, type) + assert isinstance(ToolCallCustomInput, type) + assert isinstance(ToolCallFuncInput, type) + assert isinstance(ToolCallInput, type) expected_exports = { "APIError", @@ -62,7 +66,9 @@ def test_public_imports() -> None: "RunStatus", "SearchResults", "StringContent", - "ToolCallMetadata", + "ToolCallCustomInput", + "ToolCallFuncInput", + "ToolCallInput", "ValidationError", "__version__", } diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 8b4fa2f..1b36ddb 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -4,7 +4,9 @@ PreExtractedContent, RetrievalConfig, StringContent, - ToolCallMetadata, + ToolCallCustomInput, + ToolCallFuncInput, + ToolCallInput, ) from engram._serialization import ( build_add_body, @@ -160,12 +162,15 @@ def test_build_add_body_conversation_content_with_message_timestamps() -> None: assert "tool_call_metadata" not in msg -def test_build_add_body_conversation_content_with_tool_call_metadata() -> None: +def test_build_add_body_conversation_content_with_tool_calls() -> None: messages = [ MessageContent( role="assistant", - content="using tool", - tool_call_metadata=ToolCallMetadata(name="search", id="tc1"), + tool_calls=[ + ToolCallInput( + id="tc1", function=ToolCallFuncInput(name="search", arguments='{"q":"x"}') + ) + ], ) ] body = build_add_body( @@ -175,7 +180,62 @@ def test_build_add_body_conversation_content_with_tool_call_metadata() -> None: group=None, ) msg = body["content"]["conversation"]["messages"][0] - assert msg["tool_call_metadata"] == {"name": "search", "id": "tc1"} + assert msg["tool_calls"] == [ + {"id": "tc1", "type": "function", "function": {"name": "search", "arguments": '{"q":"x"}'}} + ] + + +def test_build_add_body_conversation_content_with_custom_tool_calls() -> None: + messages = [ + MessageContent( + role="assistant", + tool_calls=[ + ToolCallInput( + id="tc2", + type="custom", + custom=ToolCallCustomInput(name="my_tool", input="some input"), + ) + ], + ) + ] + body = build_add_body( + ConversationContent(messages=messages), + user_id=None, + conversation_id=None, + group=None, + ) + msg = body["content"]["conversation"]["messages"][0] + assert msg["tool_calls"] == [ + {"id": "tc2", "type": "custom", "custom": {"name": "my_tool", "input": "some input"}} + ] + + +def test_build_add_body_conversation_content_with_tool_role() -> None: + messages = [MessageContent(role="tool", content="result", tool_call_id="tc1", name="search")] + body = build_add_body( + ConversationContent(messages=messages), + user_id=None, + conversation_id=None, + group=None, + ) + msg = body["content"]["conversation"]["messages"][0] + assert msg["role"] == "tool" + assert msg["tool_call_id"] == "tc1" + assert msg["name"] == "search" + assert msg["content"] == "result" + + +def test_build_add_body_conversation_content_with_developer_role() -> None: + messages = [MessageContent(role="developer", content="You are a helpful assistant.")] + body = build_add_body( + ConversationContent(messages=messages), + user_id=None, + conversation_id=None, + group=None, + ) + msg = body["content"]["conversation"]["messages"][0] + assert msg["role"] == "developer" + assert msg["content"] == "You are a helpful assistant." # ── build_memory_params ─────────────────────────────────────────────────