diff --git a/src/engram/__init__.py b/src/engram/__init__.py index b2d1aeb..d05fdf0 100644 --- a/src/engram/__init__.py +++ b/src/engram/__init__.py @@ -1,12 +1,16 @@ from ._models import ( CommittedOperation, CommittedOperations, + ConversationContent, Memory, + MessageContent, PreExtractedContent, RetrievalConfig, Run, RunStatus, SearchResults, + StringContent, + ToolCallMetadata, ) from .async_client import AsyncEngramClient from .client import EngramClient @@ -27,15 +31,19 @@ "CommittedOperation", "CommittedOperations", "ConnectionError", + "ConversationContent", "EngramClient", "EngramError", "EngramTimeoutError", "Memory", + "MessageContent", "PreExtractedContent", "RetrievalConfig", "Run", "RunStatus", "SearchResults", + "StringContent", + "ToolCallMetadata", "ValidationError", "__version__", ] diff --git a/src/engram/_models/__init__.py b/src/engram/_models/__init__.py index 1335c37..594e196 100644 --- a/src/engram/_models/__init__.py +++ b/src/engram/_models/__init__.py @@ -1,14 +1,28 @@ -from .memory import AddContent, Memory, PreExtractedContent, RetrievalConfig, SearchResults +from .memory import ( + AddContent, + ConversationContent, + Memory, + MessageContent, + PreExtractedContent, + RetrievalConfig, + SearchResults, + StringContent, + ToolCallMetadata, +) from .run import CommittedOperation, CommittedOperations, Run, RunStatus __all__ = [ "AddContent", "CommittedOperation", "CommittedOperations", + "ConversationContent", "Memory", + "MessageContent", "PreExtractedContent", "RetrievalConfig", "Run", "RunStatus", "SearchResults", + "StringContent", + "ToolCallMetadata", ] diff --git a/src/engram/_models/memory.py b/src/engram/_models/memory.py index 2a4f6f0..775865a 100644 --- a/src/engram/_models/memory.py +++ b/src/engram/_models/memory.py @@ -1,8 +1,8 @@ from __future__ import annotations from collections.abc import Iterator, Sequence -from dataclasses import dataclass, field -from typing import Literal, TypeAlias +from dataclasses import dataclass +from typing import Any, Literal, TypeAlias @dataclass(slots=True) @@ -10,11 +10,48 @@ class PreExtractedContent: """Pre-extracted content that bypasses the extraction pipeline.""" content: str - tags: list[str] = field(default_factory=list) + topic: str + + +@dataclass(slots=True) +class StringContent: + """String content that bypasses the extraction pipeline.""" + + content: str + + +@dataclass(slots=True) +class ToolCallMetadata: + """Tool call metadata.""" + + name: str + id: str + + +@dataclass(slots=True) +class MessageContent: + """A message in a conversation.""" + + role: Literal["user", "assistant", "system"] + content: str + created_at: str | None = None + tool_call_metadata: ToolCallMetadata | None = None + + +@dataclass(slots=True) +class ConversationContent: + """Conversation content that bypasses the extraction pipeline.""" + + messages: list[MessageContent] + metadata: dict[str, Any] | None = None + created_at: str | None = None + updated_at: str | None = None # Type alias for the content argument to memories.add() -AddContent: TypeAlias = str | list[dict[str, str]] | PreExtractedContent +AddContent: TypeAlias = ( + str | list[dict[str, str]] | PreExtractedContent | ConversationContent | StringContent +) @dataclass(slots=True) diff --git a/src/engram/_serialization/_builders.py b/src/engram/_serialization/_builders.py index b54fb75..ce80559 100644 --- a/src/engram/_serialization/_builders.py +++ b/src/engram/_serialization/_builders.py @@ -2,26 +2,59 @@ from typing import Any -from .._models import AddContent, PreExtractedContent, RetrievalConfig +from .._models import ( + AddContent, + ConversationContent, + PreExtractedContent, + RetrievalConfig, + StringContent, +) def _serialize_content(content: AddContent) -> dict[str, Any]: """Build the content envelope with the type discriminator.""" if isinstance(content, str): return {"type": "string", "content": content} + if isinstance(content, StringContent): + return {"type": "string", "content": content.content} if isinstance(content, PreExtractedContent): - d: dict[str, Any] = {"type": "pre_extracted", "content": content.content} - if content.tags: - d["tags"] = content.tags - return d + return { + "type": "pre_extracted", + "content": content.content, + "topic": content.topic, + } if isinstance(content, list): return { "type": "conversation", "conversation": {"messages": content}, } + if isinstance(content, ConversationContent): + return _serialize_conversation_content(content) raise TypeError(f"Unsupported content type: {type(content)}") # pragma: no cover +def _serialize_conversation_content(content: ConversationContent) -> dict[str, Any]: + messages = [] + for msg in content.messages: + m: dict[str, Any] = {"role": msg.role, "content": msg.content} + if msg.created_at is not None: + m["created_at"] = msg.created_at + if msg.tool_call_metadata is not None: + m["tool_call_metadata"] = { + "name": msg.tool_call_metadata.name, + "id": msg.tool_call_metadata.id, + } + messages.append(m) + conversation: dict[str, Any] = {"messages": messages} + if content.metadata is not None: + conversation["metadata"] = content.metadata + if content.created_at is not None: + conversation["created_at"] = content.created_at + if content.updated_at is not None: + conversation["updated_at"] = content.updated_at + return {"type": "conversation", "conversation": conversation} + + def build_add_body( content: AddContent, *, diff --git a/tests/test_client_async.py b/tests/test_client_async.py index 0952f9e..2228431 100644 --- a/tests/test_client_async.py +++ b/tests/test_client_async.py @@ -5,7 +5,14 @@ import pytest from engram._http import AsyncHttpTransport -from engram._models import PreExtractedContent, RetrievalConfig +from engram._models import ( + ConversationContent, + MessageContent, + PreExtractedContent, + RetrievalConfig, + StringContent, + ToolCallMetadata, +) from engram.async_client import DEFAULT_BASE_URL, AsyncEngramClient from engram.errors import APIError, AuthenticationError, ValidationError @@ -109,7 +116,7 @@ async def test_add_str() -> None: async def test_add_pre_extracted() -> None: client = _make_client(body={"run_id": "r2", "status": "pending"}) result = await client.memories.add( - PreExtractedContent(content="fact", tags=["a"]), + PreExtractedContent(content="fact", topic="topic"), user_id="u1", ) assert result.run_id == "r2" @@ -144,6 +151,73 @@ def handler(request: httpx.Request) -> httpx.Response: } +@pytest.mark.asyncio +async def test_add_string_content() -> None: + client = _make_client(body={"run_id": "r4", "status": "pending"}) + result = await client.memories.add(StringContent(content="hello"), user_id="u1") + assert result.run_id == "r4" + + +@pytest.mark.asyncio +async def test_add_string_content_sends_correct_envelope() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"run_id": "r1", "status": "pending"}) + + client = _make_client_with_handler(handler) + await client.memories.add(StringContent(content="hello"), user_id="u1", group="g1") + body = json.loads(captured[0].content) + assert body == { + "content": {"type": "string", "content": "hello"}, + "user_id": "u1", + "group": "g1", + } + + +@pytest.mark.asyncio +async def test_add_conversation_content() -> None: + client = _make_client(body={"run_id": "r5", "status": "pending"}) + result = await client.memories.add( + ConversationContent(messages=[MessageContent(role="user", content="hi")]), + user_id="u1", + conversation_id="c1", + ) + assert result.run_id == "r5" + + +@pytest.mark.asyncio +async def test_add_conversation_content_sends_correct_envelope() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"run_id": "r1", "status": "pending"}) + + client = _make_client_with_handler(handler) + await client.memories.add( + ConversationContent( + messages=[ + MessageContent(role="user", content="hi"), + MessageContent( + role="assistant", + content="using tool", + tool_call_metadata=ToolCallMetadata(name="search", id="tc1"), + ), + ], + metadata={"session_id": "s1"}, + ), + conversation_id="c1", + ) + body = json.loads(captured[0].content) + assert body["content"]["type"] == "conversation" + conv = body["content"]["conversation"] + assert conv["metadata"] == {"session_id": "s1"} + assert conv["messages"][1]["tool_call_metadata"] == {"name": "search", "id": "tc1"} + assert body["conversation_id"] == "c1" + + # ── memories.get ──────────────────────────────────────────────────────── SAMPLE_MEMORY_RESPONSE: dict[str, Any] = { diff --git a/tests/test_client_sync.py b/tests/test_client_sync.py index 73d7140..98b4c97 100644 --- a/tests/test_client_sync.py +++ b/tests/test_client_sync.py @@ -5,7 +5,14 @@ import pytest from engram._http import HttpTransport -from engram._models import PreExtractedContent, RetrievalConfig +from engram._models import ( + ConversationContent, + MessageContent, + PreExtractedContent, + RetrievalConfig, + StringContent, + ToolCallMetadata, +) from engram.client import DEFAULT_BASE_URL, EngramClient from engram.errors import APIError, AuthenticationError, ValidationError @@ -110,7 +117,7 @@ def test_add_str() -> None: def test_add_pre_extracted() -> None: client = _make_client(body={"run_id": "r2", "status": "pending"}) result = client.memories.add( - PreExtractedContent(content="fact", tags=["a"]), + PreExtractedContent(content="fact", topic="topic"), user_id="u1", ) assert result.run_id == "r2" @@ -163,6 +170,69 @@ def handler(request: httpx.Request) -> httpx.Response: } +def test_add_string_content() -> None: + client = _make_client(body={"run_id": "r4", "status": "pending"}) + result = client.memories.add(StringContent(content="hello"), user_id="u1") + assert result.run_id == "r4" + + +def test_add_string_content_sends_correct_envelope() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"run_id": "r1", "status": "pending"}) + + client = _make_client_with_handler(handler) + client.memories.add(StringContent(content="hello"), user_id="u1", group="g1") + body = json.loads(captured[0].content) + assert body == { + "content": {"type": "string", "content": "hello"}, + "user_id": "u1", + "group": "g1", + } + + +def test_add_conversation_content() -> None: + client = _make_client(body={"run_id": "r5", "status": "pending"}) + result = client.memories.add( + ConversationContent(messages=[MessageContent(role="user", content="hi")]), + user_id="u1", + conversation_id="c1", + ) + assert result.run_id == "r5" + + +def test_add_conversation_content_sends_correct_envelope() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"run_id": "r1", "status": "pending"}) + + client = _make_client_with_handler(handler) + client.memories.add( + ConversationContent( + messages=[ + MessageContent(role="user", content="hi"), + MessageContent( + role="assistant", + content="using tool", + tool_call_metadata=ToolCallMetadata(name="search", id="tc1"), + ), + ], + metadata={"session_id": "s1"}, + ), + conversation_id="c1", + ) + body = json.loads(captured[0].content) + assert body["content"]["type"] == "conversation" + conv = body["content"]["conversation"] + assert conv["metadata"] == {"session_id": "s1"} + assert conv["messages"][1]["tool_call_metadata"] == {"name": "search", "id": "tc1"} + assert body["conversation_id"] == "c1" + + # ── memories.get ──────────────────────────────────────────────────────── SAMPLE_MEMORY_RESPONSE: dict[str, Any] = { diff --git a/tests/test_imports.py b/tests/test_imports.py index 3580309..b802205 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -7,15 +7,19 @@ def test_public_imports() -> None: CommittedOperation, CommittedOperations, ConnectionError, + ConversationContent, EngramClient, EngramError, EngramTimeoutError, Memory, + MessageContent, PreExtractedContent, RetrievalConfig, Run, RunStatus, SearchResults, + StringContent, + ToolCallMetadata, ValidationError, ) @@ -34,6 +38,10 @@ def test_public_imports() -> None: assert isinstance(RetrievalConfig, type) assert isinstance(CommittedOperation, type) assert isinstance(CommittedOperations, type) + assert isinstance(ConversationContent, type) + assert isinstance(MessageContent, type) + assert isinstance(StringContent, type) + assert isinstance(ToolCallMetadata, type) expected_exports = { "APIError", @@ -42,15 +50,19 @@ def test_public_imports() -> None: "CommittedOperation", "CommittedOperations", "ConnectionError", + "ConversationContent", "EngramClient", "EngramError", "EngramTimeoutError", "Memory", + "MessageContent", "PreExtractedContent", "RetrievalConfig", "Run", "RunStatus", "SearchResults", + "StringContent", + "ToolCallMetadata", "ValidationError", "__version__", } diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 2925ff0..8b4fa2f 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -1,4 +1,11 @@ -from engram._models import PreExtractedContent, RetrievalConfig +from engram._models import ( + ConversationContent, + MessageContent, + PreExtractedContent, + RetrievalConfig, + StringContent, + ToolCallMetadata, +) from engram._serialization import ( build_add_body, build_memory_params, @@ -39,35 +46,69 @@ def test_build_add_body_str_with_options() -> None: def test_build_add_body_pre_extracted() -> None: body = build_add_body( - PreExtractedContent(content="fact", tags=["a", "b"]), + PreExtractedContent(content="fact", topic="topic"), user_id=None, conversation_id=None, group=None, ) assert body == { - "content": {"type": "pre_extracted", "content": "fact", "tags": ["a", "b"]}, + "content": {"type": "pre_extracted", "content": "fact", "topic": "topic"}, } -def test_build_add_body_pre_extracted_no_tags() -> None: +def test_build_add_body_conversation() -> None: + messages = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] body = build_add_body( - PreExtractedContent(content="fact"), + messages, + user_id="u1", + conversation_id="c1", + group=None, + ) + assert body == { + "content": { + "type": "conversation", + "conversation": {"messages": messages}, + }, + "user_id": "u1", + "conversation_id": "c1", + } + + +def test_build_add_body_string_content() -> None: + body = build_add_body( + StringContent(content="hello world"), user_id=None, conversation_id=None, group=None, ) + assert body == {"content": {"type": "string", "content": "hello world"}} + + +def test_build_add_body_string_content_with_options() -> None: + body = build_add_body( + StringContent(content="hello"), + user_id="u1", + conversation_id="c1", + group="g1", + ) assert body == { - "content": {"type": "pre_extracted", "content": "fact"}, + "content": {"type": "string", "content": "hello"}, + "user_id": "u1", + "conversation_id": "c1", + "group": "g1", } -def test_build_add_body_conversation() -> None: +def test_build_add_body_conversation_content() -> None: messages = [ - {"role": "user", "content": "hi"}, - {"role": "assistant", "content": "hello"}, + MessageContent(role="user", content="hi"), + MessageContent(role="assistant", content="hello"), ] body = build_add_body( - messages, + ConversationContent(messages=messages), user_id="u1", conversation_id="c1", group=None, @@ -75,13 +116,68 @@ def test_build_add_body_conversation() -> None: assert body == { "content": { "type": "conversation", - "conversation": {"messages": messages}, + "conversation": { + "messages": [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ], + }, }, "user_id": "u1", "conversation_id": "c1", } +def test_build_add_body_conversation_content_with_metadata() -> None: + messages = [MessageContent(role="user", content="hi")] + body = build_add_body( + ConversationContent( + messages=messages, + metadata={"session_id": "s1"}, + created_at="2024-01-01T00:00:00Z", + updated_at="2024-01-02T00:00:00Z", + ), + user_id=None, + conversation_id=None, + group=None, + ) + conv = body["content"]["conversation"] + assert conv["metadata"] == {"session_id": "s1"} + assert conv["created_at"] == "2024-01-01T00:00:00Z" + assert conv["updated_at"] == "2024-01-02T00:00:00Z" + + +def test_build_add_body_conversation_content_with_message_timestamps() -> None: + messages = [MessageContent(role="user", content="hi", created_at="2024-01-01T00:00:00Z")] + body = build_add_body( + ConversationContent(messages=messages), + user_id=None, + conversation_id=None, + group=None, + ) + msg = body["content"]["conversation"]["messages"][0] + assert msg["created_at"] == "2024-01-01T00:00:00Z" + assert "tool_call_metadata" not in msg + + +def test_build_add_body_conversation_content_with_tool_call_metadata() -> None: + messages = [ + MessageContent( + role="assistant", + content="using tool", + tool_call_metadata=ToolCallMetadata(name="search", id="tc1"), + ) + ] + body = build_add_body( + ConversationContent(messages=messages), + user_id=None, + conversation_id=None, + group=None, + ) + msg = body["content"]["conversation"]["messages"][0] + assert msg["tool_call_metadata"] == {"name": "search", "id": "tc1"} + + # ── build_memory_params ─────────────────────────────────────────────────