Skip to content

Commit a85232c

Browse files
author
tretyak-rd
committed
Use small session pool in each connection + shared support
1 parent a0294fc commit a85232c

File tree

4 files changed

+151
-72
lines changed

4 files changed

+151
-72
lines changed

test/test_core.py

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import sqlalchemy as sa
77
import ydb
88
from sqlalchemy import Table, Column, Integer, Unicode
9-
from sqlalchemy.testing.fixtures import TestBase, TablesTest
9+
from sqlalchemy.testing.fixtures import TestBase, TablesTest, config
1010
from ydb._grpc.v4.protos import ydb_common_pb2
1111

1212
from ydb_sqlalchemy import dbapi, IsolationLevel
@@ -220,8 +220,9 @@ def _create_table_and_get_desc(connection, metadata, **kwargs):
220220
)
221221
table.create(connection)
222222

223-
session: ydb.Session = connection.connection.driver_connection.session
223+
session: ydb.Session = connection.connection.driver_connection.session_pool.acquire()
224224
table_description = session.describe_table("/local/" + table.name)
225+
connection.connection.driver_connection.session_pool.release(session)
225226
return table_description
226227

227228
@pytest.mark.parametrize(
@@ -419,11 +420,11 @@ def test_interactive_transaction(
419420

420421
connection_no_trans.execution_options(isolation_level=isolation_level)
421422
with connection_no_trans.begin():
422-
tx_id = dbapi_connection.transaction.tx_id
423+
tx_id = dbapi_connection.tx_context.tx_id
423424
assert tx_id is not None
424425
cursor1 = connection_no_trans.execute(sa.select(table))
425426
cursor2 = connection_no_trans.execute(sa.select(table))
426-
assert dbapi_connection.transaction.tx_id == tx_id
427+
assert dbapi_connection.tx_context.tx_id == tx_id
427428

428429
assert set(cursor1.fetchall()) == {(5,), (6,)}
429430
assert set(cursor2.fetchall()) == {(5,), (6,)}
@@ -448,10 +449,10 @@ def test_not_interactive_transaction(
448449

449450
connection_no_trans.execution_options(isolation_level=isolation_level)
450451
with connection_no_trans.begin():
451-
assert dbapi_connection.transaction is None
452+
assert dbapi_connection.tx_context is None
452453
cursor1 = connection_no_trans.execute(sa.select(table))
453454
cursor2 = connection_no_trans.execute(sa.select(table))
454-
assert dbapi_connection.transaction is None
455+
assert dbapi_connection.tx_context is None
455456

456457
assert set(cursor1.fetchall()) == {(7,), (8,)}
457458
assert set(cursor2.fetchall()) == {(7,), (8,)}
@@ -482,7 +483,59 @@ def test_connection_set(self, connection_no_trans: sa.Connection):
482483
assert dbapi_connection.tx_mode.name == ydb_isolation_settings[0]
483484
assert dbapi_connection.interactive_transaction is ydb_isolation_settings[1]
484485
if dbapi_connection.interactive_transaction:
485-
assert dbapi_connection.transaction is not None
486-
assert dbapi_connection.transaction.tx_id is not None
486+
assert dbapi_connection.tx_context is not None
487+
assert dbapi_connection.tx_context.tx_id is not None
487488
else:
488-
assert dbapi_connection.transaction is None
489+
assert dbapi_connection.tx_context is None
490+
491+
492+
class TestEngine(TestBase):
493+
@pytest.fixture(scope="module")
494+
def ydb_driver(self):
495+
url = config.db_url
496+
driver = ydb.Driver(endpoint=f"grpc://{url.host}:{url.port}", database=url.database)
497+
try:
498+
driver.wait(timeout=5, fail_fast=True)
499+
yield driver
500+
finally:
501+
driver.stop()
502+
503+
driver.stop()
504+
505+
@pytest.fixture(scope="module")
506+
def ydb_pool(self, ydb_driver):
507+
session_pool = ydb.SessionPool(ydb_driver, size=5, workers_threads_count=1)
508+
509+
yield session_pool
510+
511+
session_pool.stop()
512+
513+
def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
514+
engine1 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool})
515+
engine2 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool})
516+
517+
with engine1.connect() as conn1, engine2.connect() as conn2:
518+
dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection
519+
dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection
520+
521+
assert dbapi_conn1.session_pool is dbapi_conn2.session_pool
522+
assert dbapi_conn1.driver is dbapi_conn2.driver
523+
524+
engine1.dispose()
525+
engine2.dispose()
526+
assert not ydb_driver._stopped
527+
528+
def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
529+
engine1 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool})
530+
engine2 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool})
531+
532+
with engine1.connect() as conn1, engine2.connect() as conn2:
533+
dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection
534+
dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection
535+
536+
assert dbapi_conn1.session_pool is dbapi_conn2.session_pool
537+
assert dbapi_conn1.driver is dbapi_conn2.driver
538+
539+
engine1.dispose()
540+
engine2.dispose()
541+
assert not ydb_driver._stopped

