diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index dc8da639..84275dd2 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -12,9 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -# aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_mysqlconnector_dialect.py +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import mysql.connector +from mysql.connector import CMySQLConnection +from mysql.connector.errors import Error from sqlalchemy.dialects.mysql.mysqlconnector import \ MySQLDialect_mysqlconnector +from sqlalchemy.engine import default + +from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.utils.properties import (Properties, + PropertiesUtils) + +if TYPE_CHECKING: + from sqlalchemy import Connection + + from aws_advanced_python_wrapper.hostinfo import HostInfo class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): @@ -27,3 +44,154 @@ class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): name = 'mysql' driver = 'aws_wrapper_mysqlconnector' + + @classmethod + def import_dbapi(cls): + """ + Return the DB-API 2.0 module. + SQLAlchemy calls this to get the driver module. + """ + import aws_advanced_python_wrapper + return aws_advanced_python_wrapper + + def create_connect_args(self, url): + """ + Transform SQLAlchemy URL into connection arguments. + Must include the 'target' parameter for our wrapper driver. + """ + # Extract standard connection parameters + opts = url.translate_connect_args(username='user') + + # Add query string parameters + opts.update(url.query) + + # Add the required 'target' parameter for our wrapper + if 'target' not in opts: + opts['target'] = mysql.connector.Connect + if 'wrapper_plugins' not in opts: + opts['plugins'] = "aurora_connection_tracker,failover" + else: + opts['plugins'] = opts['wrapper_plugins'] + opts.pop('wrapper_plugins', None) + if 'connect_timeout' in opts: + opts['connect_timeout'] = int(opts['connect_timeout']) + + # Return empty args list and kwargs dict + return [], opts + + def _detect_charset(self, connection: Connection) -> str: + if isinstance(connection, CMySQLConnection): + return connection.charset + else: + raise Exception("Could not detect charset because connection was not a CMySQLConnection.") + + def _extract_error_code(self, exception: BaseException) -> int: + if isinstance(exception, AwsWrapperError): + err = exception.driver_error + if err and isinstance(err, Error): + return err.errno + else: + raise Exception("Could not extract error code because driver_error was not a BaseException.") + else: + raise Exception("Could not extract error code because exception was not an AwsWrapperError.") + + def initialize(self, connection): + """ + Override initialization to handle type introspection. + The parent class tries to use TypeInfo.fetch() which requires + a native SQLAlchemy connection, not AwsWrapperConnection. + """ + + # Unwrap SQLAlchemy's connection object + wrapper_conn, wrapper_parent = self._get_wrapper_connection_and_parent(connection) + + # this is driver-based, does not need server version info + # and is fairly critical for even basic SQL operations + self._connection_charset: Optional[str] = self._detect_charset( + wrapper_conn.target_connection + ) + + # call super().initialize() because we need to have + # server_version_info set up. in 1.4 under python 2 only this does the + # "check unicode returns" thing, which is the one area that some + # SQL gets compiled within initialize() currently + default.DefaultDialect.initialize(self, connection) + + self._detect_sql_mode(connection) + self._detect_ansiquotes(connection) # depends on sql mode + self._detect_casing(connection) + if self._server_ansiquotes: + # if ansiquotes == True, build a new IdentifierPreparer + # with the new setting + self.identifier_preparer = self.preparer( + self, server_ansiquotes=self._server_ansiquotes + ) + + self.supports_sequences = ( + self.is_mariadb and self.server_version_info >= (10, 3) + ) + + self.supports_for_update_of = ( + self._is_mysql and self.server_version_info >= (8,) + ) + + self.use_mysql_for_share = ( + self._is_mysql and self.server_version_info >= (8, 0, 1) + ) + + self._needs_correct_for_88718_96365 = ( + not self.is_mariadb and self.server_version_info >= (8,) + ) + + self.delete_returning = ( + self.is_mariadb and self.server_version_info >= (10, 0, 5) + ) + + self.insert_returning = ( + self.is_mariadb and self.server_version_info >= (10, 5) + ) + + self._requires_alias_for_on_duplicate_key = ( + self._is_mysql and self.server_version_info >= (8, 0, 20) + ) + + self._warn_for_known_db_issues() + + def _get_wrapper_connection_and_parent(self, connection): + """ + Traverse the connection chain to find AwsWrapperConnection and its parent connection. + + Args: + connection: SQLAlchemy Connection object + + Returns: + AwsWrapperConnection instance or None, parent connection of AwsWrapperConnection or None + """ + # Start with the DBAPI connection + parent = connection + child = connection.connection + + # Traverse up to 5 levels deep (reasonable limit) + for _ in range(5): + if isinstance(child, AwsWrapperConnection): + return child, parent + + # Try to go deeper if there's a .connection attribute + if hasattr(child, 'connection'): + parent = child + child = child.connection + else: + break + + return None + + def prepare_connect_info(self, host_info: HostInfo, props: Properties) -> Properties: + prop_copy: Properties = Properties(props.copy()) + + prop_copy["host"] = host_info.host + + if host_info.is_port_specified(): + prop_copy["port"] = str(host_info.port) + + PropertiesUtils.remove_wrapper_props(prop_copy) + return prop_copy diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index 82427f95..d85f105c 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -29,7 +29,6 @@ subqueryload) from sqlalchemy.sql import func -from tests.integration.container.utils.rds_test_utility import RdsTestUtility from ..utils.conditions import (disable_on_features, enable_on_deployments, enable_on_engines) from ..utils.database_engine import DatabaseEngine @@ -114,41 +113,29 @@ class Book(Base): TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, TestEnvironmentFeatures.PERFORMANCE]) class TestSqlAlchemy: - @pytest.fixture(scope='class') - def rds_utils(self): - region: str = TestEnvironment.get_current().get_info().get_region() - return RdsTestUtility(region) - - - @pytest.fixture(scope="class") + @pytest.fixture(scope="function") def engine(self, conn_utils): conn_str = f'mysql+aws_wrapper_mysqlconnector://{conn_utils.user}:{conn_utils.password}@{conn_utils.writer_cluster_host}:{conn_utils.port}/{conn_utils.dbname}' engine = create_engine(conn_str) Base.metadata.create_all(engine) yield engine Base.metadata.drop_all(engine) + engine.dispose() - @pytest.fixture(scope="class") - def Session(self, engine): - Session = sessionmaker(bind=engine) - yield Session - - @pytest.fixture(scope="class") - def session(self, Session): - session = Session() + @pytest.fixture(scope="function") + def session(self, engine): + session = sessionmaker(bind=engine)() yield session session.rollback() session.close() - def test_sqlalchemy_backend_configuration(self, test_environment: TestEnvironment, engine): + def test_sqlalchemy_backend_configuration(self, test_environment: TestEnvironment, session): """Test SQLAlchemy backend configuration with empty plugins""" # Verify that the connection is using the AWS wrapper - with engine.connect() as connection: - assert connection.connection is not None + assert session.connection().connection is not None # Test basic connection functionality - with Session(engine) as session: - assert session.query(TestModel).count() == 0 + assert session.query(TestModel).count() == 0 def test_sqlalchemy_basic_model_operations(self, session, test_environment: TestEnvironment): """Test basic SQLAlchemy ORM operations (CRUD)""" diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py new file mode 100644 index 00000000..0ab30d28 --- /dev/null +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py @@ -0,0 +1,607 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa: N806 + +from __future__ import annotations + +import json +import uuid +from datetime import date, datetime, time, timezone +from decimal import Decimal +from time import perf_counter_ns, sleep +from typing import Any, ClassVar, Dict, List, Optional + +import boto3 +import pytest +from boto3 import client +from botocore.exceptions import ClientError +from sqlalchemy import (JSON, BigInteger, Boolean, Column, Date, DateTime, + Float, ForeignKey, Integer, Numeric, SmallInteger, + String, Text, Time, create_engine, text) +from sqlalchemy.exc import DBAPIError +from sqlalchemy.orm import (DeclarativeBase, Mapped, Session, mapped_column, + relationship, sessionmaker) + +from aws_advanced_python_wrapper.errors import FailoverSuccessError +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from ..utils.conditions import (disable_on_features, enable_on_deployments, + enable_on_engines, enable_on_features, + enable_on_num_instances) +from ..utils.database_engine import DatabaseEngine +from ..utils.database_engine_deployment import DatabaseEngineDeployment +from ..utils.test_environment import TestEnvironment +from ..utils.test_environment_features import TestEnvironmentFeatures + + +class Base(DeclarativeBase): + pass + +class TestModel(Base): + """Basic test model for SQLAlchemy ORM functionality""" + __tablename__ = 'sqlalchemy_test_model' + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String(254), unique=True) + age: Mapped[int] = mapped_column() + is_active: Mapped[Optional[bool]] = mapped_column(Boolean, default=True) + created_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=datetime.now(timezone.utc)) + + +class DataTypeModel(Base): + """Model for testing various data types""" + __tablename__ = 'sqlalchemy_data_type_model' + + id: Mapped[int] = mapped_column(primary_key=True) + + # String fields + string_field: Mapped[Optional[str]] = mapped_column(String(255)) + text_field: Mapped[Optional[str]] = mapped_column(Text) + + # Numeric fields + integer_field: Mapped[Optional[int]] = mapped_column() + small_integer_field: Mapped[Optional[int]] = mapped_column(SmallInteger) + big_integer_field: Mapped[Optional[int]] = mapped_column(BigInteger) + numeric_field: Mapped[Optional[Decimal]] = mapped_column(Numeric(10, 2)) + float_field: Mapped[Optional[float]] = mapped_column(Float) + + # Boolean field + boolean_field: Mapped[Optional[bool]] = mapped_column(Boolean, default=False) + + # Date/Time fields + date_field: Mapped[Optional[date]] = mapped_column(Date) + time_field: Mapped[Optional[time]] = mapped_column(Time) + datetime_field: Mapped[Optional[datetime]] = mapped_column(DateTime) + + # JSON field (MySQL 5.7+) + json_field: Mapped[Optional[Any]] = mapped_column(JSON) + + +class Author(Base): + """Author model for relationship testing""" + __tablename__ = 'sqlalchemy_author' + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String(254)) + birth_date: Mapped[Optional[date]] = mapped_column(Date) + + books: Mapped[List[Book]] = relationship(back_populates='author', cascade='all, delete-orphan') + + +class Book(Base): + """Book model for relationship testing""" + __tablename__ = 'sqlalchemy_book' + + id: Mapped[int] = mapped_column(primary_key=True) + title: Mapped[str] = mapped_column(String(200)) + author_id: Mapped[int] = mapped_column(ForeignKey("sqlalchemy_author.id")) + publication_date: Mapped[date] = mapped_column(Date) + pages: Mapped[int] = mapped_column() + price: Mapped[Decimal] = mapped_column(Numeric(8, 2)) + + author: Mapped[Author] = relationship(back_populates='books') + +def _build_url(user, password, host, port, dbname, wrapper_plugins=None, **extra_options): + """Build a SQLAlchemy connection URL using the aws wrapper dialect.""" + query_params = {} + if wrapper_plugins: + query_params['wrapper_plugins'] = wrapper_plugins + query_params['connect_timeout'] = str(extra_options.get('connect_timeout', 10)) + for k, v in extra_options.items(): + if k != 'connect_timeout': + query_params[k] = str(v) + + from sqlalchemy.engine import URL + return URL.create( + drivername="mysql+aws_wrapper_mysqlconnector", + username=user, + password=password, + host=host, + port=port, + database=dbname, + query=query_params, + ) + +@enable_on_engines([DatabaseEngine.MYSQL]) +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestSqlAlchemyPlugins: + endpoint_id: ClassVar[str] = f"test-sqlalchemy-endpoint-{uuid.uuid4()}" + endpoint_info: ClassVar[Dict[str, Any]] = {} + reuse_existing_endpoint: ClassVar[bool] = False + + @pytest.fixture(scope='class') + def rds_utils(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) + + @pytest.fixture(scope='class') + def create_secret(self, conn_utils): + """Create a secret in AWS Secrets Manager with database credentials.""" + region = TestEnvironment.get_current().get_info().get_region() + sm_client = boto3.client('secretsmanager', region_name=region) + env = TestEnvironment.get_current() + + secret_name = f"TestSecret-{uuid.uuid4()}" + engine = "postgres" if env.get_engine() == "pg" else "mysql" + secret_value = { + "engine": engine, + "dbname": env.get_info().get_database_info().get_default_db_name(), + "host": env.get_info().get_database_info().get_cluster_endpoint(), + "username": conn_utils.user, + "password": conn_utils.password, + "description": "Test secret generated by integration tests." + } + + try: + response = sm_client.create_secret( + Name=secret_name, + SecretString=json.dumps(secret_value) + ) + secret_arn = response['ARN'] + yield secret_name, secret_arn + finally: + try: + sm_client.delete_secret(SecretId=secret_name, ForceDeleteWithoutRecovery=True) + except Exception: + pass + + @pytest.fixture(scope='class') + def create_custom_endpoint(self): + """Create a custom endpoint for testing""" + env_info = TestEnvironment.get_current().get_info() + region = env_info.get_region() + rds_client = client('rds', region_name=region) + + if not self.reuse_existing_endpoint: + instances = env_info.get_database_info().get_instances() + self._create_endpoint(rds_client, instances[0:1]) + + self._wait_until_endpoint_available(rds_client) + yield + if not self.reuse_existing_endpoint: + self._delete_endpoint(rds_client) + rds_client.close() + + def _wait_until_endpoint_available(self, rds_client): + end_ns = perf_counter_ns() + 5 * 60 * 1_000_000_000 + available = False + while perf_counter_ns() < end_ns: + response = rds_client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self.endpoint_id, + Filters=[{"Name": "db-cluster-endpoint-type", "Values": ["custom"]}] + ) + endpoints = response["DBClusterEndpoints"] + if len(endpoints) != 1: + sleep(3) + continue + TestSqlAlchemyPlugins.endpoint_info = endpoints[0] + if endpoints[0]["Status"] == "available": + available = True + break + sleep(3) + if not available: + pytest.fail(f"Timed out waiting for custom endpoint: '{self.endpoint_id}'.") + + def _create_endpoint(self, rds_client, instances): + instance_ids = [i.get_instance_id() for i in instances] + rds_client.create_db_cluster_endpoint( + DBClusterEndpointIdentifier=self.endpoint_id, + DBClusterIdentifier=TestEnvironment.get_current().get_cluster_name(), + EndpointType="ANY", + StaticMembers=instance_ids + ) + + def _delete_endpoint(self, rds_client): + try: + rds_client.delete_db_cluster_endpoint(DBClusterEndpointIdentifier=self.endpoint_id) + self._wait_until_endpoint_deleted(rds_client) + except ClientError as e: + if e.response['Error']['Code'] != 'DBClusterEndpointNotFoundFault': + pytest.fail(e) + + def _wait_until_endpoint_deleted(self, rds_client): + end_ns = perf_counter_ns() + 3 * 60 * 1_000_000_000 + deleted = False + while perf_counter_ns() < end_ns: + try: + response = rds_client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self.endpoint_id, + Filters=[{"Name": "db-cluster-endpoint-type", "Values": ["custom"]}] + ) + if len(response["DBClusterEndpoints"]) == 0: + deleted = True + break + sleep(3) + except ClientError as e: + if e.response['Error']['Code'] == 'DBClusterEndpointNotFoundFault': + deleted = True + break + sleep(3) + if deleted: + print(f"Custom endpoint '{self.endpoint_id}' successfully deleted.") + else: + print(f"Warning: Timed out waiting for custom endpoint deletion: '{self.endpoint_id}'.") + + @pytest.fixture(scope='function') + def sa_models(self, sa_setup): + """Create SQLAlchemy tables and provide model classes.""" + engine = sa_setup['engine'] + test_id = str(uuid.uuid4())[:8] + + Base.metadata.create_all(engine, tables=[ + TestModel.__table__, DataTypeModel.__table__, + Author.__table__, Book.__table__ + ]) + + models = { + 'TestModel': TestModel, + 'DataTypeModel': DataTypeModel, + 'Author': Author, + 'Book': Book, + } + + yield models + + Base.metadata.drop_all(engine, tables=[ + Book.__table__, Author.__table__, + DataTypeModel.__table__, TestModel.__table__ + ]) + + + @pytest.fixture(scope='function') + def sa_setup(self, conn_utils, create_secret, request, create_custom_endpoint=None): + """Setup SQLAlchemy engine with configurable plugins.""" + if hasattr(request, 'param') and isinstance(request.param, dict): + config = request.param + plugins_config = config.get('wrapper_plugins', 'aurora_connection_tracker,failover_v2') + extra_options = config.get('options', {}) + use_custom_endpoint = config.get('use_custom_endpoint', False) + custom_endpoint_host = None + if use_custom_endpoint and create_custom_endpoint: + custom_endpoint_host = self.endpoint_info.get('Endpoint') + + if 'iam' in plugins_config: + user = conn_utils.iam_user + extra_options['auth_plugin'] = 'mysql_clear_password' + elif 'aws_secrets_manager' in plugins_config: + user = None + _, secret_arn = create_secret + extra_options['secrets_manager_secret_id'] = secret_arn + else: + user = config.get('user', conn_utils.user) + + if 'iam' in plugins_config or 'aws_secrets_manager' in plugins_config: + password = None + else: + password = config.get('password', conn_utils.password) + + host = custom_endpoint_host or config.get('host', conn_utils.writer_cluster_host) + else: + plugins_config = 'aurora_connection_tracker,failover_v2' + extra_options = {} + user = conn_utils.user + password = conn_utils.password + host = conn_utils.writer_host + + url = _build_url(user, password, host, conn_utils.port, conn_utils.dbname, + wrapper_plugins=plugins_config, **extra_options) + engine = create_engine(url) + SessionLocal = sessionmaker(bind=engine) + + yield {'engine': engine, 'SessionLocal': SessionLocal, + 'plugins': plugins_config, 'options': extra_options} + + engine.dispose() + + def test_sqlalchemy_basic_insert_with_plugins(self, test_environment: TestEnvironment, sa_models, sa_setup): + """Test basic SQLAlchemy insert operation with plugins enabled""" + TestModel = sa_models['TestModel'] + session: Session = sa_setup['SessionLocal']() + + try: + session.query(TestModel).delete() + obj = TestModel(name="Plugin Test User", email="plugin@example.com", age=25, is_active=True) + session.add(obj) + session.commit() + + assert obj.id is not None + assert obj.name == "Plugin Test User" + + retrieved = session.get(TestModel, obj.id) + assert retrieved and retrieved.name == "Plugin Test User" + + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{'wrapper_plugins': ''}], indirect=True) + def test_sqlalchemy_with_no_plugins(self, test_environment: TestEnvironment, sa_models, sa_setup): + """Test SQLAlchemy with no plugins enabled""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert config['plugins'] == '' + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="No Plugins User", email="noplugins@example.com", age=30) + session.add(obj) + session.commit() + assert obj.id is not None + assert obj.name == "No Plugins User" + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{'wrapper_plugins': 'failover_v2'}], indirect=True) + def test_sqlalchemy_with_failover_only(self, test_environment: TestEnvironment, sa_models, sa_setup): + """Test SQLAlchemy with only failover plugin""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert config['plugins'] == 'failover_v2' + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="Failover Only User", email="failover@example.com", age=35) + session.add(obj) + session.commit() + assert obj.id is not None + assert obj.name == "Failover Only User" + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{'wrapper_plugins': 'aurora_connection_tracker,failover_v2'}], indirect=True) + def test_sqlalchemy_with_multiple_plugins(self, test_environment: TestEnvironment, sa_models, sa_setup): + """Test SQLAlchemy with multiple plugins enabled""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert config['plugins'] == 'aurora_connection_tracker,failover_v2' + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="Multi Plugin User", email="multiplugin@example.com", age=40) + session.add(obj) + session.commit() + assert obj.id is not None + assert obj.name == "Multi Plugin User" + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{ + 'wrapper_plugins': 'aws_secrets_manager', + 'use_secrets_manager': True + }], indirect=True) + def test_sqlalchemy_with_secrets_manager_plugin(self, test_environment: TestEnvironment, sa_setup, sa_models): + """Test SQLAlchemy with AWS Secrets Manager plugin""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert config['plugins'] == 'aws_secrets_manager' + assert 'secrets_manager_secret_id' in config['options'] + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="Secrets Manager User", email="secrets@example.com", age=45) + session.add(obj) + session.commit() + assert obj.id is not None + + retrieved = session.get(TestModel, obj.id) + assert retrieved and retrieved.email == "secrets@example.com" + + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{ + 'wrapper_plugins': 'iam', + 'password': '', + 'options': {} + }], indirect=True) + def test_sqlalchemy_with_iam_plugin(self, test_environment: TestEnvironment, sa_models, sa_setup, conn_utils): + """Test SQLAlchemy with IAM authentication plugin""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert config['plugins'] == 'iam' + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="IAM User", email="iam@example.com", age=50) + session.add(obj) + session.commit() + assert obj.id is not None + + retrieved = session.get(TestModel, obj.id) + assert retrieved and retrieved.email == "iam@example.com" + + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{ + 'wrapper_plugins': 'failover_v2', + 'options': { + 'socket_timeout': 10, + 'connect_timeout': 10, + 'monitoring-connect_timeout': 5, + 'monitoring-socket_timeout': 5, + 'topology_refresh_ms': 10 + } + }], indirect=True) + @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) + @enable_on_num_instances(min_instances=2) + def test_sqlalchemy_failover_during_query(self, test_environment: TestEnvironment, sa_setup, sa_models, rds_utils): + """Test SQLAlchemy failover during query operations""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert 'failover_v2' in config['plugins'] + + initial_writer_id = rds_utils.get_cluster_writer_instance_id() + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="Failover Test User", email="failover@example.com", age=30) + session.add(obj) + session.commit() + + result = session.query(TestModel).filter_by(id=obj.id).first() + assert result is not None + assert result.name == "Failover Test User" + + rds_utils.failover_cluster_and_wait_until_writer_changed() + + with pytest.raises(DBAPIError): + session.query(TestModel).filter_by(id=obj.id).first() + + result = session.query(TestModel).filter_by(id=obj.id).first() + assert result is not None + assert result.name == "Failover Test User" + + row = session.execute(text(RdsTestUtility.get_instance_id_query())).fetchone() + if row: + current_writer_id = row._tuple()[0] + else: + raise Exception("Failed to get current_writer_id from row because row was None.") + assert rds_utils.is_db_instance_writer(current_writer_id) is True + assert current_writer_id != initial_writer_id + + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + ''' + @pytest.mark.parametrize('sa_setup', [{ + 'wrapper_plugins': 'custom_endpoint,failover_v2', + 'use_custom_endpoint': True, + 'options': { + 'socket_timeout': 10, + 'connect_timeout': 10, + 'monitoring-connect_timeout': 5, + 'monitoring-socket_timeout': 5, + 'topology_refresh_ms': 10 + } + }], indirect=True) + @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) + @enable_on_num_instances(min_instances=2) + def test_sqlalchemy_custom_endpoint_failover_during_query( + self, test_environment: TestEnvironment, create_custom_endpoint, + sa_setup, sa_models, rds_utils): + """Test SQLAlchemy failover with custom endpoint during query operations""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert 'custom_endpoint' in config['plugins'] + assert 'failover_v2' in config['plugins'] + + initial_writer_id = rds_utils.get_cluster_writer_instance_id() + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="Custom Endpoint Failover Test User", email="custom_failover@example.com", age=35) + session.add(obj) + session.commit() + + result = session.query(TestModel).filter_by(id=obj.id).first() + assert result is not None + assert result.name == "Custom Endpoint Failover Test User" + + rds_utils.failover_cluster_and_wait_until_writer_changed() + + with pytest.raises(DBAPIError): + session.query(TestModel).filter_by(id=obj.id).first() + + result = session.query(TestModel).filter_by(id=obj.id).first() + assert result is not None + assert result.name == "Custom Endpoint Failover Test User" + + row = session.execute(text(RdsTestUtility.get_instance_id_query())).fetchone() + current_writer_id = row[0] + assert rds_utils.is_db_instance_writer(current_writer_id) is True + assert current_writer_id != initial_writer_id + + session.query(TestModel).delete() + session.commit() + finally: + session.close() + ''' + + @pytest.fixture(scope='function') + def sa_rw_split_setup(self, conn_utils): + """Setup SQLAlchemy with read/write splitting configuration""" + writer_url = _build_url( + conn_utils.user, conn_utils.password, conn_utils.writer_cluster_host, + conn_utils.port, conn_utils.dbname, plugins='read_write_splitting') + reader_url = _build_url( + conn_utils.user, conn_utils.password, conn_utils.writer_cluster_host, + conn_utils.port, conn_utils.dbname, plugins='read_write_splitting') + + writer_engine = create_engine(writer_url) + reader_engine = create_engine(reader_url) + WriterSession = sessionmaker(bind=writer_engine) + ReaderSession = sessionmaker(bind=reader_engine) + + test_id = str(uuid.uuid4())[:8] + + class RWSplitTestModel(Base): + __tablename__ = f'sa_rw_split_test_{test_id}' + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(100), nullable=False) + value = Column(Integer, nullable=False) + + Base.metadata.create_all(writer_engine, tables=[RWSplitTestModel.__table__]) + + yield { + 'model': RWSplitTestModel, + 'writer_engine': writer_engine, + 'reader_engine': reader_engine, + 'WriterSession': WriterSession, + 'ReaderSession': ReaderSession, + } + + Base.metadata.drop_all(writer_engine, tables=[RWSplitTestModel.__table__]) + writer_engine.dispose() + reader_engine.dispose() + +