Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/engram/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from ._models import (
CommittedOperation,
CommittedOperations,
ConversationContent,
Memory,
MessageContent,
PreExtractedContent,
RetrievalConfig,
Run,
RunStatus,
SearchResults,
StringContent,
ToolCallMetadata,
)
from .async_client import AsyncEngramClient
from .client import EngramClient
Expand All @@ -27,15 +31,19 @@
"CommittedOperation",
"CommittedOperations",
"ConnectionError",
"ConversationContent",
"EngramClient",
"EngramError",
"EngramTimeoutError",
"Memory",
"MessageContent",
"PreExtractedContent",
"RetrievalConfig",
"Run",
"RunStatus",
"SearchResults",
"StringContent",
"ToolCallMetadata",
"ValidationError",
"__version__",
]
16 changes: 15 additions & 1 deletion src/engram/_models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
from .memory import AddContent, Memory, PreExtractedContent, RetrievalConfig, SearchResults
from .memory import (
AddContent,
ConversationContent,
Memory,
MessageContent,
PreExtractedContent,
RetrievalConfig,
SearchResults,
StringContent,
ToolCallMetadata,
)
from .run import CommittedOperation, CommittedOperations, Run, RunStatus

__all__ = [
"AddContent",
"CommittedOperation",
"CommittedOperations",
"ConversationContent",
"Memory",
"MessageContent",
"PreExtractedContent",
"RetrievalConfig",
"Run",
"RunStatus",
"SearchResults",
"StringContent",
"ToolCallMetadata",
]
45 changes: 41 additions & 4 deletions src/engram/_models/memory.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,57 @@
from __future__ import annotations

from collections.abc import Iterator, Sequence
from dataclasses import dataclass, field
from typing import Literal, TypeAlias
from dataclasses import dataclass
from typing import Any, Literal, TypeAlias


@dataclass(slots=True)
class PreExtractedContent:
"""Pre-extracted content that bypasses the extraction pipeline."""

content: str
tags: list[str] = field(default_factory=list)
topic: str


@dataclass(slots=True)
class StringContent:
"""String content that bypasses the extraction pipeline."""

content: str


@dataclass(slots=True)
class ToolCallMetadata:
"""Tool call metadata."""

name: str
id: str


@dataclass(slots=True)
class MessageContent:
"""A message in a conversation."""

role: Literal["user", "assistant", "system"]
content: str
created_at: str | None = None
tool_call_metadata: ToolCallMetadata | None = None


@dataclass(slots=True)
class ConversationContent:
"""Conversation content that bypasses the extraction pipeline."""

messages: list[MessageContent]
metadata: dict[str, Any] | None = None
created_at: str | None = None
updated_at: str | None = None


# Type alias for the content argument to memories.add()
AddContent: TypeAlias = str | list[dict[str, str]] | PreExtractedContent
AddContent: TypeAlias = (
str | list[dict[str, str]] | PreExtractedContent | ConversationContent | StringContent
)


@dataclass(slots=True)
Expand Down
43 changes: 38 additions & 5 deletions src/engram/_serialization/_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,59 @@

from typing import Any

from .._models import AddContent, PreExtractedContent, RetrievalConfig
from .._models import (
AddContent,
ConversationContent,
PreExtractedContent,
RetrievalConfig,
StringContent,
)


def _serialize_content(content: AddContent) -> dict[str, Any]:
"""Build the content envelope with the type discriminator."""
if isinstance(content, str):
return {"type": "string", "content": content}
if isinstance(content, StringContent):
return {"type": "string", "content": content.content}
if isinstance(content, PreExtractedContent):
d: dict[str, Any] = {"type": "pre_extracted", "content": content.content}
if content.tags:
d["tags"] = content.tags
return d
return {
"type": "pre_extracted",
"content": content.content,
"topic": content.topic,
}
if isinstance(content, list):
return {
"type": "conversation",
"conversation": {"messages": content},
}
if isinstance(content, ConversationContent):
return _serialize_conversation_content(content)
raise TypeError(f"Unsupported content type: {type(content)}") # pragma: no cover


def _serialize_conversation_content(content: ConversationContent) -> dict[str, Any]:
messages = []
for msg in content.messages:
m: dict[str, Any] = {"role": msg.role, "content": msg.content}
if msg.created_at is not None:
m["created_at"] = msg.created_at
if msg.tool_call_metadata is not None:
m["tool_call_metadata"] = {
"name": msg.tool_call_metadata.name,
"id": msg.tool_call_metadata.id,
}
messages.append(m)
conversation: dict[str, Any] = {"messages": messages}
if content.metadata is not None:
conversation["metadata"] = content.metadata
if content.created_at is not None:
conversation["created_at"] = content.created_at
if content.updated_at is not None:
conversation["updated_at"] = content.updated_at
return {"type": "conversation", "conversation": conversation}


def build_add_body(
content: AddContent,
*,
Expand Down
78 changes: 76 additions & 2 deletions tests/test_client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
import pytest

from engram._http import AsyncHttpTransport
from engram._models import PreExtractedContent, RetrievalConfig
from engram._models import (
ConversationContent,
MessageContent,
PreExtractedContent,
RetrievalConfig,
StringContent,
ToolCallMetadata,
)
from engram.async_client import DEFAULT_BASE_URL, AsyncEngramClient
from engram.errors import APIError, AuthenticationError, ValidationError

Expand Down Expand Up @@ -109,7 +116,7 @@ async def test_add_str() -> None:
async def test_add_pre_extracted() -> None:
client = _make_client(body={"run_id": "r2", "status": "pending"})
result = await client.memories.add(
PreExtractedContent(content="fact", tags=["a"]),
PreExtractedContent(content="fact", topic="topic"),
user_id="u1",
)
assert result.run_id == "r2"
Expand Down Expand Up @@ -144,6 +151,73 @@ def handler(request: httpx.Request) -> httpx.Response:
}


