Skip to content

Commit 5b3dc69

Browse files
committed
Implement AsyncCursor
1 parent 9c19d7a commit 5b3dc69

File tree

2 files changed

+119
-52
lines changed

2 files changed

+119
-52
lines changed

ydb_sqlalchemy/dbapi/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .connection import Connection, IsolationLevel # noqa: F401
2-
from .cursor import Cursor, YdbQuery # noqa: F401
2+
from .cursor import Cursor, AsyncCursor, YdbQuery # noqa: F401
33
from .errors import (
44
Warning,
55
Error,
@@ -13,13 +13,14 @@
1313
NotSupportedError,
1414
)
1515

16+
1617
class YdbDBApi:
1718
def __init__(self):
1819
self.paramstyle = "pyformat"
1920
self.threadsafety = 0
2021
self.apilevel = "1.0"
2122
self._init_dbapi_attributes()
22-
23+
2324
def _init_dbapi_attributes(self):
2425
for name, value in {
2526
"Warning": Warning,

ydb_sqlalchemy/dbapi/cursor.py

Lines changed: 116 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import dataclasses
22
import itertools
33
import logging
4-
from typing import Any, Mapping, Optional, Sequence, Union, Dict, Callable
4+
import functools
5+
from typing import Any, Mapping, Optional, Sequence, Union, Dict
6+
import collections.abc
7+
from sqlalchemy import util
58

69
import ydb
10+
import ydb.aio
711

812
from .errors import (
913
InternalError,
@@ -31,10 +35,45 @@ class YdbQuery:
3135
is_ddl: bool = False
3236

3337

34-
class Cursor(object):
38+
def _handle_ydb_errors(func):
39+
@functools.wraps(func)
40+
def wrapper(self, *args, **kwargs):
41+
try:
42+
return func(self, *args, **kwargs)
43+
except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e:
44+
raise IntegrityError(e.message, e.issues, e.status) from e
45+
except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e:
46+
raise NotSupportedError(e.message, e.issues, e.status) from e
47+
except (ydb.issues.BadRequest, ydb.issues.SchemeError) as e:
48+
raise ProgrammingError(e.message, e.issues, e.status) from e
49+
except (
50+
ydb.issues.TruncatedResponseError,
51+
ydb.issues.ConnectionError,
52+
ydb.issues.Aborted,
53+
ydb.issues.Unavailable,
54+
ydb.issues.Overloaded,
55+
ydb.issues.Undetermined,
56+
ydb.issues.Timeout,
57+
ydb.issues.Cancelled,
58+
ydb.issues.SessionBusy,
59+
ydb.issues.SessionExpired,
60+
ydb.issues.SessionPoolEmpty,
61+
) as e:
62+
raise OperationalError(e.message, e.issues, e.status) from e
63+
except ydb.issues.GenericError as e:
64+
raise DataError(e.message, e.issues, e.status) from e
65+
except ydb.issues.InternalError as e:
66+
raise InternalError(e.message, e.issues, e.status) from e
67+
except ydb.Error as e:
68+
raise DatabaseError(e.message, e.issues, e.status) from e
69+
70+
return wrapper
71+
72+
73+
class Cursor:
3574
def __init__(
3675
self,
37-
session_pool: ydb.SessionPool,
76+
session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool],
3877
tx_context: Optional[ydb.BaseTxContext] = None,
3978
):
4079
self.session_pool = session_pool
@@ -54,12 +93,9 @@ def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] =
5493

5594
logger.info("execute sql: %s, params: %s", query, parameters)
5695
if is_ddl:
57-
chunks = self.session_pool.retry_operation_sync(self._execute_ddl, None, query)
96+
chunks = self._execute_ddl(query)
5897
else:
59-
if self.tx_context:
60-
chunks = self._execute_dml(self.tx_context.session, query, parameters, self.tx_context)
61-
else:
62-
chunks = self.session_pool.retry_operation_sync(self._execute_dml, None, query, parameters)
98+
chunks = self._execute_dml(query, parameters)
6399

64100
rows = self._rows_iterable(chunks)
65101
# Prefetch the description:
@@ -74,57 +110,54 @@ def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] =
74110

75111
self.rows = rows
76112

77-
@classmethod
113+
@_handle_ydb_errors
78114
def _execute_dml(
79-
cls,
80-
session: ydb.Session,
81-
query: ydb.DataQuery,
82-
parameters: Optional[Mapping[str, Any]] = None,
83-
tx_context: Optional[ydb.BaseTxContext] = None,
115+
self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None
84116
) -> ydb.convert.ResultSets:
85117
prepared_query = query
86118
if isinstance(query, str) and parameters:
87-
prepared_query = session.prepare(query)
119+
if self.tx_context:
120+
prepared_query = self._run_operation_in_session(self._prepare, query)
121+
else:
122+
prepared_query = self._retry_operation_in_pool(self._prepare, query)
88123

89-
if tx_context:
90-
return cls._handle_ydb_errors(tx_context.execute, prepared_query, parameters)
124+
if self.tx_context:
125+
return self._run_operation_in_tx(self._execute_in_tx, prepared_query, parameters)
91126

