|
1 | | -from decimal import Decimal |
2 | 1 | from datetime import date, datetime |
| 2 | +from decimal import Decimal |
| 3 | +from typing import NamedTuple |
3 | 4 |
|
4 | 5 | import pytest |
| 6 | + |
5 | 7 | import sqlalchemy as sa |
6 | 8 | from sqlalchemy import Table, Column, Integer, Unicode, String |
7 | | -from sqlalchemy.testing.fixtures import TestBase, TablesTest |
| 9 | +from sqlalchemy.testing.fixtures import TestBase, TablesTest, config |
8 | 10 |
|
9 | 11 | import ydb |
10 | 12 | from ydb._grpc.v4.protos import ydb_common_pb2 |
11 | 13 |
|
| 14 | +from ydb_sqlalchemy import dbapi, IsolationLevel |
12 | 15 | from ydb_sqlalchemy.sqlalchemy import types |
13 | | - |
14 | 16 | from ydb_sqlalchemy import sqlalchemy as ydb_sa |
15 | 17 |
|
16 | 18 |
|
@@ -221,9 +223,9 @@ def _create_table_and_get_desc(connection, metadata, **kwargs): |
221 | 223 | ) |
222 | 224 | table.create(connection) |
223 | 225 |
|
224 | | - session: ydb.Session = connection.connection.driver_connection.pool.acquire() |
| 226 | + session: ydb.Session = connection.connection.driver_connection.session_pool.acquire() |
225 | 227 | table_description = session.describe_table("/local/" + table.name) |
226 | | - session.delete() |
| 228 | + connection.connection.driver_connection.session_pool.release(session) |
227 | 229 | return table_description |
228 | 230 |
|
229 | 231 | @pytest.mark.parametrize( |
@@ -371,6 +373,177 @@ def test_several_keys(self, connection, metadata): |
371 | 373 | assert desc.partitioning_settings.max_partitions_count == 5 |
372 | 374 |
|
373 | 375 |
|
| 376 | +class TestTransaction(TablesTest): |
| 377 | + @classmethod |
| 378 | + def define_tables(cls, metadata: sa.MetaData): |
| 379 | + Table( |
| 380 | + "test", |
| 381 | + metadata, |
| 382 | + Column("id", Integer, primary_key=True), |
| 383 | + ) |
| 384 | + |
| 385 | + def test_rollback(self, connection_no_trans: sa.Connection, connection: sa.Connection): |
| 386 | + table = self.tables.test |
| 387 | + |
| 388 | + connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE) |
| 389 | + with connection_no_trans.begin(): |
| 390 | + stm1 = table.insert().values(id=1) |
| 391 | + connection_no_trans.execute(stm1) |
| 392 | + stm2 = table.insert().values(id=2) |
| 393 | + connection_no_trans.execute(stm2) |
| 394 | + connection_no_trans.rollback() |
| 395 | + |
| 396 | + cursor = connection.execute(sa.select(table)) |
| 397 | + result = cursor.fetchall() |
| 398 | + assert result == [] |
| 399 | + |
| 400 | + def test_commit(self, connection_no_trans: sa.Connection, connection: sa.Connection): |
| 401 | + table = self.tables.test |
| 402 | + |
| 403 | + connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE) |
| 404 | + with connection_no_trans.begin(): |
| 405 | + stm1 = table.insert().values(id=3) |
| 406 | + connection_no_trans.execute(stm1) |
| 407 | + stm2 = table.insert().values(id=4) |
| 408 | + connection_no_trans.execute(stm2) |
| 409 | + |
| 410 | + cursor = connection.execute(sa.select(table)) |
| 411 | + result = cursor.fetchall() |
| 412 | + assert set(result) == {(3,), (4,)} |
| 413 | + |
| 414 | + @pytest.mark.parametrize("isolation_level", (IsolationLevel.SERIALIZABLE, IsolationLevel.SNAPSHOT_READONLY)) |
| 415 | + def test_interactive_transaction( |
| 416 | + self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level |
| 417 | + ): |
| 418 | + table = self.tables.test |
| 419 | + dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection |
| 420 | + |
| 421 | + stm1 = table.insert().values([{"id": 5}, {"id": 6}]) |
| 422 | + connection.execute(stm1) |
| 423 | + |
| 424 | + connection_no_trans.execution_options(isolation_level=isolation_level) |
| 425 | + with connection_no_trans.begin(): |
| 426 | + tx_id = dbapi_connection.tx_context.tx_id |
| 427 | + assert tx_id is not None |
| 428 | + cursor1 = connection_no_trans.execute(sa.select(table)) |
| 429 | + cursor2 = connection_no_trans.execute(sa.select(table)) |
| 430 | + assert dbapi_connection.tx_context.tx_id == tx_id |
| 431 | + |
| 432 | + assert set(cursor1.fetchall()) == {(5,), (6,)} |
| 433 | + assert set(cursor2.fetchall()) == {(5,), (6,)} |
| 434 | + |
| 435 | + @pytest.mark.parametrize( |
| 436 | + "isolation_level", |
| 437 | + ( |
| 438 | + IsolationLevel.ONLINE_READONLY, |
| 439 | + IsolationLevel.ONLINE_READONLY_INCONSISTENT, |
| 440 | + IsolationLevel.STALE_READONLY, |
| 441 | + IsolationLevel.AUTOCOMMIT, |
| 442 | + ), |
| 443 | + ) |
| 444 | + def test_not_interactive_transaction( |
| 445 | + self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level |
| 446 | + ): |
| 447 | + table = self.tables.test |
| 448 | + dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection |
| 449 | + |
| 450 | + stm1 = table.insert().values([{"id": 7}, {"id": 8}]) |
| 451 | + connection.execute(stm1) |
| 452 | + |
| 453 | + connection_no_trans.execution_options(isolation_level=isolation_level) |
| 454 | + with connection_no_trans.begin(): |
| 455 | + assert dbapi_connection.tx_context is None |
| 456 | + cursor1 = connection_no_trans.execute(sa.select(table)) |
| 457 | + cursor2 = connection_no_trans.execute(sa.select(table)) |
| 458 | + assert dbapi_connection.tx_context is None |
| 459 | + |
| 460 | + assert set(cursor1.fetchall()) == {(7,), (8,)} |
| 461 | + assert set(cursor2.fetchall()) == {(7,), (8,)} |
| 462 | + |
| 463 | + |
| 464 | +class TestTransactionIsolationLevel(TestBase): |
| 465 | + class IsolationSettings(NamedTuple): |
| 466 | + ydb_mode: ydb.AbstractTransactionModeBuilder |
| 467 | + interactive: bool |
| 468 | + |
| 469 | + YDB_ISOLATION_SETTINGS_MAP = { |
| 470 | + IsolationLevel.AUTOCOMMIT: IsolationSettings(ydb.SerializableReadWrite().name, False), |
| 471 | + IsolationLevel.SERIALIZABLE: IsolationSettings(ydb.SerializableReadWrite().name, True), |
| 472 | + IsolationLevel.ONLINE_READONLY: IsolationSettings(ydb.OnlineReadOnly().name, False), |
| 473 | + IsolationLevel.ONLINE_READONLY_INCONSISTENT: IsolationSettings( |
| 474 | + ydb.OnlineReadOnly().with_allow_inconsistent_reads().name, False |
| 475 | + ), |
| 476 | + IsolationLevel.STALE_READONLY: IsolationSettings(ydb.StaleReadOnly().name, False), |
| 477 | + IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.SnapshotReadOnly().name, True), |
| 478 | + } |
| 479 | + |
| 480 | + def test_connection_set(self, connection_no_trans: sa.Connection): |
| 481 | + dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection |
| 482 | + |
| 483 | + for sa_isolation_level, ydb_isolation_settings in self.YDB_ISOLATION_SETTINGS_MAP.items(): |
| 484 | + connection_no_trans.execution_options(isolation_level=sa_isolation_level) |
| 485 | + with connection_no_trans.begin(): |
| 486 | + assert dbapi_connection.tx_mode.name == ydb_isolation_settings[0] |
| 487 | + assert dbapi_connection.interactive_transaction is ydb_isolation_settings[1] |
| 488 | + if dbapi_connection.interactive_transaction: |
| 489 | + assert dbapi_connection.tx_context is not None |
| 490 | + assert dbapi_connection.tx_context.tx_id is not None |
| 491 | + else: |
| 492 | + assert dbapi_connection.tx_context is None |
| 493 | + |
| 494 | + |
| 495 | +class TestEngine(TestBase): |
| 496 | + @pytest.fixture(scope="module") |
| 497 | + def ydb_driver(self): |
| 498 | + url = config.db_url |
| 499 | + driver = ydb.Driver(endpoint=f"grpc://{url.host}:{url.port}", database=url.database) |
| 500 | + try: |
| 501 | + driver.wait(timeout=5, fail_fast=True) |
| 502 | + yield driver |
| 503 | + finally: |
| 504 | + driver.stop() |
| 505 | + |
| 506 | + driver.stop() |
| 507 | + |
| 508 | + @pytest.fixture(scope="module") |
| 509 | + def ydb_pool(self, ydb_driver): |
| 510 | + session_pool = ydb.SessionPool(ydb_driver, size=5, workers_threads_count=1) |
| 511 | + |
| 512 | + yield session_pool |
| 513 | + |
| 514 | + session_pool.stop() |
| 515 | + |
| 516 | + def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool): |
| 517 | + engine1 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool}) |
| 518 | + engine2 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool}) |
| 519 | + |
| 520 | + with engine1.connect() as conn1, engine2.connect() as conn2: |
| 521 | + dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection |
| 522 | + dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection |
| 523 | + |
| 524 | + assert dbapi_conn1.session_pool is dbapi_conn2.session_pool |
| 525 | + assert dbapi_conn1.driver is dbapi_conn2.driver |
| 526 | + |
| 527 | + engine1.dispose() |
| 528 | + engine2.dispose() |
| 529 | + assert not ydb_driver._stopped |
| 530 | + |
| 531 | + def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool): |
| 532 | + engine1 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool}) |
| 533 | + engine2 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool}) |
| 534 | + |
| 535 | + with engine1.connect() as conn1, engine2.connect() as conn2: |
| 536 | + dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection |
| 537 | + dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection |
| 538 | + |
| 539 | + assert dbapi_conn1.session_pool is dbapi_conn2.session_pool |
| 540 | + assert dbapi_conn1.driver is dbapi_conn2.driver |
| 541 | + |
| 542 | + engine1.dispose() |
| 543 | + engine2.dispose() |
| 544 | + assert not ydb_driver._stopped |
| 545 | + |
| 546 | + |
374 | 547 | class TestUpsert(TablesTest): |
375 | 548 | @classmethod |
376 | 549 | def define_tables(cls, metadata): |
|
0 commit comments