Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/engram/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
RunStatus,
SearchResults,
StringContent,
ToolCallMetadata,
ToolCallCustomInput,
ToolCallFuncInput,
ToolCallInput,
)
from .async_client import AsyncEngramClient
from .client import EngramClient
Expand Down Expand Up @@ -43,7 +45,9 @@
"RunStatus",
"SearchResults",
"StringContent",
"ToolCallMetadata",
"ToolCallCustomInput",
"ToolCallFuncInput",
"ToolCallInput",
"ValidationError",
"__version__",
]
8 changes: 6 additions & 2 deletions src/engram/_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
RetrievalConfig,
SearchResults,
StringContent,
ToolCallMetadata,
ToolCallCustomInput,
ToolCallFuncInput,
ToolCallInput,
)
from .run import CommittedOperation, CommittedOperations, Run, RunStatus

Expand All @@ -24,5 +26,7 @@
"RunStatus",
"SearchResults",
"StringContent",
"ToolCallMetadata",
"ToolCallCustomInput",
"ToolCallFuncInput",
"ToolCallInput",
]
39 changes: 33 additions & 6 deletions src/engram/_models/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,48 @@ class StringContent:


@dataclass(slots=True)
class ToolCallMetadata:
"""Tool call metadata."""
class ToolCallFuncInput:
"""The function details of an OpenAI-format function tool call."""

name: str
arguments: str


@dataclass(slots=True)
class ToolCallCustomInput:
"""The details of an OpenAI-format custom tool call."""

name: str
input: str


@dataclass(slots=True)
class ToolCallInput:
"""A single tool call in OpenAI Chat Completions format.

Set either `function` or `custom` depending on the tool type.
"""

id: str
type: str = "function"
function: ToolCallFuncInput | None = None
custom: ToolCallCustomInput | None = None


@dataclass(slots=True)
class MessageContent:
"""A message in a conversation."""
"""A message in a conversation using the OpenAI Chat Completions format.

role: Literal["user", "assistant", "system"]
content: str
- 'tool' role (tool results) is mapped to 'user' by the server.
- 'developer' role is mapped to 'system' by the server.
"""

role: Literal["user", "assistant", "system", "tool", "developer"]
content: str = ""
created_at: str | None = None
tool_call_metadata: ToolCallMetadata | None = None
tool_call_id: str | None = None
name: str | None = None
tool_calls: list[ToolCallInput] | None = None


@dataclass(slots=True)
Expand Down
21 changes: 16 additions & 5 deletions src/engram/_serialization/_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,19 @@
PreExtractedContent,
RetrievalConfig,
StringContent,
ToolCallInput,
)


def _serialize_tool_call(tc: ToolCallInput) -> dict[str, Any]:
out: dict[str, Any] = {"id": tc.id, "type": tc.type}
if tc.function is not None:
out["function"] = {"name": tc.function.name, "arguments": tc.function.arguments}
if tc.custom is not None:
out["custom"] = {"name": tc.custom.name, "input": tc.custom.input}
return out


def _serialize_content(content: AddContent) -> dict[str, Any]:
"""Build the content envelope with the type discriminator."""
if isinstance(content, str):
Expand Down Expand Up @@ -39,11 +49,12 @@ def _serialize_conversation_content(content: ConversationContent) -> dict[str, A
m: dict[str, Any] = {"role": msg.role, "content": msg.content}
if msg.created_at is not None:
m["created_at"] = msg.created_at
if msg.tool_call_metadata is not None:
m["tool_call_metadata"] = {
"name": msg.tool_call_metadata.name,
"id": msg.tool_call_metadata.id,
}
if msg.tool_call_id is not None:
m["tool_call_id"] = msg.tool_call_id
if msg.name is not None:
m["name"] = msg.name
if msg.tool_calls is not None:
m["tool_calls"] = [_serialize_tool_call(tc) for tc in msg.tool_calls]
messages.append(m)
conversation: dict[str, Any] = {"messages": messages}
if content.metadata is not None:
Expand Down
14 changes: 10 additions & 4 deletions tests/test_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
PreExtractedContent,
RetrievalConfig,
StringContent,
ToolCallMetadata,
ToolCallFuncInput,
ToolCallInput,
)
from engram.async_client import DEFAULT_BASE_URL, AsyncEngramClient
from engram.errors import APIError, AuthenticationError, ValidationError
Expand Down Expand Up @@ -202,8 +203,11 @@ def handler(request: httpx.Request) -> httpx.Response:
MessageContent(role="user", content="hi"),
MessageContent(
role="assistant",
content="using tool",
tool_call_metadata=ToolCallMetadata(name="search", id="tc1"),
tool_calls=[
ToolCallInput(
id="tc1", function=ToolCallFuncInput(name="search", arguments="{}")
)
],
),
],
metadata={"session_id": "s1"},
Expand All @@ -214,7 +218,9 @@ def handler(request: httpx.Request) -> httpx.Response:
assert body["content"]["type"] == "conversation"
conv = body["content"]["conversation"]
assert conv["metadata"] == {"session_id": "s1"}
assert conv["messages"][1]["tool_call_metadata"] == {"name": "search", "id": "tc1"}
assert conv["messages"][1]["tool_calls"] == [
{"id": "tc1", "type": "function", "function": {"name": "search", "arguments": "{}"}}
]
assert body["conversation_id"] == "c1"