test_dbapi/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import pytest
2+
23
import ydb_sqlalchemy.dbapi as dbapi
34

45

56
@pytest.fixture(scope="module")
67
def connection():
7-
conn = dbapi.connect("localhost:2136", database="/local")
8+
conn = dbapi.connect(host="localhost", port="2136", database="/local")
89
yield conn
910
conn.close()

ydb_sqlalchemy/dbapi/connection.py

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import posixpath
2-
from typing import Optional, NamedTuple
2+
from typing import Optional, NamedTuple, Any
33

44
import ydb
55

@@ -17,23 +17,38 @@ class IsolationLevel:
1717

1818

1919
class Connection:
20-
def __init__(self, endpoint=None, host=None, port=None, database=None, **conn_kwargs):
21-
self.endpoint = endpoint or f"grpc://{host}:{port}"
20+
def __init__(
21+
self,
22+
host: str = "",
23+
port: str = "",
24+
database: str = "",
25+
**conn_kwargs: Any,
26+
):
27+
self.endpoint = f"grpc://{host}:{port}"
2228
self.database = database
23-
self.table_client_settings = self._get_table_client_settings()
24-
self.driver = self._create_driver(**conn_kwargs)
25-
self.session = self._create_session()
29+
self.conn_kwargs = conn_kwargs
30+
31+
if "ydb_session_pool" in self.conn_kwargs: # Use session pool managed manually
32+
self._shared_session_pool = True
33+
self.session_pool: ydb.SessionPool = self.conn_kwargs.pop("ydb_session_pool")
34+
self.driver = self.session_pool._pool_impl._driver
35+
self.driver.table_client = ydb.TableClient(self.driver, self._get_table_client_settings())
36+
else:
37+
self._shared_session_pool = False
38+
self.driver = self._create_driver()
39+
self.session_pool = ydb.SessionPool(self.driver, size=5, workers_threads_count=1)
40+
2641
self.interactive_transaction: bool = False # AUTOCOMMIT
2742
self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite()
28-
self.transaction: Optional[ydb.TxContext] = None
43+
self.tx_context: Optional[ydb.TxContext] = None
2944

3045
def cursor(self):
31-
return Cursor(self, transaction=self.transaction)
46+
return Cursor(self.session_pool, self.tx_context)
3247

