Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 32 additions & 56 deletions fastapi_users_db_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,33 +23,19 @@ class SQLAlchemyBaseUserTable(Generic[ID]):
__tablename__ = "user"

if TYPE_CHECKING: # pragma: no cover
id: ID
email: str
hashed_password: str
is_active: bool
is_superuser: bool
is_verified: bool
else:
email: Mapped[str] = mapped_column(
String(length=320), unique=True, index=True, nullable=False
)
hashed_password: Mapped[str] = mapped_column(
String(length=1024), nullable=False
)
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
is_superuser: Mapped[bool] = mapped_column(
Boolean, default=False, nullable=False
)
is_verified: Mapped[bool] = mapped_column(
Boolean, default=False, nullable=False
)
id: Mapped[ID]

email: Mapped[str] = mapped_column(
String(length=320), unique=True, index=True, nullable=False
)
hashed_password: Mapped[str] = mapped_column(String(length=1024), nullable=False)
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
is_superuser: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
is_verified: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)


class SQLAlchemyBaseUserTableUUID(SQLAlchemyBaseUserTable[UUID_ID]):
if TYPE_CHECKING: # pragma: no cover
id: UUID_ID
else:
id: Mapped[UUID_ID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4)
id: Mapped[UUID_ID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4)


class SQLAlchemyBaseOAuthAccountTable(Generic[ID]):
Expand All @@ -58,40 +44,30 @@ class SQLAlchemyBaseOAuthAccountTable(Generic[ID]):
__tablename__ = "oauth_account"

if TYPE_CHECKING: # pragma: no cover
id: ID
oauth_name: str
access_token: str
expires_at: Optional[int]
refresh_token: Optional[str]
account_id: str
account_email: str
else:
oauth_name: Mapped[str] = mapped_column(
String(length=100), index=True, nullable=False
)
access_token: Mapped[str] = mapped_column(String(length=1024), nullable=False)
expires_at: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
refresh_token: Mapped[Optional[str]] = mapped_column(
String(length=1024), nullable=True
)
account_id: Mapped[str] = mapped_column(
String(length=320), index=True, nullable=False
)
account_email: Mapped[str] = mapped_column(String(length=320), nullable=False)
id: Mapped[ID]

oauth_name: Mapped[str] = mapped_column(
String(length=100), index=True, nullable=False
)
access_token: Mapped[str] = mapped_column(String(length=1024), nullable=False)
expires_at: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
refresh_token: Mapped[Optional[str]] = mapped_column(
String(length=1024), nullable=True
)
account_id: Mapped[str] = mapped_column(
String(length=320), index=True, nullable=False
)
account_email: Mapped[str] = mapped_column(String(length=320), nullable=False)


class SQLAlchemyBaseOAuthAccountTableUUID(SQLAlchemyBaseOAuthAccountTable[UUID_ID]):
if TYPE_CHECKING: # pragma: no cover
id: UUID_ID
user_id: UUID_ID
else:
id: Mapped[UUID_ID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4)
id: Mapped[UUID_ID] = mapped_column(GUID, primary_key=True, default=uuid.uuid4)

@declared_attr
def user_id(cls) -> Mapped[GUID]:
return mapped_column(
GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False
)
@declared_attr
def user_id(cls) -> Mapped[UUID_ID]:
return mapped_column(
GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False
)


class SQLAlchemyUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]):
Expand Down Expand Up @@ -134,8 +110,8 @@ async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP
statement = (
select(self.user_table)
.join(self.oauth_account_table)
.where(self.oauth_account_table.oauth_name == oauth) # type: ignore
.where(self.oauth_account_table.account_id == account_id) # type: ignore
.where(self.oauth_account_table.oauth_name == oauth)
.where(self.oauth_account_table.account_id == account_id)
)
return await self._get_user(statement)

Expand Down
29 changes: 10 additions & 19 deletions fastapi_users_db_sqlalchemy/access_token.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import uuid
from datetime import datetime
from typing import TYPE_CHECKING, Any, Generic, Optional
from typing import Any, Generic, Optional

from fastapi_users.authentication.strategy.db import AP, AccessTokenDatabase
from fastapi_users.models import ID
Expand All @@ -16,27 +16,18 @@ class SQLAlchemyBaseAccessTokenTable(Generic[ID]):

__tablename__ = "accesstoken"

if TYPE_CHECKING: # pragma: no cover
token: str
created_at: datetime
user_id: ID
else:
token: Mapped[str] = mapped_column(String(length=43), primary_key=True)
created_at: Mapped[datetime] = mapped_column(
TIMESTAMPAware(timezone=True), index=True, nullable=False, default=now_utc
)
token: Mapped[str] = mapped_column(String(length=43), primary_key=True)
created_at: Mapped[datetime] = mapped_column(
TIMESTAMPAware(timezone=True), index=True, nullable=False, default=now_utc
)


class SQLAlchemyBaseAccessTokenTableUUID(SQLAlchemyBaseAccessTokenTable[uuid.UUID]):
if TYPE_CHECKING: # pragma: no cover
user_id: uuid.UUID
else:

@declared_attr
def user_id(cls) -> Mapped[GUID]:
return mapped_column(
GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False
)
@declared_attr
def user_id(cls) -> Mapped[uuid.UUID]:
return mapped_column(
GUID, ForeignKey("user.id", ondelete="cascade"), nullable=False
)


class SQLAlchemyAccessTokenDatabase(Generic[AP], AccessTokenDatabase[AP]):
Expand Down