92-
return cls._handle_ydb_errors(session.transaction().execute, prepared_query, parameters, commit_tx=True)
127+
return self._retry_operation_in_pool(self._execute_in_session, prepared_query, parameters)
93128

94-
@classmethod
95-
def _execute_ddl(cls, session: ydb.Session, query: str) -> ydb.convert.ResultSets:
96-
return cls._handle_ydb_errors(session.execute_scheme, query)
129+
@_handle_ydb_errors
130+
def _execute_ddl(self, query: str) -> ydb.convert.ResultSets:
131+
return self._retry_operation_in_pool(self._execute_scheme, query)
97132

98133
@staticmethod
99-
def _handle_ydb_errors(callee: Callable, *args, **kwargs) -> Any:
100-
try:
101-
return callee(*args, **kwargs)
102-
except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e:
103-
raise IntegrityError(e.message, e.issues, e.status) from e
104-
except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e:
105-
raise NotSupportedError(e.message, e.issues, e.status) from e
106-
except (ydb.issues.BadRequest, ydb.issues.SchemeError) as e:
107-
raise ProgrammingError(e.message, e.issues, e.status) from e
108-
except (
109-
ydb.issues.TruncatedResponseError,
110-
ydb.issues.ConnectionError,
111-
ydb.issues.Aborted,
112-
ydb.issues.Unavailable,
113-
ydb.issues.Overloaded,
114-
ydb.issues.Undetermined,
115-
ydb.issues.Timeout,
116-
ydb.issues.Cancelled,
117-
ydb.issues.SessionBusy,
118-
ydb.issues.SessionExpired,
119-
ydb.issues.SessionPoolEmpty,
120-
) as e:
121-
raise OperationalError(e.message, e.issues, e.status) from e
122-
except ydb.issues.GenericError as e:
123-
raise DataError(e.message, e.issues, e.status) from e
124-
except ydb.issues.InternalError as e:
125-
raise InternalError(e.message, e.issues, e.status) from e
126-
except ydb.Error as e:
127-
raise DatabaseError(e.message, e.issues, e.status) from e
134+
def _execute_scheme(session: ydb.Session, query: str) -> ydb.convert.ResultSets:
135+
return session.execute_scheme(query)
136+
137+
@staticmethod
138+
def _prepare(session: ydb.Session, query: str) -> ydb.DataQuery:
139+
return session.prepare(query)
140+
141+
@staticmethod
142+
def _execute_in_tx(
143+
tx_context: ydb.TxContext, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]]
144+
) -> ydb.convert.ResultSets:
145+
return tx_context.execute(prepared_query, parameters, commit_tx=False)
146+
147+
@staticmethod
148+
def _execute_in_session(
149+
session: ydb.Session, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]]
150+
) -> ydb.convert.ResultSets:
151+
return session.transaction().execute(prepared_query, parameters, commit_tx=True)
152+
153+
def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs):
154+
return callee(self.tx_context, *args, **kwargs)
155+
156+
def _run_operation_in_session(self, callee: collections.abc.Callable, *args, **kwargs):
157+
return callee(self.tx_context.session, *args, **kwargs)
158+
159+
def _retry_operation_in_pool(self, callee: collections.abc.Callable, *args, **kwargs):
160+
return self.session_pool.retry_operation_sync(callee, None, *args, **kwargs)
128161

129162
def _rows_iterable(self, chunks_iterable: ydb.convert.ResultSets):
130163
try:
@@ -186,3 +219,36 @@ def close(self):
186219
@property
187220
def rowcount(self):
188221
return len(self._ensure_prefetched())
222+
223+
224+
class AsyncCursor(Cursor):
225+
_await = staticmethod(util.await_only)
226+
227+
@staticmethod
228+
async def _execute_scheme(session: ydb.aio.table.Session, query: str) -> ydb.convert.ResultSets:
229+
return await session.execute_scheme(query)
230+
231+
@staticmethod
232+
async def _prepare(session: ydb.aio.table.Session, query: str) -> ydb.DataQuery:
233+
return await session.prepare(query)
234+
235+
@staticmethod
236+
async def _execute_in_tx(
237+
tx_context: ydb.aio.table.TxContext, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]]
238+
) -> ydb.convert.ResultSets:
239+
return await tx_context.execute(prepared_query, parameters, commit_tx=False)
240+
241+
@staticmethod
242+
async def _execute_in_session(
243+
session: ydb.aio.table.Session, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]]
244+
) -> ydb.convert.ResultSets:
245+
return await session.transaction().execute(prepared_query, parameters, commit_tx=True)
246+
247+
def _run_operation_in_tx(self, callee: collections.abc.Coroutine, *args, **kwargs):
248+
return self._await(callee(self.tx_context, *args, **kwargs))
249+
250+
def _run_operation_in_session(self, callee: collections.abc.Coroutine, *args, **kwargs):
251+
return self._await(callee(self.tx_context.session, *args, **kwargs))
252+
253+
def _retry_operation_in_pool(self, callee: collections.abc.Coroutine, *args, **kwargs):
254+
return self._await(self.session_pool.retry_operation(callee, None, *args, **kwargs))

0 commit comments

Comments
 (0)