Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 58 additions & 5 deletions src/database/api_key_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import secrets
import string
import time
from datetime import datetime
from typing import List, Optional, Dict, Any

Expand All @@ -18,6 +19,9 @@
# In-memory fallback
_in_memory_api_keys: Dict[str, Dict[str, Any]] = {}

VALIDATION_CACHE_TTL_SECONDS = 120
VALIDATION_CACHE_MAX_SIZE = 2048


class APIKeyStore:
"""MongoDB-backed storage for API key management with in-memory fallback."""
Expand All @@ -34,6 +38,7 @@ def __init__(
self.api_keys = None
self._connected = False
self._in_memory = False
self._validation_cache: Dict[str, tuple[float, Dict[str, Any]]] = {}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The _validation_cache is defined as an instance variable, but APIKeyStore is instantiated multiple times in the application (e.g., in src/api/routes/api_keys.py and src/api/dependencies.py). This means each instance maintains its own independent cache.

When a key is revoked or updated in one instance, the cache in the other instance (used by the authentication middleware) will not be cleared, allowing revoked keys to remain valid for the duration of the TTL. Consider making the cache a module-level variable (similar to _in_memory_api_keys) or ensuring a singleton instance is shared across the application.


# Try to connect
self._try_connect()
Expand Down Expand Up @@ -79,6 +84,36 @@ def _hash_key(self, key: str) -> str:
"""Create SHA-256 hash of the API key."""
return hashlib.sha256(key.encode()).hexdigest()

def _clear_validation_cache(self) -> None:
"""Clear cached API key validation results."""
self._validation_cache.clear()

def _get_cached_validation(self, key_hash: str) -> Optional[Dict[str, Any]]:
"""Return a cached validation result when it is still fresh."""
cached = self._validation_cache.get(key_hash)
if not cached:
return None

expires_at, key_doc = cached
if expires_at <= time.monotonic():
self._validation_cache.pop(key_hash, None)
return None

result = dict(key_doc)
result["last_used"] = datetime.utcnow()
return result

def _cache_validation(self, key_hash: str, key_doc: Dict[str, Any]) -> None:
"""Cache a sanitized active API key document."""
if len(self._validation_cache) >= VALIDATION_CACHE_MAX_SIZE:
oldest_key = next(iter(self._validation_cache))
self._validation_cache.pop(oldest_key, None)
Comment on lines +108 to +110
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The cache eviction logic is not thread-safe. In a multi-threaded environment (like FastAPI with multiple worker threads), next(iter(self._validation_cache)) can raise a StopIteration exception if the dictionary is cleared by another thread between the length check and the next() call. It can also raise a RuntimeError if the dictionary size changes during iteration.

Consider using a threading.Lock to synchronize access to the cache or using a more robust eviction pattern that handles these concurrency edge cases.


self._validation_cache[key_hash] = (
time.monotonic() + VALIDATION_CACHE_TTL_SECONDS,
dict(key_doc),
)

def create_api_key(
self,
user_id: str,
Expand Down Expand Up @@ -122,6 +157,7 @@ def create_api_key(
"is_active": True,
}
result = self.api_keys.insert_one(key_doc)
self._clear_validation_cache()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Clearing the entire validation cache when a new API key is created is unnecessary. A newly generated key cannot have a stale entry in the cache. Removing this call will prevent unnecessary cache misses for other active users during key creation.

logger.info(f"Created new API key for user {user_id}")
return {
"key": key,
Expand Down Expand Up @@ -174,6 +210,10 @@ def validate_api_key(self, key: str) -> Optional[Dict[str, Any]]:
return result
return None

cached_doc = self._get_cached_validation(key_hash)
if cached_doc:
return cached_doc

try:
key_doc = self.api_keys.find_one({
"key_hash": key_hash,
Expand All @@ -185,10 +225,15 @@ def validate_api_key(self, key: str) -> Optional[Dict[str, Any]]:
{"_id": key_doc["_id"]},
{"$set": {"last_used": now}}
)
key_doc["last_used"] = now
key_doc["id"] = str(key_doc.pop("_id"))
key_doc = {
**key_doc,
"id": str(key_doc["_id"]),
"last_used": now,
}
key_doc.pop("_id", None)
key_doc.pop("key_hash", None)
return key_doc
self._cache_validation(key_hash, key_doc)
return dict(key_doc) if key_doc else None
except Exception as e:
logger.error(f"Database error validating API key: {e}")
return None
Expand All @@ -199,6 +244,7 @@ def revoke_api_key(self, user_id: str, key_id: str) -> bool:
if key_id in _in_memory_api_keys:
if _in_memory_api_keys[key_id].get("user_id") == user_id:
_in_memory_api_keys[key_id]["is_active"] = False
self._clear_validation_cache()
return True
return False

Expand All @@ -208,7 +254,10 @@ def revoke_api_key(self, user_id: str, key_id: str) -> bool:
{"_id": ObjectId(key_id), "user_id": user_id},
{"$set": {"is_active": False}}
)
return result.modified_count > 0
success = result.modified_count > 0
if success:
self._clear_validation_cache()
return success
except Exception as e:
logger.error(f"Failed to revoke API key {key_id}: {e}")
return False
Expand All @@ -224,6 +273,7 @@ def update_api_key_name(
if key_id in _in_memory_api_keys:
if _in_memory_api_keys[key_id].get("user_id") == user_id:
_in_memory_api_keys[key_id]["name"] = new_name
self._clear_validation_cache()
return True
return False

Expand All @@ -233,7 +283,10 @@ def update_api_key_name(
{"_id": ObjectId(key_id), "user_id": user_id},
{"$set": {"name": new_name}}
)
return result.modified_count > 0
success = result.modified_count > 0
if success:
self._clear_validation_cache()
return success
except Exception as e:
logger.error(f"Failed to update API key name {key_id}: {e}")
return False
Expand Down