@pytest.mark.asyncio
async def test_add_string_content() -> None:
client = _make_client(body={"run_id": "r4", "status": "pending"})
result = await client.memories.add(StringContent(content="hello"), user_id="u1")
assert result.run_id == "r4"


@pytest.mark.asyncio
async def test_add_string_content_sends_correct_envelope() -> None:
captured: list[httpx.Request] = []

def handler(request: httpx.Request) -> httpx.Response:
captured.append(request)
return httpx.Response(200, json={"run_id": "r1", "status": "pending"})

client = _make_client_with_handler(handler)
await client.memories.add(StringContent(content="hello"), user_id="u1", group="g1")
body = json.loads(captured[0].content)
assert body == {
"content": {"type": "string", "content": "hello"},
"user_id": "u1",
"group": "g1",
}


@pytest.mark.asyncio
async def test_add_conversation_content() -> None:
client = _make_client(body={"run_id": "r5", "status": "pending"})
result = await client.memories.add(
ConversationContent(messages=[MessageContent(role="user", content="hi")]),
user_id="u1",
conversation_id="c1",
)
assert result.run_id == "r5"


@pytest.mark.asyncio
async def test_add_conversation_content_sends_correct_envelope() -> None:
captured: list[httpx.Request] = []

def handler(request: httpx.Request) -> httpx.Response:
captured.append(request)
return httpx.Response(200, json={"run_id": "r1", "status": "pending"})

client = _make_client_with_handler(handler)
await client.memories.add(
ConversationContent(
messages=[
MessageContent(role="user", content="hi"),
MessageContent(
role="assistant",
content="using tool",
tool_call_metadata=ToolCallMetadata(name="search", id="tc1"),
),
],
metadata={"session_id": "s1"},
),
conversation_id="c1",
)
body = json.loads(captured[0].content)
assert body["content"]["type"] == "conversation"
conv = body["content"]["conversation"]
assert conv["metadata"] == {"session_id": "s1"}
assert conv["messages"][1]["tool_call_metadata"] == {"name": "search", "id": "tc1"}
assert body["conversation_id"] == "c1"


# ── memories.get ────────────────────────────────────────────────────────

SAMPLE_MEMORY_RESPONSE: dict[str, Any] = {
Expand Down
74 changes: 72 additions & 2 deletions tests/test_client_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
import pytest

from engram._http import HttpTransport
from engram._models import PreExtractedContent, RetrievalConfig
from engram._models import (
ConversationContent,
MessageContent,
PreExtractedContent,
RetrievalConfig,
StringContent,
ToolCallMetadata,
)
from engram.client import DEFAULT_BASE_URL, EngramClient
from engram.errors import APIError, AuthenticationError, ValidationError

Expand Down Expand Up @@ -110,7 +117,7 @@ def test_add_str() -> None:
def test_add_pre_extracted() -> None:
client = _make_client(body={"run_id": "r2", "status": "pending"})
result = client.memories.add(
PreExtractedContent(content="fact", tags=["a"]),
PreExtractedContent(content="fact", topic="topic"),
user_id="u1",
)
assert result.run_id == "r2"
Expand Down Expand Up @@ -163,6 +170,69 @@ def handler(request: httpx.Request) -> httpx.Response:
}


def test_add_string_content() -> None:
client = _make_client(body={"run_id": "r4", "status": "pending"})
result = client.memories.add(StringContent(content="hello"), user_id="u1")
assert result.run_id == "r4"


def test_add_string_content_sends_correct_envelope() -> None:
captured: list[httpx.Request] = []

def handler(request: httpx.Request) -> httpx.Response:
captured.append(request)
return httpx.Response(200, json={"run_id": "r1", "status": "pending"})

client = _make_client_with_handler(handler)
client.memories.add(StringContent(content="hello"), user_id="u1", group="g1")
body = json.loads(captured[0].content)
assert body == {
"content": {"type": "string", "content": "hello"},
"user_id": "u1",
"group": "g1",
}


def test_add_conversation_content() -> None:
client = _make_client(body={"run_id": "r5", "status": "pending"})
result = client.memories.add(
ConversationContent(messages=[MessageContent(role="user", content="hi")]),
user_id="u1",
conversation_id="c1",
)
assert result.run_id == "r5"


def test_add_conversation_content_sends_correct_envelope() -> None:
captured: list[httpx.Request] = []

def handler(request: httpx.Request) -> httpx.Response:
captured.append(request)
return httpx.Response(200, json={"run_id": "r1", "status": "pending"})

client = _make_client_with_handler(handler)
client.memories.add(
ConversationContent(
messages=[
MessageContent(role="user", content="hi"),
MessageContent(
role="assistant",
content="using tool",
tool_call_metadata=ToolCallMetadata(name="search", id="tc1"),
),
],
metadata={"session_id": "s1"},
),
conversation_id="c1",
)
body = json.loads(captured[0].content)
assert body["content"]["type"] == "conversation"
conv = body["content"]["conversation"]
assert conv["metadata"] == {"session_id": "s1"}
assert conv["messages"][1]["tool_call_metadata"] == {"name": "search", "id": "tc1"}
assert body["conversation_id"] == "c1"


# ── memories.get ────────────────────────────────────────────────────────

SAMPLE_MEMORY_RESPONSE: dict[str, Any] = {
Expand Down
Loading