Skip to content

Commit 4170e5f

Browse files
committed
Provide tx_mode to one-time transactions
1 parent 5c4d6ca commit 4170e5f

File tree

3 files changed

+72
-9
lines changed

3 files changed

+72
-9
lines changed

test_dbapi/test_dbapi.py

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,28 @@
99

1010

1111
class BaseDBApiTestSuit:
12+
def _test_isolation_level_read_only(self, connection: dbapi.Connection, isolation_level: str, read_only: bool):
13+
connection.cursor().execute(
14+
dbapi.YdbQuery("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", is_ddl=True)
15+
)
16+
connection.set_isolation_level(isolation_level)
17+
18+
cursor = connection.cursor()
19+
20+
connection.begin()
21+
22+
query = dbapi.YdbQuery("UPSERT INTO foo(id) VALUES (1)")
23+
if read_only:
24+
with pytest.raises(dbapi.DatabaseError):
25+
cursor.execute(query)
26+
else:
27+
cursor.execute(query)
28+
29+
connection.rollback()
30+
31+
connection.cursor().execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True))
32+
connection.cursor().close()
33+
1234
def _test_connection(self, connection: dbapi.Connection):
1335
connection.commit()
1436
connection.rollback()
@@ -100,14 +122,28 @@ def _test_errors(self, connection: dbapi.Connection):
100122

101123

102124
class TestSyncConnection(BaseDBApiTestSuit):
103-
@pytest.fixture(scope="class")
125+
@pytest.fixture
104126
def sync_connection(self) -> dbapi.Connection:
105127
conn = dbapi.YdbDBApi().connect(host="localhost", port="2136", database="/local")
106128
try:
107129
yield conn
108130
finally:
109131
conn.close()
110132

133+
@pytest.mark.parametrize(
134+
"isolation_level, read_only",
135+
[
136+
(dbapi.IsolationLevel.SERIALIZABLE, False),
137+
(dbapi.IsolationLevel.AUTOCOMMIT, False),
138+
(dbapi.IsolationLevel.ONLINE_READONLY, True),
139+
(dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT, True),
140+
(dbapi.IsolationLevel.STALE_READONLY, True),
141+
(dbapi.IsolationLevel.SNAPSHOT_READONLY, True),
142+
],
143+
)
144+
def test_isolation_level_read_only(self, isolation_level: str, read_only: bool, sync_connection: dbapi.Connection):
145+
self._test_isolation_level_read_only(sync_connection, isolation_level, read_only)
146+
111147
def test_connection(self, sync_connection: dbapi.Connection):
112148
self._test_connection(sync_connection)
113149

@@ -118,9 +154,8 @@ def test_errors(self, sync_connection: dbapi.Connection):
118154
return self._test_errors(sync_connection)
119155

120156

121-
@pytest.mark.asyncio(scope="class")
122157
class TestAsyncConnection(BaseDBApiTestSuit):
123-
@pytest_asyncio.fixture(scope="class")
158+
@pytest_asyncio.fixture
124159
async def async_connection(self) -> dbapi.AsyncConnection:
125160
def connect():
126161
return dbapi.YdbDBApi().async_connect(host="localhost", port="2136", database="/local")
@@ -131,11 +166,31 @@ def connect():
131166
finally:
132167
await util.greenlet_spawn(conn.close)
133168

169+
@pytest.mark.asyncio
170+
@pytest.mark.parametrize(
171+
"isolation_level, read_only",
172+
[
173+
(dbapi.IsolationLevel.SERIALIZABLE, False),
174+
(dbapi.IsolationLevel.AUTOCOMMIT, False),
175+
(dbapi.IsolationLevel.ONLINE_READONLY, True),
176+
(dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT, True),
177+
(dbapi.IsolationLevel.STALE_READONLY, True),
178+
(dbapi.IsolationLevel.SNAPSHOT_READONLY, True),
179+
],
180+
)
181+
async def test_isolation_level_read_only(
182+
self, isolation_level: str, read_only: bool, async_connection: dbapi.AsyncConnection
183+
):
184+
await util.greenlet_spawn(self._test_isolation_level_read_only, async_connection, isolation_level, read_only)
185+
186+
@pytest.mark.asyncio
134187
async def test_connection(self, async_connection: dbapi.AsyncConnection):
135188
await util.greenlet_spawn(self._test_connection, async_connection)
136189

190+
@pytest.mark.asyncio
137191
async def test_cursor_raw_query(self, async_connection: dbapi.AsyncConnection):
138192
await util.greenlet_spawn(self._test_cursor_raw_query, async_connection)
139193

194+
@pytest.mark.asyncio
140195
async def test_errors(self, async_connection: dbapi.AsyncConnection):
141196
await util.greenlet_spawn(self._test_errors, async_connection)

ydb_sqlalchemy/dbapi/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(
5757
self.tx_context: Optional[ydb.TxContext] = None
5858

5959
def cursor(self):
60-
return self._cursor_class(self.session_pool, self.tx_context)
60+
return self._cursor_class(self.session_pool, self.tx_mode, self.tx_context)
6161

6262
def describe(self, table_path: str) -> ydb.TableDescription:
6363
abs_table_path = posixpath.join(self.database, table_path)

ydb_sqlalchemy/dbapi/cursor.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,11 @@ class Cursor:
7676
def __init__(
7777
self,
7878
session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool],
79+
tx_mode: ydb.AbstractTransactionModeBuilder,
7980
tx_context: Optional[ydb.BaseTxContext] = None,
8081
):
8182
self.session_pool = session_pool
83+
self.tx_mode = tx_mode
8284
self.tx_context = tx_context
8385
self.description = None
8486
self.arraysize = 1
@@ -142,7 +144,7 @@ def _execute_dml(
142144
if self.tx_context:
143145
return self._run_operation_in_tx(self._execute_in_tx, prepared_query, parameters)
144146

145-
return self._retry_operation_in_pool(self._execute_in_session, prepared_query, parameters)
147+
return self._retry_operation_in_pool(self._execute_in_session, self.tx_mode, prepared_query, parameters)
146148

147149
@_handle_ydb_errors
148150
def _execute_ddl(self, query: str) -> ydb.convert.ResultSets:
@@ -176,9 +178,12 @@ def _execute_in_tx(
176178

177179
@staticmethod
178180
def _execute_in_session(
179-
session: ydb.Session, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]]
181+
session: ydb.Session,
182+
tx_mode: ydb.AbstractTransactionModeBuilder,
183+
prepared_query: ydb.DataQuery,
184+
parameters: Optional[Mapping[str, Any]],
180185
) -> ydb.convert.ResultSets:
181-
return session.transaction().execute(prepared_query, parameters, commit_tx=True)
186+
return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)
182187

183188
def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs):
184189
return callee(self.tx_context, *args, **kwargs)
@@ -282,9 +287,12 @@ async def _execute_in_tx(
282287

283288
@staticmethod
284289
async def _execute_in_session(
285-
session: ydb.aio.table.Session, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]]
290+
session: ydb.aio.table.Session,
291+
tx_mode: ydb.AbstractTransactionModeBuilder,
292+
prepared_query: ydb.DataQuery,
293+
parameters: Optional[Mapping[str, Any]],
286294
) -> ydb.convert.ResultSets:
287-
return await session.transaction().execute(prepared_query, parameters, commit_tx=True)
295+
return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)
288296

289297
def _run_operation_in_tx(self, callee: collections.abc.Coroutine, *args, **kwargs):
290298
return self._await(callee(self.tx_context, *args, **kwargs))

0 commit comments

Comments
 (0)