Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion app/core/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
class RedisKey(str, Enum):
UserSession = "user_session"
UserSessionByUser = "user_session:{user_id}"
INVALID_TOKEN_SET_KEY= "notifications:invalid_tokens"
INVALID_TOKEN_SET_KEY = "notifications:invalid_tokens"
MobileSessionCache = "session:{session_id}"


NOTIFICATION_EVENT_SUBJECT = "notification_event"
Expand Down
41 changes: 38 additions & 3 deletions app/deps/token_auth.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
from datetime import datetime, timezone
from typing import Annotated
import uuid

from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel

from app.container import get_container, Container
from app.core.config import settings
from app.core.securite import decode_access_mobile_token
from app.infra.redis import RedisClient
from app.service.session import MobileSessionCache, SessionService

security = HTTPBearer()


class MobileUserSchema(BaseModel):
user_id: uuid.UUID
email: str
Expand All @@ -20,7 +27,8 @@ async def get_current_mobile_user(
) -> MobileUserSchema:
"""
Dependency to get the current logged-in mobile user.
Returns a strict Pydantic model.
Fast path: Redis cache (0 DB queries).
Slow path: Postgres fallback (2 DB queries) with cache re-population.
"""
token = credentials.credentials
payload = decode_access_mobile_token(token)
Expand All @@ -31,7 +39,23 @@ async def get_current_mobile_user(

session_id = uuid.UUID(session_id_str)

# Validate session via SessionService
# --- Fast path: Redis cache ---
redis = RedisClient.get_instance()
cached: MobileSessionCache | None = await SessionService.get_cached_session(
redis, session_id
)
if cached is not None:
if cached.expires_at < datetime.now(timezone.utc):
raise HTTPException(status_code=401, detail="Session expired")
if cached.blocked:
raise HTTPException(status_code=403, detail="User is blocked")
return MobileUserSchema(
user_id=cached.user_id,
email=cached.email,
session_id=cached.session_id,
)

# --- Slow path: Postgres fallback ---
session = await container.session_service.session_querier.get_session_by_id(id=session_id)
if not session:
raise HTTPException(status_code=401, detail="Session not found")
Expand All @@ -46,8 +70,19 @@ async def get_current_mobile_user(
if user.blocked:
raise HTTPException(status_code=403, detail="User is blocked")

# Re-populate cache so next request hits Redis
await SessionService.cache_session_for_auth(
redis=redis,
session_id=session.id,
user_id=session.user_id,
email=user.email or "",
expires_at=session.expires_at,
blocked=user.blocked,
ttl=settings.MOBILE_SESSION_TTL_SECONDS,
)

return MobileUserSchema(
user_id=user.id,
email=user.email,
email=user.email or "",
session_id=session.id,
)
122 changes: 84 additions & 38 deletions app/service/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,31 @@
from db.generated import session as session_queries
import uuid
from db.generated.models import UserSession
from datetime import datetime,timedelta,timezone
from datetime import datetime, timedelta, timezone
from app.infra.redis import RedisClient
from app.core.constant import RedisKey
from db.generated.session import UpsertSessionRow


class SessionRedis(BaseModel):
session_id:uuid.UUID
user_id:uuid.UUID
device_id:uuid.UUID
last_active:datetime
expires_at:datetime
session_id: uuid.UUID
user_id: uuid.UUID
device_id: uuid.UUID
last_active: datetime
expires_at: datetime


class MobileSessionCache(BaseModel):
session_id: uuid.UUID
user_id: uuid.UUID
email: str
expires_at: datetime
blocked: bool

class SessionService :
session_querier : session_queries.AsyncQuerier
redis : RedisClient

class SessionService:
session_querier: session_queries.AsyncQuerier
redis: RedisClient

def init(self, session: session_queries.AsyncQuerier, redis: RedisClient) -> None:
self.session_querier = session
Expand All @@ -26,14 +36,53 @@ def init(self, session: session_queries.AsyncQuerier, redis: RedisClient) -> Non
SessionService.redis = redis

@staticmethod
async def create_session(user_id:uuid.UUID,device_id:uuid.UUID)->UpsertSessionRow:
try :
async def cache_session_for_auth(
redis: RedisClient,
session_id: uuid.UUID,
user_id: uuid.UUID,
email: str,
expires_at: datetime,
blocked: bool,
ttl: int,
) -> None:
key = RedisKey.MobileSessionCache.value.format(session_id=session_id)
payload = MobileSessionCache(
session_id=session_id,
user_id=user_id,
email=email,
expires_at=expires_at,
blocked=blocked,
)
await redis.set(key=key, value=payload.model_dump_json(), expire=ttl)

@staticmethod
async def get_cached_session(
redis: RedisClient,
session_id: uuid.UUID,
) -> MobileSessionCache | None:
key = RedisKey.MobileSessionCache.value.format(session_id=session_id)
raw = await redis.get(key)
if raw is None:
return None
return MobileSessionCache.model_validate_json(raw)

@staticmethod
async def delete_session_cache(
redis: RedisClient,
session_id: uuid.UUID,
) -> None:
key = RedisKey.MobileSessionCache.value.format(session_id=session_id)
await redis.delete(key)

@staticmethod
async def create_session(user_id: uuid.UUID, device_id: uuid.UUID) -> UpsertSessionRow:
try:
session = await SessionService.session_querier.upsert_session(
user_id=user_id,
device_id=device_id,
expires_at=datetime.now(timezone.utc) + timedelta(days=7),
)
if session is None :
if session is None:
raise AppException.internal_error("session creation failed ")

result = await SessionService.redis.set(
Expand All @@ -45,33 +94,30 @@ async def create_session(user_id:uuid.UUID,device_id:uuid.UUID)->UpsertSessionRo
last_active=session.last_active,
expires_at=session.expires_at,
).model_dump_json(),
expire=60*60*5,
nx=True
expire=60 * 60 * 5,
nx=True,
)
if not result:
AppException.forbidden("You already logged in in another device")
return session
except Exception as e :
raise DBExceptionImpl.handle(e)


except Exception as e:
raise DBExceptionImpl.handle(e)

@staticmethod
async def get_session_by_id(session_id:uuid.UUID)->UserSession:
try :
async def get_session_by_id(session_id: uuid.UUID) -> UserSession:
try:
session = await SessionService.session_querier.get_session_by_id(id=session_id)
if session is None :
if session is None:
raise AppException.not_found("session Not found ")
return session
except Exception as e :
except Exception as e:
raise DBExceptionImpl.handle(e)


@staticmethod
async def check_session(
session_id: uuid.UUID,
user_id: uuid.UUID,
device_id: uuid.UUID
device_id: uuid.UUID,
) -> bool:
try:
session_in_redis = await SessionService.redis.get(
Expand Down Expand Up @@ -125,31 +171,31 @@ async def check_session(
except Exception as e:
raise DBExceptionImpl.handle(e)


@staticmethod
async def delete_session(
session_id: uuid.UUID, user_id: uuid.UUID, device_id: uuid.UUID
) -> None:
try :
await SessionService.session_querier.delete_session_by_device(user_id=user_id,device_id=device_id)
except Exception as e :
try:
await SessionService.session_querier.delete_session_by_device(
user_id=user_id, device_id=device_id
)
except Exception as e:
raise DBExceptionImpl.handle(e)


@staticmethod
async def delete_expired_sessions() -> None:
try :
try:
await SessionService.session_querier.delete_expired_sessions()
except Exception as e :
except Exception as e:
raise DBExceptionImpl.handle(e)

@staticmethod
async def count_user_sessions(user_id:uuid.UUID)->int:
try :
count = await SessionService.session_querier.count_user_sessions(user_id=user_id)
if count is None :
async def count_user_sessions(user_id: uuid.UUID) -> int:
try:
count = await SessionService.session_querier.count_user_sessions(user_id=user_id)
if count is None:
raise AppException.internal_error("failed to count ")
else :
else:
return count
except Exception as e :
raise DBExceptionImpl.handle(e)
except Exception as e:
raise DBExceptionImpl.handle(e)
33 changes: 30 additions & 3 deletions app/service/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from app.core.logger import logger
from app.service.face_embedding import FaceImagePayload, FaceEmbeddingService
from app.schema.internal.single_face_match import ClosestUserMatch
from app.service.session import SessionService


class AuthService:
Expand Down Expand Up @@ -142,6 +143,17 @@ async def mobile_register_login(
expiry = Get_expiry_time()
logger.info("created session %s for user %s", session.id, user_id)

# Populate Redis auth cache for fast-path validation
await SessionService.cache_session_for_auth(
redis=redis,
session_id=session.id,
user_id=user_id,
email=user.email or "",
expires_at=session.expires_at,
blocked=user.blocked,
ttl=AuthService.REDIS_SESSION_TTL,
)

return MobileAuthResponse(
access_token=access_token,
refresh_token=refresh_token,
Expand Down Expand Up @@ -325,6 +337,14 @@ async def delete_user(self, *, redis: RedisClient, user_id: uuid.UUID) -> User:
session_key = constant.RedisKey.UserSessionByUser.value.format(
user_id=user_id
)
# Best-effort: also invalidate the per-session MobileSessionCache.
raw_session_id = await redis.get(session_key)
if raw_session_id:
try:
session_id = uuid.UUID(raw_session_id)
await SessionService.delete_session_cache(redis=redis, session_id=session_id)
except (ValueError, Exception):
pass
await redis.delete(session_key)
return existing
except Exception as exc:
Expand All @@ -337,9 +357,16 @@ async def block_user(self, *, redis: RedisClient, user_id: uuid.UUID) -> User:
if not user:
raise AppException.not_found("User not found")

session_key = constant.RedisKey.UserSessionByUser.value.format(
user_id=user_id
)
session_key = constant.RedisKey.UserSessionByUser.value.format(user_id=user_id)
# Best-effort: retrieve the session_id from UserSessionByUser cache to also
# invalidate the per-session MobileSessionCache entry.
raw_session_id = await redis.get(session_key)
if raw_session_id:
try:
session_id = uuid.UUID(raw_session_id)
await SessionService.delete_session_cache(redis=redis, session_id=session_id)
except (ValueError, Exception):
pass # non-blocking: session cache will expire naturally
await redis.delete(session_key)

return user
Expand Down