3348
def describe(self, table_path):
3449
full_path = posixpath.join(self.database, table_path)
3550
try:
36-
return ydb.retry_operation_sync(lambda: self.session.describe_table(full_path))
51+
return self.session_pool.retry_operation_sync(lambda session: session.describe_table(full_path))
3752
except ydb.issues.SchemeError as e:
3853
raise ProgrammingError(e.message, e.issues, e.status) from e
3954
except ydb.Error as e:
@@ -64,7 +79,7 @@ class IsolationSettings(NamedTuple):
6479
IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.SnapshotReadOnly(), interactive=True),
6580
}
6681
ydb_isolation_settings = ydb_isolation_settings_map[isolation_level]
67-
if self.transaction and self.transaction.tx_id:
82+
if self.tx_context and self.tx_context.tx_id:
6883
raise InternalError("Failed to set transaction mode: transaction is already began")
6984
self.tx_mode = ydb_isolation_settings.ydb_mode
7085
self.interactive_transaction = ydb_isolation_settings.interactive
@@ -88,27 +103,31 @@ def get_isolation_level(self) -> str:
88103
raise NotSupportedError(f"{self.tx_mode.name} is not supported")
89104

90105
def begin(self):
91-
if not self.session.initialized():
92-
raise InternalError("Failed to begin transaction: session closed")
93-
self.transaction = None
106+
self.tx_context = None
94107
if self.interactive_transaction:
95-
self.transaction = self.session.transaction(self.tx_mode)
96-
self.transaction.begin()
108+
session = self.session_pool.acquire(blocking=True)
109+
self.tx_context = session.transaction(self.tx_mode)
110+
self.tx_context.begin()
97111

98112
def commit(self):
99-
if self.transaction and self.transaction.tx_id:
100-
self.transaction.commit()
113+
if self.tx_context and self.tx_context.tx_id:
114+
self.tx_context.commit()
115+
self.session_pool.release(self.tx_context.session)
116+
self.tx_context = None
101117

102118
def rollback(self):
103-
if self.transaction and self.transaction.tx_id:
104-
self.transaction.rollback()
119+
if self.tx_context and self.tx_context.tx_id:
120+
self.tx_context.rollback()
121+
self.session_pool.release(self.tx_context.session)
122+
self.tx_context = None
105123

106124
def close(self):
107-
self._delete_session()
108-
self._stop_driver()
125+
self.rollback()
126+
if not self._shared_session_pool:
127+
self.session_pool.stop()
128+
self._stop_driver()
109129

110-
@staticmethod
111-
def _get_table_client_settings() -> ydb.TableClientSettings:
130+
def _get_table_client_settings(self) -> ydb.TableClientSettings:
112131
return (
113132
ydb.TableClientSettings()
114133
.with_native_date_in_result_sets(True)
@@ -118,13 +137,11 @@ def _get_table_client_settings() -> ydb.TableClientSettings:
118137
.with_native_json_in_result_sets(True)
119138
)
120139

121-
def _create_driver(self, **conn_kwargs):
122-
# TODO: add cache for initialized drivers/pools?
140+
def _create_driver(self):
123141
driver_config = ydb.DriverConfig(
124142
endpoint=self.endpoint,
125143
database=self.database,
126-
table_client_settings=self.table_client_settings,
127-
**conn_kwargs,
144+
table_client_settings=self._get_table_client_settings(),
128145
)
129146
driver = ydb.Driver(driver_config)
130147
try:
@@ -138,13 +155,3 @@ def _create_driver(self, **conn_kwargs):
138155

139156
def _stop_driver(self):
140157
self.driver.stop()
141-
142-
def _create_session(self) -> ydb.BaseSession:
143-
session = ydb.Session(self.driver, self.table_client_settings)
144-
session.create()
145-
return session
146-
147-
def _delete_session(self):
148-
if self.session.initialized():
149-
self.rollback()
150-
self.session.delete()

ydb_sqlalchemy/dbapi/cursor.py

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import dataclasses
22
import itertools
33
import logging
4-
5-
from typing import Any, Mapping, Optional, Sequence, Union, Dict
4+
from typing import Any, Mapping, Optional, Sequence, Union, Dict, Callable
65

76
import ydb
7+
88
from .errors import (
99
InternalError,
1010
IntegrityError,
@@ -15,7 +15,6 @@
1515
NotSupportedError,
1616
)
1717

18-
1918
logger = logging.getLogger(__name__)
2019