Expand Down
14 changes: 10 additions & 4 deletions tests/test_client_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
PreExtractedContent,
RetrievalConfig,
StringContent,
ToolCallMetadata,
ToolCallFuncInput,
ToolCallInput,
)
from engram.client import DEFAULT_BASE_URL, EngramClient
from engram.errors import APIError, AuthenticationError, ValidationError
Expand Down Expand Up @@ -217,8 +218,11 @@ def handler(request: httpx.Request) -> httpx.Response:
MessageContent(role="user", content="hi"),
MessageContent(
role="assistant",
content="using tool",
tool_call_metadata=ToolCallMetadata(name="search", id="tc1"),
tool_calls=[
ToolCallInput(
id="tc1", function=ToolCallFuncInput(name="search", arguments="{}")
)
],
),
],
metadata={"session_id": "s1"},
Expand All @@ -229,7 +233,9 @@ def handler(request: httpx.Request) -> httpx.Response:
assert body["content"]["type"] == "conversation"
conv = body["content"]["conversation"]
assert conv["metadata"] == {"session_id": "s1"}
assert conv["messages"][1]["tool_call_metadata"] == {"name": "search", "id": "tc1"}
assert conv["messages"][1]["tool_calls"] == [
{"id": "tc1", "type": "function", "function": {"name": "search", "arguments": "{}"}}
]
assert body["conversation_id"] == "c1"


Expand Down
12 changes: 9 additions & 3 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def test_public_imports() -> None:
RunStatus,
SearchResults,
StringContent,
ToolCallMetadata,
ToolCallCustomInput,
ToolCallFuncInput,
ToolCallInput,
ValidationError,
)

Expand All @@ -41,7 +43,9 @@ def test_public_imports() -> None:
assert isinstance(ConversationContent, type)
assert isinstance(MessageContent, type)
assert isinstance(StringContent, type)
assert isinstance(ToolCallMetadata, type)
assert isinstance(ToolCallCustomInput, type)
assert isinstance(ToolCallFuncInput, type)
assert isinstance(ToolCallInput, type)

expected_exports = {
"APIError",
Expand All @@ -62,7 +66,9 @@ def test_public_imports() -> None:
"RunStatus",
"SearchResults",
"StringContent",
"ToolCallMetadata",
"ToolCallCustomInput",
"ToolCallFuncInput",
"ToolCallInput",
"ValidationError",
"__version__",
}
Expand Down
70 changes: 65 additions & 5 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
PreExtractedContent,
RetrievalConfig,
StringContent,
ToolCallMetadata,
ToolCallCustomInput,
ToolCallFuncInput,
ToolCallInput,
)
from engram._serialization import (
build_add_body,
Expand Down Expand Up @@ -160,12 +162,15 @@ def test_build_add_body_conversation_content_with_message_timestamps() -> None:
assert "tool_call_metadata" not in msg


def test_build_add_body_conversation_content_with_tool_call_metadata() -> None:
def test_build_add_body_conversation_content_with_tool_calls() -> None:
messages = [
MessageContent(
role="assistant",
content="using tool",
tool_call_metadata=ToolCallMetadata(name="search", id="tc1"),
tool_calls=[
ToolCallInput(
id="tc1", function=ToolCallFuncInput(name="search", arguments='{"q":"x"}')
)
],
)
]
body = build_add_body(
Expand All @@ -175,7 +180,62 @@ def test_build_add_body_conversation_content_with_tool_call_metadata() -> None:
group=None,
)
msg = body["content"]["conversation"]["messages"][0]
assert msg["tool_call_metadata"] == {"name": "search", "id": "tc1"}
assert msg["tool_calls"] == [
{"id": "tc1", "type": "function", "function": {"name": "search", "arguments": '{"q":"x"}'}}
]


def test_build_add_body_conversation_content_with_custom_tool_calls() -> None:
messages = [
MessageContent(
role="assistant",
tool_calls=[
ToolCallInput(
id="tc2",
type="custom",
custom=ToolCallCustomInput(name="my_tool", input="some input"),
)
],
)
]
body = build_add_body(
ConversationContent(messages=messages),
user_id=None,
conversation_id=None,
group=None,
)
msg = body["content"]["conversation"]["messages"][0]
assert msg["tool_calls"] == [
{"id": "tc2", "type": "custom", "custom": {"name": "my_tool", "input": "some input"}}
]


def test_build_add_body_conversation_content_with_tool_role() -> None:
messages = [MessageContent(role="tool", content="result", tool_call_id="tc1", name="search")]
body = build_add_body(
ConversationContent(messages=messages),
user_id=None,
conversation_id=None,
group=None,
)
msg = body["content"]["conversation"]["messages"][0]
assert msg["role"] == "tool"
assert msg["tool_call_id"] == "tc1"
assert msg["name"] == "search"
assert msg["content"] == "result"


def test_build_add_body_conversation_content_with_developer_role() -> None:
messages = [MessageContent(role="developer", content="You are a helpful assistant.")]
body = build_add_body(
ConversationContent(messages=messages),
user_id=None,
conversation_id=None,
group=None,
)
msg = body["content"]["conversation"]["messages"][0]
assert msg["role"] == "developer"
assert msg["content"] == "You are a helpful assistant."


# ── build_memory_params ─────────────────────────────────────────────────
Expand Down