Skip to content

Commit cd565f4

Browse files
committed
async support v1
1 parent dfab46d commit cd565f4

File tree

18 files changed

+253
-255
lines changed

18 files changed

+253
-255
lines changed
File renamed without changes.

{{cookiecutter.project_name}}/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ __pycache__/
1212

1313
# C extensions
1414
*.so
15+
.env
1516

1617
# Distribution / packaging
1718
.Python

{{cookiecutter.project_name}}/alembic/env.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from sqlalchemy import engine_from_config
44
from sqlalchemy import pool
5+
from sqlalchemy.ext.asyncio import AsyncEngine
6+
from asyncio import get_event_loop
57

68
from app.core.config import settings
79
from alembic import context
@@ -56,7 +58,14 @@ def run_migrations_offline():
5658
context.run_migrations()
5759

5860

59-
def run_migrations_online():
61+
def do_run_migrations(connection):
62+
context.configure(connection=connection, target_metadata=target_metadata)
63+
64+
with context.begin_transaction():
65+
context.run_migrations()
66+
67+
68+
async def run_migrations_online():
6069
"""Run migrations in 'online' mode.
6170
6271
In this scenario we need to create an Engine
@@ -66,19 +75,19 @@ def run_migrations_online():
6675
configuration = config.get_section(config.config_ini_section)
6776
assert configuration
6877
configuration["sqlalchemy.url"] = get_database_uri()
69-
connectable = engine_from_config(
70-
configuration,
71-
prefix="sqlalchemy.",
72-
poolclass=pool.NullPool,
78+
connectable = AsyncEngine(
79+
engine_from_config(
80+
configuration,
81+
prefix="sqlalchemy.",
82+
poolclass=pool.NullPool,
83+
future=True,
84+
) # type: ignore
7385
)
74-
with connectable.connect() as connection:
75-
context.configure(connection=connection, target_metadata=target_metadata)
76-
77-
with context.begin_transaction():
78-
context.run_migrations()
86+
async with connectable.connect() as connection:
87+
await connection.run_sync(do_run_migrations)
7988

8089

8190
if context.is_offline_mode():
8291
run_migrations_offline()
8392
else:
84-
run_migrations_online()
93+
get_event_loop().run_until_complete(run_migrations_online())
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""init
2+
3+
Revision ID: 3c44f0828d4f
4+
Revises:
5+
Create Date: 2021-10-23 12:14:50.440666
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
11+
12+
# revision identifiers, used by Alembic.
13+
revision = '3c44f0828d4f'
14+
down_revision = None
15+
branch_labels = None
16+
depends_on = None
17+
18+
19+
def upgrade():
20+
# ### commands auto generated by Alembic - please adjust! ###
21+
op.create_table('user',
22+
sa.Column('id', sa.Integer(), nullable=False),
23+
sa.Column('full_name', sa.String(length=254), nullable=True),
24+
sa.Column('email', sa.String(length=254), nullable=False),
25+
sa.Column('hashed_password', sa.String(length=128), nullable=False),
26+
sa.PrimaryKeyConstraint('id')
27+
)
28+
op.create_index(op.f('ix_user_email'), 'user', ['email'], unique=True)
29+
op.create_index(op.f('ix_user_id'), 'user', ['id'], unique=False)
30+
# ### end Alembic commands ###
31+
32+
33+
def downgrade():
34+
# ### commands auto generated by Alembic - please adjust! ###
35+
op.drop_index(op.f('ix_user_id'), table_name='user')
36+
op.drop_index(op.f('ix_user_email'), table_name='user')
37+
op.drop_table('user')
38+
# ### end Alembic commands ###
Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,28 @@
1-
from typing import Generator, Optional
1+
from typing import AsyncGenerator, Optional
22

33
from fastapi import Depends, HTTPException, status
44
from fastapi.security import OAuth2PasswordBearer
55
from jose import jwt
66
from pydantic import ValidationError
77
from sqlalchemy import select
8-
from sqlalchemy.orm import Session
8+
from sqlalchemy.ext.asyncio import AsyncSession
99

1010
from app import schemas
1111
from app.core import security
1212
from app.core.config import settings
1313
from app.models import User
14-
from app.session import SessionLocal
14+
from app.session import async_session
1515

16-
reusable_oauth2 = OAuth2PasswordBearer(
17-
tokenUrl=f"{settings.API_STR}/login/access-token"
18-
)
16+
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_STR}/auth/access-token")
1917

2018

21-
def get_session() -> Generator:
22-
try:
23-
session: Session = SessionLocal()
19+
async def get_session() -> AsyncGenerator[AsyncSession, None]:
20+
async with async_session() as session:
2421
yield session
25-
except:
26-
raise Exception
27-
session.close()
2822

2923

30-
def get_current_user(
31-
session: Session = Depends(get_session), token: str = Depends(reusable_oauth2)
24+
async def get_current_user(
25+
session: AsyncSession = Depends(get_session), token: str = Depends(reusable_oauth2)
3226
) -> User:
3327

3428
try:
@@ -41,16 +35,16 @@ def get_current_user(
4135
status_code=status.HTTP_403_FORBIDDEN,
4236
detail="Could not validate credentials",
4337
)
44-
user: Optional[User] = (
45-
session.execute(select(User).where(User.id == token_data.sub)).scalars().first()
46-
)
38+
39+
result = await session.execute(select(User).where(User.id == token_data.sub))
40+
user: Optional[User] = result.scalars().first()
4741

4842
if not user:
4943
raise HTTPException(status_code=404, detail="User not found")
5044
return user
5145

5246

53-
def get_current_active_user(
47+
async def get_current_active_user(
5448
current_user: User = Depends(get_current_user),
5549
) -> User:
5650
return current_user

{{cookiecutter.project_name}}/app/api/endpoints/auth.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from jose import jwt
66
from pydantic import ValidationError
77
from sqlalchemy import select
8-
from sqlalchemy.orm import Session
8+
from sqlalchemy.ext.asyncio import AsyncSession
99

1010
from app import schemas
1111
from app.api import deps
@@ -17,20 +17,15 @@
1717

1818

1919
@router.post("/access-token", response_model=schemas.Token)
20-
def login_access_token(
21-
session: Session = Depends(deps.get_session),
20+
async def login_access_token(
21+
session: AsyncSession = Depends(deps.get_session),
2222
form_data: OAuth2PasswordRequestForm = Depends(),
2323
):
2424
"""
2525
OAuth2 compatible token, get an access token for future requests using username and password
2626
"""
27-
28-
user: Optional[User] = (
29-
session.execute(select(User).where(User.email == form_data.username))
30-
.scalars()
31-
.first()
32-
)
33-
27+
result = await session.execute(select(User).where(User.email == form_data.username))
28+
user: Optional[User] = result.scalars().first()
3429
if user is None:
3530
raise HTTPException(status_code=400, detail="Incorrect email or password")
3631

@@ -49,15 +44,17 @@ def login_access_token(
4944

5045

5146
@router.post("/test-token", response_model=schemas.User)
52-
def test_token(current_user: User = Depends(deps.get_current_user)):
47+
async def test_token(current_user: User = Depends(deps.get_current_user)):
5348
"""
5449
Test access token
5550
"""
5651
return current_user
5752

5853

5954
@router.post("/refresh-token", response_model=schemas.Token)
60-
def refresh_token(refresh_token: str, session: Session = Depends(deps.get_session)):
55+
async def refresh_token(
56+
refresh_token: str, session: AsyncSession = Depends(deps.get_session)
57+
):
6158
"""
6259
OAuth2 compatible token, get an access token for future requests using refresh token
6360
"""
@@ -76,13 +73,12 @@ def refresh_token(refresh_token: str, session: Session = Depends(deps.get_sessio
7673
status_code=status.HTTP_403_FORBIDDEN,
7774
detail="Could not validate credentials",
7875
)
79-
80-
user: Optional[User] = (
81-
session.execute(select(User).where(User.id == token_data.sub)).scalars().first()
82-
)
76+
result = await session.execute(select(User).where(User.id == token_data.sub))
77+
user: Optional[User] = result.scalars().first()
8378

8479
if user is None:
8580
raise HTTPException(status_code=404, detail="User not found")
81+
8682
access_token, expire_at = security.create_access_token(user.id)
8783
refresh_token, refresh_expire_at = security.create_refresh_token(user.id)
8884
return {

{{cookiecutter.project_name}}/app/api/endpoints/users.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any
22

33
from fastapi import APIRouter, Depends
4-
from sqlalchemy.orm import Session
4+
from sqlalchemy.ext.asyncio import AsyncSession
55

66
from app import models, schemas
77
from app.api import deps
@@ -11,9 +11,9 @@
1111

1212

1313
@router.put("/me", response_model=schemas.User)
14-
def update_user_me(
14+
async def update_user_me(
1515
user_update: schemas.UserUpdate,
16-
session: Session = Depends(deps.get_session),
16+
session: AsyncSession = Depends(deps.get_session),
1717
current_user: models.User = Depends(deps.get_current_active_user),
1818
) -> Any:
1919
"""
@@ -27,15 +27,15 @@ def update_user_me(
2727
current_user.email = user_update.email
2828

2929
session.add(current_user)
30-
session.commit()
31-
session.refresh(current_user)
30+
await session.commit()
31+
await session.refresh(current_user)
3232

3333
return current_user
3434

3535

3636
@router.get("/me", response_model=schemas.User)
37-
def read_user_me(
38-
session: Session = Depends(deps.get_session),
37+
async def read_user_me(
38+
session: AsyncSession = Depends(deps.get_session),
3939
current_user: models.User = Depends(deps.get_current_active_user),
4040
) -> Any:
4141
"""

{{cookiecutter.project_name}}/app/core/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _assemble_db_connection(cls, v: str, values: Dict[str, Optional[str]]) -> st
5555
postgres_server = values.get("POSTGRES_SERVER")
5656

5757
return AnyUrl.build(
58-
scheme="postgresql+psycopg2",
58+
scheme="postgresql+asyncpg",
5959
user=values.get("POSTGRES_USER"),
6060
password=values.get("POSTGRES_PASSWORD"),
6161
host=postgres_server or "localhost",
Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,40 @@
11
import logging
2+
from typing import Optional
23

34
from sqlalchemy import select
45

56
from app.core import security
67
from app.core.config import settings
78
from app.models import User
8-
from app.session import SessionLocal
9+
from app.session import async_session
10+
from asyncio import get_event_loop
911

1012

11-
def main() -> None:
13+
async def main() -> None:
1214
logging.info("Start initial data")
13-
session = SessionLocal()
14-
user = (
15-
session.execute(
15+
async with async_session() as session:
16+
17+
result = await session.execute(
1618
select(User).where(User.email == settings.FIRST_SUPERUSER_EMAIL)
1719
)
18-
.scalars()
19-
.first()
20-
)
21-
22-
if user is None:
23-
new_superuser = User(
24-
email=settings.FIRST_SUPERUSER_EMAIL,
25-
hashed_password=security.get_password_hash(
26-
settings.FIRST_SUPERUSER_PASSWORD
27-
),
28-
full_name=settings.FIRST_SUPERUSER_EMAIL,
29-
)
30-
session.add(new_superuser)
31-
session.commit()
32-
logging.info("Superuser was created")
33-
else:
34-
logging.warning("Superuser already exists in database")
20+
user: Optional[User] = result.scalars().first()
21+
22+
if user is None:
23+
new_superuser = User(
24+
email=settings.FIRST_SUPERUSER_EMAIL,
25+
hashed_password=security.get_password_hash(
26+
settings.FIRST_SUPERUSER_PASSWORD
27+
),
28+
full_name=settings.FIRST_SUPERUSER_EMAIL,
29+
)
30+
session.add(new_superuser)
31+
await session.commit()
32+
logging.info("Superuser was created")
33+
else:
34+
logging.warning("Superuser already exists in database")
3535

36-
logging.info("Initial data created")
36+
logging.info("Initial data created")
3737

3838

3939
if __name__ == "__main__":
40-
main()
40+
get_event_loop().run_until_complete(main())
Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
from sqlalchemy import create_engine
2-
from sqlalchemy.orm import sessionmaker
1+
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
2+
from sqlalchemy.orm.session import sessionmaker
3+
34

45
from app.core.config import settings
56

6-
engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, pool_pre_ping=True)
7-
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
7+
async_engine = create_async_engine(settings.SQLALCHEMY_DATABASE_URI, pool_pre_ping=True)
8+
9+
AsyncSessionLocal = AsyncSession(async_engine, expire_on_commit=False)
10+
11+
async_session = sessionmaker(async_engine, expire_on_commit=False, class_=AsyncSession)

0 commit comments

Comments
 (0)