diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ce4f1c0..3e18b6e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -10,7 +10,6 @@ jobs: name: Build and Publish runs-on: ubuntu-latest permissions: - id-token: write contents: read steps: - name: Checkout @@ -49,11 +48,23 @@ jobs: print(f"Version mismatch: pyproject={project_version} tag={tag_version}") sys.exit(1) + # Also verify version.py is in sync + version_file = Path("src/engram/version.py") + namespace: dict = {} + exec(version_file.read_text(), namespace) + module_version = namespace["__version__"] + + if module_version != project_version: + print(f"Version mismatch: version.py={module_version} pyproject={project_version}") + sys.exit(1) + print(f"Version check passed: {project_version}") PY - name: Build package run: uv build --no-sources - - name: Publish to PyPI (OIDC) - uses: pypa/gh-action-pypi-publish@release/v1 + - name: Publish to PyPI + env: + UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }} + run: uv publish diff --git a/src/engram/__init__.py b/src/engram/__init__.py index 7e481cb..b2d1aeb 100644 --- a/src/engram/__init__.py +++ b/src/engram/__init__.py @@ -1,14 +1,41 @@ +from ._models import ( + CommittedOperation, + CommittedOperations, + Memory, + PreExtractedContent, + RetrievalConfig, + Run, + RunStatus, + SearchResults, +) from .async_client import AsyncEngramClient from .client import EngramClient -from .errors import APIError, AuthError, EngramError, ValidationError +from .errors import ( + APIError, + AuthenticationError, + ConnectionError, + EngramError, + EngramTimeoutError, + ValidationError, +) from .version import __version__ __all__ = [ "APIError", "AsyncEngramClient", - "AuthError", + "AuthenticationError", + "CommittedOperation", + "CommittedOperations", + "ConnectionError", "EngramClient", "EngramError", + "EngramTimeoutError", + "Memory", + "PreExtractedContent", + "RetrievalConfig", + "Run", + "RunStatus", + "SearchResults", "ValidationError", "__version__", ] diff --git a/src/engram/_base_client.py b/src/engram/_base_client.py index a691ae6..93cad73 100644 --- a/src/engram/_base_client.py +++ b/src/engram/_base_client.py @@ -1,9 +1,6 @@ from __future__ import annotations from collections.abc import Mapping -from typing import Any - -import httpx from .errors import ValidationError from .types import ClientConfig @@ -14,16 +11,13 @@ class _BaseClient: - """Shared client behavior for sync and async clients.""" - - _http_client: httpx.Client | httpx.AsyncClient - _owns_http_client: bool + """Shared config and header logic for sync and async clients.""" def __init__( self, *, base_url: str = DEFAULT_BASE_URL, - api_key: str | None = None, + api_key: str, headers: Mapping[str, str] | None = None, timeout: float = DEFAULT_TIMEOUT, ) -> None: @@ -31,8 +25,7 @@ def __init__( raise ValidationError("Timeout must be greater than 0.") normalized_base_url = base_url.rstrip("/") - header_overrides = headers if headers is not None else {} - default_headers = _build_headers(api_key=api_key, header_overrides=header_overrides) + default_headers = _build_headers(api_key=api_key, header_overrides=headers or {}) self._config = ClientConfig( base_url=normalized_base_url, @@ -49,46 +42,17 @@ def config(self) -> ClientConfig: def default_headers(self) -> dict[str, str]: return dict(self._config.headers) - def build_request( - self, - method: str, - path: str, - *, - headers: Mapping[str, str] | None = None, - params: Mapping[str, Any] | None = None, - json: Any | None = None, - ) -> httpx.Request: - merged_headers = self.default_headers - if headers: - merged_headers.update(headers) - - return self._http_client.build_request( - method=method, - url=_build_url(self._config.base_url, path), - headers=merged_headers, - params=params, - json=json, - ) - def _build_headers( *, - api_key: str | None, + api_key: str, header_overrides: Mapping[str, str], ) -> dict[str, str]: headers: dict[str, str] = { "Accept": "application/json", "Content-Type": "application/json", "User-Agent": f"weaviate-engram/{__version__}", + "Authorization": f"Bearer {api_key}", } - if api_key: - headers["Authorization"] = f"Bearer {api_key}" headers.update(header_overrides) return headers - - -def _build_url(base_url: str, path: str) -> str: - clean_path = path.lstrip("/") - if not clean_path: - return base_url - return f"{base_url}/{clean_path}" diff --git a/src/engram/_http.py b/src/engram/_http.py new file mode 100644 index 0000000..886ac4a --- /dev/null +++ b/src/engram/_http.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +import httpx + +from .errors import APIError, AuthenticationError +from .errors import ConnectionError as EngramConnectionError +from .types import ClientConfig + + +class HttpTransport: + """Wraps a sync httpx.Client and handles request building and response processing.""" + + def __init__(self, config: ClientConfig, http_client: httpx.Client | None = None) -> None: + self._config = config + self._owns_http_client = http_client is None + self._http_client = http_client or httpx.Client(timeout=config.timeout) + + def close(self) -> None: + if self._owns_http_client: + self._http_client.close() + + def request( + self, + method: str, + path: str, + *, + params: Mapping[str, Any] | None = None, + json: Any | None = None, + ) -> dict[str, Any]: + req = self.build_request(method, path, params=params, json=json) + try: + response = self._http_client.send(req) + except httpx.ConnectError as exc: + raise EngramConnectionError(str(exc)) from exc + return _process_response(response) + + def build_request( + self, + method: str, + path: str, + *, + headers: Mapping[str, str] | None = None, + params: Mapping[str, Any] | None = None, + json: Any | None = None, + ) -> httpx.Request: + merged_headers = dict(self._config.headers) + if headers: + merged_headers.update(headers) + clean_path = path.lstrip("/") + url = f"{self._config.base_url}/{clean_path}" if clean_path else self._config.base_url + return self._http_client.build_request( + method=method, + url=url, + headers=merged_headers, + params=params, + json=json, + ) + + +class AsyncHttpTransport: + """Wraps an async httpx.AsyncClient and handles request building and response processing.""" + + def __init__(self, config: ClientConfig, http_client: httpx.AsyncClient | None = None) -> None: + self._config = config + self._owns_http_client = http_client is None + self._http_client = http_client or httpx.AsyncClient(timeout=config.timeout) + + async def close(self) -> None: + if self._owns_http_client: + await self._http_client.aclose() + + async def request( + self, + method: str, + path: str, + *, + params: Mapping[str, Any] | None = None, + json: Any | None = None, + ) -> dict[str, Any]: + req = self.build_request(method, path, params=params, json=json) + try: + response = await self._http_client.send(req) + except httpx.ConnectError as exc: + raise EngramConnectionError(str(exc)) from exc + return _process_response(response) + + def build_request( + self, + method: str, + path: str, + *, + headers: Mapping[str, str] | None = None, + params: Mapping[str, Any] | None = None, + json: Any | None = None, + ) -> httpx.Request: + merged_headers = dict(self._config.headers) + if headers: + merged_headers.update(headers) + clean_path = path.lstrip("/") + url = f"{self._config.base_url}/{clean_path}" if clean_path else self._config.base_url + return self._http_client.build_request( + method=method, + url=url, + headers=merged_headers, + params=params, + json=json, + ) + + +def _process_response(response: httpx.Response) -> dict[str, Any]: + data = _safe_json(response) + + if response.status_code == 401: + detail = _extract_detail(data, "Authentication failed") + raise AuthenticationError(detail, body=data) + + if response.status_code >= 400: + detail = _extract_detail(data, response.reason_phrase) + raise APIError(detail, status_code=response.status_code, body=data) + + if data is None: + return {} + if isinstance(data, dict): + return data + return {"data": data} + + +def _extract_detail(data: Any, fallback: str) -> str: + if isinstance(data, dict): + return str(data.get("detail", fallback)) + if data: + return str(data) + return fallback + + +def _safe_json(response: httpx.Response) -> Any: + if not response.content: + return None + try: + return response.json() + except Exception: + return None diff --git a/src/engram/_models/__init__.py b/src/engram/_models/__init__.py new file mode 100644 index 0000000..1335c37 --- /dev/null +++ b/src/engram/_models/__init__.py @@ -0,0 +1,14 @@ +from .memory import AddContent, Memory, PreExtractedContent, RetrievalConfig, SearchResults +from .run import CommittedOperation, CommittedOperations, Run, RunStatus + +__all__ = [ + "AddContent", + "CommittedOperation", + "CommittedOperations", + "Memory", + "PreExtractedContent", + "RetrievalConfig", + "Run", + "RunStatus", + "SearchResults", +] diff --git a/src/engram/_models/memory.py b/src/engram/_models/memory.py new file mode 100644 index 0000000..2a4f6f0 --- /dev/null +++ b/src/engram/_models/memory.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from collections.abc import Iterator, Sequence +from dataclasses import dataclass, field +from typing import Literal, TypeAlias + + +@dataclass(slots=True) +class PreExtractedContent: + """Pre-extracted content that bypasses the extraction pipeline.""" + + content: str + tags: list[str] = field(default_factory=list) + + +# Type alias for the content argument to memories.add() +AddContent: TypeAlias = str | list[dict[str, str]] | PreExtractedContent + + +@dataclass(slots=True) +class RetrievalConfig: + retrieval_type: Literal["vector", "bm25", "hybrid"] = "hybrid" + limit: int = 10 + + +@dataclass(slots=True) +class Memory: + id: str + project_id: str + content: str + topic: str + group: str + created_at: str + updated_at: str + user_id: str | None = None + conversation_id: str | None = None + tags: list[str] | None = None + score: float | None = None + + +class SearchResults(Sequence[Memory]): + """List-like wrapper over search results with a total count.""" + + def __init__(self, memories: list[Memory], total: int) -> None: + self._memories = memories + self.total = total + + def __getitem__(self, index: int) -> Memory: # type: ignore[override] + return self._memories[index] + + def __len__(self) -> int: + return len(self._memories) + + def __iter__(self) -> Iterator[Memory]: + return iter(self._memories) + + def __repr__(self) -> str: + return f"SearchResults(total={self.total}, returned={len(self._memories)})" diff --git a/src/engram/_models/run.py b/src/engram/_models/run.py new file mode 100644 index 0000000..b7eeda7 --- /dev/null +++ b/src/engram/_models/run.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from dataclasses import dataclass, field + + +@dataclass(slots=True) +class Run: + """Returned from memories.add() — represents a pipeline run.""" + + run_id: str + status: str + error: str | None = None + + +@dataclass(slots=True) +class CommittedOperation: + memory_id: str + committed_at: str + + +@dataclass(slots=True) +class CommittedOperations: + created: list[CommittedOperation] = field(default_factory=list) + updated: list[CommittedOperation] = field(default_factory=list) + deleted: list[CommittedOperation] = field(default_factory=list) + + +@dataclass(slots=True) +class RunStatus: + run_id: str + status: str + group_id: str + starting_step: int + input_type: str + created_at: str + updated_at: str + committed_operations: CommittedOperations | None = None + error: str | None = None + + @property + def memories_created(self) -> list[CommittedOperation]: + if self.committed_operations is None: + return [] + return self.committed_operations.created + + @property + def memories_updated(self) -> list[CommittedOperation]: + if self.committed_operations is None: + return [] + return self.committed_operations.updated + + @property + def memories_deleted(self) -> list[CommittedOperation]: + if self.committed_operations is None: + return [] + return self.committed_operations.deleted diff --git a/src/engram/_resources/__init__.py b/src/engram/_resources/__init__.py new file mode 100644 index 0000000..7ea362c --- /dev/null +++ b/src/engram/_resources/__init__.py @@ -0,0 +1,4 @@ +from .memories import AsyncMemories, Memories +from .runs import AsyncRuns, Runs + +__all__ = ["AsyncMemories", "AsyncRuns", "Memories", "Runs"] diff --git a/src/engram/_resources/memories.py b/src/engram/_resources/memories.py new file mode 100644 index 0000000..7ca41a8 --- /dev/null +++ b/src/engram/_resources/memories.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +from .._http import AsyncHttpTransport, HttpTransport +from .._models import AddContent, Memory, RetrievalConfig, Run, SearchResults +from .._serialization import ( + build_add_body, + build_memory_params, + build_search_body, + parse_memory, + parse_run, + parse_search_results, +) + +_MEMORIES_PATH = "/v1/memories" +_MEMORIES_SEARCH_PATH = "/v1/memories/search" + + +def _memory_path(memory_id: str) -> str: + return f"{_MEMORIES_PATH}/{memory_id}" + + +class Memories: + """Sync sub-resource for memory operations: client.memories.*""" + + def __init__(self, transport: HttpTransport) -> None: + self._transport = transport + + def add( + self, + content: AddContent, + *, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + ) -> Run: + body = build_add_body( + content, + user_id=user_id, + conversation_id=conversation_id, + group=group, + ) + data = self._transport.request("POST", _MEMORIES_PATH, json=body) + return parse_run(data) + + def get( + self, + memory_id: str, + *, + topic: str, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + ) -> Memory: + params = build_memory_params( + topic=topic, + user_id=user_id, + conversation_id=conversation_id, + group=group, + ) + data = self._transport.request("GET", _memory_path(memory_id), params=params) + return parse_memory(data) + + def delete( + self, + memory_id: str, + *, + topic: str, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + ) -> None: + params = build_memory_params( + topic=topic, + user_id=user_id, + conversation_id=conversation_id, + group=group, + ) + self._transport.request("DELETE", _memory_path(memory_id), params=params) + + def search( + self, + *, + query: str, + topics: list[str] | None = None, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + retrieval_config: RetrievalConfig | None = None, + ) -> SearchResults: + body = build_search_body( + query=query, + topics=topics, + user_id=user_id, + conversation_id=conversation_id, + group=group, + retrieval_config=retrieval_config or RetrievalConfig(), + ) + data = self._transport.request("POST", _MEMORIES_SEARCH_PATH, json=body) + return parse_search_results(data) + + +class AsyncMemories: + """Async sub-resource for memory operations: client.memories.*""" + + def __init__(self, transport: AsyncHttpTransport) -> None: + self._transport = transport + + async def add( + self, + content: AddContent, + *, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + ) -> Run: + body = build_add_body( + content, + user_id=user_id, + conversation_id=conversation_id, + group=group, + ) + data = await self._transport.request("POST", _MEMORIES_PATH, json=body) + return parse_run(data) + + async def get( + self, + memory_id: str, + *, + topic: str, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + ) -> Memory: + params = build_memory_params( + topic=topic, + user_id=user_id, + conversation_id=conversation_id, + group=group, + ) + data = await self._transport.request("GET", _memory_path(memory_id), params=params) + return parse_memory(data) + + async def delete( + self, + memory_id: str, + *, + topic: str, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + ) -> None: + params = build_memory_params( + topic=topic, + user_id=user_id, + conversation_id=conversation_id, + group=group, + ) + await self._transport.request("DELETE", _memory_path(memory_id), params=params) + + async def search( + self, + *, + query: str, + topics: list[str] | None = None, + user_id: str | None = None, + conversation_id: str | None = None, + group: str | None = None, + retrieval_config: RetrievalConfig | None = None, + ) -> SearchResults: + body = build_search_body( + query=query, + topics=topics, + user_id=user_id, + conversation_id=conversation_id, + group=group, + retrieval_config=retrieval_config or RetrievalConfig(), + ) + data = await self._transport.request("POST", _MEMORIES_SEARCH_PATH, json=body) + return parse_search_results(data) diff --git a/src/engram/_resources/runs.py b/src/engram/_resources/runs.py new file mode 100644 index 0000000..d42e1ca --- /dev/null +++ b/src/engram/_resources/runs.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import asyncio +import time + +from .._http import AsyncHttpTransport, HttpTransport +from .._models import RunStatus +from .._serialization import parse_run_status +from ..errors import EngramTimeoutError + +_RUNS_PATH = "/v1/runs" + +_TERMINAL_STATUSES = frozenset(("completed", "failed")) + + +def _run_path(run_id: str) -> str: + return f"{_RUNS_PATH}/{run_id}" + + +class Runs: + """Sync sub-resource for run operations: client.runs.*""" + + def __init__(self, transport: HttpTransport) -> None: + self._transport = transport + + def get(self, run_id: str) -> RunStatus: + data = self._transport.request("GET", _run_path(run_id)) + return parse_run_status(data) + + def wait( + self, + run_id: str, + *, + timeout: float = 30.0, + interval: float = 0.5, + ) -> RunStatus: + deadline = time.monotonic() + timeout + while True: + status = self.get(run_id) + if status.status in _TERMINAL_STATUSES: + return status + if time.monotonic() + interval > deadline: + raise EngramTimeoutError(run_id, timeout) + time.sleep(interval) + + +class AsyncRuns: + """Async sub-resource for run operations: client.runs.*""" + + def __init__(self, transport: AsyncHttpTransport) -> None: + self._transport = transport + + async def get(self, run_id: str) -> RunStatus: + data = await self._transport.request("GET", _run_path(run_id)) + return parse_run_status(data) + + async def wait( + self, + run_id: str, + *, + timeout: float = 30.0, + interval: float = 0.5, + ) -> RunStatus: + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout + while True: + status = await self.get(run_id) + if status.status in _TERMINAL_STATUSES: + return status + if loop.time() + interval > deadline: + raise EngramTimeoutError(run_id, timeout) + await asyncio.sleep(interval) diff --git a/src/engram/_serialization/__init__.py b/src/engram/_serialization/__init__.py new file mode 100644 index 0000000..ef0a243 --- /dev/null +++ b/src/engram/_serialization/__init__.py @@ -0,0 +1,21 @@ +from ._builders import ( + build_add_body, + build_memory_params, + build_search_body, +) +from ._parsers import ( + parse_memory, + parse_run, + parse_run_status, + parse_search_results, +) + +__all__ = [ + "build_add_body", + "build_memory_params", + "build_search_body", + "parse_memory", + "parse_run", + "parse_run_status", + "parse_search_results", +] diff --git a/src/engram/_serialization/_builders.py b/src/engram/_serialization/_builders.py new file mode 100644 index 0000000..304ccf1 --- /dev/null +++ b/src/engram/_serialization/_builders.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import Any + +from .._models import AddContent, PreExtractedContent, RetrievalConfig + + +def _serialize_content(content: AddContent) -> dict[str, Any]: + """Build the content envelope with the type discriminator.""" + if isinstance(content, str): + return {"type": "string", "content": content} + if isinstance(content, PreExtractedContent): + d: dict[str, Any] = {"type": "pre_extracted", "content": content.content} + if content.tags: + d["tags"] = content.tags + return d + if isinstance(content, list): + return { + "type": "conversation", + "conversation": {"messages": content}, + } + raise TypeError(f"Unsupported content type: {type(content)}") # pragma: no cover + + +def build_add_body( + content: AddContent, + *, + user_id: str | None, + conversation_id: str | None, + group: str | None, +) -> dict[str, Any]: + body: dict[str, Any] = {"content": _serialize_content(content)} + if user_id is not None: + body["user_id"] = user_id + if conversation_id is not None: + body["conversation_id"] = conversation_id + if group is not None: + body["group"] = group + return body + + +def build_memory_params( + *, + topic: str, + user_id: str | None, + conversation_id: str | None, + group: str | None, +) -> dict[str, str]: + params: dict[str, str] = {"topic": topic} + if user_id is not None: + params["user_id"] = user_id + if conversation_id is not None: + params["conversation_id"] = conversation_id + if group is not None: + params["group"] = group + return params + + +def build_search_body( + *, + query: str, + topics: list[str] | None, + user_id: str | None, + conversation_id: str | None, + group: str | None, + retrieval_config: RetrievalConfig, +) -> dict[str, Any]: + body: dict[str, Any] = { + "query": query, + "retrieval_config": { + "retrieval_type": retrieval_config.retrieval_type, + "limit": retrieval_config.limit, + }, + } + if topics is not None: + body["topics"] = topics + if user_id is not None: + body["user_id"] = user_id + if conversation_id is not None: + body["conversation_id"] = conversation_id + if group is not None: + body["group"] = group + return body diff --git a/src/engram/_serialization/_parsers.py b/src/engram/_serialization/_parsers.py new file mode 100644 index 0000000..3e58ecb --- /dev/null +++ b/src/engram/_serialization/_parsers.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from typing import Any + +from .._models import ( + CommittedOperation, + CommittedOperations, + Memory, + Run, + RunStatus, + SearchResults, +) + + +def parse_run(data: dict[str, Any]) -> Run: + return Run( + run_id=data["run_id"], + status=data["status"], + error=data.get("error"), + ) + + +def parse_memory(data: dict[str, Any]) -> Memory: + return Memory( + id=data["id"], + project_id=data["project_id"], + content=data["content"], + topic=data["topic"], + group=data["group"], + created_at=data["created_at"], + updated_at=data["updated_at"], + user_id=data.get("user_id"), + conversation_id=data.get("conversation_id"), + tags=data.get("tags"), + score=data.get("score"), + ) + + +def parse_search_results(data: dict[str, Any]) -> SearchResults: + return SearchResults( + memories=[parse_memory(m["Body"]) for m in data["memories"]], + total=data["total"], + ) + + +def _parse_committed_operation(data: dict[str, Any]) -> CommittedOperation: + return CommittedOperation( + memory_id=data["memory_id"], + committed_at=data["committed_at"], + ) + + +def _parse_committed_operations(data: dict[str, Any]) -> CommittedOperations: + return CommittedOperations( + created=[_parse_committed_operation(op) for op in data.get("created", [])], + updated=[_parse_committed_operation(op) for op in data.get("updated", [])], + deleted=[_parse_committed_operation(op) for op in data.get("deleted", [])], + ) + + +def parse_run_status(data: dict[str, Any]) -> RunStatus: + committed_ops = data.get("committed_operations") + return RunStatus( + run_id=data["run_id"], + status=data["status"], + group_id=data["group_id"], + starting_step=data["starting_step"], + input_type=data["input_type"], + created_at=data["created_at"], + updated_at=data["updated_at"], + committed_operations=_parse_committed_operations(committed_ops) + if committed_ops is not None + else None, + error=data.get("error"), + ) diff --git a/src/engram/async_client.py b/src/engram/async_client.py index 32d55e8..d52f75a 100644 --- a/src/engram/async_client.py +++ b/src/engram/async_client.py @@ -2,26 +2,27 @@ from collections.abc import Mapping -import httpx - from ._base_client import DEFAULT_BASE_URL, DEFAULT_TIMEOUT, _BaseClient +from ._http import AsyncHttpTransport +from ._resources import AsyncMemories, AsyncRuns __all__ = ["DEFAULT_BASE_URL", "DEFAULT_TIMEOUT", "AsyncEngramClient"] class AsyncEngramClient(_BaseClient): - """Asynchronous Engram client""" + """Asynchronous Engram client.""" - _http_client: httpx.AsyncClient + _transport: AsyncHttpTransport + memories: AsyncMemories + runs: AsyncRuns def __init__( self, *, base_url: str = DEFAULT_BASE_URL, - api_key: str | None = None, + api_key: str, headers: Mapping[str, str] | None = None, timeout: float = DEFAULT_TIMEOUT, - http_client: httpx.AsyncClient | None = None, ) -> None: super().__init__( base_url=base_url, @@ -29,12 +30,12 @@ def __init__( headers=headers, timeout=timeout, ) - self._owns_http_client = http_client is None - self._http_client = http_client or httpx.AsyncClient(timeout=timeout) + self._transport = AsyncHttpTransport(self._config) + self.memories = AsyncMemories(self._transport) + self.runs = AsyncRuns(self._transport) async def aclose(self) -> None: - if self._owns_http_client: - await self._http_client.aclose() + await self._transport.close() async def __aenter__(self) -> AsyncEngramClient: return self diff --git a/src/engram/client.py b/src/engram/client.py index e12b0ba..7a6359e 100644 --- a/src/engram/client.py +++ b/src/engram/client.py @@ -2,26 +2,27 @@ from collections.abc import Mapping -import httpx - from ._base_client import DEFAULT_BASE_URL, DEFAULT_TIMEOUT, _BaseClient +from ._http import HttpTransport +from ._resources import Memories, Runs __all__ = ["DEFAULT_BASE_URL", "DEFAULT_TIMEOUT", "EngramClient"] class EngramClient(_BaseClient): - """Synchronous Engram client""" + """Synchronous Engram client.""" - _http_client: httpx.Client + _transport: HttpTransport + memories: Memories + runs: Runs def __init__( self, *, base_url: str = DEFAULT_BASE_URL, - api_key: str | None = None, + api_key: str, headers: Mapping[str, str] | None = None, timeout: float = DEFAULT_TIMEOUT, - http_client: httpx.Client | None = None, ) -> None: super().__init__( base_url=base_url, @@ -29,12 +30,12 @@ def __init__( headers=headers, timeout=timeout, ) - self._owns_http_client = http_client is None - self._http_client = http_client or httpx.Client(timeout=timeout) + self._transport = HttpTransport(self._config) + self.memories = Memories(self._transport) + self.runs = Runs(self._transport) def close(self) -> None: - if self._owns_http_client: - self._http_client.close() + self._transport.close() def __enter__(self) -> EngramClient: return self diff --git a/src/engram/errors.py b/src/engram/errors.py index d071ae5..f00a7b6 100644 --- a/src/engram/errors.py +++ b/src/engram/errors.py @@ -22,9 +22,25 @@ def __init__( self.body = body -class AuthError(APIError): - """Raised when authentication fails.""" +class AuthenticationError(APIError): + """Raised when authentication fails (401).""" + + def __init__(self, message: str, *, body: object = None) -> None: + super().__init__(message, status_code=401, body=body) class ValidationError(EngramError): """Raised when configuration or request input is invalid.""" + + +class ConnectionError(EngramError): + """Raised when a connection to the Engram server fails.""" + + +class EngramTimeoutError(EngramError): + """Raised when a run does not reach a terminal status within the timeout.""" + + def __init__(self, run_id: str, timeout: float) -> None: + super().__init__(f"Run {run_id!r} did not reach a terminal status within {timeout}s") + self.run_id = run_id + self.timeout = timeout diff --git a/src/engram/types.py b/src/engram/types.py index 401d574..2feb49f 100644 --- a/src/engram/types.py +++ b/src/engram/types.py @@ -1,7 +1,6 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any @dataclass(slots=True) @@ -10,17 +9,3 @@ class ClientConfig: timeout: float headers: dict[str, str] = field(default_factory=dict) api_key: str | None = None - - -@dataclass(slots=True) -class APIRequest: - method: str - path: str - params: dict[str, Any] | None = None - json: dict[str, Any] | None = None - - -@dataclass(slots=True) -class APIResponse: - status_code: int - data: dict[str, Any] | list[Any] | str | None = None diff --git a/tests/test_client_async.py b/tests/test_client_async.py index c51d8ee..30aa0ff 100644 --- a/tests/test_client_async.py +++ b/tests/test_client_async.py @@ -1,18 +1,24 @@ +import json +from typing import Any + +import httpx import pytest +from engram._http import AsyncHttpTransport +from engram._models import PreExtractedContent, RetrievalConfig from engram.async_client import DEFAULT_BASE_URL, AsyncEngramClient -from engram.errors import ValidationError +from engram.errors import APIError, AuthenticationError, ValidationError @pytest.mark.asyncio async def test_async_client_defaults() -> None: - client = AsyncEngramClient() + client = AsyncEngramClient(api_key="test-key") try: assert client.config.base_url == DEFAULT_BASE_URL assert client.config.timeout == 30.0 assert client.default_headers["Accept"] == "application/json" assert client.default_headers["Content-Type"] == "application/json" - assert "Authorization" not in client.default_headers + assert client.default_headers["Authorization"] == "Bearer test-key" finally: await client.aclose() @@ -31,7 +37,7 @@ async def test_async_client_custom_config_and_header_merging() -> None: assert client.default_headers["Authorization"] == "Bearer test-key" assert client.default_headers["X-Custom"] == "from-client" - request = client.build_request( + request = client._transport.build_request( "GET", "v1/items", headers={"X-Custom": "from-request", "X-Request": "value"}, @@ -46,4 +52,253 @@ async def test_async_client_custom_config_and_header_merging() -> None: def test_async_client_rejects_non_positive_timeout() -> None: with pytest.raises(ValidationError): - AsyncEngramClient(timeout=-1) + AsyncEngramClient(api_key="test-key", timeout=-1) + + +def test_async_client_has_sub_resources() -> None: + client = AsyncEngramClient(api_key="test-key") + assert hasattr(client, "memories") + assert hasattr(client, "runs") + + +# ── Helpers ───────────────────────────────────────────────────────────── + + +def _make_client( + status_code: int = 200, + body: dict[str, Any] | None = None, +) -> AsyncEngramClient: + mock = httpx.MockTransport( + lambda _: httpx.Response(status_code, json=body if body is not None else {}) + ) + client = AsyncEngramClient(base_url="https://test.example.com", api_key="test-key") + transport = AsyncHttpTransport(client._config, httpx.AsyncClient(transport=mock)) + client._transport = transport + client.memories._transport = transport + client.runs._transport = transport + return client + + +def _make_client_with_handler( + handler: Any, + *, + base_url: str = "https://test.example.com", + api_key: str = "k", +) -> AsyncEngramClient: + http_client = httpx.AsyncClient(transport=httpx.MockTransport(handler)) + client = AsyncEngramClient(base_url=base_url, api_key=api_key) + transport = AsyncHttpTransport(client._config, http_client) + client._transport = transport + client.memories._transport = transport + client.runs._transport = transport + return client + + +# ── memories.add ──────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_add_str() -> None: + client = _make_client(body={"run_id": "r1", "status": "pending"}) + result = await client.memories.add("hello") + assert result.run_id == "r1" + assert result.status == "pending" + + +@pytest.mark.asyncio +async def test_add_pre_extracted() -> None: + client = _make_client(body={"run_id": "r2", "status": "pending"}) + result = await client.memories.add( + PreExtractedContent(content="fact", tags=["a"]), + user_id="u1", + ) + assert result.run_id == "r2" + + +@pytest.mark.asyncio +async def test_add_conversation() -> None: + client = _make_client(body={"run_id": "r3", "status": "pending"}) + result = await client.memories.add( + [{"role": "user", "content": "hi"}], + user_id="u1", + conversation_id="c1", + ) + assert result.run_id == "r3" + + +@pytest.mark.asyncio +async def test_add_sends_content_envelope() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"run_id": "r1", "status": "pending"}) + + client = _make_client_with_handler(handler) + await client.memories.add("hello", user_id="u1", group="g1") + body = json.loads(captured[0].content) + assert body == { + "content": {"type": "string", "content": "hello"}, + "user_id": "u1", + "group": "g1", + } + + +# ── memories.get ──────────────────────────────────────────────────────── + +SAMPLE_MEMORY_RESPONSE: dict[str, Any] = { + "id": "m1", + "project_id": "p1", + "content": "some content", + "topic": "t1", + "group": "g1", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", +} + + +@pytest.mark.asyncio +async def test_get_memory() -> None: + client = _make_client(body=SAMPLE_MEMORY_RESPONSE) + mem = await client.memories.get("m1", topic="t1") + assert mem.id == "m1" + assert mem.content == "some content" + + +@pytest.mark.asyncio +async def test_get_memory_sends_query_params() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json=SAMPLE_MEMORY_RESPONSE) + + client = _make_client_with_handler(handler) + await client.memories.get("m1", topic="t1", user_id="u1", group="g1") + url = captured[0].url + assert url.params["topic"] == "t1" + assert url.params["user_id"] == "u1" + assert url.params["group"] == "g1" + + +# ── memories.delete ───────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_delete_memory() -> None: + client = _make_client_with_handler(lambda _: httpx.Response(204, content=b"")) + await client.memories.delete("m1", topic="t1") + + +# ── memories.search ───────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_search_memories() -> None: + response_body: dict[str, Any] = { + "memories": [{"Body": SAMPLE_MEMORY_RESPONSE}], + "total": 1, + } + client = _make_client(body=response_body) + result = await client.memories.search(query="test") + assert result.total == 1 + assert len(result) == 1 + + +@pytest.mark.asyncio +async def test_search_memories_iterable() -> None: + response_body: dict[str, Any] = { + "memories": [ + {"Body": SAMPLE_MEMORY_RESPONSE}, + {"Body": {**SAMPLE_MEMORY_RESPONSE, "id": "m2", "score": 0.85}}, + ], + "total": 2, + } + client = _make_client(body=response_body) + results = await client.memories.search(query="test") + ids = [m.id for m in results] + assert ids == ["m1", "m2"] + assert results[1].score == 0.85 + + +@pytest.mark.asyncio +async def test_search_sends_correct_body() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"memories": [], "total": 0}) + + client = _make_client_with_handler(handler) + await client.memories.search( + query="find this", + topics=["a"], + retrieval_config=RetrievalConfig(retrieval_type="vector", limit=5), + ) + body = json.loads(captured[0].content) + assert body["query"] == "find this" + assert body["topics"] == ["a"] + assert body["retrieval_config"]["retrieval_type"] == "vector" + assert body["retrieval_config"]["limit"] == 5 + + +# ── runs.get ──────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_get_run() -> None: + response_body: dict[str, Any] = { + "run_id": "r1", + "status": "completed", + "group_id": "g1", + "starting_step": 0, + "input_type": "string", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", + "committed_operations": { + "created": [{"memory_id": "m1", "committed_at": "2024-01-01T00:00:00Z"}], + "updated": [], + "deleted": [], + }, + } + client = _make_client(body=response_body) + result = await client.runs.get("r1") + assert result.run_id == "r1" + assert result.status == "completed" + assert len(result.memories_created) == 1 + + +# ── Error handling ────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_401_raises_authentication_error() -> None: + client = _make_client(status_code=401, body={"detail": "Invalid token"}) + with pytest.raises(AuthenticationError) as exc_info: + await client.memories.add("hello") + assert exc_info.value.status_code == 401 + assert "Invalid token" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_404_raises_api_error() -> None: + client = _make_client(status_code=404, body={"detail": "Not found"}) + with pytest.raises(APIError) as exc_info: + await client.memories.get("missing", topic="t1") + assert exc_info.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_400_raises_api_error() -> None: + client = _make_client(status_code=400, body={"detail": "Bad request"}) + with pytest.raises(APIError) as exc_info: + await client.memories.search(query="test") + assert exc_info.value.status_code == 400 + + +@pytest.mark.asyncio +async def test_500_raises_api_error() -> None: + client = _make_client(status_code=500, body={"detail": "Internal server error"}) + with pytest.raises(APIError) as exc_info: + await client.runs.get("r1") + assert exc_info.value.status_code == 500 diff --git a/tests/test_client_sync.py b/tests/test_client_sync.py index 7a59da5..4f56dec 100644 --- a/tests/test_client_sync.py +++ b/tests/test_client_sync.py @@ -1,17 +1,23 @@ +import json +from typing import Any + +import httpx import pytest +from engram._http import HttpTransport +from engram._models import PreExtractedContent, RetrievalConfig from engram.client import DEFAULT_BASE_URL, EngramClient -from engram.errors import ValidationError +from engram.errors import APIError, AuthenticationError, ValidationError def test_client_defaults() -> None: - client = EngramClient() + client = EngramClient(api_key="test-key") try: assert client.config.base_url == DEFAULT_BASE_URL assert client.config.timeout == 30.0 assert client.default_headers["Accept"] == "application/json" assert client.default_headers["Content-Type"] == "application/json" - assert "Authorization" not in client.default_headers + assert client.default_headers["Authorization"] == "Bearer test-key" finally: client.close() @@ -29,7 +35,7 @@ def test_client_custom_config_and_header_merging() -> None: assert client.default_headers["Authorization"] == "Bearer test-key" assert client.default_headers["X-Custom"] == "from-client" - request = client.build_request( + request = client._transport.build_request( "GET", "/v1/items", headers={"X-Custom": "from-request", "X-Request": "value"}, @@ -44,4 +50,279 @@ def test_client_custom_config_and_header_merging() -> None: def test_client_rejects_non_positive_timeout() -> None: with pytest.raises(ValidationError): - EngramClient(timeout=0) + EngramClient(api_key="test-key", timeout=0) + + +def test_client_has_sub_resources() -> None: + client = EngramClient(api_key="test-key") + try: + assert hasattr(client, "memories") + assert hasattr(client, "runs") + finally: + client.close() + + +# ── Helpers ───────────────────────────────────────────────────────────── + + +def _make_client( + status_code: int = 200, + body: dict[str, Any] | None = None, +) -> EngramClient: + mock = httpx.MockTransport( + lambda _: httpx.Response(status_code, json=body if body is not None else {}) + ) + client = EngramClient(base_url="https://test.example.com", api_key="test-key") + transport = HttpTransport(client._config, httpx.Client(transport=mock)) + client._transport.close() + client._transport = transport + client.memories._transport = transport + client.runs._transport = transport + return client + + +def _make_client_with_handler( + handler: Any, + *, + base_url: str = "https://test.example.com", + api_key: str = "k", +) -> EngramClient: + http_client = httpx.Client(transport=httpx.MockTransport(handler)) + client = EngramClient(base_url=base_url, api_key=api_key) + transport = HttpTransport(client._config, http_client) + client._transport.close() + client._transport = transport + client.memories._transport = transport + client.runs._transport = transport + return client + + +# ── memories.add ──────────────────────────────────────────────────────── + + +def test_add_str() -> None: + client = _make_client(body={"run_id": "r1", "status": "pending"}) + result = client.memories.add("hello") + assert result.run_id == "r1" + assert result.status == "pending" + + +def test_add_pre_extracted() -> None: + client = _make_client(body={"run_id": "r2", "status": "pending"}) + result = client.memories.add( + PreExtractedContent(content="fact", tags=["a"]), + user_id="u1", + ) + assert result.run_id == "r2" + + +def test_add_conversation() -> None: + client = _make_client(body={"run_id": "r3", "status": "pending"}) + result = client.memories.add( + [{"role": "user", "content": "hi"}], + user_id="u1", + conversation_id="c1", + ) + assert result.run_id == "r3" + + +def test_add_sends_content_envelope() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"run_id": "r1", "status": "pending"}) + + client = _make_client_with_handler(handler) + client.memories.add("hello", user_id="u1", group="g1") + body = json.loads(captured[0].content) + assert body == { + "content": {"type": "string", "content": "hello"}, + "user_id": "u1", + "group": "g1", + } + + +def test_add_conversation_sends_correct_envelope() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"run_id": "r1", "status": "pending"}) + + client = _make_client_with_handler(handler) + messages = [{"role": "user", "content": "hi"}] + client.memories.add(messages, conversation_id="c1") + body = json.loads(captured[0].content) + assert body == { + "content": { + "type": "conversation", + "conversation": {"messages": messages}, + }, + "conversation_id": "c1", + } + + +# ── memories.get ──────────────────────────────────────────────────────── + +SAMPLE_MEMORY_RESPONSE: dict[str, Any] = { + "id": "m1", + "project_id": "p1", + "content": "some content", + "topic": "t1", + "group": "g1", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", +} + + +def test_get_memory() -> None: + client = _make_client(body=SAMPLE_MEMORY_RESPONSE) + mem = client.memories.get("m1", topic="t1") + assert mem.id == "m1" + assert mem.content == "some content" + + +def test_get_memory_sends_query_params() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json=SAMPLE_MEMORY_RESPONSE) + + client = _make_client_with_handler(handler) + client.memories.get("m1", topic="t1", user_id="u1", group="g1") + url = captured[0].url + assert url.params["topic"] == "t1" + assert url.params["user_id"] == "u1" + assert url.params["group"] == "g1" + + +# ── memories.delete ───────────────────────────────────────────────────── + + +def test_delete_memory() -> None: + client = _make_client_with_handler(lambda _: httpx.Response(204, content=b"")) + client.memories.delete("m1", topic="t1") + + +# ── memories.search ───────────────────────────────────────────────────── + + +def test_search_memories() -> None: + response_body: dict[str, Any] = { + "memories": [{"Body": SAMPLE_MEMORY_RESPONSE}], + "total": 1, + } + client = _make_client(body=response_body) + result = client.memories.search(query="test") + assert result.total == 1 + assert len(result) == 1 + + +def test_search_memories_iterable() -> None: + response_body: dict[str, Any] = { + "memories": [ + {"Body": SAMPLE_MEMORY_RESPONSE}, + {"Body": {**SAMPLE_MEMORY_RESPONSE, "id": "m2", "score": 0.85}}, + ], + "total": 2, + } + client = _make_client(body=response_body) + results = client.memories.search(query="test") + ids = [m.id for m in results] + assert ids == ["m1", "m2"] + assert results[1].score == 0.85 + + +def test_search_sends_correct_body() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"memories": [], "total": 0}) + + client = _make_client_with_handler(handler) + client.memories.search( + query="find this", + topics=["a"], + retrieval_config=RetrievalConfig(retrieval_type="vector", limit=5), + ) + body = json.loads(captured[0].content) + assert body["query"] == "find this" + assert body["topics"] == ["a"] + assert body["retrieval_config"]["retrieval_type"] == "vector" + assert body["retrieval_config"]["limit"] == 5 + + +def test_search_default_retrieval_config() -> None: + captured: list[httpx.Request] = [] + + def handler(request: httpx.Request) -> httpx.Response: + captured.append(request) + return httpx.Response(200, json={"memories": [], "total": 0}) + + client = _make_client_with_handler(handler) + client.memories.search(query="test") + body = json.loads(captured[0].content) + assert body["retrieval_config"] == { + "retrieval_type": "hybrid", + "limit": 10, + } + + +# ── runs.get ──────────────────────────────────────────────────────────── + + +def test_get_run() -> None: + response_body: dict[str, Any] = { + "run_id": "r1", + "status": "completed", + "group_id": "g1", + "starting_step": 0, + "input_type": "string", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", + "committed_operations": { + "created": [{"memory_id": "m1", "committed_at": "2024-01-01T00:00:00Z"}], + "updated": [], + "deleted": [], + }, + } + client = _make_client(body=response_body) + result = client.runs.get("r1") + assert result.run_id == "r1" + assert result.status == "completed" + assert len(result.memories_created) == 1 + + +# ── Error handling ────────────────────────────────────────────────────── + + +def test_401_raises_authentication_error() -> None: + client = _make_client(status_code=401, body={"detail": "Invalid token"}) + with pytest.raises(AuthenticationError) as exc_info: + client.memories.add("hello") + assert exc_info.value.status_code == 401 + assert "Invalid token" in str(exc_info.value) + + +def test_404_raises_api_error() -> None: + client = _make_client(status_code=404, body={"detail": "Not found"}) + with pytest.raises(APIError) as exc_info: + client.memories.get("missing", topic="t1") + assert exc_info.value.status_code == 404 + + +def test_400_raises_api_error() -> None: + client = _make_client(status_code=400, body={"detail": "Bad request"}) + with pytest.raises(APIError) as exc_info: + client.memories.search(query="test") + assert exc_info.value.status_code == 400 + + +def test_500_raises_api_error() -> None: + client = _make_client(status_code=500, body={"detail": "Internal server error"}) + with pytest.raises(APIError) as exc_info: + client.runs.get("r1") + assert exc_info.value.status_code == 500 diff --git a/tests/test_imports.py b/tests/test_imports.py index c0e46ca..3580309 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -1,11 +1,21 @@ def test_public_imports() -> None: import engram - from engram import ( + from engram import ( # noqa: F401 APIError, AsyncEngramClient, - AuthError, + AuthenticationError, + CommittedOperation, + CommittedOperations, + ConnectionError, EngramClient, EngramError, + EngramTimeoutError, + Memory, + PreExtractedContent, + RetrievalConfig, + Run, + RunStatus, + SearchResults, ValidationError, ) @@ -13,15 +23,34 @@ def test_public_imports() -> None: assert isinstance(AsyncEngramClient, type) assert isinstance(EngramError, type) assert isinstance(APIError, type) - assert isinstance(AuthError, type) + assert isinstance(AuthenticationError, type) assert isinstance(ValidationError, type) + assert isinstance(EngramTimeoutError, type) + assert isinstance(Memory, type) + assert isinstance(Run, type) + assert isinstance(RunStatus, type) + assert isinstance(SearchResults, type) + assert isinstance(PreExtractedContent, type) + assert isinstance(RetrievalConfig, type) + assert isinstance(CommittedOperation, type) + assert isinstance(CommittedOperations, type) expected_exports = { "APIError", "AsyncEngramClient", - "AuthError", + "AuthenticationError", + "CommittedOperation", + "CommittedOperations", + "ConnectionError", "EngramClient", "EngramError", + "EngramTimeoutError", + "Memory", + "PreExtractedContent", + "RetrievalConfig", + "Run", + "RunStatus", + "SearchResults", "ValidationError", "__version__", } diff --git a/tests/test_serialization.py b/tests/test_serialization.py new file mode 100644 index 0000000..8bf664e --- /dev/null +++ b/tests/test_serialization.py @@ -0,0 +1,256 @@ +from engram._models import PreExtractedContent, RetrievalConfig +from engram._serialization import ( + build_add_body, + build_memory_params, + build_search_body, + parse_memory, + parse_run, + parse_run_status, + parse_search_results, +) + +# ── build_add_body ────────────────────────────────────────────────────── + + +def test_build_add_body_str() -> None: + body = build_add_body( + "hello world", + user_id=None, + conversation_id=None, + group=None, + ) + assert body == {"content": {"type": "string", "content": "hello world"}} + + +def test_build_add_body_str_with_options() -> None: + body = build_add_body( + "hello", + user_id="u1", + conversation_id="c1", + group="g1", + ) + assert body == { + "content": {"type": "string", "content": "hello"}, + "user_id": "u1", + "conversation_id": "c1", + "group": "g1", + } + + +def test_build_add_body_pre_extracted() -> None: + body = build_add_body( + PreExtractedContent(content="fact", tags=["a", "b"]), + user_id=None, + conversation_id=None, + group=None, + ) + assert body == { + "content": {"type": "pre_extracted", "content": "fact", "tags": ["a", "b"]}, + } + + +def test_build_add_body_pre_extracted_no_tags() -> None: + body = build_add_body( + PreExtractedContent(content="fact"), + user_id=None, + conversation_id=None, + group=None, + ) + assert body == { + "content": {"type": "pre_extracted", "content": "fact"}, + } + + +def test_build_add_body_conversation() -> None: + messages = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + body = build_add_body( + messages, + user_id="u1", + conversation_id="c1", + group=None, + ) + assert body == { + "content": { + "type": "conversation", + "conversation": {"messages": messages}, + }, + "user_id": "u1", + "conversation_id": "c1", + } + + +# ── build_memory_params ───────────────────────────────────────────────── + + +def test_build_memory_params_minimal() -> None: + params = build_memory_params(topic="t1", user_id=None, conversation_id=None, group=None) + assert params == {"topic": "t1"} + + +def test_build_memory_params_full() -> None: + params = build_memory_params(topic="t1", user_id="u1", conversation_id="c1", group="g1") + assert params == { + "topic": "t1", + "user_id": "u1", + "conversation_id": "c1", + "group": "g1", + } + + +# ── build_search_body ─────────────────────────────────────────────────── + + +def test_build_search_body_defaults() -> None: + body = build_search_body( + query="test", + topics=None, + user_id=None, + conversation_id=None, + group=None, + retrieval_config=RetrievalConfig(), + ) + assert body == { + "query": "test", + "retrieval_config": {"retrieval_type": "hybrid", "limit": 10}, + } + + +def test_build_search_body_full() -> None: + body = build_search_body( + query="test", + topics=["a", "b"], + user_id="u1", + conversation_id="c1", + group="g1", + retrieval_config=RetrievalConfig(retrieval_type="vector", limit=5), + ) + assert body["topics"] == ["a", "b"] + assert body["user_id"] == "u1" + assert body["retrieval_config"]["retrieval_type"] == "vector" + assert body["retrieval_config"]["limit"] == 5 + + +# ── parse_run ─────────────────────────────────────────────────────────── + + +def test_parse_run() -> None: + result = parse_run({"run_id": "r1", "status": "pending"}) + assert result.run_id == "r1" + assert result.status == "pending" + assert result.error is None + + +def test_parse_run_with_error() -> None: + result = parse_run({"run_id": "r1", "status": "failed", "error": "boom"}) + assert result.error == "boom" + + +# ── parse_memory ──────────────────────────────────────────────────────── + + +SAMPLE_MEMORY = { + "id": "m1", + "project_id": "p1", + "content": "some content", + "topic": "t1", + "group": "g1", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", +} + + +def test_parse_memory_minimal() -> None: + mem = parse_memory(SAMPLE_MEMORY) + assert mem.id == "m1" + assert mem.project_id == "p1" + assert mem.user_id is None + assert mem.score is None + + +def test_parse_memory_with_optional_fields() -> None: + data = { + **SAMPLE_MEMORY, + "user_id": "u1", + "conversation_id": "c1", + "tags": ["x"], + "score": 0.95, + } + mem = parse_memory(data) + assert mem.user_id == "u1" + assert mem.tags == ["x"] + assert mem.score == 0.95 + + +# ── parse_search_results ──────────────────────────────────────────────── + + +def test_parse_search_results() -> None: + data = {"memories": [{"Body": SAMPLE_MEMORY}], "total": 1} + result = parse_search_results(data) + assert result.total == 1 + assert len(result) == 1 + assert result[0].id == "m1" + + +def test_parse_search_results_empty() -> None: + result = parse_search_results({"memories": [], "total": 0}) + assert result.total == 0 + assert len(result) == 0 + + +def test_search_results_iterable() -> None: + data = { + "memories": [{"Body": SAMPLE_MEMORY}, {"Body": {**SAMPLE_MEMORY, "id": "m2"}}], + "total": 2, + } + result = parse_search_results(data) + ids = [m.id for m in result] + assert ids == ["m1", "m2"] + + +# ── parse_run_status ──────────────────────────────────────────────────── + + +SAMPLE_RUN_STATUS = { + "run_id": "r1", + "status": "completed", + "group_id": "g1", + "starting_step": 0, + "input_type": "string", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-02T00:00:00Z", +} + + +def test_parse_run_status_minimal() -> None: + result = parse_run_status(SAMPLE_RUN_STATUS) + assert result.run_id == "r1" + assert result.starting_step == 0 + assert result.committed_operations is None + assert result.error is None + assert result.memories_created == [] + + +def test_parse_run_status_with_committed_operations() -> None: + data = { + **SAMPLE_RUN_STATUS, + "committed_operations": { + "created": [{"memory_id": "m1", "committed_at": "2024-01-01T00:00:00Z"}], + "updated": [], + "deleted": [], + }, + } + result = parse_run_status(data) + assert result.committed_operations is not None + assert len(result.memories_created) == 1 + assert result.memories_created[0].memory_id == "m1" + assert result.memories_updated == [] + + +def test_parse_run_status_with_error() -> None: + data = {**SAMPLE_RUN_STATUS, "status": "failed", "error": "boom"} + result = parse_run_status(data) + assert result.error == "boom"