From 26342bd857d44bf62fc649d3b2148342a2bc5f39 Mon Sep 17 00:00:00 2001 From: SiddarthAA Date: Sat, 25 Apr 2026 01:34:06 +0530 Subject: [PATCH] fix: session lifecycle and store.delete params Claim 1 - session lifetime scoped to AsyncJigsawStack (async_request.py, __init__.py): - Add optional 'session' field to AsyncRequestConfig. - _SessionContext wraps a shared session without closing it on exit. - AsyncRequest reads the injected session; falls back to a per-request ClientSession when used standalone. - AsyncJigsawStack gains __aenter__ / __aexit__ / aclose() that open a single aiohttp.ClientSession and share it across all service calls. Claim 3 - Store.delete passes raw string as URL query params (store.py): - params=key sent the file key as a malformed query string. The key is already in the URL path. Fix is params={}. (DELETE has no body, so there is nothing to corrupt.) Tests: - tests/test_session_lifecycle.py 8 tests (claim 1) - tests/test_store_delete.py 3 tests (claim 3) --- jigsawstack/__init__.py | 86 ++++++++++++++--------- jigsawstack/async_request.py | 34 +++++++-- jigsawstack/store.py | 4 +- tests/test_session_lifecycle.py | 119 ++++++++++++++++++++++++++++++++ tests/test_store_delete.py | 109 +++++++++++++++++++++++++++++ 5 files changed, 311 insertions(+), 41 deletions(-) create mode 100644 tests/test_session_lifecycle.py create mode 100644 tests/test_store_delete.py diff --git a/jigsawstack/__init__.py b/jigsawstack/__init__.py index 9218810..cc63af4 100644 --- a/jigsawstack/__init__.py +++ b/jigsawstack/__init__.py @@ -1,6 +1,9 @@ import os -from typing import Dict, Union +from typing import Dict, List, Optional, Union +import aiohttp + +from ._config import ClientConfig from .audio import AsyncAudio, Audio from .classification import AsyncClassification, Classification from .embedding import AsyncEmbedding, Embedding @@ -155,51 +158,70 @@ def __init__( self.base_url = base_url self.headers = headers or {"Content-Type": "application/json"} - self.web = AsyncWeb(api_key=api_key, base_url=base_url + "/v1", headers=headers) + # _async_services holds every async service instance so that + # __aenter__ / aclose() can inject / remove the shared session. + self._async_services: List[ClientConfig] = [] + self._session: Optional[aiohttp.ClientSession] = None - self.validate = AsyncValidate(api_key=api_key, base_url=base_url + "/v1", headers=headers) + def _reg(svc: ClientConfig) -> ClientConfig: + self._async_services.append(svc) + return svc - self.audio = AsyncAudio(api_key=api_key, base_url=base_url + "/v1", headers=headers) + self.web = _reg(AsyncWeb(api_key=api_key, base_url=base_url + "/v1", headers=headers)) - self.vision = AsyncVision(api_key=api_key, base_url=base_url + "/v1", headers=headers) + self.validate = _reg(AsyncValidate(api_key=api_key, base_url=base_url + "/v1", headers=headers)) - self.store = AsyncStore(api_key=api_key, base_url=base_url + "/v1", headers=headers) + self.audio = _reg(AsyncAudio(api_key=api_key, base_url=base_url + "/v1", headers=headers)) - self.summary = AsyncSummary( - api_key=api_key, base_url=base_url + "/v1", headers=headers - ).summarize + self.vision = _reg(AsyncVision(api_key=api_key, base_url=base_url + "/v1", headers=headers)) - self.prediction = AsyncPrediction(api_key=api_key, base_url=base_url + "/v1").predict + self.store = _reg(AsyncStore(api_key=api_key, base_url=base_url + "/v1", headers=headers)) - self.text_to_sql = AsyncSQL( - api_key=api_key, base_url=base_url + "/v1", headers=headers - ).text_to_sql + _summary = _reg(AsyncSummary(api_key=api_key, base_url=base_url + "/v1", headers=headers)) + self.summary = _summary.summarize - self.sentiment = AsyncSentiment( - api_key=api_key, base_url=base_url + "/v1", headers=headers - ).analyze + _prediction = _reg(AsyncPrediction(api_key=api_key, base_url=base_url + "/v1", headers=headers)) + self.prediction = _prediction.predict - self.translate = AsyncTranslate(api_key=api_key, base_url=base_url + "/v1", headers=headers) + _sql = _reg(AsyncSQL(api_key=api_key, base_url=base_url + "/v1", headers=headers)) + self.text_to_sql = _sql.text_to_sql - self.embedding = AsyncEmbedding( - api_key=api_key, base_url=base_url + "/v1", headers=headers - ).execute + _sentiment = _reg(AsyncSentiment(api_key=api_key, base_url=base_url + "/v1", headers=headers)) + self.sentiment = _sentiment.analyze - self.embedding_v2 = AsyncEmbeddingV2( - api_key=api_key, base_url=base_url + "/v2", headers=headers - ).execute + self.translate = _reg(AsyncTranslate(api_key=api_key, base_url=base_url + "/v1", headers=headers)) - self.image_generation = AsyncImageGeneration( - api_key=api_key, base_url=base_url + "/v1", headers=headers - ).image_generation + _embedding = _reg(AsyncEmbedding(api_key=api_key, base_url=base_url + "/v1", headers=headers)) + self.embedding = _embedding.execute - self.classification = AsyncClassification( - api_key=api_key, base_url=base_url + "/v1", headers=headers - ).classify + _embedding_v2 = _reg(AsyncEmbeddingV2(api_key=api_key, base_url=base_url + "/v2", headers=headers)) + self.embedding_v2 = _embedding_v2.execute - self.prompt_engine = AsyncPromptEngine( - api_key=api_key, base_url=base_url + "/v1", headers=headers - ) + _image_gen = _reg(AsyncImageGeneration(api_key=api_key, base_url=base_url + "/v1", headers=headers)) + self.image_generation = _image_gen.image_generation + + _classification = _reg(AsyncClassification(api_key=api_key, base_url=base_url + "/v1", headers=headers)) + self.classification = _classification.classify + + self.prompt_engine = _reg(AsyncPromptEngine(api_key=api_key, base_url=base_url + "/v1", headers=headers)) + + async def __aenter__(self) -> "AsyncJigsawStack": + """Open a shared aiohttp.ClientSession reused across all requests.""" + self._session = aiohttp.ClientSession() + for svc in self._async_services: + svc.config["session"] = self._session + return self + + async def aclose(self) -> None: + """Close the shared session and clear it from all service configs.""" + if self._session is not None: + for svc in self._async_services: + svc.config.pop("session", None) + await self._session.close() + self._session = None + + async def __aexit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: + await self.aclose() # Create a global instance of the Web class diff --git a/jigsawstack/async_request.py b/jigsawstack/async_request.py index d86e6b2..f1ee76b 100644 --- a/jigsawstack/async_request.py +++ b/jigsawstack/async_request.py @@ -1,9 +1,9 @@ import json from io import BytesIO -from typing import Any, AsyncGenerator, Dict, Generic, List, TypedDict, Union, cast +from typing import Any, AsyncGenerator, Dict, Generic, List, Optional, TypedDict, Union, cast import aiohttp -from typing_extensions import Literal, TypeVar +from typing_extensions import Literal, NotRequired, TypeVar from .exceptions import NoContentError, raise_for_code_and_type @@ -16,6 +16,23 @@ class AsyncRequestConfig(TypedDict): base_url: str api_key: str headers: Union[Dict[str, str], None] + session: NotRequired[aiohttp.ClientSession] + + +class _SessionContext: + """Async context manager that wraps an existing ClientSession without closing it. + Used when a shared session is injected from the client (AsyncJigsawStack). + """ + __slots__ = ("_session",) + + def __init__(self, session: aiohttp.ClientSession) -> None: + self._session = session + + async def __aenter__(self) -> aiohttp.ClientSession: + return self._session + + async def __aexit__(self, *_: object) -> None: + pass # session lifetime is managed by the caller class AsyncRequest(Generic[T]): @@ -38,6 +55,8 @@ def __init__( self.headers = config.get("headers", None) or {"Content-Type": "application/json"} self.stream = stream self.files = files # Store files for multipart requests + # Optional shared session injected by AsyncJigsawStack. + self._shared_session: Optional[aiohttp.ClientSession] = config.get("session", None) def __convert_params( self, params: Union[Dict[Any, Any], List[Dict[Any, Any]]] @@ -269,13 +288,14 @@ async def make_request( headers=headers, ) - def __get_session(self) -> aiohttp.ClientSession: + def __get_session(self) -> Union["_SessionContext", aiohttp.ClientSession]: """ - Create and return an async client session. - - Returns: - aiohttp.ClientSession: An async client session + Return an async context manager that provides a ClientSession. + If a shared session was injected via config, reuse it without closing. + Otherwise open a fresh session for this request only. """ + if self._shared_session is not None: + return _SessionContext(self._shared_session) return aiohttp.ClientSession() @staticmethod diff --git a/jigsawstack/store.py b/jigsawstack/store.py index 89facea..47b8500 100644 --- a/jigsawstack/store.py +++ b/jigsawstack/store.py @@ -80,7 +80,7 @@ def delete(self, key: str) -> FileDeleteResponse: resp = Request( config=self.config, path=path, - params=key, + params={}, verb="delete", ).perform_with_content() return resp @@ -140,7 +140,7 @@ async def delete(self, key: str) -> FileDeleteResponse: resp = await AsyncRequest( config=self.config, path=path, - params=key, + params={}, verb="delete", ).perform_with_content() return resp diff --git a/tests/test_session_lifecycle.py b/tests/test_session_lifecycle.py new file mode 100644 index 0000000..6422fcc --- /dev/null +++ b/tests/test_session_lifecycle.py @@ -0,0 +1,119 @@ +""" +tests/test_session_lifecycle.py + +Tests for claim 1: AsyncJigsawStack session management. + +Verifies that: +- Without __aenter__, each request creates its own temporary session. +- With __aenter__, a single shared ClientSession is injected into every + service config and reused across requests. +- __aexit__ / aclose() removes the session from all configs and closes it. +- Re-entering after aclose() works correctly (fresh session). +No real network calls are made. +""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import aiohttp +import pytest + +from jigsawstack import AsyncJigsawStack +from jigsawstack.async_request import AsyncRequest, AsyncRequestConfig + + +# --------------------------------------------------------------------------- +# Unit: _SessionContext +# --------------------------------------------------------------------------- + +class TestSessionContext: + def test_reuses_session_without_closing(self): + from jigsawstack.async_request import _SessionContext + + mock_session = MagicMock(spec=aiohttp.ClientSession) + ctx = _SessionContext(mock_session) + + async def run(): + async with ctx as s: + assert s is mock_session + mock_session.close.assert_not_called() + + asyncio.run(run()) + + +# --------------------------------------------------------------------------- +# Unit: AsyncRequest picks up shared session from config +# --------------------------------------------------------------------------- + +class TestAsyncRequestSessionInjection: + def test_no_session_in_config_uses_own_session(self): + config = AsyncRequestConfig(base_url="http://test", api_key="key", headers=None) + r = AsyncRequest(config=config, path="/x", params={}, verb="get") + assert r._shared_session is None + + def test_session_in_config_is_stored(self): + mock_session = MagicMock(spec=aiohttp.ClientSession) + config = AsyncRequestConfig(base_url="http://test", api_key="key", headers=None) + config["session"] = mock_session + r = AsyncRequest(config=config, path="/x", params={}, verb="get") + assert r._shared_session is mock_session + + +# --------------------------------------------------------------------------- +# Integration: AsyncJigsawStack as async context manager +# --------------------------------------------------------------------------- + +class TestAsyncJigsawStackContextManager: + def test_no_session_before_enter(self): + client = AsyncJigsawStack(api_key="test-key") + assert client._session is None + # configs should not have a session key yet + for svc in client._async_services: + assert svc.config.get("session") is None + + def test_enter_injects_session_into_all_services(self): + async def run(): + async with AsyncJigsawStack(api_key="test-key") as client: + assert isinstance(client._session, aiohttp.ClientSession) + for svc in client._async_services: + assert svc.config.get("session") is client._session + await client._session.close() # prevent ResourceWarning in test + + asyncio.run(run()) + + def test_exit_clears_session_from_all_services(self): + async def run(): + client = AsyncJigsawStack(api_key="test-key") + await client.__aenter__() + session = client._session + await client.__aexit__(None, None, None) + + assert client._session is None + assert session.closed + for svc in client._async_services: + assert svc.config.get("session") is None + + asyncio.run(run()) + + def test_aclose_is_idempotent(self): + async def run(): + async with AsyncJigsawStack(api_key="test-key") as client: + pass + await client.aclose() + + asyncio.run(run()) + + def test_reenter_after_aclose_creates_fresh_session(self): + async def run(): + client = AsyncJigsawStack(api_key="test-key") + await client.__aenter__() + first_session = client._session + await client.aclose() + + await client.__aenter__() + second_session = client._session + assert second_session is not first_session + assert isinstance(second_session, aiohttp.ClientSession) + await client.aclose() + + asyncio.run(run()) diff --git a/tests/test_store_delete.py b/tests/test_store_delete.py new file mode 100644 index 0000000..04a782a --- /dev/null +++ b/tests/test_store_delete.py @@ -0,0 +1,109 @@ +""" +tests/test_store_delete.py + +Tests for claim 3: Store.delete and AsyncStore.delete passed params=key +(a raw string) instead of params={}. + +For a DELETE /store/file/read/{key} request the key is already in the URL +path; passing a raw string as params would append it as a malformed query +string. The fix is params={}. +No real network calls are made. +""" + +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from jigsawstack.store import AsyncStore, Store +from jigsawstack.request import RequestConfig +from jigsawstack.async_request import AsyncRequestConfig + + +def _sync_store() -> Store: + return Store(api_key="key", base_url="http://test") + + +def _async_store() -> AsyncStore: + return AsyncStore(api_key="key", base_url="http://test") + + +class TestStoreDeleteParams: + def test_sync_delete_passes_empty_params(self): + """Request built by Store.delete must use params={}, not params=key.""" + captured = {} + + original_init = __import__("jigsawstack.request", fromlist=["Request"]).Request.__init__ + + def capturing_init(self, config, path, params, verb, **kwargs): + captured["params"] = params + captured["verb"] = verb + # minimal setup so perform_with_content() doesn't blow up + self.path = path + self.params = params + self.verb = verb + self.base_url = config.get("base_url") + self.api_key = config.get("api_key") + self.data = None + self.headers = {"Content-Type": "application/json"} + self.stream = False + self.files = None + + with patch("jigsawstack.store.Request.__init__", capturing_init), \ + patch("jigsawstack.store.Request.perform_with_content", return_value={"success": True}): + _sync_store().delete("my-key") + + assert captured["verb"] == "delete" + assert captured["params"] == {}, ( + f"Expected params={{}}, got {captured['params']!r}. " + "Passing the key string as params appends a malformed query string." + ) + + def test_async_delete_passes_empty_params(self): + """AsyncRequest built by AsyncStore.delete must use params={}, not params=key.""" + from unittest.mock import AsyncMock + from jigsawstack.async_request import AsyncRequest as AR + + calls = [] + original_init = AR.__init__ + + def spy_init(self, config, path, params, verb, **kw): + calls.append({"params": params, "verb": verb}) + original_init(self, config=config, path=path, params=params, verb=verb, **kw) + + async def run(): + with patch.object(AR, "__init__", spy_init), \ + patch.object(AR, "perform_with_content", new_callable=AsyncMock, + return_value={"success": True}): + store = _async_store() + await store.delete("my-key") + + asyncio.run(run()) + + assert len(calls) == 1 + assert calls[0]["verb"] == "delete" + assert calls[0]["params"] == {}, ( + f"Expected params={{}}, got {calls[0]['params']!r}." + ) + + def test_sync_delete_key_in_url_path(self): + """The key must appear in the request path, not as a parameter.""" + captured_path = {} + + def capturing_init(self, config, path, params, verb, **kwargs): + captured_path["path"] = path + self.path = path + self.params = params + self.verb = verb + self.base_url = config.get("base_url") + self.api_key = config.get("api_key") + self.data = None + self.headers = {"Content-Type": "application/json"} + self.stream = False + self.files = None + + with patch("jigsawstack.store.Request.__init__", capturing_init), \ + patch("jigsawstack.store.Request.perform_with_content", return_value={"success": True}): + _sync_store().delete("my-key") + + assert "my-key" in captured_path["path"]