From 7876d2d58a965c4819cb5fd1d1ecd93c39a92645 Mon Sep 17 00:00:00 2001 From: Charles Pierse Date: Mon, 16 Feb 2026 17:12:34 +0000 Subject: [PATCH 1/6] Add basic memory crud functionality --- src/engram/__init__.py | 31 +++- src/engram/_base_client.py | 55 ++++-- src/engram/_models.py | 115 ++++++++++++ src/engram/_resources/__init__.py | 4 + src/engram/_resources/memories.py | 188 +++++++++++++++++++ src/engram/_resources/runs.py | 79 ++++++++ src/engram/_serialization.py | 161 ++++++++++++++++ src/engram/async_client.py | 24 ++- src/engram/client.py | 24 ++- src/engram/errors.py | 12 +- src/engram/types.py | 15 -- tests/test_client_async.py | 261 +++++++++++++++++++++++++- tests/test_client_sync.py | 295 +++++++++++++++++++++++++++++- tests/test_imports.py | 37 +++- tests/test_serialization.py | 256 ++++++++++++++++++++++++++ 15 files changed, 1519 insertions(+), 38 deletions(-) create mode 100644 src/engram/_models.py create mode 100644 src/engram/_resources/__init__.py create mode 100644 src/engram/_resources/memories.py create mode 100644 src/engram/_resources/runs.py create mode 100644 src/engram/_serialization.py create mode 100644 tests/test_serialization.py diff --git a/src/engram/__init__.py b/src/engram/__init__.py index 7e481cb..1806fa9 100644 --- a/src/engram/__init__.py +++ b/src/engram/__init__.py @@ -1,14 +1,41 @@ +from ._models import ( + CommittedOperation, + CommittedOperations, + Memory, + PreExtractedContent, + RetrievalConfig, + Run, + RunStatus, + SearchResults, +) from .async_client import AsyncEngramClient from .client import EngramClient -from .errors import APIError, AuthError, EngramError, ValidationError +from .errors import ( + APIError, + AuthenticationError, + ConnectionError, + EngramError, + NotFoundError, + ValidationError, +) from .version import __version__ __all__ = [ "APIError", "AsyncEngramClient", - "AuthError", + "AuthenticationError", + "CommittedOperation", + "CommittedOperations", + "ConnectionError", "EngramClient", "EngramError", + "Memory", + "NotFoundError", + "PreExtractedContent", + "RetrievalConfig", + "Run", + "RunStatus", + "SearchResults", "ValidationError", "__version__", ] diff --git a/src/engram/_base_client.py b/src/engram/_base_client.py index a691ae6..7c7372a 100644 --- a/src/engram/_base_client.py +++ b/src/engram/_base_client.py @@ -5,7 +5,7 @@ import httpx -from .errors import ValidationError +from .errors import APIError, AuthenticationError, NotFoundError, ValidationError from .types import ClientConfig from .version import __version__ @@ -31,8 +31,7 @@ def __init__( raise ValidationError("Timeout must be greater than 0.") normalized_base_url = base_url.rstrip("/") - header_overrides = headers if headers is not None else {} - default_headers = _build_headers(api_key=api_key, header_overrides=header_overrides) + default_headers = _build_headers(api_key=api_key, header_overrides=headers or {}) self._config = ClientConfig( base_url=normalized_base_url, @@ -49,6 +48,27 @@ def config(self) -> ClientConfig: def default_headers(self) -> dict[str, str]: return dict(self._config.headers) + def _process_response(self, response: httpx.Response) -> dict[str, Any]: + data = _safe_json(response) + + if response.status_code == 401: + detail = _extract_detail(data, "Authentication failed") + raise AuthenticationError(detail, status_code=401, body=data) + + if response.status_code == 404: + detail = _extract_detail(data, "Not found") + raise NotFoundError(detail, status_code=404, body=data) + + if response.status_code >= 400: + detail = _extract_detail(data, response.reason_phrase) + raise APIError(detail, status_code=response.status_code, body=data) + + if data is None: + return {} + if isinstance(data, dict): + return data + return {"data": data} + def build_request( self, method: str, @@ -62,15 +82,35 @@ def build_request( if headers: merged_headers.update(headers) + clean_path = path.lstrip("/") + url = f"{self._config.base_url}/{clean_path}" if clean_path else self._config.base_url + return self._http_client.build_request( method=method, - url=_build_url(self._config.base_url, path), + url=url, headers=merged_headers, params=params, json=json, ) +def _extract_detail(data: Any, fallback: str) -> str: + if isinstance(data, dict): + return str(data.get("detail", fallback)) + if data: + return str(data) + return fallback + + +def _safe_json(response: httpx.Response) -> Any: + if not response.content: + return None + try: + return response.json() + except Exception: + return None + + def _build_headers( *, api_key: str | None, @@ -85,10 +125,3 @@ def _build_headers( headers["Authorization"] = f"Bearer {api_key}" headers.update(header_overrides) return headers - - -def _build_url(base_url: str, path: str) -> str: - clean_path = path.lstrip("/") - if not clean_path: - return base_url - return f"{base_url}/{clean_path}" diff --git a/src/engram/_models.py b/src/engram/_models.py new file mode 100644 index 0000000..7ee087a --- /dev/null +++ b/src/engram/_models.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from collections.abc import Iterator, Sequence +from dataclasses import dataclass, field +from typing import Literal + +# ── Request models ────────────────────────────────────────────────────── + + +@dataclass(slots=True) +class PreExtractedContent: + """Pre-extracted content that bypasses the extraction pipeline.""" + + content: str + tags: list[str] = field(default_factory=list) + + +@dataclass(slots=True) +class RetrievalConfig: + retrieval_type: Literal["vector", "bm25", "hybrid"] = "hybrid" + limit: int = 10 + + +# ── Response models ───────────────────────────────────────────────────── + +# Type alias for the content argument to memories.add() +AddContent = str | list[dict[str, str]] | PreExtractedContent + + +@dataclass(slots=True) +class Run: + """Returned from memories.add() — represents a pipeline run.""" + + run_id: str + status: str + error: str | None = None + + +@dataclass(slots=True) +class Memory: + id: str + project_id: str + content: str + topic: str + group: str + created_at: str + updated_at: str + user_id: str | None = None + conversation_id: str | None = None + tags: list[str] | None = None + score: float | None = None + + +class SearchResults(Sequence[Memory]): + """List-like wrapper over search results with a total count.""" + + def __init__(self, memories: list[Memory], total: int) -> None: + self._memories = memories + self.total = total + + def __getitem__(self, index: int) -> Memory: # type: ignore[override] + return self._memories[index] + + def __len__(self) -> int: + return len(self._memories) + + def __iter__(self) -> Iterator[Memory]: + return iter(self._memories) + + def __repr__(self) -> str: + return f"SearchResults(total={self.total}, returned={len(self._memories)})" + + +@dataclass(slots=True) +class CommittedOperation: + memory_id: str + committed_at: str + + +@dataclass(slots=True) +class CommittedOperations: + created: list[CommittedOperation] = field(default_factory=list) + updated: list[CommittedOperation] = field(default_factory=list) + deleted: list[CommittedOperation] = field(default_factory=list) + + +@dataclass(slots=True) +class RunStatus: + run_id: str + status: str + group_id: str + starting_step: int + input_type: str + created_at: str + updated_at: str + committed_operations: CommittedOperations | None = None + error: str | None = None + + @property + def memories_created(self) -> list[CommittedOperation]: + if self.committed_operations is None: + return [] + return self.committed_operations.created + + @property + def memories_updated(self) -> list[CommittedOperation]: + if self.committed_operations is None: + return [] + return self.committed_operations.updated + + @property + def memories_deleted(self) -> list[CommittedOperation]: + if self.committed_operations is None: + return [] + return self.committed_operations.deleted diff --git a/src/engram/_resources/__init__.py b/src/engram/_resources/__init__.py new file mode 100644 index 0000000..7ea362c --- /dev/null +++ b/src/engram/_resources/__init__.py @@ -0,0 +1,4 @@ +from .memories import AsyncMemories, Memories +from .runs import AsyncRuns, Runs + +__all__ = ["AsyncMemories", "AsyncRuns", "Memories", "Runs"] diff --git a/src/engram/_resources/memories.py b/src/engram/_resources/memories.py new file mode 100644 index 0000000..85126bd --- /dev/null +++ b/src/engram/_resources/memories.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from .._models import AddContent, Memory, RetrievalConfig, Run, SearchResults +from .._serialization import ( + build_add_body, + build_memory_params, + build_search_body, + parse_memory, + parse_run, + parse_search_results, +) + +if TYPE_CHECKING: + from ..async_client import AsyncEngramClient + from ..client import EngramClient + +_MEMORIES_PATH = "/v1/memories" +_MEMORIES_SEARCH_PATH = "/v1/memories/search" + + +def _memory_path(memory_id: str) -> str: + return f"{_MEMORIES_PATH}/{memory_id}" + + +class Memories: + """Sync sub-resource for memory operations: client.memories.*""" + + _client: EngramClient + + def __init__(self, client: EngramClient) -> None: + self._client = client + + def add( + self, + content: AddContent, + *, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + ) -> Run: + body = build_add_body( + content, + user_id=user_id, + conversation_id=conversation_id, + group=group, + ) + data = self._client._request("POST", _MEMORIES_PATH, json=body) + return parse_run(data) + + def get( + self, + memory_id: str, + *, + topic: str, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + ) -> Memory: + params = build_memory_params( + topic=topic, + user_id=user_id, + conversation_id=conversation_id, + group=group, + ) + data = self._client._request("GET", _memory_path(memory_id), params=params) + return parse_memory(data) + + def delete( + self, + memory_id: str, + *, + topic: str, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + ) -> None: + params = build_memory_params( + topic=topic, + user_id=user_id, + conversation_id=conversation_id, + group=group, + ) + self._client._request("DELETE", _memory_path(memory_id), params=params) + + def search( + self, + *, + query: str, + topics: list[str] | None = None, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + retrieval_config: RetrievalConfig | None = None, + ) -> SearchResults: + body = build_search_body( + query=query, + topics=topics, + user_id=user_id, + conversation_id=conversation_id, + group=group, + retrieval_config=retrieval_config or RetrievalConfig(), + ) + data = self._client._request("POST", _MEMORIES_SEARCH_PATH, json=body) + return parse_search_results(data) + + +class AsyncMemories: + """Async sub-resource for memory operations: client.memories.*""" + + _client: AsyncEngramClient + + def __init__(self, client: AsyncEngramClient) -> None: + self._client = client + + async def add( + self, + content: AddContent, + *, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + ) -> Run: + body = build_add_body( + content, + user_id=user_id, + conversation_id=conversation_id, + group=group, + ) + data = await self._client._request("POST", _MEMORIES_PATH, json=body) + return parse_run(data) + + async def get( + self, + memory_id: str, + *, + topic: str, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + ) -> Memory: + params = build_memory_params( + topic=topic, + user_id=user_id, + conversation_id=conversation_id, + group=group, + ) + data = await self._client._request("GET", _memory_path(memory_id), params=params) + return parse_memory(data) + + async def delete( + self, + memory_id: str, + *, + topic: str, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + ) -> None: + params = build_memory_params( + topic=topic, + user_id=user_id, + conversation_id=conversation_id, + group=group, + ) + await self._client._request("DELETE", _memory_path(memory_id), params=params) + + async def search( + self, + *, + query: str, + topics: list[str] | None = None, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + retrieval_config: RetrievalConfig | None = None, + ) -> SearchResults: + body = build_search_body( + query=query, + topics=topics, + user_id=user_id, + conversation_id=conversation_id, + group=group, + retrieval_config=retrieval_config or RetrievalConfig(), + ) + data = await self._client._request("POST", _MEMORIES_SEARCH_PATH, json=body) + return parse_search_results(data) diff --git a/src/engram/_resources/runs.py b/src/engram/_resources/runs.py new file mode 100644 index 0000000..e224ee4 --- /dev/null +++ b/src/engram/_resources/runs.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import asyncio +import time +from typing import TYPE_CHECKING + +from .._models import RunStatus +from .._serialization import parse_run_status + +if TYPE_CHECKING: + from ..async_client import AsyncEngramClient + from ..client import EngramClient + +_RUNS_PATH = "/v1/runs" + +_TERMINAL_STATUSES = frozenset(("completed", "failed")) + + +def _run_path(run_id: str) -> str: + return f"{_RUNS_PATH}/{run_id}" + + +class Runs: + """Sync sub-resource for run operations: client.runs.*""" + + _client: EngramClient + + def __init__(self, client: EngramClient) -> None: + self._client = client + + def get(self, run_id: str) -> RunStatus: + data = self._client._request("GET", _run_path(run_id)) + return parse_run_status(data) + + def wait( + self, + run_id: str, + *, + timeout: float = 30.0, + interval: float = 0.5, + ) -> RunStatus: + deadline = time.monotonic() + timeout + while True: + status = self.get(run_id) + if status.status in _TERMINAL_STATUSES: + return status + if time.monotonic() + interval > deadline: + return status + time.sleep(interval) + + +class AsyncRuns: + """Async sub-resource for run operations: client.runs.*""" + + _client: AsyncEngramClient + + def __init__(self, client: AsyncEngramClient) -> None: + self._client = client + + async def get(self, run_id: str) -> RunStatus: + data = await self._client._request("GET", _run_path(run_id)) + return parse_run_status(data) + + async def wait( + self, + run_id: str, + *, + timeout: float = 30.0, + interval: float = 0.5, + ) -> RunStatus: + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout + while True: + status = await self.get(run_id) + if status.status in _TERMINAL_STATUSES: + return status + if loop.time() + interval > deadline: + return status + await asyncio.sleep(interval) diff --git a/src/engram/_serialization.py b/src/engram/_serialization.py new file mode 100644 index 0000000..812c0d7 --- /dev/null +++ b/src/engram/_serialization.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from typing import Any + +from ._models import ( + AddContent, + CommittedOperation, + CommittedOperations, + Memory, + PreExtractedContent, + RetrievalConfig, + Run, + RunStatus, + SearchResults, +) + +# ── Body builders ─────────────────────────────────────────────────────── + + +def _serialize_content(content: AddContent) -> dict[str, Any]: + """Build the content envelope with the type discriminator.""" + if isinstance(content, str): + return {"type": "string", "content": content} + if isinstance(content, PreExtractedContent): + d: dict[str, Any] = {"type": "pre_extracted", "content": content.content} + if content.tags: + d["tags"] = content.tags + return d + if isinstance(content, list): + return { + "type": "conversation", + "conversation": {"messages": content}, + } + raise TypeError(f"Unsupported content type: {type(content)}") # pragma: no cover + + +def build_add_body( + content: AddContent, + *, + user_id: str | None, + conversation_id: str | None, + group: str | None, +) -> dict[str, Any]: + body: dict[str, Any] = {"content": _serialize_content(content)} + if user_id is not None: + body["user_id"] = user_id + if conversation_id is not None: + body["conversation_id"] = conversation_id + if group is not None: + body["group"] = group + return body + + +def build_memory_params( + *, + topic: str, + user_id: str | None, + conversation_id: str | None, + group: str | None, +) -> dict[str, str]: + params: dict[str, str] = {"topic": topic} + if user_id is not None: + params["user_id"] = user_id + if conversation_id is not None: + params["conversation_id"] = conversation_id + if group is not None: + params["group"] = group + return params + + +def build_search_body( + *, + query: str, + topics: list[str] | None, + user_id: str | None, + conversation_id: str | None, + group: str | None, + retrieval_config: RetrievalConfig, +) -> dict[str, Any]: + body: dict[str, Any] = { + "query": query, + "retrieval_config": { + "retrieval_type": retrieval_config.retrieval_type, + "limit": retrieval_config.limit, + }, + } + if topics is not None: + body["topics"] = topics + if user_id is not None: + body["user_id"] = user_id + if conversation_id is not None: + body["conversation_id"] = conversation_id + if group is not None: + body["group"] = group + return body + + +# ── Response parsers ──────────────────────────────────────────────────── + + +def parse_run(data: dict[str, Any]) -> Run: + return Run( + run_id=data["run_id"], + status=data["status"], + error=data.get("error"), + ) + + +def parse_memory(data: dict[str, Any]) -> Memory: + return Memory( + id=data["id"], + project_id=data["project_id"], + content=data["content"], + topic=data["topic"], + group=data["group"], + created_at=data["created_at"], + updated_at=data["updated_at"], + user_id=data.get("user_id"), + conversation_id=data.get("conversation_id"), + tags=data.get("tags"), + score=data.get("score"), + ) + + +def parse_search_results(data: dict[str, Any]) -> SearchResults: + return SearchResults( + memories=[parse_memory(m) for m in data["memories"]], + total=data["total"], + ) + + +def _parse_committed_operation(data: dict[str, Any]) -> CommittedOperation: + return CommittedOperation( + memory_id=data["memory_id"], + committed_at=data["committed_at"], + ) + + +def _parse_committed_operations(data: dict[str, Any]) -> CommittedOperations: + return CommittedOperations( + created=[_parse_committed_operation(op) for op in data.get("created", [])], + updated=[_parse_committed_operation(op) for op in data.get("updated", [])], + deleted=[_parse_committed_operation(op) for op in data.get("deleted", [])], + ) + + +def parse_run_status(data: dict[str, Any]) -> RunStatus: + committed_ops = data.get("committed_operations") + return RunStatus( + run_id=data["run_id"], + status=data["status"], + group_id=data["group_id"], + starting_step=data["starting_step"], + input_type=data["input_type"], + created_at=data["created_at"], + updated_at=data["updated_at"], + committed_operations=_parse_committed_operations(committed_ops) + if committed_ops is not None + else None, + error=data.get("error"), + ) diff --git a/src/engram/async_client.py b/src/engram/async_client.py index 32d55e8..a585564 100644 --- a/src/engram/async_client.py +++ b/src/engram/async_client.py @@ -1,18 +1,23 @@ from __future__ import annotations from collections.abc import Mapping +from typing import Any import httpx from ._base_client import DEFAULT_BASE_URL, DEFAULT_TIMEOUT, _BaseClient +from ._resources import AsyncMemories, AsyncRuns +from .errors import ConnectionError as EngramConnectionError __all__ = ["DEFAULT_BASE_URL", "DEFAULT_TIMEOUT", "AsyncEngramClient"] class AsyncEngramClient(_BaseClient): - """Asynchronous Engram client""" + """Asynchronous Engram client.""" _http_client: httpx.AsyncClient + memories: AsyncMemories + runs: AsyncRuns def __init__( self, @@ -31,6 +36,23 @@ def __init__( ) self._owns_http_client = http_client is None self._http_client = http_client or httpx.AsyncClient(timeout=timeout) + self.memories = AsyncMemories(self) + self.runs = AsyncRuns(self) + + async def _request( + self, + method: str, + path: str, + *, + params: Mapping[str, Any] | None = None, + json: Any | None = None, + ) -> dict[str, Any]: + request = self.build_request(method, path, params=params, json=json) + try: + response = await self._http_client.send(request) + except httpx.ConnectError as exc: + raise EngramConnectionError(str(exc)) from exc + return self._process_response(response) async def aclose(self) -> None: if self._owns_http_client: diff --git a/src/engram/client.py b/src/engram/client.py index e12b0ba..1aab55c 100644 --- a/src/engram/client.py +++ b/src/engram/client.py @@ -1,18 +1,23 @@ from __future__ import annotations from collections.abc import Mapping +from typing import Any import httpx from ._base_client import DEFAULT_BASE_URL, DEFAULT_TIMEOUT, _BaseClient +from ._resources import Memories, Runs +from .errors import ConnectionError as EngramConnectionError __all__ = ["DEFAULT_BASE_URL", "DEFAULT_TIMEOUT", "EngramClient"] class EngramClient(_BaseClient): - """Synchronous Engram client""" + """Synchronous Engram client.""" _http_client: httpx.Client + memories: Memories + runs: Runs def __init__( self, @@ -31,6 +36,23 @@ def __init__( ) self._owns_http_client = http_client is None self._http_client = http_client or httpx.Client(timeout=timeout) + self.memories = Memories(self) + self.runs = Runs(self) + + def _request( + self, + method: str, + path: str, + *, + params: Mapping[str, Any] | None = None, + json: Any | None = None, + ) -> dict[str, Any]: + request = self.build_request(method, path, params=params, json=json) + try: + response = self._http_client.send(request) + except httpx.ConnectError as exc: + raise EngramConnectionError(str(exc)) from exc + return self._process_response(response) def close(self) -> None: if self._owns_http_client: diff --git a/src/engram/errors.py b/src/engram/errors.py index d071ae5..6ac07ba 100644 --- a/src/engram/errors.py +++ b/src/engram/errors.py @@ -22,9 +22,17 @@ def __init__( self.body = body -class AuthError(APIError): - """Raised when authentication fails.""" +class AuthenticationError(APIError): + """Raised when authentication fails (401).""" + + +class NotFoundError(APIError): + """Raised when a resource is not found (404).""" class ValidationError(EngramError): """Raised when configuration or request input is invalid.""" + + +class ConnectionError(EngramError): + """Raised when a connection to the Engram server fails.""" diff --git a/src/engram/types.py b/src/engram/types.py index 401d574..2feb49f 100644 --- a/src/engram/types.py +++ b/src/engram/types.py @@ -1,7 +1,6 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any @dataclass(slots=True) @@ -10,17 +9,3 @@ class ClientConfig: timeout: float headers: dict[str, str] = field(default_factory=dict) api_key: str | None = None - - -@dataclass(slots=True) -class APIRequest: - method: str - path: str - params: dict[str, Any] | None = None - json: dict[str, Any] | None = None - - -@dataclass(slots=True) -class APIResponse: - status_code: int - data: dict[str, Any] | list[Any] | str | None = None diff --git a/tests/test_client_async.py b/tests/test_client_async.py index c51d8ee..f034035 100644 --- a/tests/test_client_async.py +++ b/tests/test_client_async.py @@ -1,7 +1,12 @@ +import json +from typing import Any + +import httpx import pytest +from engram._models import PreExtractedContent, RetrievalConfig from engram.async_client import DEFAULT_BASE_URL, AsyncEngramClient -from engram.errors import ValidationError +from engram.errors import APIError, AuthenticationError, NotFoundError, ValidationError @pytest.mark.asyncio @@ -47,3 +52,257 @@ async def test_async_client_custom_config_and_header_merging() -> None: def test_async_client_rejects_non_positive_timeout() -> None: with pytest.raises(ValidationError): AsyncEngramClient(timeout=-1) + + +def test_async_client_has_sub_resources() -> None: + client = AsyncEngramClient() + assert hasattr(client, "memories") + assert hasattr(client, "runs") + + +# ── Helpers ───────────────────────────────────────────────────────────── + + +def _make_client( + status_code: int = 200, + body: dict[str, Any] | None = None, +) -> AsyncEngramClient: + transport = httpx.MockTransport( + lambda _: httpx.Response(status_code, json=body if body is not None else {}) + ) + return AsyncEngramClient( + base_url="https://test.example.com", + api_key="test-key", + http_client=httpx.AsyncClient(transport=transport), + ) + + +# ── memories.add ──────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_add_str() -> None: + client = _make_client(body={"run_id": "r1", "status": "pending"}) + result = await client.memories.add("hello") + assert result.run_id == "r1" + assert result.status == "pending" + + +@pytest.mark.asyncio +async def test_add_pre_extracted() -> None: + client = _make_client(body={"run_id": "r2", "status": "pending"}) + result = await client.memories.add( + PreExtractedContent(content="fact", tags=["a"]), + user_id="u1", + ) + assert result.run_id == "r2" + + +@pytest.mark.asyncio +async def test_add_conversation() -> None: + client = _make_client(body={"run_id": "r3", "status": "pending"}) + result = await client.memories.add( + [{"role": "user", "content": "hi"}], + user_id="u1", + conversation_id="c1", + ) + assert result.run_id == "r3" + + +@pytest.mark.asyncio +async def test_add_sends_content_envelope() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"run_id": "r1", "status": "pending"}) + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + client = AsyncEngramClient( + base_url="https://test.example.com", + api_key="k", + http_client=http_client, + ) + await client.memories.add("hello", user_id="u1", group="g1") + body = json.loads(captured[0].content) + assert body == { + "content": {"type": "string", "content": "hello"}, + "user_id": "u1", + "group": "g1", + } + + +# ── memories.get ──────────────────────────────────────────────────────── + +SAMPLE_MEMORY_RESPONSE: dict[str, Any] = { + "id": "m1", + "project_id": "p1", + "content": "some content", + "topic": "t1", + "group": "g1", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", +} + + +@pytest.mark.asyncio +async def test_get_memory() -> None: + client = _make_client(body=SAMPLE_MEMORY_RESPONSE) + mem = await client.memories.get("m1", topic="t1") + assert mem.id == "m1" + assert mem.content == "some content" + + +@pytest.mark.asyncio +async def test_get_memory_sends_query_params() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json=SAMPLE_MEMORY_RESPONSE) + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + client = AsyncEngramClient( + base_url="https://test.example.com", + api_key="k", + http_client=http_client, + ) + await client.memories.get("m1", topic="t1", user_id="u1", group="g1") + url = captured[0].url + assert url.params["topic"] == "t1" + assert url.params["user_id"] == "u1" + assert url.params["group"] == "g1" + + +# ── memories.delete ───────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_delete_memory() -> None: + transport = httpx.MockTransport(lambda _: httpx.Response(204, content=b"")) + http_client = httpx.AsyncClient(transport=transport) + client = AsyncEngramClient( + base_url="https://test.example.com", + api_key="k", + http_client=http_client, + ) + await client.memories.delete("m1", topic="t1") + + +# ── memories.search ───────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_search_memories() -> None: + response_body: dict[str, Any] = { + "memories": [SAMPLE_MEMORY_RESPONSE], + "total": 1, + } + client = _make_client(body=response_body) + result = await client.memories.search(query="test") + assert result.total == 1 + assert len(result) == 1 + + +@pytest.mark.asyncio +async def test_search_memories_iterable() -> None: + response_body: dict[str, Any] = { + "memories": [ + SAMPLE_MEMORY_RESPONSE, + {**SAMPLE_MEMORY_RESPONSE, "id": "m2", "score": 0.85}, + ], + "total": 2, + } + client = _make_client(body=response_body) + results = await client.memories.search(query="test") + ids = [m.id for m in results] + assert ids == ["m1", "m2"] + assert results[1].score == 0.85 + + +@pytest.mark.asyncio +async def test_search_sends_correct_body() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"memories": [], "total": 0}) + + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + client = AsyncEngramClient( + base_url="https://test.example.com", + api_key="k", + http_client=http_client, + ) + await client.memories.search( + query="find this", + topics=["a"], + retrieval_config=RetrievalConfig(retrieval_type="vector", limit=5), + ) + body = json.loads(captured[0].content) + assert body["query"] == "find this" + assert body["topics"] == ["a"] + assert body["retrieval_config"]["retrieval_type"] == "vector" + assert body["retrieval_config"]["limit"] == 5 + + +# ── runs.get ──────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_get_run() -> None: + response_body: dict[str, Any] = { + "run_id": "r1", + "status": "completed", + "group_id": "g1", + "starting_step": 0, + "input_type": "string", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", + "committed_operations": { + "created": [{"memory_id": "m1", "committed_at": "2024-01-01T00:00:00Z"}], + "updated": [], + "deleted": [], + }, + } + client = _make_client(body=response_body) + result = await client.runs.get("r1") + assert result.run_id == "r1" + assert result.status == "completed" + assert len(result.memories_created) == 1 + + +# ── Error handling ────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_401_raises_authentication_error() -> None: + client = _make_client(status_code=401, body={"detail": "Invalid token"}) + with pytest.raises(AuthenticationError) as exc_info: + await client.memories.add("hello") + assert exc_info.value.status_code == 401 + assert "Invalid token" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_404_raises_not_found_error() -> None: + client = _make_client(status_code=404, body={"detail": "Not found"}) + with pytest.raises(NotFoundError) as exc_info: + await client.memories.get("missing", topic="t1") + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_400_raises_api_error() -> None: + client = _make_client(status_code=400, body={"detail": "Bad request"}) + with pytest.raises(APIError) as exc_info: + await client.memories.search(query="test") + assert exc_info.value.status_code == 400 + + +@pytest.mark.asyncio +async def test_500_raises_api_error() -> None: + client = _make_client(status_code=500, body={"detail": "Internal server error"}) + with pytest.raises(APIError) as exc_info: + await client.runs.get("r1") + assert exc_info.value.status_code == 500 diff --git a/tests/test_client_sync.py b/tests/test_client_sync.py index 7a59da5..e6533af 100644 --- a/tests/test_client_sync.py +++ b/tests/test_client_sync.py @@ -1,7 +1,12 @@ +import json +from typing import Any + +import httpx import pytest +from engram._models import PreExtractedContent, RetrievalConfig from engram.client import DEFAULT_BASE_URL, EngramClient -from engram.errors import ValidationError +from engram.errors import APIError, AuthenticationError, NotFoundError, ValidationError def test_client_defaults() -> None: @@ -45,3 +50,291 @@ def test_client_custom_config_and_header_merging() -> None: def test_client_rejects_non_positive_timeout() -> None: with pytest.raises(ValidationError): EngramClient(timeout=0) + + +def test_client_has_sub_resources() -> None: + client = EngramClient() + try: + assert hasattr(client, "memories") + assert hasattr(client, "runs") + finally: + client.close() + + +# ── Helpers ───────────────────────────────────────────────────────────── + + +def _make_client( + status_code: int = 200, + body: dict[str, Any] | None = None, +) -> EngramClient: + transport = httpx.MockTransport( + lambda _: httpx.Response(status_code, json=body if body is not None else {}) + ) + return EngramClient( + base_url="https://test.example.com", + api_key="test-key", + http_client=httpx.Client(transport=transport), + ) + + +# ── memories.add ──────────────────────────────────────────────────────── + + +def test_add_str() -> None: + client = _make_client(body={"run_id": "r1", "status": "pending"}) + result = client.memories.add("hello") + assert result.run_id == "r1" + assert result.status == "pending" + + +def test_add_pre_extracted() -> None: + client = _make_client(body={"run_id": "r2", "status": "pending"}) + result = client.memories.add( + PreExtractedContent(content="fact", tags=["a"]), + user_id="u1", + ) + assert result.run_id == "r2" + + +def test_add_conversation() -> None: + client = _make_client(body={"run_id": "r3", "status": "pending"}) + result = client.memories.add( + [{"role": "user", "content": "hi"}], + user_id="u1", + conversation_id="c1", + ) + assert result.run_id == "r3" + + +def test_add_sends_content_envelope() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"run_id": "r1", "status": "pending"}) + + http_client = httpx.Client(transport=httpx.MockTransport(handler)) + client = EngramClient( + base_url="https://test.example.com", + api_key="k", + http_client=http_client, + ) + client.memories.add("hello", user_id="u1", group="g1") + body = json.loads(captured[0].content) + assert body == { + "content": {"type": "string", "content": "hello"}, + "user_id": "u1", + "group": "g1", + } + + +def test_add_conversation_sends_correct_envelope() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"run_id": "r1", "status": "pending"}) + + http_client = httpx.Client(transport=httpx.MockTransport(handler)) + client = EngramClient( + base_url="https://test.example.com", + api_key="k", + http_client=http_client, + ) + messages = [{"role": "user", "content": "hi"}] + client.memories.add(messages, conversation_id="c1") + body = json.loads(captured[0].content) + assert body == { + "content": { + "type": "conversation", + "conversation": {"messages": messages}, + }, + "conversation_id": "c1", + } + + +# ── memories.get ──────────────────────────────────────────────────────── + +SAMPLE_MEMORY_RESPONSE: dict[str, Any] = { + "id": "m1", + "project_id": "p1", + "content": "some content", + "topic": "t1", + "group": "g1", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", +} + + +def test_get_memory() -> None: + client = _make_client(body=SAMPLE_MEMORY_RESPONSE) + mem = client.memories.get("m1", topic="t1") + assert mem.id == "m1" + assert mem.content == "some content" + + +def test_get_memory_sends_query_params() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json=SAMPLE_MEMORY_RESPONSE) + + http_client = httpx.Client(transport=httpx.MockTransport(handler)) + client = EngramClient( + base_url="https://test.example.com", + api_key="k", + http_client=http_client, + ) + client.memories.get("m1", topic="t1", user_id="u1", group="g1") + url = captured[0].url + assert url.params["topic"] == "t1" + assert url.params["user_id"] == "u1" + assert url.params["group"] == "g1" + + +# ── memories.delete ───────────────────────────────────────────────────── + + +def test_delete_memory() -> None: + transport = httpx.MockTransport(lambda _: httpx.Response(204, content=b"")) + http_client = httpx.Client(transport=transport) + client = EngramClient( + base_url="https://test.example.com", + api_key="k", + http_client=http_client, + ) + client.memories.delete("m1", topic="t1") + + +# ── memories.search ───────────────────────────────────────────────────── + + +def test_search_memories() -> None: + response_body: dict[str, Any] = { + "memories": [SAMPLE_MEMORY_RESPONSE], + "total": 1, + } + client = _make_client(body=response_body) + result = client.memories.search(query="test") + assert result.total == 1 + assert len(result) == 1 + + +def test_search_memories_iterable() -> None: + response_body: dict[str, Any] = { + "memories": [ + SAMPLE_MEMORY_RESPONSE, + {**SAMPLE_MEMORY_RESPONSE, "id": "m2", "score": 0.85}, + ], + "total": 2, + } + client = _make_client(body=response_body) + results = client.memories.search(query="test") + ids = [m.id for m in results] + assert ids == ["m1", "m2"] + assert results[1].score == 0.85 + + +def test_search_sends_correct_body() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"memories": [], "total": 0}) + + http_client = httpx.Client(transport=httpx.MockTransport(handler)) + client = EngramClient( + base_url="https://test.example.com", + api_key="k", + http_client=http_client, + ) + client.memories.search( + query="find this", + topics=["a"], + retrieval_config=RetrievalConfig(retrieval_type="vector", limit=5), + ) + body = json.loads(captured[0].content) + assert body["query"] == "find this" + assert body["topics"] == ["a"] + assert body["retrieval_config"]["retrieval_type"] == "vector" + assert body["retrieval_config"]["limit"] == 5 + + +def test_search_default_retrieval_config() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"memories": [], "total": 0}) + + http_client = httpx.Client(transport=httpx.MockTransport(handler)) + client = EngramClient( + base_url="https://test.example.com", + api_key="k", + http_client=http_client, + ) + client.memories.search(query="test") + body = json.loads(captured[0].content) + assert body["retrieval_config"] == { + "retrieval_type": "hybrid", + "limit": 10, + } + + +# ── runs.get ──────────────────────────────────────────────────────────── + + +def test_get_run() -> None: + response_body: dict[str, Any] = { + "run_id": "r1", + "status": "completed", + "group_id": "g1", + "starting_step": 0, + "input_type": "string", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", + "committed_operations": { + "created": [{"memory_id": "m1", "committed_at": "2024-01-01T00:00:00Z"}], + "updated": [], + "deleted": [], + }, + } + client = _make_client(body=response_body) + result = client.runs.get("r1") + assert result.run_id == "r1" + assert result.status == "completed" + assert len(result.memories_created) == 1 + + +# ── Error handling ────────────────────────────────────────────────────── + + +def test_401_raises_authentication_error() -> None: + client = _make_client(status_code=401, body={"detail": "Invalid token"}) + with pytest.raises(AuthenticationError) as exc_info: + client.memories.add("hello") + assert exc_info.value.status_code == 401 + assert "Invalid token" in str(exc_info.value) + + +def test_404_raises_not_found_error() -> None: + client = _make_client(status_code=404, body={"detail": "Not found"}) + with pytest.raises(NotFoundError) as exc_info: + client.memories.get("missing", topic="t1") + assert exc_info.value.status_code == 404 + + +def test_400_raises_api_error() -> None: + client = _make_client(status_code=400, body={"detail": "Bad request"}) + with pytest.raises(APIError) as exc_info: + client.memories.search(query="test") + assert exc_info.value.status_code == 400 + + +def test_500_raises_api_error() -> None: + client = _make_client(status_code=500, body={"detail": "Internal server error"}) + with pytest.raises(APIError) as exc_info: + client.runs.get("r1") + assert exc_info.value.status_code == 500 diff --git a/tests/test_imports.py b/tests/test_imports.py index c0e46ca..e489a15 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -1,11 +1,21 @@ def test_public_imports() -> None: import engram - from engram import ( + from engram import ( # noqa: F401 APIError, AsyncEngramClient, - AuthError, + AuthenticationError, + CommittedOperation, + CommittedOperations, + ConnectionError, EngramClient, EngramError, + Memory, + NotFoundError, + PreExtractedContent, + RetrievalConfig, + Run, + RunStatus, + SearchResults, ValidationError, ) @@ -13,15 +23,34 @@ def test_public_imports() -> None: assert isinstance(AsyncEngramClient, type) assert isinstance(EngramError, type) assert isinstance(APIError, type) - assert isinstance(AuthError, type) + assert isinstance(AuthenticationError, type) + assert isinstance(NotFoundError, type) assert isinstance(ValidationError, type) + assert isinstance(Memory, type) + assert isinstance(Run, type) + assert isinstance(RunStatus, type) + assert isinstance(SearchResults, type) + assert isinstance(PreExtractedContent, type) + assert isinstance(RetrievalConfig, type) + assert isinstance(CommittedOperation, type) + assert isinstance(CommittedOperations, type) expected_exports = { "APIError", "AsyncEngramClient", - "AuthError", + "AuthenticationError", + "CommittedOperation", + "CommittedOperations", + "ConnectionError", "EngramClient", "EngramError", + "Memory", + "NotFoundError", + "PreExtractedContent", + "RetrievalConfig", + "Run", + "RunStatus", + "SearchResults", "ValidationError", "__version__", } diff --git a/tests/test_serialization.py b/tests/test_serialization.py new file mode 100644 index 0000000..62bbe0b --- /dev/null +++ b/tests/test_serialization.py @@ -0,0 +1,256 @@ +from engram._models import PreExtractedContent, RetrievalConfig +from engram._serialization import ( + build_add_body, + build_memory_params, + build_search_body, + parse_memory, + parse_run, + parse_run_status, + parse_search_results, +) + +# ── build_add_body ────────────────────────────────────────────────────── + + +def test_build_add_body_str() -> None: + body = build_add_body( + "hello world", + user_id=None, + conversation_id=None, + group=None, + ) + assert body == {"content": {"type": "string", "content": "hello world"}} + + +def test_build_add_body_str_with_options() -> None: + body = build_add_body( + "hello", + user_id="u1", + conversation_id="c1", + group="g1", + ) + assert body == { + "content": {"type": "string", "content": "hello"}, + "user_id": "u1", + "conversation_id": "c1", + "group": "g1", + } + + +def test_build_add_body_pre_extracted() -> None: + body = build_add_body( + PreExtractedContent(content="fact", tags=["a", "b"]), + user_id=None, + conversation_id=None, + group=None, + ) + assert body == { + "content": {"type": "pre_extracted", "content": "fact", "tags": ["a", "b"]}, + } + + +def test_build_add_body_pre_extracted_no_tags() -> None: + body = build_add_body( + PreExtractedContent(content="fact"), + user_id=None, + conversation_id=None, + group=None, + ) + assert body == { + "content": {"type": "pre_extracted", "content": "fact"}, + } + + +def test_build_add_body_conversation() -> None: + messages = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + body = build_add_body( + messages, + user_id="u1", + conversation_id="c1", + group=None, + ) + assert body == { + "content": { + "type": "conversation", + "conversation": {"messages": messages}, + }, + "user_id": "u1", + "conversation_id": "c1", + } + + +# ── build_memory_params ───────────────────────────────────────────────── + + +def test_build_memory_params_minimal() -> None: + params = build_memory_params(topic="t1", user_id=None, conversation_id=None, group=None) + assert params == {"topic": "t1"} + + +def test_build_memory_params_full() -> None: + params = build_memory_params(topic="t1", user_id="u1", conversation_id="c1", group="g1") + assert params == { + "topic": "t1", + "user_id": "u1", + "conversation_id": "c1", + "group": "g1", + } + + +# ── build_search_body ─────────────────────────────────────────────────── + + +def test_build_search_body_defaults() -> None: + body = build_search_body( + query="test", + topics=None, + user_id=None, + conversation_id=None, + group=None, + retrieval_config=RetrievalConfig(), + ) + assert body == { + "query": "test", + "retrieval_config": {"retrieval_type": "hybrid", "limit": 10}, + } + + +def test_build_search_body_full() -> None: + body = build_search_body( + query="test", + topics=["a", "b"], + user_id="u1", + conversation_id="c1", + group="g1", + retrieval_config=RetrievalConfig(retrieval_type="vector", limit=5), + ) + assert body["topics"] == ["a", "b"] + assert body["user_id"] == "u1" + assert body["retrieval_config"]["retrieval_type"] == "vector" + assert body["retrieval_config"]["limit"] == 5 + + +# ── parse_run ─────────────────────────────────────────────────────────── + + +def test_parse_run() -> None: + result = parse_run({"run_id": "r1", "status": "pending"}) + assert result.run_id == "r1" + assert result.status == "pending" + assert result.error is None + + +def test_parse_run_with_error() -> None: + result = parse_run({"run_id": "r1", "status": "failed", "error": "boom"}) + assert result.error == "boom" + + +# ── parse_memory ──────────────────────────────────────────────────────── + + +SAMPLE_MEMORY = { + "id": "m1", + "project_id": "p1", + "content": "some content", + "topic": "t1", + "group": "g1", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", +} + + +def test_parse_memory_minimal() -> None: + mem = parse_memory(SAMPLE_MEMORY) + assert mem.id == "m1" + assert mem.project_id == "p1" + assert mem.user_id is None + assert mem.score is None + + +def test_parse_memory_with_optional_fields() -> None: + data = { + **SAMPLE_MEMORY, + "user_id": "u1", + "conversation_id": "c1", + "tags": ["x"], + "score": 0.95, + } + mem = parse_memory(data) + assert mem.user_id == "u1" + assert mem.tags == ["x"] + assert mem.score == 0.95 + + +# ── parse_search_results ──────────────────────────────────────────────── + + +def test_parse_search_results() -> None: + data = {"memories": [SAMPLE_MEMORY], "total": 1} + result = parse_search_results(data) + assert result.total == 1 + assert len(result) == 1 + assert result[0].id == "m1" + + +def test_parse_search_results_empty() -> None: + result = parse_search_results({"memories": [], "total": 0}) + assert result.total == 0 + assert len(result) == 0 + + +def test_search_results_iterable() -> None: + data = { + "memories": [SAMPLE_MEMORY, {**SAMPLE_MEMORY, "id": "m2"}], + "total": 2, + } + result = parse_search_results(data) + ids = [m.id for m in result] + assert ids == ["m1", "m2"] + + +# ── parse_run_status ──────────────────────────────────────────────────── + + +SAMPLE_RUN_STATUS = { + "run_id": "r1", + "status": "completed", + "group_id": "g1", + "starting_step": 0, + "input_type": "string", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", +} + + +def test_parse_run_status_minimal() -> None: + result = parse_run_status(SAMPLE_RUN_STATUS) + assert result.run_id == "r1" + assert result.starting_step == 0 + assert result.committed_operations is None + assert result.error is None + assert result.memories_created == [] + + +def test_parse_run_status_with_committed_operations() -> None: + data = { + **SAMPLE_RUN_STATUS, + "committed_operations": { + "created": [{"memory_id": "m1", "committed_at": "2024-01-01T00:00:00Z"}], + "updated": [], + "deleted": [], + }, + } + result = parse_run_status(data) + assert result.committed_operations is not None + assert len(result.memories_created) == 1 + assert result.memories_created[0].memory_id == "m1" + assert result.memories_updated == [] + + +def test_parse_run_status_with_error() -> None: + data = {**SAMPLE_RUN_STATUS, "status": "failed", "error": "boom"} + result = parse_run_status(data) + assert result.error == "boom" From ca897fc61241ed655b1fc914214940b52a41347c Mon Sep 17 00:00:00 2001 From: Charles Pierse Date: Tue, 17 Feb 2026 19:48:38 +0000 Subject: [PATCH 2/6] PR changes --- src/engram/__init__.py | 2 + src/engram/_base_client.py | 72 +--------------- src/engram/_http.py | 139 ++++++++++++++++++++++++++++++ src/engram/_models.py | 115 ------------------------ src/engram/_models/__init__.py | 14 +++ src/engram/_models/memory.py | 58 +++++++++++++ src/engram/_models/run.py | 56 ++++++++++++ src/engram/_resources/memories.py | 35 +++----- src/engram/_resources/runs.py | 27 +++--- src/engram/async_client.py | 27 +++--- src/engram/client.py | 27 +++--- src/engram/errors.py | 15 ++++ tests/test_imports.py | 3 + 13 files changed, 338 insertions(+), 252 deletions(-) create mode 100644 src/engram/_http.py delete mode 100644 src/engram/_models.py create mode 100644 src/engram/_models/__init__.py create mode 100644 src/engram/_models/memory.py create mode 100644 src/engram/_models/run.py diff --git a/src/engram/__init__.py b/src/engram/__init__.py index 1806fa9..851be08 100644 --- a/src/engram/__init__.py +++ b/src/engram/__init__.py @@ -15,6 +15,7 @@ AuthenticationError, ConnectionError, EngramError, + EngramTimeoutError, NotFoundError, ValidationError, ) @@ -29,6 +30,7 @@ "ConnectionError", "EngramClient", "EngramError", + "EngramTimeoutError", "Memory", "NotFoundError", "PreExtractedContent", diff --git a/src/engram/_base_client.py b/src/engram/_base_client.py index 7c7372a..4608c79 100644 --- a/src/engram/_base_client.py +++ b/src/engram/_base_client.py @@ -1,11 +1,8 @@ from __future__ import annotations from collections.abc import Mapping -from typing import Any -import httpx - -from .errors import APIError, AuthenticationError, NotFoundError, ValidationError +from .errors import ValidationError from .types import ClientConfig from .version import __version__ @@ -14,10 +11,7 @@ class _BaseClient: - """Shared client behavior for sync and async clients.""" - - _http_client: httpx.Client | httpx.AsyncClient - _owns_http_client: bool + """Shared config and header logic for sync and async clients.""" def __init__( self, @@ -48,68 +42,6 @@ def config(self) -> ClientConfig: def default_headers(self) -> dict[str, str]: return dict(self._config.headers) - def _process_response(self, response: httpx.Response) -> dict[str, Any]: - data = _safe_json(response) - - if response.status_code == 401: - detail = _extract_detail(data, "Authentication failed") - raise AuthenticationError(detail, status_code=401, body=data) - - if response.status_code == 404: - detail = _extract_detail(data, "Not found") - raise NotFoundError(detail, status_code=404, body=data) - - if response.status_code >= 400: - detail = _extract_detail(data, response.reason_phrase) - raise APIError(detail, status_code=response.status_code, body=data) - - if data is None: - return {} - if isinstance(data, dict): - return data - return {"data": data} - - def build_request( - self, - method: str, - path: str, - *, - headers: Mapping[str, str] | None = None, - params: Mapping[str, Any] | None = None, - json: Any | None = None, - ) -> httpx.Request: - merged_headers = self.default_headers - if headers: - merged_headers.update(headers) - - clean_path = path.lstrip("/") - url = f"{self._config.base_url}/{clean_path}" if clean_path else self._config.base_url - - return self._http_client.build_request( - method=method, - url=url, - headers=merged_headers, - params=params, - json=json, - ) - - -def _extract_detail(data: Any, fallback: str) -> str: - if isinstance(data, dict): - return str(data.get("detail", fallback)) - if data: - return str(data) - return fallback - - -def _safe_json(response: httpx.Response) -> Any: - if not response.content: - return None - try: - return response.json() - except Exception: - return None - def _build_headers( *, diff --git a/src/engram/_http.py b/src/engram/_http.py new file mode 100644 index 0000000..7cb3266 --- /dev/null +++ b/src/engram/_http.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +import httpx + +from .errors import APIError, AuthenticationError, NotFoundError +from .errors import ConnectionError as EngramConnectionError +from .types import ClientConfig + + +class HttpTransport: + """Wraps a sync httpx.Client and handles request building and response processing.""" + + def __init__(self, config: ClientConfig, http_client: httpx.Client) -> None: + self._config = config + self._http_client = http_client + + def request( + self, + method: str, + path: str, + *, + params: Mapping[str, Any] | None = None, + json: Any | None = None, + ) -> dict[str, Any]: + req = self.build_request(method, path, params=params, json=json) + try: + response = self._http_client.send(req) + except httpx.ConnectError as exc: + raise EngramConnectionError(str(exc)) from exc + return _process_response(response) + + def build_request( + self, + method: str, + path: str, + *, + headers: Mapping[str, str] | None = None, + params: Mapping[str, Any] | None = None, + json: Any | None = None, + ) -> httpx.Request: + merged_headers = dict(self._config.headers) + if headers: + merged_headers.update(headers) + clean_path = path.lstrip("/") + url = f"{self._config.base_url}/{clean_path}" if clean_path else self._config.base_url + return self._http_client.build_request( + method=method, + url=url, + headers=merged_headers, + params=params, + json=json, + ) + + +class AsyncHttpTransport: + """Wraps an async httpx.AsyncClient and handles request building and response processing.""" + + def __init__(self, config: ClientConfig, http_client: httpx.AsyncClient) -> None: + self._config = config + self._http_client = http_client + + async def request( + self, + method: str, + path: str, + *, + params: Mapping[str, Any] | None = None, + json: Any | None = None, + ) -> dict[str, Any]: + req = self.build_request(method, path, params=params, json=json) + try: + response = await self._http_client.send(req) + except httpx.ConnectError as exc: + raise EngramConnectionError(str(exc)) from exc + return _process_response(response) + + def build_request( + self, + method: str, + path: str, + *, + headers: Mapping[str, str] | None = None, + params: Mapping[str, Any] | None = None, + json: Any | None = None, + ) -> httpx.Request: + merged_headers = dict(self._config.headers) + if headers: + merged_headers.update(headers) + clean_path = path.lstrip("/") + url = f"{self._config.base_url}/{clean_path}" if clean_path else self._config.base_url + return self._http_client.build_request( + method=method, + url=url, + headers=merged_headers, + params=params, + json=json, + ) + + +def _process_response(response: httpx.Response) -> dict[str, Any]: + data = _safe_json(response) + + if response.status_code == 401: + detail = _extract_detail(data, "Authentication failed") + raise AuthenticationError(detail, body=data) + + if response.status_code == 404: + detail = _extract_detail(data, "Not found") + raise NotFoundError(detail, body=data) + + if response.status_code >= 400: + detail = _extract_detail(data, response.reason_phrase) + raise APIError(detail, status_code=response.status_code, body=data) + + if data is None: + return {} + if isinstance(data, dict): + return data + return {"data": data} + + +def _extract_detail(data: Any, fallback: str) -> str: + if isinstance(data, dict): + return str(data.get("detail", fallback)) + if data: + return str(data) + return fallback + + +def _safe_json(response: httpx.Response) -> Any: + if not response.content: + return None + try: + return response.json() + except Exception: + return None diff --git a/src/engram/_models.py b/src/engram/_models.py deleted file mode 100644 index 7ee087a..0000000 --- a/src/engram/_models.py +++ /dev/null @@ -1,115 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterator, Sequence -from dataclasses import dataclass, field -from typing import Literal - -# ── Request models ────────────────────────────────────────────────────── - - -@dataclass(slots=True) -class PreExtractedContent: - """Pre-extracted content that bypasses the extraction pipeline.""" - - content: str - tags: list[str] = field(default_factory=list) - - -@dataclass(slots=True) -class RetrievalConfig: - retrieval_type: Literal["vector", "bm25", "hybrid"] = "hybrid" - limit: int = 10 - - -# ── Response models ───────────────────────────────────────────────────── - -# Type alias for the content argument to memories.add() -AddContent = str | list[dict[str, str]] | PreExtractedContent - - -@dataclass(slots=True) -class Run: - """Returned from memories.add() — represents a pipeline run.""" - - run_id: str - status: str - error: str | None = None - - -@dataclass(slots=True) -class Memory: - id: str - project_id: str - content: str - topic: str - group: str - created_at: str - updated_at: str - user_id: str | None = None - conversation_id: str | None = None - tags: list[str] | None = None - score: float | None = None - - -class SearchResults(Sequence[Memory]): - """List-like wrapper over search results with a total count.""" - - def __init__(self, memories: list[Memory], total: int) -> None: - self._memories = memories - self.total = total - - def __getitem__(self, index: int) -> Memory: # type: ignore[override] - return self._memories[index] - - def __len__(self) -> int: - return len(self._memories) - - def __iter__(self) -> Iterator[Memory]: - return iter(self._memories) - - def __repr__(self) -> str: - return f"SearchResults(total={self.total}, returned={len(self._memories)})" - - -@dataclass(slots=True) -class CommittedOperation: - memory_id: str - committed_at: str - - -@dataclass(slots=True) -class CommittedOperations: - created: list[CommittedOperation] = field(default_factory=list) - updated: list[CommittedOperation] = field(default_factory=list) - deleted: list[CommittedOperation] = field(default_factory=list) - - -@dataclass(slots=True) -class RunStatus: - run_id: str - status: str - group_id: str - starting_step: int - input_type: str - created_at: str - updated_at: str - committed_operations: CommittedOperations | None = None - error: str | None = None - - @property - def memories_created(self) -> list[CommittedOperation]: - if self.committed_operations is None: - return [] - return self.committed_operations.created - - @property - def memories_updated(self) -> list[CommittedOperation]: - if self.committed_operations is None: - return [] - return self.committed_operations.updated - - @property - def memories_deleted(self) -> list[CommittedOperation]: - if self.committed_operations is None: - return [] - return self.committed_operations.deleted diff --git a/src/engram/_models/__init__.py b/src/engram/_models/__init__.py new file mode 100644 index 0000000..1335c37 --- /dev/null +++ b/src/engram/_models/__init__.py @@ -0,0 +1,14 @@ +from .memory import AddContent, Memory, PreExtractedContent, RetrievalConfig, SearchResults +from .run import CommittedOperation, CommittedOperations, Run, RunStatus + +__all__ = [ + "AddContent", + "CommittedOperation", + "CommittedOperations", + "Memory", + "PreExtractedContent", + "RetrievalConfig", + "Run", + "RunStatus", + "SearchResults", +] diff --git a/src/engram/_models/memory.py b/src/engram/_models/memory.py new file mode 100644 index 0000000..2a4f6f0 --- /dev/null +++ b/src/engram/_models/memory.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from collections.abc import Iterator, Sequence +from dataclasses import dataclass, field +from typing import Literal, TypeAlias + + +@dataclass(slots=True) +class PreExtractedContent: + """Pre-extracted content that bypasses the extraction pipeline.""" + + content: str + tags: list[str] = field(default_factory=list) + + +# Type alias for the content argument to memories.add() +AddContent: TypeAlias = str | list[dict[str, str]] | PreExtractedContent + + +@dataclass(slots=True) +class RetrievalConfig: + retrieval_type: Literal["vector", "bm25", "hybrid"] = "hybrid" + limit: int = 10 + + +@dataclass(slots=True) +class Memory: + id: str + project_id: str + content: str + topic: str + group: str + created_at: str + updated_at: str + user_id: str | None = None + conversation_id: str | None = None + tags: list[str] | None = None + score: float | None = None + + +class SearchResults(Sequence[Memory]): + """List-like wrapper over search results with a total count.""" + + def __init__(self, memories: list[Memory], total: int) -> None: + self._memories = memories + self.total = total + + def __getitem__(self, index: int) -> Memory: # type: ignore[override] + return self._memories[index] + + def __len__(self) -> int: + return len(self._memories) + + def __iter__(self) -> Iterator[Memory]: + return iter(self._memories) + + def __repr__(self) -> str: + return f"SearchResults(total={self.total}, returned={len(self._memories)})" diff --git a/src/engram/_models/run.py b/src/engram/_models/run.py new file mode 100644 index 0000000..b7eeda7 --- /dev/null +++ b/src/engram/_models/run.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass(slots=True) +class Run: + """Returned from memories.add() — represents a pipeline run.""" + + run_id: str + status: str + error: str | None = None + + +@dataclass(slots=True) +class CommittedOperation: + memory_id: str + committed_at: str + + +@dataclass(slots=True) +class CommittedOperations: + created: list[CommittedOperation] = field(default_factory=list) + updated: list[CommittedOperation] = field(default_factory=list) + deleted: list[CommittedOperation] = field(default_factory=list) + + +@dataclass(slots=True) +class RunStatus: + run_id: str + status: str + group_id: str + starting_step: int + input_type: str + created_at: str + updated_at: str + committed_operations: CommittedOperations | None = None + error: str | None = None + + @property + def memories_created(self) -> list[CommittedOperation]: + if self.committed_operations is None: + return [] + return self.committed_operations.created + + @property + def memories_updated(self) -> list[CommittedOperation]: + if self.committed_operations is None: + return [] + return self.committed_operations.updated + + @property + def memories_deleted(self) -> list[CommittedOperation]: + if self.committed_operations is None: + return [] + return self.committed_operations.deleted diff --git a/src/engram/_resources/memories.py b/src/engram/_resources/memories.py index 85126bd..7ca41a8 100644 --- a/src/engram/_resources/memories.py +++ b/src/engram/_resources/memories.py @@ -1,7 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING - +from .._http import AsyncHttpTransport, HttpTransport from .._models import AddContent, Memory, RetrievalConfig, Run, SearchResults from .._serialization import ( build_add_body, @@ -12,10 +11,6 @@ parse_search_results, ) -if TYPE_CHECKING: - from ..async_client import AsyncEngramClient - from ..client import EngramClient - _MEMORIES_PATH = "/v1/memories" _MEMORIES_SEARCH_PATH = "/v1/memories/search" @@ -27,10 +22,8 @@ def _memory_path(memory_id: str) -> str: class Memories: """Sync sub-resource for memory operations: client.memories.*""" - _client: EngramClient - - def __init__(self, client: EngramClient) -> None: - self._client = client + def __init__(self, transport: HttpTransport) -> None: + self._transport = transport def add( self, @@ -46,7 +39,7 @@ def add( conversation_id=conversation_id, group=group, ) - data = self._client._request("POST", _MEMORIES_PATH, json=body) + data = self._transport.request("POST", _MEMORIES_PATH, json=body) return parse_run(data) def get( @@ -64,7 +57,7 @@ def get( conversation_id=conversation_id, group=group, ) - data = self._client._request("GET", _memory_path(memory_id), params=params) + data = self._transport.request("GET", _memory_path(memory_id), params=params) return parse_memory(data) def delete( @@ -82,7 +75,7 @@ def delete( conversation_id=conversation_id, group=group, ) - self._client._request("DELETE", _memory_path(memory_id), params=params) + self._transport.request("DELETE", _memory_path(memory_id), params=params) def search( self, @@ -102,17 +95,15 @@ def search( group=group, retrieval_config=retrieval_config or RetrievalConfig(), ) - data = self._client._request("POST", _MEMORIES_SEARCH_PATH, json=body) + data = self._transport.request("POST", _MEMORIES_SEARCH_PATH, json=body) return parse_search_results(data) class AsyncMemories: """Async sub-resource for memory operations: client.memories.*""" - _client: AsyncEngramClient - - def __init__(self, client: AsyncEngramClient) -> None: - self._client = client + def __init__(self, transport: AsyncHttpTransport) -> None: + self._transport = transport async def add( self, @@ -128,7 +119,7 @@ async def add( conversation_id=conversation_id, group=group, ) - data = await self._client._request("POST", _MEMORIES_PATH, json=body) + data = await self._transport.request("POST", _MEMORIES_PATH, json=body) return parse_run(data) async def get( @@ -146,7 +137,7 @@ async def get( conversation_id=conversation_id, group=group, ) - data = await self._client._request("GET", _memory_path(memory_id), params=params) + data = await self._transport.request("GET", _memory_path(memory_id), params=params) return parse_memory(data) async def delete( @@ -164,7 +155,7 @@ async def delete( conversation_id=conversation_id, group=group, ) - await self._client._request("DELETE", _memory_path(memory_id), params=params) + await self._transport.request("DELETE", _memory_path(memory_id), params=params) async def search( self, @@ -184,5 +175,5 @@ async def search( group=group, retrieval_config=retrieval_config or RetrievalConfig(), ) - data = await self._client._request("POST", _MEMORIES_SEARCH_PATH, json=body) + data = await self._transport.request("POST", _MEMORIES_SEARCH_PATH, json=body) return parse_search_results(data) diff --git a/src/engram/_resources/runs.py b/src/engram/_resources/runs.py index e224ee4..d42e1ca 100644 --- a/src/engram/_resources/runs.py +++ b/src/engram/_resources/runs.py @@ -2,14 +2,11 @@ import asyncio import time -from typing import TYPE_CHECKING +from .._http import AsyncHttpTransport, HttpTransport from .._models import RunStatus from .._serialization import parse_run_status - -if TYPE_CHECKING: - from ..async_client import AsyncEngramClient - from ..client import EngramClient +from ..errors import EngramTimeoutError _RUNS_PATH = "/v1/runs" @@ -23,13 +20,11 @@ def _run_path(run_id: str) -> str: class Runs: """Sync sub-resource for run operations: client.runs.*""" - _client: EngramClient - - def __init__(self, client: EngramClient) -> None: - self._client = client + def __init__(self, transport: HttpTransport) -> None: + self._transport = transport def get(self, run_id: str) -> RunStatus: - data = self._client._request("GET", _run_path(run_id)) + data = self._transport.request("GET", _run_path(run_id)) return parse_run_status(data) def wait( @@ -45,20 +40,18 @@ def wait( if status.status in _TERMINAL_STATUSES: return status if time.monotonic() + interval > deadline: - return status + raise EngramTimeoutError(run_id, timeout) time.sleep(interval) class AsyncRuns: """Async sub-resource for run operations: client.runs.*""" - _client: AsyncEngramClient - - def __init__(self, client: AsyncEngramClient) -> None: - self._client = client + def __init__(self, transport: AsyncHttpTransport) -> None: + self._transport = transport async def get(self, run_id: str) -> RunStatus: - data = await self._client._request("GET", _run_path(run_id)) + data = await self._transport.request("GET", _run_path(run_id)) return parse_run_status(data) async def wait( @@ -75,5 +68,5 @@ async def wait( if status.status in _TERMINAL_STATUSES: return status if loop.time() + interval > deadline: - return status + raise EngramTimeoutError(run_id, timeout) await asyncio.sleep(interval) diff --git a/src/engram/async_client.py b/src/engram/async_client.py index a585564..a2af845 100644 --- a/src/engram/async_client.py +++ b/src/engram/async_client.py @@ -6,8 +6,8 @@ import httpx from ._base_client import DEFAULT_BASE_URL, DEFAULT_TIMEOUT, _BaseClient +from ._http import AsyncHttpTransport from ._resources import AsyncMemories, AsyncRuns -from .errors import ConnectionError as EngramConnectionError __all__ = ["DEFAULT_BASE_URL", "DEFAULT_TIMEOUT", "AsyncEngramClient"] @@ -15,7 +15,7 @@ class AsyncEngramClient(_BaseClient): """Asynchronous Engram client.""" - _http_client: httpx.AsyncClient + _transport: AsyncHttpTransport memories: AsyncMemories runs: AsyncRuns @@ -35,28 +35,27 @@ def __init__( timeout=timeout, ) self._owns_http_client = http_client is None - self._http_client = http_client or httpx.AsyncClient(timeout=timeout) - self.memories = AsyncMemories(self) - self.runs = AsyncRuns(self) + http = http_client or httpx.AsyncClient(timeout=timeout) + self._transport = AsyncHttpTransport(self._config, http) + self.memories = AsyncMemories(self._transport) + self.runs = AsyncRuns(self._transport) - async def _request( + def build_request( self, method: str, path: str, *, + headers: Mapping[str, str] | None = None, params: Mapping[str, Any] | None = None, json: Any | None = None, - ) -> dict[str, Any]: - request = self.build_request(method, path, params=params, json=json) - try: - response = await self._http_client.send(request) - except httpx.ConnectError as exc: - raise EngramConnectionError(str(exc)) from exc - return self._process_response(response) + ) -> httpx.Request: + return self._transport.build_request( + method, path, headers=headers, params=params, json=json + ) async def aclose(self) -> None: if self._owns_http_client: - await self._http_client.aclose() + await self._transport._http_client.aclose() async def __aenter__(self) -> AsyncEngramClient: return self diff --git a/src/engram/client.py b/src/engram/client.py index 1aab55c..80c41a0 100644 --- a/src/engram/client.py +++ b/src/engram/client.py @@ -6,8 +6,8 @@ import httpx from ._base_client import DEFAULT_BASE_URL, DEFAULT_TIMEOUT, _BaseClient +from ._http import HttpTransport from ._resources import Memories, Runs -from .errors import ConnectionError as EngramConnectionError __all__ = ["DEFAULT_BASE_URL", "DEFAULT_TIMEOUT", "EngramClient"] @@ -15,7 +15,7 @@ class EngramClient(_BaseClient): """Synchronous Engram client.""" - _http_client: httpx.Client + _transport: HttpTransport memories: Memories runs: Runs @@ -35,28 +35,27 @@ def __init__( timeout=timeout, ) self._owns_http_client = http_client is None - self._http_client = http_client or httpx.Client(timeout=timeout) - self.memories = Memories(self) - self.runs = Runs(self) + http = http_client or httpx.Client(timeout=timeout) + self._transport = HttpTransport(self._config, http) + self.memories = Memories(self._transport) + self.runs = Runs(self._transport) - def _request( + def build_request( self, method: str, path: str, *, + headers: Mapping[str, str] | None = None, params: Mapping[str, Any] | None = None, json: Any | None = None, - ) -> dict[str, Any]: - request = self.build_request(method, path, params=params, json=json) - try: - response = self._http_client.send(request) - except httpx.ConnectError as exc: - raise EngramConnectionError(str(exc)) from exc - return self._process_response(response) + ) -> httpx.Request: + return self._transport.build_request( + method, path, headers=headers, params=params, json=json + ) def close(self) -> None: if self._owns_http_client: - self._http_client.close() + self._transport._http_client.close() def __enter__(self) -> EngramClient: return self diff --git a/src/engram/errors.py b/src/engram/errors.py index 6ac07ba..66e6233 100644 --- a/src/engram/errors.py +++ b/src/engram/errors.py @@ -25,10 +25,16 @@ def __init__( class AuthenticationError(APIError): """Raised when authentication fails (401).""" + def __init__(self, message: str, *, body: object = None) -> None: + super().__init__(message, status_code=401, body=body) + class NotFoundError(APIError): """Raised when a resource is not found (404).""" + def __init__(self, message: str, *, body: object = None) -> None: + super().__init__(message, status_code=404, body=body) + class ValidationError(EngramError): """Raised when configuration or request input is invalid.""" @@ -36,3 +42,12 @@ class ValidationError(EngramError): class ConnectionError(EngramError): """Raised when a connection to the Engram server fails.""" + + +class EngramTimeoutError(EngramError): + """Raised when a run does not reach a terminal status within the timeout.""" + + def __init__(self, run_id: str, timeout: float) -> None: + super().__init__(f"Run {run_id!r} did not reach a terminal status within {timeout}s") + self.run_id = run_id + self.timeout = timeout diff --git a/tests/test_imports.py b/tests/test_imports.py index e489a15..fb8236e 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -9,6 +9,7 @@ def test_public_imports() -> None: ConnectionError, EngramClient, EngramError, + EngramTimeoutError, Memory, NotFoundError, PreExtractedContent, @@ -26,6 +27,7 @@ def test_public_imports() -> None: assert isinstance(AuthenticationError, type) assert isinstance(NotFoundError, type) assert isinstance(ValidationError, type) + assert isinstance(EngramTimeoutError, type) assert isinstance(Memory, type) assert isinstance(Run, type) assert isinstance(RunStatus, type) @@ -44,6 +46,7 @@ def test_public_imports() -> None: "ConnectionError", "EngramClient", "EngramError", + "EngramTimeoutError", "Memory", "NotFoundError", "PreExtractedContent", From 80e274f187b2b180ddc3d31d7ed76ae4af09f9af Mon Sep 17 00:00:00 2001 From: Charles Pierse Date: Wed, 18 Feb 2026 16:06:46 +0000 Subject: [PATCH 3/6] Fix serialization bug --- src/engram/_serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/engram/_serialization.py b/src/engram/_serialization.py index 812c0d7..f1afde9 100644 --- a/src/engram/_serialization.py +++ b/src/engram/_serialization.py @@ -124,7 +124,7 @@ def parse_memory(data: dict[str, Any]) -> Memory: def parse_search_results(data: dict[str, Any]) -> SearchResults: return SearchResults( - memories=[parse_memory(m) for m in data["memories"]], + memories=[parse_memory(m["Body"]) for m in data["memories"]], total=data["total"], ) From 062db4b54d039615fe7eb9c6e0b2566332bd379b Mon Sep 17 00:00:00 2001 From: Charles Pierse Date: Thu, 19 Feb 2026 16:17:15 +0000 Subject: [PATCH 4/6] Better organization of serialization --- src/engram/_http.py | 18 ++- src/engram/_serialization.py | 161 ------------------------- src/engram/_serialization/__init__.py | 21 ++++ src/engram/_serialization/_builders.py | 83 +++++++++++++ src/engram/_serialization/_parsers.py | 75 ++++++++++++ src/engram/async_client.py | 21 +--- src/engram/client.py | 21 +--- tests/test_client_async.py | 8 +- tests/test_client_sync.py | 8 +- tests/test_serialization.py | 4 +- 10 files changed, 207 insertions(+), 213 deletions(-) delete mode 100644 src/engram/_serialization.py create mode 100644 src/engram/_serialization/__init__.py create mode 100644 src/engram/_serialization/_builders.py create mode 100644 src/engram/_serialization/_parsers.py diff --git a/src/engram/_http.py b/src/engram/_http.py index 7cb3266..c79b0c9 100644 --- a/src/engram/_http.py +++ b/src/engram/_http.py @@ -13,9 +13,14 @@ class HttpTransport: """Wraps a sync httpx.Client and handles request building and response processing.""" - def __init__(self, config: ClientConfig, http_client: httpx.Client) -> None: + def __init__(self, config: ClientConfig, http_client: httpx.Client | None = None) -> None: self._config = config - self._http_client = http_client + self._owns_http_client = http_client is None + self._http_client = http_client or httpx.Client(timeout=config.timeout) + + def close(self) -> None: + if self._owns_http_client: + self._http_client.close() def request( self, @@ -58,9 +63,14 @@ def build_request( class AsyncHttpTransport: """Wraps an async httpx.AsyncClient and handles request building and response processing.""" - def __init__(self, config: ClientConfig, http_client: httpx.AsyncClient) -> None: + def __init__(self, config: ClientConfig, http_client: httpx.AsyncClient | None = None) -> None: self._config = config - self._http_client = http_client + self._owns_http_client = http_client is None + self._http_client = http_client or httpx.AsyncClient(timeout=config.timeout) + + async def close(self) -> None: + if self._owns_http_client: + await self._http_client.aclose() async def request( self, diff --git a/src/engram/_serialization.py b/src/engram/_serialization.py deleted file mode 100644 index f1afde9..0000000 --- a/src/engram/_serialization.py +++ /dev/null @@ -1,161 +0,0 @@ -from __future__ import annotations - -from typing import Any - -from ._models import ( - AddContent, - CommittedOperation, - CommittedOperations, - Memory, - PreExtractedContent, - RetrievalConfig, - Run, - RunStatus, - SearchResults, -) - -# ── Body builders ─────────────────────────────────────────────────────── - - -def _serialize_content(content: AddContent) -> dict[str, Any]: - """Build the content envelope with the type discriminator.""" - if isinstance(content, str): - return {"type": "string", "content": content} - if isinstance(content, PreExtractedContent): - d: dict[str, Any] = {"type": "pre_extracted", "content": content.content} - if content.tags: - d["tags"] = content.tags - return d - if isinstance(content, list): - return { - "type": "conversation", - "conversation": {"messages": content}, - } - raise TypeError(f"Unsupported content type: {type(content)}") # pragma: no cover - - -def build_add_body( - content: AddContent, - *, - user_id: str | None, - conversation_id: str | None, - group: str | None, -) -> dict[str, Any]: - body: dict[str, Any] = {"content": _serialize_content(content)} - if user_id is not None: - body["user_id"] = user_id - if conversation_id is not None: - body["conversation_id"] = conversation_id - if group is not None: - body["group"] = group - return body - - -def build_memory_params( - *, - topic: str, - user_id: str | None, - conversation_id: str | None, - group: str | None, -) -> dict[str, str]: - params: dict[str, str] = {"topic": topic} - if user_id is not None: - params["user_id"] = user_id - if conversation_id is not None: - params["conversation_id"] = conversation_id - if group is not None: - params["group"] = group - return params - - -def build_search_body( - *, - query: str, - topics: list[str] | None, - user_id: str | None, - conversation_id: str | None, - group: str | None, - retrieval_config: RetrievalConfig, -) -> dict[str, Any]: - body: dict[str, Any] = { - "query": query, - "retrieval_config": { - "retrieval_type": retrieval_config.retrieval_type, - "limit": retrieval_config.limit, - }, - } - if topics is not None: - body["topics"] = topics - if user_id is not None: - body["user_id"] = user_id - if conversation_id is not None: - body["conversation_id"] = conversation_id - if group is not None: - body["group"] = group - return body - - -# ── Response parsers ──────────────────────────────────────────────────── - - -def parse_run(data: dict[str, Any]) -> Run: - return Run( - run_id=data["run_id"], - status=data["status"], - error=data.get("error"), - ) - - -def parse_memory(data: dict[str, Any]) -> Memory: - return Memory( - id=data["id"], - project_id=data["project_id"], - content=data["content"], - topic=data["topic"], - group=data["group"], - created_at=data["created_at"], - updated_at=data["updated_at"], - user_id=data.get("user_id"), - conversation_id=data.get("conversation_id"), - tags=data.get("tags"), - score=data.get("score"), - ) - - -def parse_search_results(data: dict[str, Any]) -> SearchResults: - return SearchResults( - memories=[parse_memory(m["Body"]) for m in data["memories"]], - total=data["total"], - ) - - -def _parse_committed_operation(data: dict[str, Any]) -> CommittedOperation: - return CommittedOperation( - memory_id=data["memory_id"], - committed_at=data["committed_at"], - ) - - -def _parse_committed_operations(data: dict[str, Any]) -> CommittedOperations: - return CommittedOperations( - created=[_parse_committed_operation(op) for op in data.get("created", [])], - updated=[_parse_committed_operation(op) for op in data.get("updated", [])], - deleted=[_parse_committed_operation(op) for op in data.get("deleted", [])], - ) - - -def parse_run_status(data: dict[str, Any]) -> RunStatus: - committed_ops = data.get("committed_operations") - return RunStatus( - run_id=data["run_id"], - status=data["status"], - group_id=data["group_id"], - starting_step=data["starting_step"], - input_type=data["input_type"], - created_at=data["created_at"], - updated_at=data["updated_at"], - committed_operations=_parse_committed_operations(committed_ops) - if committed_ops is not None - else None, - error=data.get("error"), - ) diff --git a/src/engram/_serialization/__init__.py b/src/engram/_serialization/__init__.py new file mode 100644 index 0000000..ef0a243 --- /dev/null +++ b/src/engram/_serialization/__init__.py @@ -0,0 +1,21 @@ +from ._builders import ( + build_add_body, + build_memory_params, + build_search_body, +) +from ._parsers import ( + parse_memory, + parse_run, + parse_run_status, + parse_search_results, +) + +__all__ = [ + "build_add_body", + "build_memory_params", + "build_search_body", + "parse_memory", + "parse_run", + "parse_run_status", + "parse_search_results", +] diff --git a/src/engram/_serialization/_builders.py b/src/engram/_serialization/_builders.py new file mode 100644 index 0000000..304ccf1 --- /dev/null +++ b/src/engram/_serialization/_builders.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import Any + +from .._models import AddContent, PreExtractedContent, RetrievalConfig + + +def _serialize_content(content: AddContent) -> dict[str, Any]: + """Build the content envelope with the type discriminator.""" + if isinstance(content, str): + return {"type": "string", "content": content} + if isinstance(content, PreExtractedContent): + d: dict[str, Any] = {"type": "pre_extracted", "content": content.content} + if content.tags: + d["tags"] = content.tags + return d + if isinstance(content, list): + return { + "type": "conversation", + "conversation": {"messages": content}, + } + raise TypeError(f"Unsupported content type: {type(content)}") # pragma: no cover + + +def build_add_body( + content: AddContent, + *, + user_id: str | None, + conversation_id: str | None, + group: str | None, +) -> dict[str, Any]: + body: dict[str, Any] = {"content": _serialize_content(content)} + if user_id is not None: + body["user_id"] = user_id + if conversation_id is not None: + body["conversation_id"] = conversation_id + if group is not None: + body["group"] = group + return body + + +def build_memory_params( + *, + topic: str, + user_id: str | None, + conversation_id: str | None, + group: str | None, +) -> dict[str, str]: + params: dict[str, str] = {"topic": topic} + if user_id is not None: + params["user_id"] = user_id + if conversation_id is not None: + params["conversation_id"] = conversation_id + if group is not None: + params["group"] = group + return params + + +def build_search_body( + *, + query: str, + topics: list[str] | None, + user_id: str | None, + conversation_id: str | None, + group: str | None, + retrieval_config: RetrievalConfig, +) -> dict[str, Any]: + body: dict[str, Any] = { + "query": query, + "retrieval_config": { + "retrieval_type": retrieval_config.retrieval_type, + "limit": retrieval_config.limit, + }, + } + if topics is not None: + body["topics"] = topics + if user_id is not None: + body["user_id"] = user_id + if conversation_id is not None: + body["conversation_id"] = conversation_id + if group is not None: + body["group"] = group + return body diff --git a/src/engram/_serialization/_parsers.py b/src/engram/_serialization/_parsers.py new file mode 100644 index 0000000..3e58ecb --- /dev/null +++ b/src/engram/_serialization/_parsers.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import Any + +from .._models import ( + CommittedOperation, + CommittedOperations, + Memory, + Run, + RunStatus, + SearchResults, +) + + +def parse_run(data: dict[str, Any]) -> Run: + return Run( + run_id=data["run_id"], + status=data["status"], + error=data.get("error"), + ) + + +def parse_memory(data: dict[str, Any]) -> Memory: + return Memory( + id=data["id"], + project_id=data["project_id"], + content=data["content"], + topic=data["topic"], + group=data["group"], + created_at=data["created_at"], + updated_at=data["updated_at"], + user_id=data.get("user_id"), + conversation_id=data.get("conversation_id"), + tags=data.get("tags"), + score=data.get("score"), + ) + + +def parse_search_results(data: dict[str, Any]) -> SearchResults: + return SearchResults( + memories=[parse_memory(m["Body"]) for m in data["memories"]], + total=data["total"], + ) + + +def _parse_committed_operation(data: dict[str, Any]) -> CommittedOperation: + return CommittedOperation( + memory_id=data["memory_id"], + committed_at=data["committed_at"], + ) + + +def _parse_committed_operations(data: dict[str, Any]) -> CommittedOperations: + return CommittedOperations( + created=[_parse_committed_operation(op) for op in data.get("created", [])], + updated=[_parse_committed_operation(op) for op in data.get("updated", [])], + deleted=[_parse_committed_operation(op) for op in data.get("deleted", [])], + ) + + +def parse_run_status(data: dict[str, Any]) -> RunStatus: + committed_ops = data.get("committed_operations") + return RunStatus( + run_id=data["run_id"], + status=data["status"], + group_id=data["group_id"], + starting_step=data["starting_step"], + input_type=data["input_type"], + created_at=data["created_at"], + updated_at=data["updated_at"], + committed_operations=_parse_committed_operations(committed_ops) + if committed_ops is not None + else None, + error=data.get("error"), + ) diff --git a/src/engram/async_client.py b/src/engram/async_client.py index a2af845..7ccbb5a 100644 --- a/src/engram/async_client.py +++ b/src/engram/async_client.py @@ -1,7 +1,6 @@ from __future__ import annotations from collections.abc import Mapping -from typing import Any import httpx @@ -34,28 +33,12 @@ def __init__( headers=headers, timeout=timeout, ) - self._owns_http_client = http_client is None - http = http_client or httpx.AsyncClient(timeout=timeout) - self._transport = AsyncHttpTransport(self._config, http) + self._transport = AsyncHttpTransport(self._config, http_client) self.memories = AsyncMemories(self._transport) self.runs = AsyncRuns(self._transport) - def build_request( - self, - method: str, - path: str, - *, - headers: Mapping[str, str] | None = None, - params: Mapping[str, Any] | None = None, - json: Any | None = None, - ) -> httpx.Request: - return self._transport.build_request( - method, path, headers=headers, params=params, json=json - ) - async def aclose(self) -> None: - if self._owns_http_client: - await self._transport._http_client.aclose() + await self._transport.close() async def __aenter__(self) -> AsyncEngramClient: return self diff --git a/src/engram/client.py b/src/engram/client.py index 80c41a0..feca706 100644 --- a/src/engram/client.py +++ b/src/engram/client.py @@ -1,7 +1,6 @@ from __future__ import annotations from collections.abc import Mapping -from typing import Any import httpx @@ -34,28 +33,12 @@ def __init__( headers=headers, timeout=timeout, ) - self._owns_http_client = http_client is None - http = http_client or httpx.Client(timeout=timeout) - self._transport = HttpTransport(self._config, http) + self._transport = HttpTransport(self._config, http_client) self.memories = Memories(self._transport) self.runs = Runs(self._transport) - def build_request( - self, - method: str, - path: str, - *, - headers: Mapping[str, str] | None = None, - params: Mapping[str, Any] | None = None, - json: Any | None = None, - ) -> httpx.Request: - return self._transport.build_request( - method, path, headers=headers, params=params, json=json - ) - def close(self) -> None: - if self._owns_http_client: - self._transport._http_client.close() + self._transport.close() def __enter__(self) -> EngramClient: return self diff --git a/tests/test_client_async.py b/tests/test_client_async.py index f034035..1b829a1 100644 --- a/tests/test_client_async.py +++ b/tests/test_client_async.py @@ -36,7 +36,7 @@ async def test_async_client_custom_config_and_header_merging() -> None: assert client.default_headers["Authorization"] == "Bearer test-key" assert client.default_headers["X-Custom"] == "from-client" - request = client.build_request( + request = client._transport.build_request( "GET", "v1/items", headers={"X-Custom": "from-request", "X-Request": "value"}, @@ -195,7 +195,7 @@ async def test_delete_memory() -> None: @pytest.mark.asyncio async def test_search_memories() -> None: response_body: dict[str, Any] = { - "memories": [SAMPLE_MEMORY_RESPONSE], + "memories": [{"Body": SAMPLE_MEMORY_RESPONSE}], "total": 1, } client = _make_client(body=response_body) @@ -208,8 +208,8 @@ async def test_search_memories() -> None: async def test_search_memories_iterable() -> None: response_body: dict[str, Any] = { "memories": [ - SAMPLE_MEMORY_RESPONSE, - {**SAMPLE_MEMORY_RESPONSE, "id": "m2", "score": 0.85}, + {"Body": SAMPLE_MEMORY_RESPONSE}, + {"Body": {**SAMPLE_MEMORY_RESPONSE, "id": "m2", "score": 0.85}}, ], "total": 2, } diff --git a/tests/test_client_sync.py b/tests/test_client_sync.py index e6533af..0ecd695 100644 --- a/tests/test_client_sync.py +++ b/tests/test_client_sync.py @@ -34,7 +34,7 @@ def test_client_custom_config_and_header_merging() -> None: assert client.default_headers["Authorization"] == "Bearer test-key" assert client.default_headers["X-Custom"] == "from-client" - request = client.build_request( + request = client._transport.build_request( "GET", "/v1/items", headers={"X-Custom": "from-request", "X-Request": "value"}, @@ -213,7 +213,7 @@ def test_delete_memory() -> None: def test_search_memories() -> None: response_body: dict[str, Any] = { - "memories": [SAMPLE_MEMORY_RESPONSE], + "memories": [{"Body": SAMPLE_MEMORY_RESPONSE}], "total": 1, } client = _make_client(body=response_body) @@ -225,8 +225,8 @@ def test_search_memories() -> None: def test_search_memories_iterable() -> None: response_body: dict[str, Any] = { "memories": [ - SAMPLE_MEMORY_RESPONSE, - {**SAMPLE_MEMORY_RESPONSE, "id": "m2", "score": 0.85}, + {"Body": SAMPLE_MEMORY_RESPONSE}, + {"Body": {**SAMPLE_MEMORY_RESPONSE, "id": "m2", "score": 0.85}}, ], "total": 2, } diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 62bbe0b..8bf664e 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -188,7 +188,7 @@ def test_parse_memory_with_optional_fields() -> None: def test_parse_search_results() -> None: - data = {"memories": [SAMPLE_MEMORY], "total": 1} + data = {"memories": [{"Body": SAMPLE_MEMORY}], "total": 1} result = parse_search_results(data) assert result.total == 1 assert len(result) == 1 @@ -203,7 +203,7 @@ def test_parse_search_results_empty() -> None: def test_search_results_iterable() -> None: data = { - "memories": [SAMPLE_MEMORY, {**SAMPLE_MEMORY, "id": "m2"}], + "memories": [{"Body": SAMPLE_MEMORY}, {"Body": {**SAMPLE_MEMORY, "id": "m2"}}], "total": 2, } result = parse_search_results(data) From 60a8ac7477a90f4a27d3194a0e179d4582ec2145 Mon Sep 17 00:00:00 2001 From: Charles Pierse Date: Fri, 20 Feb 2026 11:51:52 +0000 Subject: [PATCH 5/6] Address most recent PR comments --- src/engram/__init__.py | 2 - src/engram/_base_client.py | 7 ++- src/engram/_http.py | 6 +-- src/engram/async_client.py | 7 +-- src/engram/client.py | 7 +-- src/engram/errors.py | 7 --- tests/test_client_async.py | 72 +++++++++++++++---------------- tests/test_client_sync.py | 88 ++++++++++++++++---------------------- tests/test_imports.py | 3 -- 9 files changed, 80 insertions(+), 119 deletions(-) diff --git a/src/engram/__init__.py b/src/engram/__init__.py index 851be08..b2d1aeb 100644 --- a/src/engram/__init__.py +++ b/src/engram/__init__.py @@ -16,7 +16,6 @@ ConnectionError, EngramError, EngramTimeoutError, - NotFoundError, ValidationError, ) from .version import __version__ @@ -32,7 +31,6 @@ "EngramError", "EngramTimeoutError", "Memory", - "NotFoundError", "PreExtractedContent", "RetrievalConfig", "Run", diff --git a/src/engram/_base_client.py b/src/engram/_base_client.py index 4608c79..93cad73 100644 --- a/src/engram/_base_client.py +++ b/src/engram/_base_client.py @@ -17,7 +17,7 @@ def __init__( self, *, base_url: str = DEFAULT_BASE_URL, - api_key: str | None = None, + api_key: str, headers: Mapping[str, str] | None = None, timeout: float = DEFAULT_TIMEOUT, ) -> None: @@ -45,15 +45,14 @@ def default_headers(self) -> dict[str, str]: def _build_headers( *, - api_key: str | None, + api_key: str, header_overrides: Mapping[str, str], ) -> dict[str, str]: headers: dict[str, str] = { "Accept": "application/json", "Content-Type": "application/json", "User-Agent": f"weaviate-engram/{__version__}", + "Authorization": f"Bearer {api_key}", } - if api_key: - headers["Authorization"] = f"Bearer {api_key}" headers.update(header_overrides) return headers diff --git a/src/engram/_http.py b/src/engram/_http.py index c79b0c9..886ac4a 100644 --- a/src/engram/_http.py +++ b/src/engram/_http.py @@ -5,7 +5,7 @@ import httpx -from .errors import APIError, AuthenticationError, NotFoundError +from .errors import APIError, AuthenticationError from .errors import ConnectionError as EngramConnectionError from .types import ClientConfig @@ -117,10 +117,6 @@ def _process_response(response: httpx.Response) -> dict[str, Any]: detail = _extract_detail(data, "Authentication failed") raise AuthenticationError(detail, body=data) - if response.status_code == 404: - detail = _extract_detail(data, "Not found") - raise NotFoundError(detail, body=data) - if response.status_code >= 400: detail = _extract_detail(data, response.reason_phrase) raise APIError(detail, status_code=response.status_code, body=data) diff --git a/src/engram/async_client.py b/src/engram/async_client.py index 7ccbb5a..d52f75a 100644 --- a/src/engram/async_client.py +++ b/src/engram/async_client.py @@ -2,8 +2,6 @@ from collections.abc import Mapping -import httpx - from ._base_client import DEFAULT_BASE_URL, DEFAULT_TIMEOUT, _BaseClient from ._http import AsyncHttpTransport from ._resources import AsyncMemories, AsyncRuns @@ -22,10 +20,9 @@ def __init__( self, *, base_url: str = DEFAULT_BASE_URL, - api_key: str | None = None, + api_key: str, headers: Mapping[str, str] | None = None, timeout: float = DEFAULT_TIMEOUT, - http_client: httpx.AsyncClient | None = None, ) -> None: super().__init__( base_url=base_url, @@ -33,7 +30,7 @@ def __init__( headers=headers, timeout=timeout, ) - self._transport = AsyncHttpTransport(self._config, http_client) + self._transport = AsyncHttpTransport(self._config) self.memories = AsyncMemories(self._transport) self.runs = AsyncRuns(self._transport) diff --git a/src/engram/client.py b/src/engram/client.py index feca706..7a6359e 100644 --- a/src/engram/client.py +++ b/src/engram/client.py @@ -2,8 +2,6 @@ from collections.abc import Mapping -import httpx - from ._base_client import DEFAULT_BASE_URL, DEFAULT_TIMEOUT, _BaseClient from ._http import HttpTransport from ._resources import Memories, Runs @@ -22,10 +20,9 @@ def __init__( self, *, base_url: str = DEFAULT_BASE_URL, - api_key: str | None = None, + api_key: str, headers: Mapping[str, str] | None = None, timeout: float = DEFAULT_TIMEOUT, - http_client: httpx.Client | None = None, ) -> None: super().__init__( base_url=base_url, @@ -33,7 +30,7 @@ def __init__( headers=headers, timeout=timeout, ) - self._transport = HttpTransport(self._config, http_client) + self._transport = HttpTransport(self._config) self.memories = Memories(self._transport) self.runs = Runs(self._transport) diff --git a/src/engram/errors.py b/src/engram/errors.py index 66e6233..f00a7b6 100644 --- a/src/engram/errors.py +++ b/src/engram/errors.py @@ -29,13 +29,6 @@ def __init__(self, message: str, *, body: object = None) -> None: super().__init__(message, status_code=401, body=body) -class NotFoundError(APIError): - """Raised when a resource is not found (404).""" - - def __init__(self, message: str, *, body: object = None) -> None: - super().__init__(message, status_code=404, body=body) - - class ValidationError(EngramError): """Raised when configuration or request input is invalid.""" diff --git a/tests/test_client_async.py b/tests/test_client_async.py index 1b829a1..30aa0ff 100644 --- a/tests/test_client_async.py +++ b/tests/test_client_async.py @@ -4,20 +4,21 @@ import httpx import pytest +from engram._http import AsyncHttpTransport from engram._models import PreExtractedContent, RetrievalConfig from engram.async_client import DEFAULT_BASE_URL, AsyncEngramClient -from engram.errors import APIError, AuthenticationError, NotFoundError, ValidationError +from engram.errors import APIError, AuthenticationError, ValidationError @pytest.mark.asyncio async def test_async_client_defaults() -> None: - client = AsyncEngramClient() + client = AsyncEngramClient(api_key="test-key") try: assert client.config.base_url == DEFAULT_BASE_URL assert client.config.timeout == 30.0 assert client.default_headers["Accept"] == "application/json" assert client.default_headers["Content-Type"] == "application/json" - assert "Authorization" not in client.default_headers + assert client.default_headers["Authorization"] == "Bearer test-key" finally: await client.aclose() @@ -51,11 +52,11 @@ async def test_async_client_custom_config_and_header_merging() -> None: def test_async_client_rejects_non_positive_timeout() -> None: with pytest.raises(ValidationError): - AsyncEngramClient(timeout=-1) + AsyncEngramClient(api_key="test-key", timeout=-1) def test_async_client_has_sub_resources() -> None: - client = AsyncEngramClient() + client = AsyncEngramClient(api_key="test-key") assert hasattr(client, "memories") assert hasattr(client, "runs") @@ -67,14 +68,30 @@ def _make_client( status_code: int = 200, body: dict[str, Any] | None = None, ) -> AsyncEngramClient: - transport = httpx.MockTransport( + mock = httpx.MockTransport( lambda _: httpx.Response(status_code, json=body if body is not None else {}) ) - return AsyncEngramClient( - base_url="https://test.example.com", - api_key="test-key", - http_client=httpx.AsyncClient(transport=transport), - ) + client = AsyncEngramClient(base_url="https://test.example.com", api_key="test-key") + transport = AsyncHttpTransport(client._config, httpx.AsyncClient(transport=mock)) + client._transport = transport + client.memories._transport = transport + client.runs._transport = transport + return client + + +def _make_client_with_handler( + handler: Any, + *, + base_url: str = "https://test.example.com", + api_key: str = "k", +) -> AsyncEngramClient: + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + client = AsyncEngramClient(base_url=base_url, api_key=api_key) + transport = AsyncHttpTransport(client._config, http_client) + client._transport = transport + client.memories._transport = transport + client.runs._transport = transport + return client # ── memories.add ──────────────────────────────────────────────────────── @@ -117,12 +134,7 @@ def handler(request: httpx.Request) -> httpx.Response: captured.append(request) return httpx.Response(200, json={"run_id": "r1", "status": "pending"}) - http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) - client = AsyncEngramClient( - base_url="https://test.example.com", - api_key="k", - http_client=http_client, - ) + client = _make_client_with_handler(handler) await client.memories.add("hello", user_id="u1", group="g1") body = json.loads(captured[0].content) assert body == { @@ -161,12 +173,7 @@ def handler(request: httpx.Request) -> httpx.Response: captured.append(request) return httpx.Response(200, json=SAMPLE_MEMORY_RESPONSE) - http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) - client = AsyncEngramClient( - base_url="https://test.example.com", - api_key="k", - http_client=http_client, - ) + client = _make_client_with_handler(handler) await client.memories.get("m1", topic="t1", user_id="u1", group="g1") url = captured[0].url assert url.params["topic"] == "t1" @@ -179,13 +186,7 @@ def handler(request: httpx.Request) -> httpx.Response: @pytest.mark.asyncio async def test_delete_memory() -> None: - transport = httpx.MockTransport(lambda _: httpx.Response(204, content=b"")) - http_client = httpx.AsyncClient(transport=transport) - client = AsyncEngramClient( - base_url="https://test.example.com", - api_key="k", - http_client=http_client, - ) + client = _make_client_with_handler(lambda _: httpx.Response(204, content=b"")) await client.memories.delete("m1", topic="t1") @@ -228,12 +229,7 @@ def handler(request: httpx.Request) -> httpx.Response: captured.append(request) return httpx.Response(200, json={"memories": [], "total": 0}) - http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) - client = AsyncEngramClient( - base_url="https://test.example.com", - api_key="k", - http_client=http_client, - ) + client = _make_client_with_handler(handler) await client.memories.search( query="find this", topics=["a"], @@ -285,9 +281,9 @@ async def test_401_raises_authentication_error() -> None: @pytest.mark.asyncio -async def test_404_raises_not_found_error() -> None: +async def test_404_raises_api_error() -> None: client = _make_client(status_code=404, body={"detail": "Not found"}) - with pytest.raises(NotFoundError) as exc_info: + with pytest.raises(APIError) as exc_info: await client.memories.get("missing", topic="t1") assert exc_info.value.status_code == 404 diff --git a/tests/test_client_sync.py b/tests/test_client_sync.py index 0ecd695..4f56dec 100644 --- a/tests/test_client_sync.py +++ b/tests/test_client_sync.py @@ -4,19 +4,20 @@ import httpx import pytest +from engram._http import HttpTransport from engram._models import PreExtractedContent, RetrievalConfig from engram.client import DEFAULT_BASE_URL, EngramClient -from engram.errors import APIError, AuthenticationError, NotFoundError, ValidationError +from engram.errors import APIError, AuthenticationError, ValidationError def test_client_defaults() -> None: - client = EngramClient() + client = EngramClient(api_key="test-key") try: assert client.config.base_url == DEFAULT_BASE_URL assert client.config.timeout == 30.0 assert client.default_headers["Accept"] == "application/json" assert client.default_headers["Content-Type"] == "application/json" - assert "Authorization" not in client.default_headers + assert client.default_headers["Authorization"] == "Bearer test-key" finally: client.close() @@ -49,11 +50,11 @@ def test_client_custom_config_and_header_merging() -> None: def test_client_rejects_non_positive_timeout() -> None: with pytest.raises(ValidationError): - EngramClient(timeout=0) + EngramClient(api_key="test-key", timeout=0) def test_client_has_sub_resources() -> None: - client = EngramClient() + client = EngramClient(api_key="test-key") try: assert hasattr(client, "memories") assert hasattr(client, "runs") @@ -68,14 +69,32 @@ def _make_client( status_code: int = 200, body: dict[str, Any] | None = None, ) -> EngramClient: - transport = httpx.MockTransport( + mock = httpx.MockTransport( lambda _: httpx.Response(status_code, json=body if body is not None else {}) ) - return EngramClient( - base_url="https://test.example.com", - api_key="test-key", - http_client=httpx.Client(transport=transport), - ) + client = EngramClient(base_url="https://test.example.com", api_key="test-key") + transport = HttpTransport(client._config, httpx.Client(transport=mock)) + client._transport.close() + client._transport = transport + client.memories._transport = transport + client.runs._transport = transport + return client + + +def _make_client_with_handler( + handler: Any, + *, + base_url: str = "https://test.example.com", + api_key: str = "k", +) -> EngramClient: + http_client = httpx.Client(transport=httpx.MockTransport(handler)) + client = EngramClient(base_url=base_url, api_key=api_key) + transport = HttpTransport(client._config, http_client) + client._transport.close() + client._transport = transport + client.memories._transport = transport + client.runs._transport = transport + return client # ── memories.add ──────────────────────────────────────────────────────── @@ -114,12 +133,7 @@ def handler(request: httpx.Request) -> httpx.Response: captured.append(request) return httpx.Response(200, json={"run_id": "r1", "status": "pending"}) - http_client = httpx.Client(transport=httpx.MockTransport(handler)) - client = EngramClient( - base_url="https://test.example.com", - api_key="k", - http_client=http_client, - ) + client = _make_client_with_handler(handler) client.memories.add("hello", user_id="u1", group="g1") body = json.loads(captured[0].content) assert body == { @@ -136,12 +150,7 @@ def handler(request: httpx.Request) -> httpx.Response: captured.append(request) return httpx.Response(200, json={"run_id": "r1", "status": "pending"}) - http_client = httpx.Client(transport=httpx.MockTransport(handler)) - client = EngramClient( - base_url="https://test.example.com", - api_key="k", - http_client=http_client, - ) + client = _make_client_with_handler(handler) messages = [{"role": "user", "content": "hi"}] client.memories.add(messages, conversation_id="c1") body = json.loads(captured[0].content) @@ -181,12 +190,7 @@ def handler(request: httpx.Request) -> httpx.Response: captured.append(request) return httpx.Response(200, json=SAMPLE_MEMORY_RESPONSE) - http_client = httpx.Client(transport=httpx.MockTransport(handler)) - client = EngramClient( - base_url="https://test.example.com", - api_key="k", - http_client=http_client, - ) + client = _make_client_with_handler(handler) client.memories.get("m1", topic="t1", user_id="u1", group="g1") url = captured[0].url assert url.params["topic"] == "t1" @@ -198,13 +202,7 @@ def handler(request: httpx.Request) -> httpx.Response: def test_delete_memory() -> None: - transport = httpx.MockTransport(lambda _: httpx.Response(204, content=b"")) - http_client = httpx.Client(transport=transport) - client = EngramClient( - base_url="https://test.example.com", - api_key="k", - http_client=http_client, - ) + client = _make_client_with_handler(lambda _: httpx.Response(204, content=b"")) client.memories.delete("m1", topic="t1") @@ -244,12 +242,7 @@ def handler(request: httpx.Request) -> httpx.Response: captured.append(request) return httpx.Response(200, json={"memories": [], "total": 0}) - http_client = httpx.Client(transport=httpx.MockTransport(handler)) - client = EngramClient( - base_url="https://test.example.com", - api_key="k", - http_client=http_client, - ) + client = _make_client_with_handler(handler) client.memories.search( query="find this", topics=["a"], @@ -269,12 +262,7 @@ def handler(request: httpx.Request) -> httpx.Response: captured.append(request) return httpx.Response(200, json={"memories": [], "total": 0}) - http_client = httpx.Client(transport=httpx.MockTransport(handler)) - client = EngramClient( - base_url="https://test.example.com", - api_key="k", - http_client=http_client, - ) + client = _make_client_with_handler(handler) client.memories.search(query="test") body = json.loads(captured[0].content) assert body["retrieval_config"] == { @@ -319,9 +307,9 @@ def test_401_raises_authentication_error() -> None: assert "Invalid token" in str(exc_info.value) -def test_404_raises_not_found_error() -> None: +def test_404_raises_api_error() -> None: client = _make_client(status_code=404, body={"detail": "Not found"}) - with pytest.raises(NotFoundError) as exc_info: + with pytest.raises(APIError) as exc_info: client.memories.get("missing", topic="t1") assert exc_info.value.status_code == 404 diff --git a/tests/test_imports.py b/tests/test_imports.py index fb8236e..3580309 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -11,7 +11,6 @@ def test_public_imports() -> None: EngramError, EngramTimeoutError, Memory, - NotFoundError, PreExtractedContent, RetrievalConfig, Run, @@ -25,7 +24,6 @@ def test_public_imports() -> None: assert isinstance(EngramError, type) assert isinstance(APIError, type) assert isinstance(AuthenticationError, type) - assert isinstance(NotFoundError, type) assert isinstance(ValidationError, type) assert isinstance(EngramTimeoutError, type) assert isinstance(Memory, type) @@ -48,7 +46,6 @@ def test_public_imports() -> None: "EngramError", "EngramTimeoutError", "Memory", - "NotFoundError", "PreExtractedContent", "RetrievalConfig", "Run", From 2f8017e9550e9e1f1eb7504e3f5b897920f3b1b6 Mon Sep 17 00:00:00 2001 From: Charles Pierse Date: Fri, 20 Feb 2026 11:59:26 +0000 Subject: [PATCH 6/6] Update release script --- .github/workflows/release.yml | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ce4f1c0..3e18b6e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,7 +10,6 @@ jobs: name: Build and Publish runs-on: ubuntu-latest permissions: - id-token: write contents: read steps: - name: Checkout @@ -49,11 +48,23 @@ jobs: print(f"Version mismatch: pyproject={project_version} tag={tag_version}") sys.exit(1) + # Also verify version.py is in sync + version_file = Path("src/engram/version.py") + namespace: dict = {} + exec(version_file.read_text(), namespace) + module_version = namespace["__version__"] + + if module_version != project_version: + print(f"Version mismatch: version.py={module_version} pyproject={project_version}") + sys.exit(1) + print(f"Version check passed: {project_version}") PY - name: Build package run: uv build --no-sources - - name: Publish to PyPI (OIDC) - uses: pypa/gh-action-pypi-publish@release/v1 + - name: Publish to PyPI + env: + UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }} + run: uv publish