From d998446eadf7e8c813249972b29c03dcec68e641 Mon Sep 17 00:00:00 2001 From: Genmin Date: Fri, 1 May 2026 11:51:04 -0700 Subject: [PATCH 1/2] perf: cache API key validation --- src/database/api_key_store.py | 63 ++++++++++++++-- tests/test_api_key_store.py | 131 ++++++++++++++++++++++++++++++++++ 2 files changed, 189 insertions(+), 5 deletions(-) create mode 100644 tests/test_api_key_store.py diff --git a/src/database/api_key_store.py b/src/database/api_key_store.py index 9f017e4..7095750 100644 --- a/src/database/api_key_store.py +++ b/src/database/api_key_store.py @@ -4,6 +4,7 @@ import logging import secrets import string +import time from datetime import datetime from typing import List, Optional, Dict, Any @@ -18,6 +19,9 @@ # In-memory fallback _in_memory_api_keys: Dict[str, Dict[str, Any]] = {} +VALIDATION_CACHE_TTL_SECONDS = 120 +VALIDATION_CACHE_MAX_SIZE = 2048 + class APIKeyStore: """MongoDB-backed storage for API key management with in-memory fallback.""" @@ -34,6 +38,7 @@ def __init__( self.api_keys = None self._connected = False self._in_memory = False + self._validation_cache: Dict[str, tuple[float, Dict[str, Any]]] = {} # Try to connect self._try_connect() @@ -79,6 +84,36 @@ def _hash_key(self, key: str) -> str: """Create SHA-256 hash of the API key.""" return hashlib.sha256(key.encode()).hexdigest() + def _clear_validation_cache(self) -> None: + """Clear cached API key validation results.""" + self._validation_cache.clear() + + def _get_cached_validation(self, key_hash: str) -> Optional[Dict[str, Any]]: + """Return a cached validation result when it is still fresh.""" + cached = self._validation_cache.get(key_hash) + if not cached: + return None + + expires_at, key_doc = cached + if expires_at <= time.monotonic(): + self._validation_cache.pop(key_hash, None) + return None + + result = dict(key_doc) + result["last_used"] = datetime.utcnow() + return result + + def _cache_validation(self, key_hash: str, key_doc: Dict[str, Any]) -> None: + """Cache a sanitized active API key document.""" + if len(self._validation_cache) >= VALIDATION_CACHE_MAX_SIZE: + oldest_key = next(iter(self._validation_cache)) + self._validation_cache.pop(oldest_key, None) + + self._validation_cache[key_hash] = ( + time.monotonic() + VALIDATION_CACHE_TTL_SECONDS, + dict(key_doc), + ) + def create_api_key( self, user_id: str, @@ -122,6 +157,7 @@ def create_api_key( "is_active": True, } result = self.api_keys.insert_one(key_doc) + self._clear_validation_cache() logger.info(f"Created new API key for user {user_id}") return { "key": key, @@ -174,6 +210,10 @@ def validate_api_key(self, key: str) -> Optional[Dict[str, Any]]: return result return None + cached_doc = self._get_cached_validation(key_hash) + if cached_doc: + return cached_doc + try: key_doc = self.api_keys.find_one({ "key_hash": key_hash, @@ -185,10 +225,15 @@ def validate_api_key(self, key: str) -> Optional[Dict[str, Any]]: {"_id": key_doc["_id"]}, {"$set": {"last_used": now}} ) - key_doc["last_used"] = now - key_doc["id"] = str(key_doc.pop("_id")) + key_doc = { + **key_doc, + "id": str(key_doc["_id"]), + "last_used": now, + } + key_doc.pop("_id", None) key_doc.pop("key_hash", None) - return key_doc + self._cache_validation(key_hash, key_doc) + return dict(key_doc) if key_doc else None except Exception as e: logger.error(f"Database error validating API key: {e}") return None @@ -199,6 +244,7 @@ def revoke_api_key(self, user_id: str, key_id: str) -> bool: if key_id in _in_memory_api_keys: if _in_memory_api_keys[key_id].get("user_id") == user_id: _in_memory_api_keys[key_id]["is_active"] = False + self._clear_validation_cache() return True return False @@ -208,7 +254,10 @@ def revoke_api_key(self, user_id: str, key_id: str) -> bool: {"_id": ObjectId(key_id), "user_id": user_id}, {"$set": {"is_active": False}} ) - return result.modified_count > 0 + success = result.modified_count > 0 + if success: + self._clear_validation_cache() + return success except Exception as e: logger.error(f"Failed to revoke API key {key_id}: {e}") return False @@ -224,6 +273,7 @@ def update_api_key_name( if key_id in _in_memory_api_keys: if _in_memory_api_keys[key_id].get("user_id") == user_id: _in_memory_api_keys[key_id]["name"] = new_name + self._clear_validation_cache() return True return False @@ -233,7 +283,10 @@ def update_api_key_name( {"_id": ObjectId(key_id), "user_id": user_id}, {"$set": {"name": new_name}} ) - return result.modified_count > 0 + success = result.modified_count > 0 + if success: + self._clear_validation_cache() + return success except Exception as e: logger.error(f"Failed to update API key name {key_id}: {e}") return False diff --git a/tests/test_api_key_store.py b/tests/test_api_key_store.py new file mode 100644 index 0000000..7d31073 --- /dev/null +++ b/tests/test_api_key_store.py @@ -0,0 +1,131 @@ +import importlib.util +import os +import sys +import types +from datetime import datetime +from pathlib import Path + +os.environ.setdefault("PINECONE_API_KEY", "test-pinecone-key") +os.environ.setdefault("NEO4J_PASSWORD", "test-neo4j-password") +os.environ.setdefault("GEMINI_API_KEY", "test-gemini-key") + +config_module = types.ModuleType("src.config") +config_module.settings = types.SimpleNamespace( + mongodb_uri="mongodb://localhost:27017", + mongodb_database="xmem-test", +) +sys.modules.setdefault("src.config", config_module) + +API_KEY_STORE_PATH = ( + Path(__file__).resolve().parents[1] / "src" / "database" / "api_key_store.py" +) +api_key_store_spec = importlib.util.spec_from_file_location( + "api_key_store_under_test", + API_KEY_STORE_PATH, +) +api_key_store_module = importlib.util.module_from_spec(api_key_store_spec) +api_key_store_spec.loader.exec_module(api_key_store_module) +APIKeyStore = api_key_store_module.APIKeyStore + + +class FakeUpdateResult: + def __init__(self, modified_count=1): + self.modified_count = modified_count + + +class FakeAPIKeysCollection: + def __init__(self, doc): + self.doc = doc + self.find_one_calls = 0 + self.update_one_calls = [] + + def find_one(self, query): + self.find_one_calls += 1 + if ( + self.doc + and self.doc.get("key_hash") == query.get("key_hash") + and self.doc.get("is_active") == query.get("is_active") + ): + return dict(self.doc) + return None + + def update_one(self, query, update): + self.update_one_calls.append((query, update)) + if self.doc and self.doc.get("_id") == query.get("_id"): + self.doc.update(update.get("$set", {})) + return FakeUpdateResult() + return FakeUpdateResult(modified_count=0) + + +def build_store(collection): + store = object.__new__(APIKeyStore) + store._in_memory = False + store._connected = True + store.api_keys = collection + store._validation_cache = {} + return store + + +def key_doc_for(store, key, **overrides): + doc = { + "_id": "key-1", + "user_id": "user-1", + "key_hash": store._hash_key(key), + "key_prefix": key[:8], + "name": "Primary", + "created_at": datetime(2026, 1, 1), + "last_used": None, + "is_active": True, + } + doc.update(overrides) + return doc + + +def test_validate_api_key_caches_active_database_key(monkeypatch): + monkeypatch.setattr(api_key_store_module, "VALIDATION_CACHE_TTL_SECONDS", 60) + store = build_store(None) + key = "xmem_test-key" + collection = FakeAPIKeysCollection(key_doc_for(store, key)) + store.api_keys = collection + + first = store.validate_api_key(key) + second = store.validate_api_key(key) + + assert collection.find_one_calls == 1 + assert len(collection.update_one_calls) == 1 + assert first["id"] == "key-1" + assert second["id"] == "key-1" + assert "key_hash" not in first + assert "key_hash" not in second + + second["name"] = "mutated caller copy" + third = store.validate_api_key(key) + assert third["name"] == "Primary" + + +def test_validate_api_key_requeries_after_cache_expiry(monkeypatch): + monkeypatch.setattr(api_key_store_module, "VALIDATION_CACHE_TTL_SECONDS", -1) + store = build_store(None) + key = "xmem_test-key" + collection = FakeAPIKeysCollection(key_doc_for(store, key)) + store.api_keys = collection + + store.validate_api_key(key) + store.validate_api_key(key) + + assert collection.find_one_calls == 2 + assert len(collection.update_one_calls) == 2 + + +def test_validate_api_key_does_not_cache_missing_or_inactive_key(monkeypatch): + monkeypatch.setattr(api_key_store_module, "VALIDATION_CACHE_TTL_SECONDS", 60) + store = build_store(None) + key = "xmem_test-key" + collection = FakeAPIKeysCollection(key_doc_for(store, key, is_active=False)) + store.api_keys = collection + + assert store.validate_api_key(key) is None + assert store.validate_api_key(key) is None + + assert collection.find_one_calls == 2 + assert collection.update_one_calls == [] From 18856d02043ad28f57bd1f9f3ff936afb2048380 Mon Sep 17 00:00:00 2001 From: Vedant Mahajan Date: Sat, 2 May 2026 13:39:41 +0530 Subject: [PATCH 2/2] Delete tests/test_api_key_store.py --- tests/test_api_key_store.py | 131 ------------------------------------ 1 file changed, 131 deletions(-) delete mode 100644 tests/test_api_key_store.py diff --git a/tests/test_api_key_store.py b/tests/test_api_key_store.py deleted file mode 100644 index 7d31073..0000000 --- a/tests/test_api_key_store.py +++ /dev/null @@ -1,131 +0,0 @@ -import importlib.util -import os -import sys -import types -from datetime import datetime -from pathlib import Path - -os.environ.setdefault("PINECONE_API_KEY", "test-pinecone-key") -os.environ.setdefault("NEO4J_PASSWORD", "test-neo4j-password") -os.environ.setdefault("GEMINI_API_KEY", "test-gemini-key") - -config_module = types.ModuleType("src.config") -config_module.settings = types.SimpleNamespace( - mongodb_uri="mongodb://localhost:27017", - mongodb_database="xmem-test", -) -sys.modules.setdefault("src.config", config_module) - -API_KEY_STORE_PATH = ( - Path(__file__).resolve().parents[1] / "src" / "database" / "api_key_store.py" -) -api_key_store_spec = importlib.util.spec_from_file_location( - "api_key_store_under_test", - API_KEY_STORE_PATH, -) -api_key_store_module = importlib.util.module_from_spec(api_key_store_spec) -api_key_store_spec.loader.exec_module(api_key_store_module) -APIKeyStore = api_key_store_module.APIKeyStore - - -class FakeUpdateResult: - def __init__(self, modified_count=1): - self.modified_count = modified_count - - -class FakeAPIKeysCollection: - def __init__(self, doc): - self.doc = doc - self.find_one_calls = 0 - self.update_one_calls = [] - - def find_one(self, query): - self.find_one_calls += 1 - if ( - self.doc - and self.doc.get("key_hash") == query.get("key_hash") - and self.doc.get("is_active") == query.get("is_active") - ): - return dict(self.doc) - return None - - def update_one(self, query, update): - self.update_one_calls.append((query, update)) - if self.doc and self.doc.get("_id") == query.get("_id"): - self.doc.update(update.get("$set", {})) - return FakeUpdateResult() - return FakeUpdateResult(modified_count=0) - - -def build_store(collection): - store = object.__new__(APIKeyStore) - store._in_memory = False - store._connected = True - store.api_keys = collection - store._validation_cache = {} - return store - - -def key_doc_for(store, key, **overrides): - doc = { - "_id": "key-1", - "user_id": "user-1", - "key_hash": store._hash_key(key), - "key_prefix": key[:8], - "name": "Primary", - "created_at": datetime(2026, 1, 1), - "last_used": None, - "is_active": True, - } - doc.update(overrides) - return doc - - -def test_validate_api_key_caches_active_database_key(monkeypatch): - monkeypatch.setattr(api_key_store_module, "VALIDATION_CACHE_TTL_SECONDS", 60) - store = build_store(None) - key = "xmem_test-key" - collection = FakeAPIKeysCollection(key_doc_for(store, key)) - store.api_keys = collection - - first = store.validate_api_key(key) - second = store.validate_api_key(key) - - assert collection.find_one_calls == 1 - assert len(collection.update_one_calls) == 1 - assert first["id"] == "key-1" - assert second["id"] == "key-1" - assert "key_hash" not in first - assert "key_hash" not in second - - second["name"] = "mutated caller copy" - third = store.validate_api_key(key) - assert third["name"] == "Primary" - - -def test_validate_api_key_requeries_after_cache_expiry(monkeypatch): - monkeypatch.setattr(api_key_store_module, "VALIDATION_CACHE_TTL_SECONDS", -1) - store = build_store(None) - key = "xmem_test-key" - collection = FakeAPIKeysCollection(key_doc_for(store, key)) - store.api_keys = collection - - store.validate_api_key(key) - store.validate_api_key(key) - - assert collection.find_one_calls == 2 - assert len(collection.update_one_calls) == 2 - - -def test_validate_api_key_does_not_cache_missing_or_inactive_key(monkeypatch): - monkeypatch.setattr(api_key_store_module, "VALIDATION_CACHE_TTL_SECONDS", 60) - store = build_store(None) - key = "xmem_test-key" - collection = FakeAPIKeysCollection(key_doc_for(store, key, is_active=False)) - store.api_keys = collection - - assert store.validate_api_key(key) is None - assert store.validate_api_key(key) is None - - assert collection.find_one_calls == 2 - assert collection.update_one_calls == []