diff --git a/src/database/api_key_store.py b/src/database/api_key_store.py index 9f017e4..7095750 100644 --- a/src/database/api_key_store.py +++ b/src/database/api_key_store.py @@ -4,6 +4,7 @@ import logging import secrets import string +import time from datetime import datetime from typing import List, Optional, Dict, Any @@ -18,6 +19,9 @@ # In-memory fallback _in_memory_api_keys: Dict[str, Dict[str, Any]] = {} +VALIDATION_CACHE_TTL_SECONDS = 120 +VALIDATION_CACHE_MAX_SIZE = 2048 + class APIKeyStore: """MongoDB-backed storage for API key management with in-memory fallback.""" @@ -34,6 +38,7 @@ def __init__( self.api_keys = None self._connected = False self._in_memory = False + self._validation_cache: Dict[str, tuple[float, Dict[str, Any]]] = {} # Try to connect self._try_connect() @@ -79,6 +84,36 @@ def _hash_key(self, key: str) -> str: """Create SHA-256 hash of the API key.""" return hashlib.sha256(key.encode()).hexdigest() + def _clear_validation_cache(self) -> None: + """Clear cached API key validation results.""" + self._validation_cache.clear() + + def _get_cached_validation(self, key_hash: str) -> Optional[Dict[str, Any]]: + """Return a cached validation result when it is still fresh.""" + cached = self._validation_cache.get(key_hash) + if not cached: + return None + + expires_at, key_doc = cached + if expires_at <= time.monotonic(): + self._validation_cache.pop(key_hash, None) + return None + + result = dict(key_doc) + result["last_used"] = datetime.utcnow() + return result + + def _cache_validation(self, key_hash: str, key_doc: Dict[str, Any]) -> None: + """Cache a sanitized active API key document.""" + if len(self._validation_cache) >= VALIDATION_CACHE_MAX_SIZE: + oldest_key = next(iter(self._validation_cache)) + self._validation_cache.pop(oldest_key, None) + + self._validation_cache[key_hash] = ( + time.monotonic() + VALIDATION_CACHE_TTL_SECONDS, + dict(key_doc), + ) + def create_api_key( self, user_id: str, @@ -122,6 +157,7 @@ def create_api_key( "is_active": True, } result = self.api_keys.insert_one(key_doc) + self._clear_validation_cache() logger.info(f"Created new API key for user {user_id}") return { "key": key, @@ -174,6 +210,10 @@ def validate_api_key(self, key: str) -> Optional[Dict[str, Any]]: return result return None + cached_doc = self._get_cached_validation(key_hash) + if cached_doc: + return cached_doc + try: key_doc = self.api_keys.find_one({ "key_hash": key_hash, @@ -185,10 +225,15 @@ def validate_api_key(self, key: str) -> Optional[Dict[str, Any]]: {"_id": key_doc["_id"]}, {"$set": {"last_used": now}} ) - key_doc["last_used"] = now - key_doc["id"] = str(key_doc.pop("_id")) + key_doc = { + **key_doc, + "id": str(key_doc["_id"]), + "last_used": now, + } + key_doc.pop("_id", None) key_doc.pop("key_hash", None) - return key_doc + self._cache_validation(key_hash, key_doc) + return dict(key_doc) if key_doc else None except Exception as e: logger.error(f"Database error validating API key: {e}") return None @@ -199,6 +244,7 @@ def revoke_api_key(self, user_id: str, key_id: str) -> bool: if key_id in _in_memory_api_keys: if _in_memory_api_keys[key_id].get("user_id") == user_id: _in_memory_api_keys[key_id]["is_active"] = False + self._clear_validation_cache() return True return False @@ -208,7 +254,10 @@ def revoke_api_key(self, user_id: str, key_id: str) -> bool: {"_id": ObjectId(key_id), "user_id": user_id}, {"$set": {"is_active": False}} ) - return result.modified_count > 0 + success = result.modified_count > 0 + if success: + self._clear_validation_cache() + return success except Exception as e: logger.error(f"Failed to revoke API key {key_id}: {e}") return False @@ -224,6 +273,7 @@ def update_api_key_name( if key_id in _in_memory_api_keys: if _in_memory_api_keys[key_id].get("user_id") == user_id: _in_memory_api_keys[key_id]["name"] = new_name + self._clear_validation_cache() return True return False @@ -233,7 +283,10 @@ def update_api_key_name( {"_id": ObjectId(key_id), "user_id": user_id}, {"$set": {"name": new_name}} ) - return result.modified_count > 0 + success = result.modified_count > 0 + if success: + self._clear_validation_cache() + return success except Exception as e: logger.error(f"Failed to update API key name {key_id}: {e}") return False