diff --git a/python/lib/sift_client/_internal/low_level_wrappers/base.py b/python/lib/sift_client/_internal/low_level_wrappers/base.py index d349e51a2..169734d25 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/base.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/base.py @@ -1,7 +1,9 @@ from __future__ import annotations from abc import ABC -from typing import Any, Callable +from typing import Any, Awaitable, Callable + +from sift_py.grpc.cache import with_cache, with_force_refresh class LowLevelClientBase(ABC): @@ -50,3 +52,67 @@ async def _handle_pagination( if max_results and len(results) > max_results: results = results[:max_results] return results + + @staticmethod + async def _call_with_cache( + stub_method: Callable[..., Awaitable[Any]], + request: Any, + *, + use_cache: bool = True, + force_refresh: bool = False, + ttl: int | None = None, + ) -> Any: + """Call a gRPC stub method with cache control. + + This is a convenience method for low-level wrappers to easily enable caching + on their gRPC calls without manually constructing metadata. + + Args: + stub_method: The gRPC stub method to call (e.g., stub.GetData). + request: The protobuf request object. + use_cache: Whether to enable caching for this request. Default: True. + force_refresh: Whether to force refresh the cache. Default: False. + ttl: Optional custom TTL in seconds. If not provided, uses the default TTL. + + Returns: + The response from the gRPC call. + + Example: + # Enable caching + response = await self._call_with_cache( + stub.GetData, + request, + use_cache=True, + ) + + # Force refresh + response = await self._call_with_cache( + stub.GetData, + request, + force_refresh=True, + ) + + # With custom TTL + response = await self._call_with_cache( + stub.GetData, + request, + use_cache=True, + ttl=7200, # 2 hours + ) + + # Ignore cache + response = await self._call_with_cache( + stub.GetData, + request, + use_cache=False, + ) + """ + if not use_cache: + return await stub_method(request) + + if force_refresh: + metadata = with_force_refresh(ttl=ttl) + else: + metadata = with_cache(ttl=ttl) + + return await stub_method(request, metadata=metadata) diff --git a/python/lib/sift_client/_internal/low_level_wrappers/data.py b/python/lib/sift_client/_internal/low_level_wrappers/data.py index e5370bbe7..2f3e4699d 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/data.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/data.py @@ -74,8 +74,6 @@ def _update_name_id_map(self, channels: list[Channel]): ) self.channel_cache.name_id_map[channel.name] = str(channel.id_) - # TODO: Cache calls. Only read cache if end_time is more than 30 min in the past. - # Also, consider manually caching full channel data and evaluating start/end times while ignoring pagination. Do this ful caching at a higher level though to handle case where pagination fails. async def _get_data_impl( self, *, @@ -86,8 +84,27 @@ async def _get_data_impl( page_size: int | None = None, page_token: str | None = None, order_by: str | None = None, + use_cache: bool = False, + force_refresh: bool = False, + cache_ttl: int | None = None, ) -> tuple[list[Any], str | None]: - """Get the data for a channel during a run.""" + """Get the data for a channel during a run. + + Args: + channel_ids: List of channel IDs to fetch data for. + run_id: Optional run ID to filter data. + start_time: Optional start time for the data range. + end_time: End time for the data range. + page_size: Number of results per page. 
+ page_token: Token for pagination. + order_by: Field to order results by. + use_cache: Whether to enable caching for this request. Default: False. + force_refresh: Whether to force refresh the cache. Default: False. + cache_ttl: Optional custom TTL in seconds for cached responses. + + Returns: + Tuple of (data list, next page token). + """ queries = [ Query(channel=ChannelQuery(channel_id=channel_id, run_id=run_id)) for channel_id in channel_ids @@ -102,7 +119,19 @@ async def _get_data_impl( } request = GetDataRequest(**request_kwargs) - response = await self._grpc_client.get_stub(DataServiceStub).GetData(request) + + # Use cache helper if caching is enabled + if use_cache or force_refresh: + response = await self._call_with_cache( + self._grpc_client.get_stub(DataServiceStub).GetData, + request, + use_cache=use_cache, + force_refresh=force_refresh, + ttl=cache_ttl, + ) + else: + response = await self._grpc_client.get_stub(DataServiceStub).GetData(request) + response = cast("GetDataResponse", response) return response.data, response.next_page_token # type: ignore # mypy doesn't know RepeatedCompositeFieldContainer can be treated like a list diff --git a/python/lib/sift_client/_internal/low_level_wrappers/ping.py b/python/lib/sift_client/_internal/low_level_wrappers/ping.py index 650f2d44a..b28b733f3 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/ping.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/ping.py @@ -33,6 +33,9 @@ class PingLowLevelClient(LowLevelClientBase, WithGrpcClient): It handles common concerns like error handling and retries. """ + _cache_results: bool + """Whether to cache the results of the ping request. Used for testing.""" + def __init__(self, grpc_client: GrpcClient): """Initialize the PingLowLevelClient. @@ -40,11 +43,14 @@ def __init__(self, grpc_client: GrpcClient): grpc_client: The gRPC client to use for making API calls. 
""" super().__init__(grpc_client=grpc_client) + self._cache_results = False - async def ping(self) -> str: + async def ping(self, _force_refresh: bool = False) -> str: """Send a ping request to the server in the current event loop.""" # get stub bound to this loop stub = self._grpc_client.get_stub(PingServiceStub) request = PingRequest() - response = await stub.Ping(request) + response = await self._call_with_cache( + stub.Ping, request, use_cache=self._cache_results, force_refresh=_force_refresh, ttl=1 + ) return cast("PingResponse", response).response diff --git a/python/lib/sift_client/_tests/conftest.py b/python/lib/sift_client/_tests/conftest.py index 397848d7e..c2c9e5c4e 100644 --- a/python/lib/sift_client/_tests/conftest.py +++ b/python/lib/sift_client/_tests/conftest.py @@ -6,6 +6,7 @@ import pytest from sift_client import SiftClient, SiftConnectionConfig +from sift_client.transport import CacheConfig, CacheMode from sift_client.util.util import AsyncAPIs @@ -26,6 +27,7 @@ def sift_client() -> SiftClient: grpc_url=grpc_url, rest_url=rest_url, use_ssl=True, + cache_config=CacheConfig(mode=CacheMode.CLEAR_ON_INIT), ) ) diff --git a/python/lib/sift_client/_tests/resources/test_ping.py b/python/lib/sift_client/_tests/resources/test_ping.py index 587d8a7f0..cc519f0e4 100644 --- a/python/lib/sift_client/_tests/resources/test_ping.py +++ b/python/lib/sift_client/_tests/resources/test_ping.py @@ -3,17 +3,45 @@ These tests demonstrate and validate the usage of the Ping API including: - Basic ping functionality - Connection health checks +- Cache behavior and performance - Error handling and edge cases """ +import asyncio +import os +import time + import pytest -from sift_client import SiftClient +from sift_client import SiftClient, SiftConnectionConfig from sift_client.resources import PingAPI, PingAPIAsync +from sift_client.transport import CacheConfig, CacheMode pytestmark = pytest.mark.integration +# We reimplement this here so that the cache is cleared each time we instantiate +@pytest.fixture +def sift_client() -> SiftClient: + """Create a SiftClient instance for testing. + + This fixture is shared across all test files and is session-scoped + to avoid creating multiple client instances. 
+ """ + grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") + rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") + api_key = os.getenv("SIFT_API_KEY", "") + + return SiftClient( + connection_config=SiftConnectionConfig( + api_key=api_key, + grpc_url=grpc_url, + rest_url=rest_url, + cache_config=CacheConfig(mode=CacheMode.CLEAR_ON_INIT), + ) + ) + + def test_client_binding(sift_client): assert sift_client.ping assert isinstance(sift_client.ping, PingAPI) @@ -60,3 +88,207 @@ def test_basic_ping(self, ping_api_sync): # Verify response is not empty assert len(response) > 0 + + +class TestPingCacheBehavior: + """Test suite for ping cache behavior.""" + + @pytest.mark.asyncio + async def test_cache_enabled(self, ping_api_async): + """Test that caching can be enabled for ping requests.""" + # Enable caching on the low-level client + ping_api_async._low_level_client._cache_results = True + + # Measure time for first ping - should hit the server (slower) + start1 = time.perf_counter() + response1 = await ping_api_async.ping() + duration1 = time.perf_counter() - start1 + assert isinstance(response1, str) + assert len(response1) > 0 + + # Measure time for second ping - should use cache (much faster) + start2 = time.perf_counter() + response2 = await ping_api_async.ping() + duration2 = time.perf_counter() - start2 + assert response2 == response1 + + # Print timing info + print(f"\nFirst ping (server): {duration1 * 1000:.2f}ms") + print(f"Second ping (cache): {duration2 * 1000:.2f}ms") + print(f"Speedup: {duration1 / duration2:.2f}x") + + # Cached call should be significantly faster (at least 5x) + assert duration2 < duration1 / 5, ( + f"Cached ping should be much faster. " + f"First: {duration1 * 1000:.2f}ms, Second: {duration2 * 1000:.2f}ms" + ) + + # Disable caching for cleanup + ping_api_async._low_level_client._cache_results = False + + @pytest.mark.asyncio + async def test_force_refresh_bypasses_cache(self, ping_api_async): + """Test that force_refresh bypasses the cache.""" + # Enable caching + ping_api_async._low_level_client._cache_results = True + + # First ping - populate cache + start1 = time.perf_counter() + response1 = await ping_api_async._low_level_client.ping() + duration1 = time.perf_counter() - start1 + assert isinstance(response1, str) + + # Second ping without force_refresh - should use cache (fast) + start2 = time.perf_counter() + response2 = await ping_api_async._low_level_client.ping(_force_refresh=False) + duration2 = time.perf_counter() - start2 + assert isinstance(response2, str) + + # Third ping with force_refresh - should bypass cache (slow, like first call) + start3 = time.perf_counter() + response3 = await ping_api_async._low_level_client.ping(_force_refresh=True) + duration3 = time.perf_counter() - start3 + assert isinstance(response3, str) + + # Print timing info + print(f"\nFirst ping (server): {duration1 * 1000:.2f}ms") + print(f"Second ping (cache): {duration2 * 1000:.2f}ms") + print(f"Third ping (force_refresh, server): {duration3 * 1000:.2f}ms") + + # Cached call should be much faster than both server calls + assert duration2 < duration1 / 5, ( + f"Cached ping should be much faster than first ping. " + f"First: {duration1 * 1000:.2f}ms, Cached: {duration2 * 1000:.2f}ms" + ) + assert duration2 < duration3 / 5, ( + f"Cached ping should be much faster than force_refresh ping. 
" + f"Force refresh: {duration3 * 1000:.2f}ms, Cached: {duration2 * 1000:.2f}ms" + ) + + # Disable caching for cleanup + ping_api_async._low_level_client._cache_results = False + + @pytest.mark.asyncio + async def test_cache_ttl_expiration(self, ping_api_async): + """Test that cache entries expire after TTL.""" + # Enable caching with very short TTL (1 second) + ping_api_async._low_level_client._cache_results = True + + # First ping - populate cache with 1 second TTL + response1 = await ping_api_async._low_level_client.ping() + assert isinstance(response1, str) + + # Immediate second ping - should use cache + response2 = await ping_api_async._low_level_client.ping() + assert isinstance(response2, str) + + # Wait for TTL to expire (1 second + buffer) + await asyncio.sleep(1.5) + + # Third ping - cache should have expired, will fetch fresh + response3 = await ping_api_async._low_level_client.ping() + assert isinstance(response3, str) + + # Disable caching for cleanup + ping_api_async._low_level_client._cache_results = False + + @pytest.mark.asyncio + async def test_cache_performance(self, ping_api_async): + """Test that cached ping requests are faster than uncached ones.""" + num_iterations = 10 + + # Enable caching + ping_api_async._low_level_client._cache_results = True + + # Measure uncached performance (force_refresh=True) + start_time = time.perf_counter() + for _ in range(num_iterations): + await ping_api_async._low_level_client.ping(_force_refresh=True) + uncached_duration = time.perf_counter() - start_time + + # Warm up cache + await ping_api_async._low_level_client.ping() + + # Measure cached performance + start_time = time.perf_counter() + for _ in range(num_iterations): + await ping_api_async._low_level_client.ping(_force_refresh=False) + cached_duration = time.perf_counter() - start_time + + # Print performance metrics + print(f"\n{'=' * 60}") + print(f"Ping Cache Performance ({num_iterations} iterations)") + print(f"{'=' * 60}") + print( + f"Cached duration: {cached_duration:.4f}s ({cached_duration / num_iterations * 1000:.2f}ms per call)" + ) + print( + f"Uncached duration: {uncached_duration:.4f}s ({uncached_duration / num_iterations * 1000:.2f}ms per call)" + ) + print(f"Speedup: {uncached_duration / cached_duration:.2f}x") + print(f"Time saved: {uncached_duration - cached_duration:.4f}s") + print(f"{'=' * 60}\n") + + # Assert that cached is faster + assert cached_duration < uncached_duration, ( + f"Cached pings should be faster. 
" + f"Cached: {cached_duration:.4f}s, Uncached: {uncached_duration:.4f}s" + ) + + # Disable caching for cleanup + ping_api_async._low_level_client._cache_results = False + + @pytest.mark.asyncio + async def test_cache_disabled_by_default(self, ping_api_async): + """Test that caching is disabled by default for ping.""" + # Verify cache is disabled by default + assert ping_api_async._low_level_client._cache_results is False + + # Multiple pings should all hit the server (no caching) + response1 = await ping_api_async.ping() + response2 = await ping_api_async.ping() + response3 = await ping_api_async.ping() + + # All should succeed + assert isinstance(response1, str) + assert isinstance(response2, str) + assert isinstance(response3, str) + + @pytest.mark.asyncio + async def test_ping_without_grpc_cache(self): + """Test that ping works when GrpcCache is not enabled on the SiftClient.""" + import os + + from sift_client import SiftClient, SiftConnectionConfig + + # Create a client with caching explicitly disabled + grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051") + rest_url = os.getenv("SIFT_REST_URI", "localhost:8080") + api_key = os.getenv("SIFT_API_KEY", "") + + client = SiftClient( + connection_config=SiftConnectionConfig( + api_key=api_key, + grpc_url=grpc_url, + rest_url=rest_url, + use_ssl=True, + cache_config=None, + ) + ) + + # Verify cache is not initialized + assert client.grpc_client.cache is None + + # Ping should still work without cache + response1 = await client.async_.ping.ping() + assert isinstance(response1, str) + assert len(response1) > 0 + + # Multiple pings should work + response2 = await client.async_.ping.ping() + assert isinstance(response2, str) + + response3 = await client.async_.ping.ping() + assert isinstance(response3, str) + + print(f"\nPing without cache successful: {response1}") diff --git a/python/lib/sift_client/client.py b/python/lib/sift_client/client.py index 2a2252ef8..c3a75045a 100644 --- a/python/lib/sift_client/client.py +++ b/python/lib/sift_client/client.py @@ -31,6 +31,7 @@ WithGrpcClient, WithRestClient, ) +from sift_client.transport.grpc_transport import DEFAULT_CACHE_CONFIG from sift_client.util.util import AsyncAPIs _sift_client_experimental_warning() @@ -116,7 +117,7 @@ def __init__( api_key: The Sift API key for authentication. grpc_url: The Sift gRPC API URL. rest_url: The Sift REST API URL. - connection_config: A SiftConnectionConfig object to configure the connection behavior of the SiftClient. + connection_config: A SiftConnectionConfig object to configure the connection and cache behavior of the SiftClient. 
""" if not (api_key and grpc_url and rest_url) and not connection_config: raise ValueError( @@ -124,10 +125,14 @@ def __init__( ) if connection_config: - grpc_client = GrpcClient(connection_config.get_grpc_config()) + grpc_config = connection_config.get_grpc_config() + # Override cache_config if provided directly to SiftClient + grpc_client = GrpcClient(grpc_config) rest_client = RestClient(connection_config.get_rest_config()) elif api_key and grpc_url and rest_url: - grpc_client = GrpcClient(GrpcConfig(grpc_url, api_key)) + grpc_client = GrpcClient( + GrpcConfig(grpc_url, api_key, cache_config=DEFAULT_CACHE_CONFIG) + ) rest_client = RestClient(RestConfig(rest_url, api_key)) else: raise ValueError( diff --git a/python/lib/sift_client/transport/__init__.py b/python/lib/sift_client/transport/__init__.py index 249d9bc7e..a7bb2a9ca 100644 --- a/python/lib/sift_client/transport/__init__.py +++ b/python/lib/sift_client/transport/__init__.py @@ -3,10 +3,12 @@ WithGrpcClient, WithRestClient, ) -from sift_client.transport.grpc_transport import GrpcClient, GrpcConfig +from sift_client.transport.grpc_transport import CacheConfig, CacheMode, GrpcClient, GrpcConfig from sift_client.transport.rest_transport import RestClient, RestConfig __all__ = [ + "CacheConfig", + "CacheMode", "GrpcClient", "GrpcConfig", "RestClient", diff --git a/python/lib/sift_client/transport/base_connection.py b/python/lib/sift_client/transport/base_connection.py index 02f0e096e..c6b764d20 100644 --- a/python/lib/sift_client/transport/base_connection.py +++ b/python/lib/sift_client/transport/base_connection.py @@ -3,7 +3,12 @@ from abc import ABC from typing import TYPE_CHECKING -from sift_client.transport.grpc_transport import GrpcClient, GrpcConfig +from sift_client.transport.grpc_transport import ( + DEFAULT_CACHE_CONFIG, + CacheConfig, + GrpcClient, + GrpcConfig, +) from sift_client.transport.rest_transport import RestClient, RestConfig if TYPE_CHECKING: @@ -24,6 +29,7 @@ def __init__( api_key: str, use_ssl: bool = True, cert_via_openssl: bool = False, + cache_config: CacheConfig | None = DEFAULT_CACHE_CONFIG, ): """Initialize the connection configuration. @@ -33,12 +39,14 @@ def __init__( api_key: The API key for authentication. use_ssl: Whether to use SSL/TLS for secure connections. cert_via_openssl: Whether to use OpenSSL for certificate validation. + cache_config: Optional cache configuration for gRPC responses. """ self.api_key = api_key self.grpc_url = grpc_url self.rest_url = rest_url self.use_ssl = use_ssl self.cert_via_openssl = cert_via_openssl + self.cache_config = cache_config def get_grpc_config(self): """Create and return a GrpcConfig with the current settings. 
@@ -51,6 +59,7 @@ def get_grpc_config(self):
             api_key=self.api_key,
             use_ssl=self.use_ssl,
             cert_via_openssl=self.cert_via_openssl,
+            cache_config=self.cache_config,
         )
 
     def get_rest_config(self):
diff --git a/python/lib/sift_client/transport/cache.py b/python/lib/sift_client/transport/cache.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/python/lib/sift_client/transport/grpc_transport.py b/python/lib/sift_client/transport/grpc_transport.py
index b27ce8fc1..cd99d2880 100644
--- a/python/lib/sift_client/transport/grpc_transport.py
+++ b/python/lib/sift_client/transport/grpc_transport.py
@@ -8,11 +8,16 @@
 import asyncio
 import atexit
+import enum
 import logging
 import threading
+from pathlib import Path
 from typing import Any
 
+from platformdirs import user_cache_dir
+from sift_py.grpc.cache import GrpcCache
 from sift_py.grpc.transport import (
+    SiftCacheConfig,
     SiftChannelConfig,
     use_sift_async_channel,
 )
@@ -34,6 +39,82 @@ def _suppress_blocking_io(loop, context):
         loop.default_exception_handler(context)
 
 
+DEFAULT_CACHE_TTL_SECONDS = 7 * 24 * 60 * 60  # 1 week
+DEFAULT_CACHE_FOLDER = Path(user_cache_dir("sift_client"))
+DEFAULT_CACHE_SIZE_LIMIT_BYTES = 5 * 1024**3  # 5GB
+
+
+class CacheMode(str, enum.Enum):
+    """Cache behavior modes.
+
+    - ENABLED: Cache is enabled and persists across sessions
+    - DISABLED: Cache is completely disabled
+    - CLEAR_ON_INIT: Cache is cleared when client is initialized (useful for notebooks)
+    """
+
+    ENABLED = "enabled"
+    DISABLED = "disabled"
+    CLEAR_ON_INIT = "clear_on_init"
+
+
+class CacheConfig:
+    """Configuration for gRPC response caching.
+
+    Attributes:
+        mode: Cache behavior mode (enabled, disabled, clear_on_init).
+        ttl: Time-to-live for cached entries in seconds. Default is 1 week.
+        cache_folder: Path to the cache directory. Default is the platform user
+            cache directory for "sift_client" (via platformdirs).
+        size_limit: Maximum size of the cache in bytes. Default is 5GB.
+    """
+
+    def __init__(
+        self,
+        mode: CacheMode = CacheMode.ENABLED,
+        ttl: int = DEFAULT_CACHE_TTL_SECONDS,
+        cache_folder: Path | str = DEFAULT_CACHE_FOLDER,
+        size_limit: int = DEFAULT_CACHE_SIZE_LIMIT_BYTES,
+    ):
+        """Initialize the cache configuration.
+
+        Args:
+            mode: Cache behavior mode (use CacheMode constants).
+            ttl: Time-to-live for cached entries in seconds.
+            cache_folder: Path to the cache directory.
+            size_limit: Maximum size of the cache in bytes.
+        """
+        self.mode = mode
+        self.ttl = ttl
+        self.cache_path = str(Path(cache_folder) / "grpc_cache")
+        self.size_limit = size_limit
+        self._should_clear_on_init = mode == CacheMode.CLEAR_ON_INIT
+
+    @property
+    def is_enabled(self) -> bool:
+        """Check if caching is enabled."""
+        return self.mode != CacheMode.DISABLED
+
+    @property
+    def should_clear_on_init(self) -> bool:
+        """Check if cache should be cleared on initialization."""
+        return self._should_clear_on_init
+
+    def to_sift_cache_config(self) -> SiftCacheConfig:
+        """Convert to a SiftCacheConfig for use with sift_py.grpc.transport.
+
+        Returns:
+            A SiftCacheConfig dictionary.
+        """
+        return {
+            "ttl": self.ttl,
+            "cache_path": self.cache_path,
+            "size_limit": self.size_limit,
+            "clear_on_init": self.should_clear_on_init,
+        }
+
+
+DEFAULT_CACHE_CONFIG = CacheConfig()
+
+
 class GrpcConfig:
     """Configuration for gRPC API clients."""
 
@@ -44,6 +125,7 @@ def __init__(
         use_ssl: bool = True,
         cert_via_openssl: bool = False,
         metadata: dict[str, str] | None = None,
+        cache_config: CacheConfig | None = None,
     ):
         """Initialize the gRPC configuration.
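
Note: a quick sketch (not part of the patch) of how the `CacheConfig` class above resolves its defaults and converts to the `SiftCacheConfig` TypedDict consumed by `sift_py`. The exact `cache_path` depends on the platform's user cache directory.

```python
# Sketch only -- exercising the CacheConfig class defined above.
from sift_client.transport.grpc_transport import CacheConfig, CacheMode

cfg = CacheConfig(mode=CacheMode.CLEAR_ON_INIT, ttl=3600)
assert cfg.is_enabled            # only CacheMode.DISABLED reports disabled
assert cfg.should_clear_on_init  # derived from the mode

sift_cfg = cfg.to_sift_cache_config()
# e.g. {"ttl": 3600,
#       "cache_path": "<user cache dir>/sift_client/grpc_cache",  # platform-dependent
#       "size_limit": 5368709120,  # 5 GiB default
#       "clear_on_init": True}
```
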
@@ -52,14 +134,15 @@ def __init__( api_key: The API key for authentication. use_ssl: Whether to use SSL/TLS. cert_via_openssl: Whether to use OpenSSL for SSL/TLS. - use_async: Whether to use async gRPC client. metadata: Additional metadata to include in all requests. + cache_config: Optional cache configuration. If None, caching is disabled. """ self.uri = url self.api_key = api_key self.use_ssl = use_ssl self.cert_via_openssl = cert_via_openssl self.metadata = metadata or {} + self.cache_config = cache_config def _to_sift_channel_config(self) -> SiftChannelConfig: """Convert to a SiftChannelConfig. @@ -67,13 +150,19 @@ def _to_sift_channel_config(self) -> SiftChannelConfig: Returns: A SiftChannelConfig. """ - return { + config: SiftChannelConfig = { "uri": self.uri, "apikey": self.api_key, "use_ssl": self.use_ssl, "cert_via_openssl": self.cert_via_openssl, } + # Add cache config if enabled + if self.cache_config and self.cache_config.is_enabled: + config["cache_config"] = self.cache_config.to_sift_cache_config() + + return config + class GrpcClient: """A simple wrapper around sift_py/grpc/transport.py for making gRPC API calls. @@ -91,6 +180,10 @@ def __init__(self, config: GrpcConfig): # map each asyncio loop to its async channel and stub dict self._channels_async: dict[asyncio.AbstractEventLoop, Any] = {} self._stubs_async_map: dict[asyncio.AbstractEventLoop, dict[type[Any], Any]] = {} + + # Initialize cache if caching is enabled + self.cache = self._init_cache() + # default loop for sync API self._default_loop = asyncio.new_event_loop() atexit.register(self.close_sync) @@ -116,6 +209,24 @@ def _run_default_loop(): self._channels_async[self._default_loop] = channel self._stubs_async_map[self._default_loop] = {} + def _init_cache(self) -> GrpcCache | None: + """Initialize the GrpcCache instance if caching is enabled.""" + if not self._config.cache_config or not self._config.cache_config.is_enabled: + return None + + try: + cache_config = self._config.cache_config + sift_cache_config: SiftCacheConfig = { + "ttl": cache_config.ttl, + "cache_path": cache_config.cache_path, + "size_limit": cache_config.size_limit, + "clear_on_init": cache_config.mode == CacheMode.CLEAR_ON_INIT, + } + return GrpcCache(sift_cache_config) + except Exception as e: + logger.warning(f"Failed to initialize cache: {e}") + return None + @property def default_loop(self) -> asyncio.AbstractEventLoop: """Return the default event loop used for synchronous API operations. @@ -138,7 +249,7 @@ def get_stub(self, stub_class: type[Any]) -> Any: if loop not in self._channels_async: channel = use_sift_async_channel( - self._config._to_sift_channel_config(), self._config.metadata + self._config._to_sift_channel_config(), self._config.metadata, self.cache ) self._channels_async[loop] = channel self._stubs_async_map[loop] = {} @@ -181,4 +292,4 @@ async def _create_async_channel( self, cfg: SiftChannelConfig, metadata: dict[str, str] | None ) -> Any: """Helper to create async channel on default loop.""" - return use_sift_async_channel(cfg, metadata) + return use_sift_async_channel(cfg, metadata, self.cache) diff --git a/python/lib/sift_py/grpc/_async_interceptors/caching.py b/python/lib/sift_py/grpc/_async_interceptors/caching.py new file mode 100644 index 000000000..2c3a866c7 --- /dev/null +++ b/python/lib/sift_py/grpc/_async_interceptors/caching.py @@ -0,0 +1,154 @@ +"""Async gRPC caching interceptor for transparent local response caching. 
+
+This module provides an async caching interceptor that can be used to cache gRPC
+unary-unary responses locally using diskcache. The cache is initialized at the
+GrpcClient level and passed to the interceptor.
+
+Note: Cache initialization is handled by GrpcClient, not by this interceptor.
+
+Usage:
+    # Cache is initialized at GrpcClient level
+    cache = GrpcCache(
+        {"ttl": 3600, "cache_path": ".grpc_cache", "size_limit": 1024**3, "clear_on_init": False}
+    )
+
+    # Create interceptor with the cache instance
+    cache_interceptor = CachingAsyncInterceptor(cache=cache)
+
+    # Use with metadata to control caching:
+    metadata = [
+        ("use-cache", "true"),  # Enable caching for this call
+        # ("force-refresh", "true"),  # Bypass cache and store fresh result
+        # Omitting "use-cache" (or setting it to "false") skips the cache entirely
+    ]
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Any
+
+import diskcache
+from google.protobuf import message, symbol_database
+from grpc import aio as grpc_aio
+
+from sift_py.grpc._async_interceptors.base import ClientAsyncInterceptor
+from sift_py.grpc.cache import GrpcCache
+
+logger = logging.getLogger(__name__)
+
+
+class CachingAsyncInterceptor(ClientAsyncInterceptor):
+    """Async interceptor that caches unary-unary gRPC responses locally.
+
+    This interceptor uses a diskcache instance for persistent storage with TTL support.
+    The cache instance must be provided during initialization (typically from GrpcClient).
+    Cache keys are generated deterministically based on the gRPC method name
+    and serialized request payload.
+
+    Responses are serialized to bytes before caching to avoid pickling issues with
+    async objects.
+
+    Note: diskcache operations are synchronous, but the overhead is minimal
+    for most use cases. For high-throughput scenarios, consider using an
+    async-native cache backend.
+
+    Attributes:
+        cache: The GrpcCache instance provided during initialization.
+    """
+
+    def __init__(
+        self,
+        cache: GrpcCache,
+    ):
+        """Initialize the async caching interceptor.
+
+        Args:
+            cache: Pre-initialized GrpcCache instance (required).
+        """
+        self.cache = cache
+        self.symbol_db = symbol_database.Default()
+
+    async def intercept(
+        self,
+        method: Any,
+        request_or_iterator: Any,
+        client_call_details: grpc_aio.ClientCallDetails,
+    ) -> Any:
+        """Intercept the async gRPC call and apply caching logic.
+
+        Uses GrpcCache.resolve_cache_metadata() to determine caching behavior.
+
+        Args:
+            method: The continuation to call for the actual RPC.
+            request_or_iterator: The request object or iterator.
+            client_call_details: The call details including method name and metadata.
+
+        Returns:
+            The response from the cache or the actual RPC call.
+ """ + # Resolve cache metadata to determine behavior + cache_settings = self.cache.resolve_cache_metadata(client_call_details.metadata) + + # Generate cache key + key = self.cache.key_from_proto_message( + method_name=client_call_details.method, request=request_or_iterator + ) + + # Try to read from cache if allowed + if cache_settings.use_cache and not cache_settings.force_refresh: + try: + cached_data = self.cache.get(key) + if cached_data is not None: + logger.debug(f"Cache hit for `{key}`") + # Reconstruct the response + response = self._deserialize_response(cached_data) + if response is not None: + return response + else: + logger.warning(f"Failed to deserialize cached response for `{key}`") + except diskcache.Timeout as e: + logger.debug(f"Cache read timeout for `{key}`: {e}") + except Exception as e: + logger.warning(f"Failed to deserialize cached response for `{key}`: {e}") + + # Force refresh if requested + if cache_settings.force_refresh: + logger.debug(f"Forcing refresh for `{key}`") + self.cache.delete(key) + + # Make the actual RPC call + call = await method(request_or_iterator, client_call_details) + + # The call is a UnaryUnaryCall object, we need to await it to get the actual response + response = await call + + # Cache the response if allowed + if cache_settings.use_cache and response is not None: + try: + # Serialize the protobuf response to bytes before caching + new_data = self._serialize_response(response) + if new_data is not None: + self.cache.set_with_default_ttl(key, new_data, expire=cache_settings.custom_ttl) + logger.debug(f"Cached response for `{key}`") + except diskcache.Timeout as e: + logger.warning(f"Failed to cache response for `{key}`: {e}") + + return response + + @staticmethod + def _serialize_response(response: message.Message) -> tuple[Any, bytes] | None: + if isinstance(response, message.Message): + return response.DESCRIPTOR.full_name, response.SerializeToString() + else: + logger.warning(f"Response is not a protobuf message: {type(response)}") + return None + + def _deserialize_response(self, response: tuple[Any, bytes]) -> message.Message | None: + response_type, data = response + try: + response_type_cls = self.symbol_db.GetSymbol(response_type) + message = response_type_cls() + message.ParseFromString(data) + return message + except Exception as e: + logger.warning(f"Failed to deserialize response: {e}") + return None diff --git a/python/lib/sift_py/grpc/_async_interceptors/metadata.py b/python/lib/sift_py/grpc/_async_interceptors/metadata.py index 0592c3648..d506266f7 100644 --- a/python/lib/sift_py/grpc/_async_interceptors/metadata.py +++ b/python/lib/sift_py/grpc/_async_interceptors/metadata.py @@ -26,10 +26,17 @@ async def intercept( client_call_details: grpc_aio.ClientCallDetails, ): call_details = cast(grpc_aio.ClientCallDetails, client_call_details) + + # Merge existing metadata with interceptor metadata + # Existing metadata from the call takes precedence + merged_metadata = list(self.metadata) + if call_details.metadata: + merged_metadata.extend(call_details.metadata) + new_details = grpc_aio.ClientCallDetails( call_details.method, call_details.timeout, - self.metadata, + merged_metadata, call_details.credentials, call_details.wait_for_ready, ) diff --git a/python/lib/sift_py/grpc/cache.py b/python/lib/sift_py/grpc/cache.py new file mode 100644 index 000000000..f327a6cbc --- /dev/null +++ b/python/lib/sift_py/grpc/cache.py @@ -0,0 +1,202 @@ +"""Utilities for controlling gRPC response caching. 
+
+This module provides helper functions and constants for working with the gRPC
+caching interceptor. Use these utilities to control caching behavior on a
+per-request basis via metadata.
+
+Example:
+    from sift_py.grpc.cache import with_cache, with_force_refresh
+
+    # Enable caching for a request
+    metadata = with_cache()
+    response = stub.GetData(request, metadata=metadata)
+
+    # Force refresh (bypass cache and store fresh result)
+    metadata = with_force_refresh()
+    response = stub.GetData(request, metadata=metadata)
+
+    # Skip the cache entirely by omitting cache metadata
+    response = stub.GetData(request)
+"""
+
+from __future__ import annotations
+
+import hashlib
+import logging
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, NamedTuple
+
+import diskcache
+from google.protobuf import json_format, message
+
+if TYPE_CHECKING:
+    from sift_py.grpc.transport import SiftCacheConfig
+
+logger = logging.getLogger(__name__)
+
+
+class CacheSettings(NamedTuple):
+    """Resolved cache metadata from gRPC request."""
+
+    use_cache: bool
+    force_refresh: bool
+    custom_ttl: float | None
+
+
+# Metadata keys for cache control
+METADATA_USE_CACHE = "use-cache"
+METADATA_FORCE_REFRESH = "force-refresh"
+METADATA_CACHE_TTL = "cache-ttl"
+
+
+class GrpcCache(diskcache.Cache):
+    """Subclass of diskcache.Cache for gRPC response caching."""
+
+    def __init__(self, config: SiftCacheConfig):
+        """Initialize the cache from configuration.
+
+        Args:
+            config: Cache configuration with ttl, cache_path, size_limit, clear_on_init.
+        """
+        self.default_ttl = config["ttl"]
+        self.cache_path = Path(config["cache_path"])
+        self.size_limit = config["size_limit"]
+
+        # Create cache directory if it doesn't exist
+        self.cache_path.mkdir(parents=True, exist_ok=True)
+
+        # Initialize parent diskcache.Cache
+        super().__init__(str(self.cache_path), size_limit=self.size_limit)
+
+        # Clear cache if requested
+        if config.get("clear_on_init", False):
+            logger.debug(f"Clearing cache on initialization: {self.cache_path}")
+            self.clear()
+
+        logger.debug(
+            f"Cache initialized at {self.cache_path.absolute()!r} "
+            f"with size {self.volume() / (1024**2):.2f} MB"
+        )
+
+    def set_with_default_ttl(
+        self, key: str, value: Any, expire: float | None = None, **kwargs
+    ) -> bool:
+        expire_time = expire if expire is not None else self.default_ttl
+        return super().set(key, value, expire=expire_time, **kwargs)
+
+    @staticmethod
+    def key_from_proto_message(method_name: str | bytes, request: message.Message) -> str:
+        # Serialize the request deterministically via its canonical JSON form
+        request_json = json_format.MessageToJson(request).encode("utf-8")
+
+        if isinstance(method_name, str):
+            method_name = method_name.encode("utf-8")
+
+        # Create a hash of method name + request
+        hasher = hashlib.sha256()
+        hasher.update(method_name)
+        hasher.update(request_json)
+
+        return hasher.hexdigest()
+
+    @staticmethod
+    def resolve_cache_metadata(metadata: tuple[tuple[str, str], ...] | None) -> CacheSettings:
+        """Extract and resolve cache-related metadata fields.
+
+        Args:
+            metadata: The gRPC request metadata tuple.
+
+        Returns:
+            CacheSettings named tuple with resolved cache control fields:
+            - use_cache: bool - Whether to use caching
+            - force_refresh: bool - Whether to force refresh
+            - custom_ttl: int | None - Custom TTL if specified
+
+        Example:
+            settings = cache.resolve_cache_metadata(metadata)
+            if settings.use_cache and not settings.force_refresh:
+                cached = cache.get(key)
+            if settings.use_cache:
+                cache.set_with_default_ttl(key, response, expire=settings.custom_ttl)
+        """
+        # Handle both tuple and grpc.aio.Metadata types (each iterates as key/value pairs)
+        metadata_dict: dict[str, str] = dict(metadata) if metadata else {}
+
+        use_cache = metadata_dict.get(METADATA_USE_CACHE, "false").lower() == "true"
+
+        if not use_cache:
+            return CacheSettings(use_cache=False, force_refresh=False, custom_ttl=None)
+
+        force_refresh = metadata_dict.get(METADATA_FORCE_REFRESH, "false").lower() == "true"
+        custom_ttl_str = metadata_dict.get(METADATA_CACHE_TTL)
+
+        # Parse custom TTL if provided
+        custom_ttl = None
+        if custom_ttl_str:
+            try:
+                custom_ttl = int(custom_ttl_str)
+            except ValueError:
+                logger.warning(f"Invalid cache TTL value: {custom_ttl_str}, using default")
+
+        return CacheSettings(
+            use_cache=use_cache,
+            force_refresh=force_refresh,
+            custom_ttl=custom_ttl,
+        )
+
+
+def with_cache(ttl: int | None = None) -> tuple[tuple[str, str], ...]:
+    """Enable caching for a gRPC request.
+
+    Args:
+        ttl: Optional custom TTL in seconds. If not provided, uses the default TTL.
+
+    Returns:
+        Metadata tuple to pass to the gRPC stub method.
+
+    Example:
+        metadata = with_cache()
+        response = stub.GetData(request, metadata=metadata)
+
+        # With custom TTL
+        metadata = with_cache(ttl=7200)  # 2 hours
+        response = stub.GetData(request, metadata=metadata)
+    """
+    metadata = [(METADATA_USE_CACHE, "true")]
+    if ttl is not None:
+        metadata.append((METADATA_CACHE_TTL, str(ttl)))
+    return tuple(metadata)
+
+
+def with_force_refresh(ttl: int | None = None) -> tuple[tuple[str, str], ...]:
+    """Force refresh the cache for a gRPC request.
+
+    Bypasses the cache, fetches fresh data from the server, and stores the result.
+
+    Args:
+        ttl: Optional custom TTL in seconds. If not provided, uses the default TTL.
+
+    Returns:
+        Metadata tuple to pass to the gRPC stub method.
+
+    Example:
+        metadata = with_force_refresh()
+        response = stub.GetData(request, metadata=metadata)
+    """
+    metadata = [
+        (METADATA_USE_CACHE, "true"),
+        (METADATA_FORCE_REFRESH, "true"),
+    ]
+    if ttl is not None:
+        metadata.append((METADATA_CACHE_TTL, str(ttl)))
+    return tuple(metadata)
diff --git a/python/lib/sift_py/grpc/cache_test.py b/python/lib/sift_py/grpc/cache_test.py
new file mode 100644
index 000000000..e060a6ff3
--- /dev/null
+++ b/python/lib/sift_py/grpc/cache_test.py
@@ -0,0 +1,473 @@
+# ruff: noqa: N802
+
+import logging
+import tempfile
+from concurrent import futures
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, Callable, cast
+
+import grpc
+import pytest
+from pytest_mock import MockFixture
+from sift.data.v2.data_pb2 import GetDataRequest, GetDataResponse
+from sift.data.v2.data_pb2_grpc import (
+    DataServiceServicer,
+    DataServiceStub,
+    add_DataServiceServicer_to_server,
+)
+
+from sift_py._internal.test_util.server_interceptor import ServerInterceptor
+from sift_py.grpc.cache import (
+    GrpcCache,
+    with_cache,
+    with_force_refresh,
+)
+from sift_py.grpc.transport import SiftChannelConfig, use_sift_async_channel
+
+# Enable debug logging for cache-related modules
+logging.basicConfig(
+    level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logging.getLogger("sift_py").setLevel(logging.DEBUG)
+
+
+class DataService(DataServiceServicer):
+    """Mock data service that returns a unique response each time."""
+
+    call_count: int
+
+    def __init__(self):
+        self.call_count = 0
+
+    def GetData(self, request: GetDataRequest, context: grpc.ServicerContext):
+        self.call_count += 1
+        # Return a unique token each time to verify caching
+        return GetDataResponse(next_page_token=f"token-{self.call_count}")
+
+
+class AuthInterceptor(ServerInterceptor):
+    """Simple auth interceptor that checks for Bearer token."""
+
+    def intercept(
+        self,
+        method: Callable,
+        request_or_iterator: Any,
+        context: grpc.ServicerContext,
+        method_name: str,
+    ) -> Any:
+        authenticated = False
+        for metadata in context.invocation_metadata():
+            if metadata.key == "authorization":
+                if metadata.value.startswith("Bearer "):
+                    authenticated = True
+                    break
+
+        if authenticated:
+            return method(request_or_iterator, context)
+        else:
+            # abort() raises, ending the RPC with UNAUTHENTICATED
+            context.abort(grpc.StatusCode.UNAUTHENTICATED, "Invalid or missing API key")
+
+
+@contextmanager
+def server_with_service(mocker: MockFixture):
+    """Create a test server with a spy on the DataService.
+ + Returns: + Tuple of (spy, data_service, port) + """ + server = grpc.server( + thread_pool=futures.ThreadPoolExecutor(max_workers=1), + interceptors=[AuthInterceptor()], + ) + + data_service = DataService() + spy = mocker.spy(data_service, "GetData") + + add_DataServiceServicer_to_server(data_service, server) + # Use port 0 to let the OS assign an available port + port = server.add_insecure_port("[::]:0") + server.start() + try: + yield spy, data_service, port + finally: + server.stop(None) + server.wait_for_termination() + + +def test_cache_helper_functions(): + """Test the cache metadata helper functions.""" + # Test with_cache + metadata = with_cache() + assert metadata == (("use-cache", "true"),) + + # Test with_cache with custom TTL + metadata = with_cache(ttl=7200) + assert metadata == (("use-cache", "true"), ("cache-ttl", "7200")) + + # Test with_force_refresh + metadata = with_force_refresh() + assert metadata == (("use-cache", "true"), ("force-refresh", "true")) + + # Test with_force_refresh with custom TTL + metadata = with_force_refresh(ttl=3600) + assert metadata == (("use-cache", "true"), ("force-refresh", "true"), ("cache-ttl", "3600")) + + +def test_grpc_cache_initialization(): + """Test GrpcCache initialization and configuration.""" + with tempfile.TemporaryDirectory() as tmpdir: + cache_config = { + "ttl": 1800, + "cache_path": str(Path(tmpdir) / "test_cache"), + "size_limit": 1024 * 1024, # 1MB + "clear_on_init": False, + } + + cache = GrpcCache(cache_config) + assert cache.default_ttl == 1800 + assert cache.cache_path == Path(tmpdir) / "test_cache" + assert cache.size_limit == 1024 * 1024 + assert cache.cache_path.exists() + + # Test clear_on_init + cache.set("test-key", "test-value") + assert cache.get("test-key") == "test-value" + + cache_config["clear_on_init"] = True + cache2 = GrpcCache(cache_config) + assert cache2.get("test-key") is None + + +def test_cache_key_generation(): + """Test deterministic cache key generation.""" + request1 = GetDataRequest(page_size=100) + request2 = GetDataRequest(page_size=100) + request3 = GetDataRequest(page_size=200) + + key1 = GrpcCache.key_from_proto_message("/sift.data.v2.DataService/GetData", request1) + key2 = GrpcCache.key_from_proto_message("/sift.data.v2.DataService/GetData", request2) + key3 = GrpcCache.key_from_proto_message("/sift.data.v2.DataService/GetData", request3) + + # Same request should generate same key + assert key1 == key2 + + # Different request should generate different key + assert key1 != key3 + + # Keys should be SHA256 hashes (64 hex characters) + assert len(key1) == 64 + assert all(c in "0123456789abcdef" for c in key1) + + +def test_cache_metadata_resolution(): + """Test cache metadata resolution logic.""" + # No metadata + settings = GrpcCache.resolve_cache_metadata(None) + assert settings.use_cache is False + assert settings.force_refresh is False + assert settings.custom_ttl is None + + # use-cache enabled + settings = GrpcCache.resolve_cache_metadata((("use-cache", "true"),)) + assert settings.use_cache is True + assert settings.force_refresh is False + assert settings.custom_ttl is None + + # force-refresh enabled + settings = GrpcCache.resolve_cache_metadata((("use-cache", "true"), ("force-refresh", "true"))) + assert settings.use_cache is True + assert settings.force_refresh is True + assert settings.custom_ttl is None + + # Custom TTL + settings = GrpcCache.resolve_cache_metadata((("use-cache", "true"), ("cache-ttl", "7200"))) + assert settings.use_cache is True + assert 
settings.force_refresh is False + assert settings.custom_ttl == 7200 + + # Invalid TTL (should be ignored) + settings = GrpcCache.resolve_cache_metadata((("use-cache", "true"), ("cache-ttl", "invalid"))) + assert settings.use_cache is True + assert settings.custom_ttl is None + + +@pytest.mark.asyncio +async def test_basic_caching(mocker: MockFixture): + """Test basic cache hit and miss scenarios.""" + with tempfile.TemporaryDirectory() as tmpdir: + with server_with_service(mocker) as (get_data_spy, data_service, port): + config: SiftChannelConfig = { + "uri": f"localhost:{port}", + "apikey": "test-token", + "use_ssl": False, + "cache_config": { + "ttl": 3600, + "cache_path": str(Path(tmpdir) / "cache"), + "size_limit": 1024 * 1024, + "clear_on_init": True, + }, + } + cache = GrpcCache(config["cache_config"]) + + async with use_sift_async_channel(config, cache=cache) as channel: + stub = DataServiceStub(channel) + request = GetDataRequest(page_size=100) + + # First call without cache - should hit server + res1 = cast(GetDataResponse, await stub.GetData(request)) + assert res1.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Second call without cache - should hit server again + res2 = cast(GetDataResponse, await stub.GetData(request)) + assert res2.next_page_token == "token-2" + assert data_service.call_count == 2 + + # Third call WITH cache - should hit server + res3 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res3.next_page_token == "token-3" + assert data_service.call_count == 3 + + # Fourth call WITH cache - should use cached response + res4 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res4.next_page_token == "token-3" # Same as res3! + assert data_service.call_count == 3 # No new call + + # Fifth call WITH cache - should still use cached response + res5 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res5.next_page_token == "token-3" + assert data_service.call_count == 3 + + +@pytest.mark.asyncio +async def test_force_refresh(mocker: MockFixture): + """Test force refresh bypasses cache and updates it.""" + with tempfile.TemporaryDirectory() as tmpdir: + with server_with_service(mocker) as (get_data_spy, data_service, port): + config: SiftChannelConfig = { + "uri": f"localhost:{port}", + "apikey": "test-token", + "use_ssl": False, + "cache_config": { + "ttl": 3600, + "cache_path": str(Path(tmpdir) / "cache"), + "size_limit": 1024 * 1024, + "clear_on_init": True, + }, + } + cache = GrpcCache(config["cache_config"]) + + async with use_sift_async_channel(config, cache=cache) as channel: + stub = DataServiceStub(channel) + request = GetDataRequest(page_size=100) + + # First call with cache + res1 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res1.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Second call with cache - should use cached + res2 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res2.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Force refresh - should hit server and update cache + res3 = cast( + GetDataResponse, await stub.GetData(request, metadata=with_force_refresh()) + ) + assert res3.next_page_token == "token-2" + assert data_service.call_count == 2 + + # Next call with cache should use the refreshed value + res4 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert 
res4.next_page_token == "token-2" + assert data_service.call_count == 2 + + +@pytest.mark.asyncio +async def test_ignore_cache(mocker: MockFixture): + """Test ignore_cache bypasses cache without updating it.""" + with tempfile.TemporaryDirectory() as tmpdir: + with server_with_service(mocker) as (get_data_spy, data_service, port): + config: SiftChannelConfig = { + "uri": f"localhost:{port}", + "apikey": "test-token", + "use_ssl": False, + "cache_config": { + "ttl": 3600, + "cache_path": str(Path(tmpdir) / "cache"), + "size_limit": 1024 * 1024, + "clear_on_init": True, + }, + } + cache = GrpcCache(config["cache_config"]) + + async with use_sift_async_channel(config, cache=cache) as channel: + stub = DataServiceStub(channel) + request = GetDataRequest(page_size=100) + + # First call with cache + res1 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res1.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Call with no metadata - should hit server + res2 = cast(GetDataResponse, await stub.GetData(request)) + assert res2.next_page_token == "token-2" + assert data_service.call_count == 2 + + # Call with cache again - should still have original cached value + res3 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res3.next_page_token == "token-1" # Original cached value + assert data_service.call_count == 2 + + +@pytest.mark.asyncio +async def test_different_requests_different_cache_keys(mocker: MockFixture): + """Test that different requests use different cache entries.""" + with tempfile.TemporaryDirectory() as tmpdir: + with server_with_service(mocker) as (get_data_spy, data_service, port): + config: SiftChannelConfig = { + "uri": f"localhost:{port}", + "apikey": "test-token", + "use_ssl": False, + "cache_config": { + "ttl": 3600, + "cache_path": str(Path(tmpdir) / "cache"), + "size_limit": 1024 * 1024, + "clear_on_init": True, + }, + } + cache = GrpcCache(config["cache_config"]) + + async with use_sift_async_channel(config, cache=cache) as channel: + stub = DataServiceStub(channel) + request1 = GetDataRequest(page_size=100) + request2 = GetDataRequest(page_size=200) + + # First request with cache + res1 = cast(GetDataResponse, await stub.GetData(request1, metadata=with_cache())) + assert res1.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Different request with cache - should hit server + res2 = cast(GetDataResponse, await stub.GetData(request2, metadata=with_cache())) + assert res2.next_page_token == "token-2" + assert data_service.call_count == 2 + + # First request again - should use cache + res3 = cast(GetDataResponse, await stub.GetData(request1, metadata=with_cache())) + assert res3.next_page_token == "token-1" + assert data_service.call_count == 2 + + # Second request again - should use cache + res4 = cast(GetDataResponse, await stub.GetData(request2, metadata=with_cache())) + assert res4.next_page_token == "token-2" + assert data_service.call_count == 2 + + +@pytest.mark.asyncio +async def test_cache_persists_across_channels(mocker: MockFixture): + """Test that cache persists across different channel instances.""" + with tempfile.TemporaryDirectory() as tmpdir: + with server_with_service(mocker) as (get_data_spy, data_service, port): + config: SiftChannelConfig = { + "uri": f"localhost:{port}", + "apikey": "test-token", + "use_ssl": False, + "cache_config": { + "ttl": 3600, + "cache_path": str(Path(tmpdir) / "cache"), + "size_limit": 1024 * 1024, + "clear_on_init": 
False, + }, + } + cache = GrpcCache(config["cache_config"]) + + # First channel - populate cache + async with use_sift_async_channel(config, cache=cache) as channel1: + stub1 = DataServiceStub(channel1) + request = GetDataRequest(page_size=100) + res1 = cast(GetDataResponse, await stub1.GetData(request, metadata=with_cache())) + assert res1.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Second channel - should use cached value + async with use_sift_async_channel(config, cache=cache) as channel2: + stub2 = DataServiceStub(channel2) + request = GetDataRequest(page_size=100) + res2 = cast(GetDataResponse, await stub2.GetData(request, metadata=with_cache())) + assert res2.next_page_token == "token-1" # Same as first call + assert data_service.call_count == 1 # No new server call + + +@pytest.mark.asyncio +async def test_custom_ttl(mocker: MockFixture): + """Test custom TTL parameter.""" + with tempfile.TemporaryDirectory() as tmpdir: + with server_with_service(mocker) as (get_data_spy, data_service, port): + config: SiftChannelConfig = { + "uri": f"localhost:{port}", + "apikey": "test-token", + "use_ssl": False, + "cache_config": { + "ttl": 3600, + "cache_path": str(Path(tmpdir) / "cache"), + "size_limit": 1024 * 1024, + "clear_on_init": True, + }, + } + + cache = GrpcCache(config["cache_config"]) + + async with use_sift_async_channel(config, cache=cache) as channel: + stub = DataServiceStub(channel) + request = GetDataRequest(page_size=100) + + # Call with custom TTL + res1 = cast( + GetDataResponse, await stub.GetData(request, metadata=with_cache(ttl=7200)) + ) + assert res1.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Verify it's cached + res2 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res2.next_page_token == "token-1" + assert data_service.call_count == 1 + + +@pytest.mark.asyncio +async def test_metadata_merging(mocker: MockFixture): + """Test that cache metadata is properly merged with API key metadata.""" + with tempfile.TemporaryDirectory() as tmpdir: + with server_with_service(mocker) as (get_data_spy, data_service, port): + config: SiftChannelConfig = { + "uri": f"localhost:{port}", + "apikey": "test-token", + "use_ssl": False, + "cache_config": { + "ttl": 3600, + "cache_path": str(Path(tmpdir) / "cache"), + "size_limit": 1024 * 1024, + "clear_on_init": True, + }, + } + cache = GrpcCache(config["cache_config"]) + + async with use_sift_async_channel(config, cache=cache) as channel: + stub = DataServiceStub(channel) + request = GetDataRequest(page_size=100) + + # This should work - cache metadata should be merged with auth metadata + res = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res.next_page_token == "token-1" + assert data_service.call_count == 1 + + # Verify cache works + res2 = cast(GetDataResponse, await stub.GetData(request, metadata=with_cache())) + assert res2.next_page_token == "token-1" + assert data_service.call_count == 1 diff --git a/python/lib/sift_py/grpc/transport.py b/python/lib/sift_py/grpc/transport.py index 07d13f667..b8b065130 100644 --- a/python/lib/sift_py/grpc/transport.py +++ b/python/lib/sift_py/grpc/transport.py @@ -15,6 +15,7 @@ from typing_extensions import NotRequired, TypeAlias from sift_py.grpc._async_interceptors.base import ClientAsyncInterceptor +from sift_py.grpc._async_interceptors.caching import CachingAsyncInterceptor from sift_py.grpc._async_interceptors.metadata import MetadataAsyncInterceptor from 
sift_py.grpc._interceptors.base import ClientInterceptor from sift_py.grpc._interceptors.metadata import Metadata, MetadataInterceptor @@ -78,7 +79,7 @@ def use_sift_channel( def use_sift_async_channel( - config: SiftChannelConfig, metadata: Optional[Dict[str, Any]] = None + config: SiftChannelConfig, metadata: Optional[Dict[str, Any]] = None, cache: Any = None ) -> SiftAsyncChannel: """ Like `use_sift_channel` but returns a channel meant to be used within the context @@ -88,13 +89,13 @@ def use_sift_async_channel( cert_via_openssl = config.get("cert_via_openssl", False) if not use_ssl: - return _use_insecure_sift_async_channel(config, metadata) + return _use_insecure_sift_async_channel(config, metadata, cache) return grpc_aio.secure_channel( target=_clean_uri(config["uri"], use_ssl), credentials=get_ssl_credentials(cert_via_openssl), options=_compute_channel_options(config), - interceptors=_compute_sift_async_interceptors(config, metadata), + interceptors=_compute_sift_async_interceptors(config, metadata, cache), ) @@ -112,7 +113,7 @@ def _use_insecure_sift_channel( def _use_insecure_sift_async_channel( - config: SiftChannelConfig, metadata: Optional[Dict[str, Any]] = None + config: SiftChannelConfig, metadata: Optional[Dict[str, Any]] = None, cache: Any = None ) -> SiftAsyncChannel: """ FOR DEVELOPMENT PURPOSES ONLY @@ -120,7 +121,7 @@ def _use_insecure_sift_async_channel( return grpc_aio.insecure_channel( target=_clean_uri(config["uri"], False), options=_compute_channel_options(config), - interceptors=_compute_sift_async_interceptors(config, metadata), + interceptors=_compute_sift_async_interceptors(config, metadata, cache), ) @@ -130,17 +131,27 @@ def _compute_sift_interceptors( """ Initialized all interceptors here. """ - return [ - _metadata_interceptor(config, metadata), - ] + interceptors: List[ClientInterceptor] = [] + + # Metadata interceptor should be last to ensure metadata is always added + interceptors.append(_metadata_interceptor(config, metadata)) + + return interceptors def _compute_sift_async_interceptors( - config: SiftChannelConfig, metadata: Optional[Dict[str, Any]] = None + config: SiftChannelConfig, metadata: Optional[Dict[str, Any]] = None, cache: Any = None ) -> List[grpc_aio.ClientInterceptor]: - return [ - _metadata_async_interceptor(config, metadata), - ] + interceptors: List[grpc_aio.ClientInterceptor] = [] + + # Add caching interceptor if cache instance is provided + if cache is not None: + interceptors.append(CachingAsyncInterceptor(cache=cache)) + + # Metadata interceptor should be last to ensure metadata is always added + interceptors.append(_metadata_async_interceptor(config, metadata)) + + return interceptors def _compute_channel_options(opts: SiftChannelConfig) -> List[Tuple[str, Any]]: @@ -229,6 +240,21 @@ def _compute_keep_alive_channel_opts(config: KeepaliveConfig) -> List[Tuple[str, ] +class SiftCacheConfig(TypedDict): + """ + Configuration for gRPC response caching. + - `ttl`: Time-to-live for cached entries in seconds. + - `cache_path`: Path to the cache directory. + - `size_limit`: Maximum size of the cache in bytes. + - `clear_on_init`: Whether to clear the cache on initialization. + """ + + ttl: int + cache_path: str + size_limit: int + clear_on_init: bool + + class SiftChannelConfig(TypedDict): """ Config class used to instantiate a `SiftChannel` via `use_sift_channel`. @@ -241,6 +267,8 @@ class SiftChannelConfig(TypedDict): Run `pip install sift-stack-py[openssl]` to install the dependencies required to use this option. 
This works around this issue with grpc loading SSL certificates: https://github.com/grpc/grpc/issues/29682. Default is False. + - `cache_config`: Optional configuration for response caching. If provided, caching will be enabled. + Use metadata flags to control caching on a per-request basis. """ uri: str @@ -248,3 +276,4 @@ class SiftChannelConfig(TypedDict): enable_keepalive: NotRequired[Union[bool, KeepaliveConfig]] use_ssl: NotRequired[bool] cert_via_openssl: NotRequired[bool] + cache_config: NotRequired[SiftCacheConfig] diff --git a/python/pyproject.toml b/python/pyproject.toml index 993b4236b..bbbaccc71 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -35,6 +35,8 @@ dependencies = [ "types-protobuf>=4.0", "typing-extensions~=4.6", "types-requests~=2.25", + "diskcache~=5.6", + "platformdirs~=4.0" ] [project.urls] @@ -149,6 +151,12 @@ module = "requests_toolbelt" ignore_missing_imports = true ignore_errors = true +[[tool.mypy.overrides]] +module = "diskcache" +ignore_missing_imports = true +ignore_errors = true + + [tool.setuptools.packages.find] where = ["lib"]
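
Note: pulling the pieces together, a minimal end-to-end sketch (not part of the patch) of per-request cache control, mirroring the usage in `cache_test.py`. The URI and API key are placeholders, and a reachable server is assumed.

```python
# Sketch only -- per-request cache control via the metadata helpers added here.
import asyncio

from sift.data.v2.data_pb2 import GetDataRequest
from sift.data.v2.data_pb2_grpc import DataServiceStub

from sift_py.grpc.cache import GrpcCache, with_cache, with_force_refresh
from sift_py.grpc.transport import SiftChannelConfig, use_sift_async_channel


async def main() -> None:
    config: SiftChannelConfig = {
        "uri": "localhost:50051",  # placeholder
        "apikey": "my-api-key",    # placeholder
        "use_ssl": False,
        "cache_config": {
            "ttl": 3600,
            "cache_path": ".grpc_cache",
            "size_limit": 1024**3,
            "clear_on_init": False,
        },
    }
    cache = GrpcCache(config["cache_config"])

    async with use_sift_async_channel(config, cache=cache) as channel:
        stub = DataServiceStub(channel)
        request = GetDataRequest(page_size=100)

        # Repeated identical requests with use-cache metadata are served locally.
        await stub.GetData(request, metadata=with_cache())
        # force-refresh bypasses the cache and repopulates it.
        await stub.GetData(request, metadata=with_force_refresh())
        # No cache metadata: hits the server and leaves the cache untouched.
        await stub.GetData(request)


asyncio.run(main())
```
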