From f82ef8420508c44022e3ee87058c3621ccba050e Mon Sep 17 00:00:00 2001 From: Alessandro Colace Date: Sat, 6 Jun 2026 15:45:55 +0200 Subject: [PATCH] Fix typing on SQLAlchemy where clause --- fastapi_users_db_sqlalchemy/__init__.py | 88 ++++++++------------- fastapi_users_db_sqlalchemy/access_token.py | 29 +++---- 2 files changed, 42 insertions(+), 75 deletions(-) diff --git a/fastapi_users_db_sqlalchemy/__init__.py b/fastapi_users_db_sqlalchemy/__init__.py index 467a2bf..f4419cc 100644 --- a/fastapi_users_db_sqlalchemy/__init__.py +++ b/fastapi_users_db_sqlalchemy/__init__.py @@ -23,33 +23,19 @@ class SQLAlchemyBaseUserTable(Generic[ID]): __tablename__ = "user" if TYPE_CHECKING: # pragma: no cover - id: ID - email: str - hashed_password: str - is_active: bool - is_superuser: bool - is_verified: bool - else: - email: Mapped[str] = mapped_column( - String(length=320), unique=True, index=True, nullable=False - ) - hashed_password: Mapped[str] = mapped_column( - String(length=1024), nullable=False - ) - is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) - is_superuser: Mapped[bool] = mapped_column( - Boolean, default=False, nullable=False - ) - is_verified: Mapped[bool] = mapped_column( - Boolean, default=False, nullable=False - ) + id: Mapped[ID] + + email: Mapped[str] = mapped_column( + String(length=320), unique=True, index=True, nullable=False + ) + hashed_password: Mapped[str] = mapped_column(String(length=1024), nullable=False) + is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + is_superuser: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + is_verified: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) class SQLAlchemyBaseUserTableUUID(SQLAlchemyBaseUserTable[UUID_ID]): - if TYPE_CHECKING: # pragma: no cover - id: UUID_ID - else: - id: Mapped[UUID_ID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4) + id: Mapped[UUID_ID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4) class SQLAlchemyBaseOAuthAccountTable(Generic[ID]): @@ -58,40 +44,30 @@ class SQLAlchemyBaseOAuthAccountTable(Generic[ID]): __tablename__ = "oauth_account" if TYPE_CHECKING: # pragma: no cover - id: ID - oauth_name: str - access_token: str - expires_at: Optional[int] - refresh_token: Optional[str] - account_id: str - account_email: str - else: - oauth_name: Mapped[str] = mapped_column( - String(length=100), index=True, nullable=False - ) - access_token: Mapped[str] = mapped_column(String(length=1024), nullable=False) - expires_at: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) - refresh_token: Mapped[Optional[str]] = mapped_column( - String(length=1024), nullable=True - ) - account_id: Mapped[str] = mapped_column( - String(length=320), index=True, nullable=False - ) - account_email: Mapped[str] = mapped_column(String(length=320), nullable=False) + id: Mapped[ID] + + oauth_name: Mapped[str] = mapped_column( + String(length=100), index=True, nullable=False + ) + access_token: Mapped[str] = mapped_column(String(length=1024), nullable=False) + expires_at: Mapped[Optional[int]] = mapped_column(Integer, nullable=True) + refresh_token: Mapped[Optional[str]] = mapped_column( + String(length=1024), nullable=True + ) + account_id: Mapped[str] = mapped_column( + String(length=320), index=True, nullable=False + ) + account_email: Mapped[str] = mapped_column(String(length=320), nullable=False) class SQLAlchemyBaseOAuthAccountTableUUID(SQLAlchemyBaseOAuthAccountTable[UUID_ID]): - if TYPE_CHECKING: # pragma: no cover - id: UUID_ID - user_id: UUID_ID - else: - id: Mapped[UUID_ID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4) + id: Mapped[UUID_ID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4) - @declared_attr - def user_id(cls) -> Mapped[GUID]: - return mapped_column( - GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False - ) + @declared_attr + def user_id(cls) -> Mapped[UUID_ID]: + return mapped_column( + GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False + ) class SQLAlchemyUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]): @@ -134,8 +110,8 @@ async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP statement = ( select(self.user_table) .join(self.oauth_account_table) - .where(self.oauth_account_table.oauth_name == oauth) # type: ignore - .where(self.oauth_account_table.account_id == account_id) # type: ignore + .where(self.oauth_account_table.oauth_name == oauth) + .where(self.oauth_account_table.account_id == account_id) ) return await self._get_user(statement) diff --git a/fastapi_users_db_sqlalchemy/access_token.py b/fastapi_users_db_sqlalchemy/access_token.py index 9f68af6..2a6800f 100644 --- a/fastapi_users_db_sqlalchemy/access_token.py +++ b/fastapi_users_db_sqlalchemy/access_token.py @@ -1,6 +1,6 @@ import uuid from datetime import datetime -from typing import TYPE_CHECKING, Any, Generic, Optional +from typing import Any, Generic, Optional from fastapi_users.authentication.strategy.db import AP, AccessTokenDatabase from fastapi_users.models import ID @@ -16,27 +16,18 @@ class SQLAlchemyBaseAccessTokenTable(Generic[ID]): __tablename__ = "accesstoken" - if TYPE_CHECKING: # pragma: no cover - token: str - created_at: datetime - user_id: ID - else: - token: Mapped[str] = mapped_column(String(length=43), primary_key=True) - created_at: Mapped[datetime] = mapped_column( - TIMESTAMPAware(timezone=True), index=True, nullable=False, default=now_utc - ) + token: Mapped[str] = mapped_column(String(length=43), primary_key=True) + created_at: Mapped[datetime] = mapped_column( + TIMESTAMPAware(timezone=True), index=True, nullable=False, default=now_utc + ) class SQLAlchemyBaseAccessTokenTableUUID(SQLAlchemyBaseAccessTokenTable[uuid.UUID]): - if TYPE_CHECKING: # pragma: no cover - user_id: uuid.UUID - else: - - @declared_attr - def user_id(cls) -> Mapped[GUID]: - return mapped_column( - GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False - ) + @declared_attr + def user_id(cls) -> Mapped[uuid.UUID]: + return mapped_column( + GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False + ) class SQLAlchemyAccessTokenDatabase(Generic[AP], AccessTokenDatabase[AP]):