2120

@@ -33,10 +32,13 @@ class YdbQuery:
3332

3433

3534
class Cursor(object):
36-
def __init__(self, connection, transaction: Optional[ydb.BaseTxContext] = None):
37-
self.connection = connection
38-
self.session: ydb.Session = self.connection.session
39-
self.transaction = transaction
35+
def __init__(
36+
self,
37+
session_pool: ydb.SessionPool,
38+
tx_context: Optional[ydb.BaseTxContext] = None,
39+
):
40+
self.session_pool = session_pool
41+
self.tx_context = tx_context
4042
self.description = None
4143
self.arraysize = 1
4244
self.rows = None
@@ -50,7 +52,15 @@ def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] =
5052
query = ydb.DataQuery(operation.yql_text, operation.parameters_types)
5153
is_ddl = operation.is_ddl
5254

53-
chunks = self._execute(query, parameters, is_ddl)
55+
logger.info("execute sql: %s, params: %s", query, parameters)
56+
if is_ddl:
57+
chunks = self.session_pool.retry_operation_sync(self._execute_ddl, None, query)
58+
else:
59+
if self.tx_context:
60+
chunks = self._execute_dml(self.tx_context.session, query, parameters, self.tx_context)
61+
else:
62+
chunks = self.session_pool.retry_operation_sync(self._execute_dml, None, query, parameters)
63+
5464
rows = self._rows_iterable(chunks)
5565
# Prefetch the description:
5666
try:
@@ -64,23 +74,31 @@ def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] =
6474

6575
self.rows = rows
6676

67-
def _execute(self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]], is_ddl: bool):
68-
self.description = None
69-
logger.info("execute sql: %s, params: %s", query, parameters)
77+
@classmethod
78+
def _execute_dml(
79+
cls,
80+
session: ydb.Session,
81+
query: ydb.DataQuery,
82+
parameters: Optional[Mapping[str, Any]] = None,
83+
tx_context: Optional[ydb.BaseTxContext] = None,
84+
) -> ydb.convert.ResultSets:
85+
prepared_query = query
86+
if isinstance(query, str) and parameters:
87+
prepared_query = session.prepare(query)
88+
89+
if tx_context:
90+
return cls._handle_ydb_errors(tx_context.execute, prepared_query, parameters)
91+
92+
return cls._handle_ydb_errors(session.transaction().execute, prepared_query, parameters, commit_tx=True)
93+
94+
@classmethod
95+
def _execute_ddl(cls, session: ydb.Session, query: str) -> ydb.convert.ResultSets:
96+
return cls._handle_ydb_errors(session.execute_scheme, query)
97+
98+
@staticmethod
99+
def _handle_ydb_errors(callee: Callable, *args, **kwargs) -> Any:
70100
try:
71-
if is_ddl:
72-
return ydb.retry_operation_sync(lambda: self.session.execute_scheme(query))
73-
74-
prepared_query = query
75-
if isinstance(query, str) and parameters:
76-
prepared_query = self.session.prepare(query)
77-
78-
if not self.transaction:
79-
return ydb.retry_operation_sync(
80-
lambda: self.session.transaction().execute(prepared_query, parameters, commit_tx=True)
81-
)
82-
else:
83-
return self.transaction.execute(prepared_query, parameters)
101+
return callee(*args, **kwargs)
84102
except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e:
85103
raise IntegrityError(e.message, e.issues, e.status) from e
86104
except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e:
@@ -108,7 +126,7 @@ def _execute(self, query: Union[ydb.DataQuery, str], parameters: Optional[Mappin
108126
except ydb.Error as e:
109127
raise DatabaseError(e.message, e.issues, e.status) from e
110128

111-
def _rows_iterable(self, chunks_iterable):
129+
def _rows_iterable(self, chunks_iterable: ydb.convert.ResultSets):
112130
try:
113131
for chunk in chunks_iterable:
114132
self.description = [

0 commit comments

Comments
 (0)