From 72457dc119179d4f515340d2728c7d74234153ea Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 6 Feb 2026 14:41:36 -0800 Subject: [PATCH 01/39] Simple PG workflow working --- aws_advanced_python_wrapper/__init__.py | 24 ++- .../sqlalchemy/orm_dialect.py | 139 ++++++++++++++++++ pyproject.toml | 2 + tests/unit/test_sqlalchemy_orm.py | 66 +++++++++ 4 files changed, 230 insertions(+), 1 deletion(-) create mode 100644 aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py create mode 100644 tests/unit/test_sqlalchemy_orm.py diff --git a/aws_advanced_python_wrapper/__init__.py b/aws_advanced_python_wrapper/__init__.py index fbac66233..7d4dbf38a 100644 --- a/aws_advanced_python_wrapper/__init__.py +++ b/aws_advanced_python_wrapper/__init__.py @@ -15,8 +15,20 @@ from logging import DEBUG, getLogger from .cleanup import release_resources +from .driver_info import DriverInfo from .utils.utils import LogUtils from .wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper.pep249 import ( + Error, + InterfaceError, + DatabaseError, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError +) # PEP249 compliance connect = AwsWrapperConnection.connect @@ -32,9 +44,19 @@ 'set_logger', 'apilevel', 'threadsafety', - 'paramstyle' + 'paramstyle', + 'Error', + 'InterfaceError', + 'DatabaseError', + 'DataError', + 'OperationalError', + 'IntegrityError', + 'InternalError', + 'ProgrammingError', + 'NotSupportedError' ] +__version__ = DriverInfo.DRIVER_VERSION def set_logger(name='aws_advanced_python_wrapper', level=DEBUG, format_string=None): LogUtils.setup_logger(getLogger(name), level, format_string) diff --git a/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py new file mode 100644 index 000000000..498e7adce --- /dev/null +++ b/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py @@ -0,0 +1,139 @@ +# aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_psycopg_dialect.py +from psycopg import Connection +from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg +import re + +class SqlAlchemyOrmPgDialect(PGDialect_psycopg): + """ + SQLAlchemy dialect for AWS Advanced Python Wrapper. + Extends PostgreSQL psycopg dialect with Aurora-aware connection handling. + """ + + name = 'postgresql' + driver = 'aws_wrapper' + + def __init__(self, **kwargs): + # Skip parent's version check since we're a wrapper, not psycopg itself + super(PGDialect_psycopg, self).__init__(**kwargs) + + # Dynamically detect the actual psycopg version we're wrapping to ensure + # SQLAlchemy uses the correct feature set and SQL generation + try: + import psycopg + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", psycopg.__version__) + if m: + self.psycopg_version = tuple( + int(x) for x in m.group(1, 2, 3) if x is not None + ) + else: + self.psycopg_version = (3, 0, 2) # Minimum supported + except (ImportError, AttributeError): + self.psycopg_version = (3, 0, 2) + + @classmethod + def import_dbapi(cls): + """ + Return the DB-API 2.0 module. + SQLAlchemy calls this to get the driver module. + """ + import aws_advanced_python_wrapper + return aws_advanced_python_wrapper + + def create_connect_args(self, url): + """ + Transform SQLAlchemy URL into connection arguments. + Must include 'target' parameter for the wrapper. + """ + # Extract standard connection parameters + opts = url.translate_connect_args(username='user') + + # Add query string parameters + opts.update(url.query) + + # Add the required 'target' parameter for your wrapper + if 'target' not in opts: + opts['target'] = Connection.connect + + # Return empty args list and kwargs dict + return ([], opts) + + def on_connect(self): + """ + Return a callable that will be executed on new connections. This can be used if we need to set any session-level + parameters. + """ + + def set_session_params(conn): + # Set any Aurora-specific session parameters + cursor = conn.cursor() + try: + # Example: Set statement timeout + cursor.execute("SET statement_timeout = '60s'") + finally: + cursor.close() + + return set_session_params + + def get_isolation_level(self, dbapi_connection): + """Get the current isolation level""" + cursor = dbapi_connection.cursor() + try: + cursor.execute("SHOW transaction_isolation") + val = cursor.fetchone() + if val: + # Extract first element from tuple and format + return val.upper().replace(' ', '_') + return 'READ_COMMITTED' # PostgreSQL's default + finally: + cursor.close() + + def initialize(self, connection): + """ + Override initialization to handle type introspection. + The parent class tries to use TypeInfo.fetch() which requires + a native psycopg connection, not our wrapper. + """ + # Find the AwsWrapperConnection at whatever nesting level + wrapper_conn = self._get_wrapper_connection(connection) + + if wrapper_conn and hasattr(wrapper_conn, 'connection'): + # Get the underlying psycopg connection + underlying_conn = wrapper_conn.connection + + # Temporarily swap the entire connection chain + original_dbapi_conn = connection.connection + connection.connection = underlying_conn + + try: + # Call parent initialization with native psycopg connection + super().initialize(connection) + finally: + # Restore original connection chain + connection.connection = original_dbapi_conn + else: + # If we can't find wrapper or it doesn't expose underlying connection, + # skip type introspection (custom types won't be auto-configured) + pass + + def _get_wrapper_connection(self, connection): + """ + Traverse the connection chain to find AwsWrapperConnection. + Handles variable nesting depths depending on pool configuration. + """ + from aws_advanced_python_wrapper import AwsWrapperConnection + + # Start with the DBAPI connection + current = connection.connection + + # Traverse up to 5 levels deep (reasonable limit) + for _ in range(5): + if isinstance(current, AwsWrapperConnection): + return current + + # Try to go deeper if there's a .connection attribute + if hasattr(current, 'connection'): + current = current.connection + else: + break + + return None diff --git a/pyproject.toml b/pyproject.toml index ffd73d2f4..80d787048 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,3 +84,5 @@ filterwarnings = [ 'ignore:Exception during reset or similar:pytest.PytestUnhandledThreadExceptionWarning' ] +[tool.poetry.plugins."sqlalchemy.dialects"] +"postgresql.aws_wrapper" = "aws_advanced_python_wrapper.sqlalchemy.orm_dialect:SqlAlchemyOrmPgDialect" diff --git a/tests/unit/test_sqlalchemy_orm.py b/tests/unit/test_sqlalchemy_orm.py new file mode 100644 index 000000000..7ec7bec4b --- /dev/null +++ b/tests/unit/test_sqlalchemy_orm.py @@ -0,0 +1,66 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy import create_engine, Column, Integer, String +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +class TestSqlAlchemyORM: + def test_basic_workflow(self): + # Step 1: Create engine (connection to database) + engine = create_engine('postgresql+aws_wrapper://pguser:pgpassword@mydb.cluster-XYZ.us-west-1.rds.amazonaws.com:5432/somedb') + + # Step 2: Define base class for declarative models + Base = declarative_base() + + # Step 3: Define model class (separate from database operations) + class User(Base): + __tablename__ = 'users' + + id = Column(Integer, primary_key=True) + name = Column(String(50)) + email = Column(String(100)) + + # Step 4: Create tables + Base.metadata.create_all(engine) + + # Step 5: Create session factory + Session = sessionmaker(bind=engine) + + # Step 6: Use session for database operations + session = Session() + + # INSERT - Create new object and add to session + new_user = User(name='John Doe', email='john@example.com') + session.add(new_user) + session.commit() # Explicit commit required + + # SELECT - Query using session + users = session.query(User).filter(User.name == 'John Doe').all() + for user in users: + print(f"{user.name}: {user.email}") + + + # UPDATE - Modify object and commit + user = session.query(User).filter(User.name == "John Doe").first() + user.email = 'newemail@example.com' + session.commit() # Changes tracked by session + + # DELETE - Remove object from session + user_to_delete = session.query(User).filter(User.name == "John Doe").first() + session.delete(user_to_delete) + session.commit() + + # Always close session when done + session.close() From 6cd61c7e8caa1455492a42523614c33c0f525a1a Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Fri, 6 Feb 2026 15:23:33 -0800 Subject: [PATCH 02/39] Cleanup --- .../sqlalchemy/orm_dialect.py | 84 +++++++++++-------- tests/unit/test_sqlalchemy_orm.py | 46 +++++----- 2 files changed, 68 insertions(+), 62 deletions(-) diff --git a/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py index 498e7adce..c71e6f22a 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py @@ -3,21 +3,30 @@ from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg import re +from aws_advanced_python_wrapper import AwsWrapperConnection + + class SqlAlchemyOrmPgDialect(PGDialect_psycopg): """ - SQLAlchemy dialect for AWS Advanced Python Wrapper. - Extends PostgreSQL psycopg dialect with Aurora-aware connection handling. + SQLAlchemy dialect for AWS Advanced Python Wrapper with psycopg. Extends the SQLAlchemy PostgreSQL psycopg dialect. + This dialect is not related to the DriverDialect or DatabaseDialect classes used by our driver. Instead, it is used + directly by SQLAlchemy. This dialect is registered in pyproject.toml and is selected by prefixing the connection + string passed to create_engine with "postgresql+aws_wrapper://" ("[name]+[driver]"). """ name = 'postgresql' driver = 'aws_wrapper' def __init__(self, **kwargs): - # Skip parent's version check since we're a wrapper, not psycopg itself + # PGDialect_psycopg's __init__ function checks the driver version and raises an exception if it is lower than + # 3.0.2. If we call it, the exception is raised because it mistakenly interprets our driver version as its own. + # As a workaround we call the grandparent __init__ instead of the parent's __init__. + # TODO: since we are calling the grandparent's __init__ instead of the parent's __init__, we should investigate + # whether any important code in the parent's __init__ needs to be executed. super(PGDialect_psycopg, self).__init__(**kwargs) - # Dynamically detect the actual psycopg version we're wrapping to ensure - # SQLAlchemy uses the correct feature set and SQL generation + # Dynamically detect the actual psycopg version installed and set it as self.psycopg_version. Note that setting + # this field before calling super().__init__ does not avoid the issue noted above. try: import psycopg m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", psycopg.__version__) @@ -26,8 +35,10 @@ def __init__(self, **kwargs): int(x) for x in m.group(1, 2, 3) if x is not None ) else: - self.psycopg_version = (3, 0, 2) # Minimum supported + # Fallback to 3.0.2 if version parsing fails, which is the minimum required psycopg version. + self.psycopg_version = (3, 0, 2) except (ImportError, AttributeError): + # Fallback to 3.0.2 if version parsing fails, which is the minimum required psycopg version. self.psycopg_version = (3, 0, 2) @classmethod @@ -42,7 +53,7 @@ def import_dbapi(cls): def create_connect_args(self, url): """ Transform SQLAlchemy URL into connection arguments. - Must include 'target' parameter for the wrapper. + Must include the 'target' parameter for our wrapper driver. """ # Extract standard connection parameters opts = url.translate_connect_args(username='user') @@ -50,7 +61,7 @@ def create_connect_args(self, url): # Add query string parameters opts.update(url.query) - # Add the required 'target' parameter for your wrapper + # Add the required 'target' parameter for our wrapper if 'target' not in opts: opts['target'] = Connection.connect @@ -62,7 +73,6 @@ def on_connect(self): Return a callable that will be executed on new connections. This can be used if we need to set any session-level parameters. """ - def set_session_params(conn): # Set any Aurora-specific session parameters cursor = conn.cursor() @@ -75,15 +85,13 @@ def set_session_params(conn): return set_session_params def get_isolation_level(self, dbapi_connection): - """Get the current isolation level""" cursor = dbapi_connection.cursor() try: cursor.execute("SHOW transaction_isolation") val = cursor.fetchone() if val: - # Extract first element from tuple and format return val.upper().replace(' ', '_') - return 'READ_COMMITTED' # PostgreSQL's default + return 'READ_COMMITTED' # return Postgres' default isolation level. finally: cursor.close() @@ -91,48 +99,50 @@ def initialize(self, connection): """ Override initialization to handle type introspection. The parent class tries to use TypeInfo.fetch() which requires - a native psycopg connection, not our wrapper. + a native psycopg connection, not AwsWrapperConnection. """ - # Find the AwsWrapperConnection at whatever nesting level - wrapper_conn = self._get_wrapper_connection(connection) - - if wrapper_conn and hasattr(wrapper_conn, 'connection'): - # Get the underlying psycopg connection - underlying_conn = wrapper_conn.connection + # Unwrap SQLAlchemy's connection object + wrapper_conn, wrapper_parent = self._get_wrapper_connection_and_parent(connection) - # Temporarily swap the entire connection chain - original_dbapi_conn = connection.connection - connection.connection = underlying_conn + # Check if wrapper_conn and wrapper_parent expose their underlying connections + if wrapper_conn and hasattr(wrapper_conn, 'connection') and wrapper_parent and hasattr(wrapper_parent.connection, 'connection'): + # Temporarily remove the AwsWrapperConnection from the connection chain + psycopg_conn = wrapper_conn.connection + wrapper_parent.connection = psycopg_conn try: - # Call parent initialization with native psycopg connection super().initialize(connection) finally: - # Restore original connection chain - connection.connection = original_dbapi_conn + # Restore wrapper connection in the connection chain. + wrapper_parent.connection = wrapper_conn else: - # If we can't find wrapper or it doesn't expose underlying connection, - # skip type introspection (custom types won't be auto-configured) + # If unable to swap underlying pscyopg connection, skip type introspection. + # This means custom types (hstore, json, etc.) won't be auto-configured. pass - def _get_wrapper_connection(self, connection): + def _get_wrapper_connection_and_parent(self, connection): """ - Traverse the connection chain to find AwsWrapperConnection. - Handles variable nesting depths depending on pool configuration. - """ - from aws_advanced_python_wrapper import AwsWrapperConnection + Traverse the connection chain to find AwsWrapperConnection and its parent connection. + + Args: + connection: SQLAlchemy Connection object + Returns: + AwsWrapperConnection instance or None, parent connection of AwsWrapperConnection or None + """ # Start with the DBAPI connection - current = connection.connection + parent = connection + child = connection.connection # Traverse up to 5 levels deep (reasonable limit) for _ in range(5): - if isinstance(current, AwsWrapperConnection): - return current + if isinstance(child, AwsWrapperConnection): + return child, parent # Try to go deeper if there's a .connection attribute - if hasattr(current, 'connection'): - current = current.connection + if hasattr(child, 'connection'): + parent = child + child = child.connection else: break diff --git a/tests/unit/test_sqlalchemy_orm.py b/tests/unit/test_sqlalchemy_orm.py index 7ec7bec4b..ffbbcf35a 100644 --- a/tests/unit/test_sqlalchemy_orm.py +++ b/tests/unit/test_sqlalchemy_orm.py @@ -39,28 +39,24 @@ class User(Base): Session = sessionmaker(bind=engine) # Step 6: Use session for database operations - session = Session() - - # INSERT - Create new object and add to session - new_user = User(name='John Doe', email='john@example.com') - session.add(new_user) - session.commit() # Explicit commit required - - # SELECT - Query using session - users = session.query(User).filter(User.name == 'John Doe').all() - for user in users: - print(f"{user.name}: {user.email}") - - - # UPDATE - Modify object and commit - user = session.query(User).filter(User.name == "John Doe").first() - user.email = 'newemail@example.com' - session.commit() # Changes tracked by session - - # DELETE - Remove object from session - user_to_delete = session.query(User).filter(User.name == "John Doe").first() - session.delete(user_to_delete) - session.commit() - - # Always close session when done - session.close() + with Session() as session: + # INSERT - Create new object and add to session + new_user = User(name='John Doe', email='john@example.com') + session.add(new_user) + session.commit() # Explicit commit required + + # SELECT - Query using session + users = session.query(User).filter(User.name == 'John Doe').all() + for user in users: + print(f"{user.name}: {user.email}") + + + # UPDATE - Modify object and commit + user = session.query(User).filter(User.name == "John Doe").first() + user.email = 'newemail@example.com' + session.commit() + + # DELETE - Remove object from session + user_to_delete = session.query(User).filter(User.name == "John Doe").first() + session.delete(user_to_delete) + session.commit() From bc607f6139cb9fd05dbe2b329f90cad5af218380 Mon Sep 17 00:00:00 2001 From: aaron-congo Date: Mon, 9 Feb 2026 10:20:28 -0800 Subject: [PATCH 03/39] Fix failover2 wrong writer host --- .../cluster_topology_monitor.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/aws_advanced_python_wrapper/cluster_topology_monitor.py b/aws_advanced_python_wrapper/cluster_topology_monitor.py index 173cb0014..660b48bae 100644 --- a/aws_advanced_python_wrapper/cluster_topology_monitor.py +++ b/aws_advanced_python_wrapper/cluster_topology_monitor.py @@ -21,7 +21,6 @@ from time import perf_counter_ns from typing import TYPE_CHECKING, Dict, Optional -from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.host_availability import HostAvailability from aws_advanced_python_wrapper.hostinfo import HostInfo, Topology from aws_advanced_python_wrapper.utils import services_container @@ -345,8 +344,8 @@ def _open_any_connection_and_update_topology(self) -> Topology: self._cluster_id, self._initial_host_info.host) try: - writer_id = self._topology_utils.get_writer_id_if_connected( - conn, self._plugin_service.driver_dialect) + writer_id = self._topology_utils.get_writer_host_if_connected( + conn, self._plugin_service.driver_dialect) if writer_id: self._is_verified_writer_connection = True writer_verified_by_this_thread = True @@ -355,10 +354,9 @@ def _open_any_connection_and_update_topology(self) -> Topology: writer_host_info = self._initial_host_info self._writer_host_info.set(writer_host_info) else: - instance_template = self._get_instance_template(writer_id, conn) - writer_host = instance_template.host.replace("?", writer_id) - port = instance_template.port \ - if instance_template.is_port_specified() \ + writer_host = self._instance_template.host.replace("?", writer_id) + port = self._instance_template.port \ + if self._instance_template.is_port_specified() \ else self._initial_host_info.port writer_host_info = HostInfo( writer_host, From 5daa9701eabbb01e8cddfcec1bc3a7887e2afb6e Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Tue, 24 Mar 2026 17:24:04 -0700 Subject: [PATCH 04/39] Add mysql-connector SQLAlchemy ORM --- .../sqlalchemy/mysql_orm_dialect.py | 19 + .../{orm_dialect.py => pg_orm_dialect.py} | 4 +- .../sqlalchemy/test_sqlalchemy_basic.py | 993 ++++++++++++++++++ tests/unit/test_sqlalchemy_orm.py | 2 +- 4 files changed, 1015 insertions(+), 3 deletions(-) create mode 100644 aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py rename aws_advanced_python_wrapper/sqlalchemy/{orm_dialect.py => pg_orm_dialect.py} (97%) create mode 100644 tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py new file mode 100644 index 000000000..6d7ff34db --- /dev/null +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -0,0 +1,19 @@ +# aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_mysqlconnector_dialect.py +from psycopg import Connection +from sqlalchemy.dialects.mysql.mysqlconnector import MySQLDialect_mysqlconnector +import re + +from aws_advanced_python_wrapper import AwsWrapperConnection + + +class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): + """ + SQLAlchemy dialect for AWS Advanced Python Wrapper with mysqlconnector. Extends the SQLAlchemy MySQL mysqlconnector dialect. + This dialect is not related to the DriverDialect or DatabaseDialect classes used by our driver. Instead, it is used + directly by SQLAlchemy. This dialect is registered in pyproject.toml and is selected by prefixing the connection + string passed to create_engine with "mysql+aws_wrapper_mysqlconnector://" ("[name]+[driver]"). + """ + + name = 'mysql' + driver = 'aws_wrapper_mysqlconnector' + diff --git a/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py similarity index 97% rename from aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py rename to aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py index c71e6f22a..c2780b861 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py @@ -11,11 +11,11 @@ class SqlAlchemyOrmPgDialect(PGDialect_psycopg): SQLAlchemy dialect for AWS Advanced Python Wrapper with psycopg. Extends the SQLAlchemy PostgreSQL psycopg dialect. This dialect is not related to the DriverDialect or DatabaseDialect classes used by our driver. Instead, it is used directly by SQLAlchemy. This dialect is registered in pyproject.toml and is selected by prefixing the connection - string passed to create_engine with "postgresql+aws_wrapper://" ("[name]+[driver]"). + string passed to create_engine with "postgresql+aws_wrapper_psycopg://" ("[name]+[driver]"). """ name = 'postgresql' - driver = 'aws_wrapper' + driver = 'aws_wrapper_psycopg' def __init__(self, **kwargs): # PGDialect_psycopg's __init__ function checks the driver version and raises an exception if it is lower than diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py new file mode 100644 index 000000000..bea23f156 --- /dev/null +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -0,0 +1,993 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa: N806 + +from __future__ import annotations + +from datetime import date, datetime, time +from decimal import Decimal +from typing import Any + +import pytest +from sqlalchemy.orm import declarative_base, Mapped, mapped_column +from sqlalchemy import Column, Integer, String + +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from ..utils.conditions import (disable_on_features, enable_on_deployments, + enable_on_engines) +from ..utils.database_engine import DatabaseEngine +from ..utils.database_engine_deployment import DatabaseEngineDeployment +from ..utils.test_environment import TestEnvironment +from ..utils.test_environment_features import TestEnvironmentFeatures + + +@enable_on_engines([DatabaseEngine.MYSQL]) # MySQL Specific until PG is implemented +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestSqlAlchemy: + TestModel: Any + DataTypeModel: Any + Author: Any + Book: Any + + @pytest.fixture(scope='class') + def rds_utils(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) + + @pytest.fixture(scope='class') + def sqlalchemy_models(self, sqlalchemy_setup): + """Create SQLAlchemy models after SQLAlchemy is set up""" + + Base = declarative_base() + + class TestModel(Base): + """Basic test model for SQLAlchemy ORM functionality""" + __tablename__ = 'sqlalchemy_test_model' + + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String, primary_key=True) + age: Mapped[int] = mapped_column(Integer) + is_active: Mapped[bool] = mapped_column(Bool, server_default=True) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now()) + +''' + class DataTypeModel(models.Model): + """Model for testing various data types""" + # String fields + char_field = models.CharField(max_length=255, null=True, blank=True) + text_field = models.TextField(null=True, blank=True) + + # Numeric fields + integer_field = models.IntegerField(null=True, blank=True) + big_integer_field = models.BigIntegerField(null=True, blank=True) + decimal_field = models.DecimalField(max_digits=10, decimal_places=2, null=True, blank=True) + float_field = models.FloatField(null=True, blank=True) + + # Boolean field + boolean_field = models.BooleanField(default=False) + + # Date/Time fields + date_field = models.DateField(null=True, blank=True) + time_field = models.TimeField(null=True, blank=True) + datetime_field = models.DateTimeField(null=True, blank=True) + + # JSON field (MySQL 5.7+) + json_field = models.JSONField(null=True, blank=True) + + class Meta: + app_label = 'test_app' + db_table = 'django_data_type_model' + + class Author(models.Model): + """Author model for relationship testing""" + name = models.CharField(max_length=100) + email = models.EmailField() + birth_date = models.DateField(null=True, blank=True) + + class Meta: + app_label = 'test_app' + db_table = 'django_author' + + # Store Author first so it's available for Book's ForeignKey + TestDjango.Author = Author + + class Book(models.Model): + """Book model for relationship testing""" + title = models.CharField(max_length=200) + author = models.ForeignKey(TestDjango.Author, on_delete=models.CASCADE, related_name='books') + publication_date = models.DateField() + pages = models.IntegerField() + price = models.DecimalField(max_digits=8, decimal_places=2) + + class Meta: + app_label = 'test_app' + db_table = 'django_book' + + # Store models as class attributes for easy access + TestDjango.TestModel = TestModel + TestDjango.DataTypeModel = DataTypeModel + TestDjango.Book = Book + + # Create tables for our test models + with connection.schema_editor() as schema_editor: + schema_editor.create_model(TestModel) + schema_editor.create_model(DataTypeModel) + schema_editor.create_model(Author) + schema_editor.create_model(Book) + + yield + + # Clean up tables + with connection.schema_editor() as schema_editor: + schema_editor.delete_model(Book) + schema_editor.delete_model(Author) + schema_editor.delete_model(DataTypeModel) + schema_editor.delete_model(TestModel) + + @pytest.fixture(scope='class') + def django_setup(self, conn_utils): + """Setup Django configuration for testing""" + # Configure Django settings + if not settings.configured: + db_config = { + 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', + 'NAME': conn_utils.dbname, + 'USER': conn_utils.user, + 'PASSWORD': conn_utils.password, + 'HOST': conn_utils.writer_cluster_host, + 'PORT': conn_utils.port, + 'OPTIONS': { + 'plugins': 'failover_v2,aurora_connection_tracker', + 'connect_timeout': 10, + 'autocommit': True, + }, + } + + settings.configure( + DEBUG=True, + DATABASES={'default': db_config}, + INSTALLED_APPS=[ + 'django.contrib.contenttypes', + 'django.contrib.auth', + ], + SECRET_KEY='test-secret-key-for-django-tests', + USE_TZ=True, + ) + + django.setup() + setup_test_environment() + + yield + connections.close_all() + + teardown_test_environment() + + def test_django_backend_configuration(self, test_environment: TestEnvironment, django_models): + """Test Django backend configuration with empty plugins""" + # Verify that the connection is using the AWS wrapper + assert hasattr(connection, 'connection') + + # Test basic connection functionality + assert self.TestModel.objects.count() == 0 + + def test_django_basic_model_operations(self, test_environment: TestEnvironment, django_models): + """Test basic Django ORM operations (CRUD)""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create + test_obj = TestModel.objects.create( + name="John Doe", + email="john@example.com", + age=30, + is_active=True + ) + assert test_obj.id is not None + assert test_obj.name == "John Doe" + + # Read + retrieved_obj = TestModel.objects.get(id=test_obj.id) + assert retrieved_obj.name == "John Doe" + assert retrieved_obj.email == "john@example.com" + assert retrieved_obj.age == 30 + assert retrieved_obj.is_active is True + + # Update + retrieved_obj.name = "Jane Doe" + retrieved_obj.age = 25 + retrieved_obj.save() + + updated_obj = TestModel.objects.get(id=test_obj.id) + assert updated_obj.name == "Jane Doe" + assert updated_obj.age == 25 + + # Delete + updated_obj.delete() + assert TestModel.objects.filter(id=test_obj.id).count() == 0 + + def test_django_queryset_operations(self, test_environment: TestEnvironment, django_models): + """Test Django QuerySet operations""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + + # Test filtering + active_users = TestModel.objects.filter(is_active=True) + assert active_users.count() == 2 + + # Test ordering + ordered_users = TestModel.objects.order_by('age') + ages = [user.age for user in ordered_users] + assert ages == [25, 30, 35] + + # Test complex queries + young_active_users = TestModel.objects.filter(age__lt=30, is_active=True) + assert young_active_users.count() == 1 + assert young_active_users.first().name == "Alice" + + # Test exclude + non_bob_users = TestModel.objects.exclude(name="Bob") + assert non_bob_users.count() == 2 + + # Test exists + assert TestModel.objects.filter(name="Alice").exists() + assert not TestModel.objects.filter(name="David").exists() + + # Clean up + TestModel.objects.all().delete() + + def test_django_data_types(self, test_environment: TestEnvironment, django_models): + """Test Django ORM with various data types""" + DataTypeModel = self.DataTypeModel + + # Ensure clean slate + DataTypeModel.objects.all().delete() + + # Create test data with various data types + test_datetime = datetime(2023, 12, 25, 14, 30, 0) + test_datetime_aware = timezone.make_aware(test_datetime) + + test_data = DataTypeModel.objects.create( + char_field="Test String", + text_field="This is a longer text field content", + integer_field=42, + big_integer_field=9223372036854775807, + decimal_field=Decimal('123.45'), + float_field=3.14159, + boolean_field=True, + date_field=date(2023, 12, 25), + time_field=time(14, 30, 0), + datetime_field=test_datetime_aware, # Use timezone-aware datetime + json_field={"key": "value", "number": 123, "array": [1, 2, 3]} + ) + + # Retrieve and verify data + retrieved = DataTypeModel.objects.get(id=test_data.id) + + assert retrieved.char_field == "Test String" + assert retrieved.text_field == "This is a longer text field content" + assert retrieved.integer_field == 42 + assert retrieved.big_integer_field == 9223372036854775807 + assert retrieved.decimal_field == Decimal('123.45') + assert abs(retrieved.float_field - 3.14159) < 0.001 + assert retrieved.boolean_field is True + assert retrieved.date_field == date(2023, 12, 25) + assert retrieved.time_field == time(14, 30, 0) + # Compare timezone-aware datetimes + assert retrieved.datetime_field == test_datetime_aware + assert retrieved.json_field == {"key": "value", "number": 123, "array": [1, 2, 3]} + + # Clean up + DataTypeModel.objects.all().delete() + + def test_django_null_values(self, test_environment: TestEnvironment, django_models): + """Test Django ORM handling of NULL values""" + DataTypeModel = self.DataTypeModel + + # First, ensure we start with a clean slate + DataTypeModel.objects.all().delete() + + # Create object with NULL values + test_obj = DataTypeModel.objects.create( + char_field=None, + integer_field=None, + date_field=None, + boolean_field=False # This field has default=False, so it won't be NULL + ) + + # Retrieve and verify NULL values + retrieved = DataTypeModel.objects.get(id=test_obj.id) + assert retrieved.char_field is None + assert retrieved.integer_field is None + assert retrieved.date_field is None + assert retrieved.boolean_field is False + + # Test filtering with NULL values + null_char_objects = DataTypeModel.objects.filter(char_field__isnull=True) + assert null_char_objects.count() == 1 + + not_null_char_objects = DataTypeModel.objects.filter(char_field__isnull=False) + assert not_null_char_objects.count() == 0 + + # Create an object with non-NULL values to test the opposite + DataTypeModel.objects.create( + char_field="Not NULL", + integer_field=42, + date_field=date(2023, 1, 1) + ) + + # Now test filtering again + null_char_objects = DataTypeModel.objects.filter(char_field__isnull=True) + assert null_char_objects.count() == 1 # Still one NULL object + + not_null_char_objects = DataTypeModel.objects.filter(char_field__isnull=False) + assert not_null_char_objects.count() == 1 # Now one non-NULL object + + # Clean up + DataTypeModel.objects.all().delete() + + def test_django_relationships(self, test_environment: TestEnvironment, django_models): + """Test Django ORM relationships (ForeignKey)""" + Author = self.Author + Book = self.Book + + # Create author + author = Author.objects.create( + name="J.K. Rowling", + email="jk@example.com", + birth_date=date(1965, 7, 31) + ) + + # Create books + book1 = Book.objects.create( + title="Harry Potter and the Philosopher's Stone", + author=author, + publication_date=date(1997, 6, 26), + pages=223, + price=Decimal('12.99') + ) + + book2 = Book.objects.create( + title="Harry Potter and the Chamber of Secrets", + author=author, + publication_date=date(1998, 7, 2), + pages=251, + price=Decimal('13.99') + ) + + # Test forward relationship + assert book1.author.name == "J.K. Rowling" + assert book2.author.email == "jk@example.com" + + # Test reverse relationship + author_books = author.books.all() + assert author_books.count() == 2 + book_titles = [book.title for book in author_books.order_by('publication_date')] + assert "Harry Potter and the Philosopher's Stone" in book_titles + assert "Harry Potter and the Chamber of Secrets" in book_titles + + # Test related queries + books_by_author = Book.objects.filter(author__name="J.K. Rowling") + assert books_by_author.count() == 2 + + # Test select_related for optimization + book_with_author = Book.objects.select_related('author').get(id=book1.id) + assert book_with_author.author.name == "J.K. Rowling" + + # Clean up + Book.objects.all().delete() + Author.objects.all().delete() + + def test_django_aggregations(self, test_environment: TestEnvironment, django_models): + """Test Django ORM aggregations""" + Author = self.Author + Book = self.Book + + # Create test data + author = Author.objects.create(name="Test Author", email="test@example.com") + + Book.objects.create(title="Book 1", author=author, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')) + Book.objects.create(title="Book 2", author=author, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')) + Book.objects.create(title="Book 3", author=author, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')) + + # Test aggregations + stats = Book.objects.aggregate( + total_books=Count('id'), + total_pages=Sum('pages'), + avg_price=Avg('price'), + max_pages=Max('pages'), + min_price=Min('price') + ) + + assert stats['total_books'] == 3 + assert stats['total_pages'] == 600 + assert abs(float(stats['avg_price']) - 20.0) < 0.01 + assert stats['max_pages'] == 300 + assert stats['min_price'] == Decimal('10.00') + + # Clean up + Book.objects.all().delete() + Author.objects.all().delete() + + def test_django_transactions(self, test_environment: TestEnvironment, django_models): + """Test Django transaction handling""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + initial_count = TestModel.objects.count() + + # Test successful transaction + with transaction.atomic(): + TestModel.objects.create(name="User 1", email="user1@example.com", age=25) + TestModel.objects.create(name="User 2", email="user2@example.com", age=30) + + assert TestModel.objects.count() == initial_count + 2 + + # Test rollback transaction + try: + with transaction.atomic(): + TestModel.objects.create(name="User 3", email="user3@example.com", age=35) + TestModel.objects.create(name="User 4", email="user4@example.com", age=40) + # Force an error to trigger rollback + raise Exception("Force rollback") + except Exception: + pass # Expected exception + + # Should still have only 2 additional records (rollback occurred) + assert TestModel.objects.count() == initial_count + 2 + + # Clean up + TestModel.objects.all().delete() + + def test_django_bulk_operations(self, test_environment: TestEnvironment, django_models): + """Test Django bulk operations""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Test bulk_create + test_objects = [ + TestModel(name=f"User {i}", email=f"user{i}@example.com", age=20 + i) + for i in range(10) + ] + + created_objects = TestModel.objects.bulk_create(test_objects) + assert len(created_objects) == 10 + assert TestModel.objects.count() == 10 + + # Test bulk_update - need to get the objects first and modify them + objects_to_update = list(TestModel.objects.all()) + for obj in objects_to_update: + obj.age += 5 + + TestModel.objects.bulk_update(objects_to_update, ['age']) + + # Verify updates - get fresh objects from database + ages = list(TestModel.objects.values_list('age', flat=True).order_by('name')) + expected_ages = [25 + i for i in range(10)] # 20+i+5 for i in range(10) + assert ages == expected_ages + + # Clean up + TestModel.objects.all().delete() + + def test_django_complex_queries(self, test_environment: TestEnvironment, django_models): + """Test complex Django queries with Q objects and F expressions""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + TestModel.objects.create(name="David", email="david@example.com", age=28, is_active=True) + + # Test Q objects for complex conditions + complex_query = TestModel.objects.filter( + Q(age__gte=30) | Q(name__startswith='A') + ) + assert complex_query.count() == 3 # Bob (30), Charlie (35), Alice (starts with A) + + # Test F expressions + TestModel.objects.filter(age__lt=30).update(age=F('age') + 5) + + # Verify F expression update + alice = TestModel.objects.get(name="Alice") + david = TestModel.objects.get(name="David") + assert alice.age == 30 # 25 + 5 + assert david.age == 33 # 28 + 5 + + # Clean up, might get a failover error from this connection + TestModel.objects.all().delete() + + def test_django_raw_sql_queries(self, test_environment: TestEnvironment, django_models): + """Test Django raw SQL query execution""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + + # Test raw() method + raw_results = TestModel.objects.raw( + f'SELECT * FROM {TestModel._meta.db_table} WHERE age >= %s ORDER BY age', + [30] + ) + raw_list = list(raw_results) + assert len(raw_list) == 2 + assert raw_list[0].name == "Bob" + assert raw_list[1].name == "Charlie" + + # Test connection.cursor() for custom SQL + with connection.cursor() as cursor: + cursor.execute( + f'SELECT name, age FROM {TestModel._meta.db_table} WHERE is_active = %s ORDER BY age', + [True] + ) + rows = cursor.fetchall() + assert len(rows) == 2 + assert rows[0][0] == "Alice" # name + assert rows[0][1] == 25 # age + assert rows[1][0] == "Charlie" + assert rows[1][1] == 35 + + # Test raw SQL with connection for aggregate + with connection.cursor() as cursor: + cursor.execute(f'SELECT COUNT(*), AVG(age) FROM {TestModel._meta.db_table}') + count, avg_age = cursor.fetchone() + assert count == 3 + assert abs(float(avg_age) - 30.0) < 0.01 + + # Clean up + TestModel.objects.all().delete() + + def test_django_get_or_create(self, test_environment: TestEnvironment, django_models): + """Test Django get_or_create pattern""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Test create case + obj1, created1 = TestModel.objects.get_or_create( + email="test@example.com", + defaults={'name': 'Test User', 'age': 25, 'is_active': True} + ) + assert created1 is True + assert obj1.name == "Test User" + assert obj1.age == 25 + + # Test get case (object already exists) + obj2, created2 = TestModel.objects.get_or_create( + email="test@example.com", + defaults={'name': 'Different Name', 'age': 30, 'is_active': False} + ) + assert created2 is False + assert obj2.id == obj1.id + assert obj2.name == "Test User" # Should keep original values + assert obj2.age == 25 + + # Verify only one object exists + assert TestModel.objects.filter(email="test@example.com").count() == 1 + + # Clean up + TestModel.objects.all().delete() + + def test_django_update_or_create(self, test_environment: TestEnvironment, django_models): + """Test Django update_or_create pattern""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Test create case + obj1, created1 = TestModel.objects.update_or_create( + email="update@example.com", + defaults={'name': 'Initial Name', 'age': 25, 'is_active': True} + ) + assert created1 is True + assert obj1.name == "Initial Name" + assert obj1.age == 25 + + # Test update case (object already exists) + obj2, created2 = TestModel.objects.update_or_create( + email="update@example.com", + defaults={'name': 'Updated Name', 'age': 30, 'is_active': False} + ) + assert created2 is False + assert obj2.id == obj1.id + assert obj2.name == "Updated Name" # Should be updated + assert obj2.age == 30 + assert obj2.is_active is False + + # Verify only one object exists + assert TestModel.objects.filter(email="update@example.com").count() == 1 + + # Verify the update persisted + retrieved = TestModel.objects.get(email="update@example.com") + assert retrieved.name == "Updated Name" + assert retrieved.age == 30 + + # Clean up + TestModel.objects.all().delete() + + def test_django_prefetch_related(self, test_environment: TestEnvironment, django_models): + """Test Django prefetch_related for optimizing queries""" + Author = self.Author + Book = self.Book + + # Create test data + author1 = Author.objects.create(name="Author 1", email="author1@example.com") + author2 = Author.objects.create(name="Author 2", email="author2@example.com") + + Book.objects.create(title="Book 1A", author=author1, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')) + Book.objects.create(title="Book 1B", author=author1, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')) + Book.objects.create(title="Book 2A", author=author2, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')) + + # Test prefetch_related + authors = Author.objects.prefetch_related('books').all() + + # Access related books (should not trigger additional queries due to prefetch) + for author in authors: + books = list(author.books.all()) + if author.name == "Author 1": + assert len(books) == 2 + book_titles = [book.title for book in books] + assert "Book 1A" in book_titles + assert "Book 1B" in book_titles + elif author.name == "Author 2": + assert len(books) == 1 + assert books[0].title == "Book 2A" + + # Clean up + Book.objects.all().delete() + Author.objects.all().delete() + + def test_django_database_functions(self, test_environment: TestEnvironment, django_models): + """Test Django database functions""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="alice", email="alice@example.com", age=25) + TestModel.objects.create(name="BOB", email="bob@example.com", age=30) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35) + + # Test Upper function + upper_names = TestModel.objects.annotate(upper_name=Upper('name')).values_list('upper_name', flat=True) + upper_list = list(upper_names) + assert "ALICE" in upper_list + assert "BOB" in upper_list + assert "CHARLIE" in upper_list + + # Test Lower function + lower_names = TestModel.objects.annotate(lower_name=Lower('name')).values_list('lower_name', flat=True) + lower_list = list(lower_names) + assert "alice" in lower_list + assert "bob" in lower_list + assert "charlie" in lower_list + + # Test Length function + name_lengths = TestModel.objects.annotate(name_length=Length('name')).filter(name_length__gte=5) + assert name_lengths.count() == 2 # "alice" (5) and "Charlie" (7) + + # Test Concat function + full_info = TestModel.objects.annotate( + full_info=Concat('name', Value(' - '), 'email', output_field=CharField()) + ).first() + assert ' - ' in full_info.full_info + assert '@example.com' in full_info.full_info + + # Clean up + TestModel.objects.all().delete() + + def test_django_annotations(self, test_environment: TestEnvironment, django_models): + """Test Django annotations with expressions""" + TestModel = self.TestModel + Book = self.Book + Author = self.Author + + # Create test data for TestModel + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + + # Test annotate with F expression for calculations + test_with_age_plus_ten = TestModel.objects.annotate( + age_plus_ten=F('age') + 10 + ).order_by('age') + + # Verify calculation + first_obj = test_with_age_plus_ten.first() + assert first_obj.age_plus_ten == first_obj.age + 10 + assert first_obj.age_plus_ten == 35 # 25 + 10 + + # Create books for F expression testing + author = Author.objects.create(name="Test Author", email="test@example.com") + Book.objects.create(title="Book 1", author=author, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')) + Book.objects.create(title="Book 2", author=author, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')) + Book.objects.create(title="Book 3", author=author, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')) + + # Test annotate with F expression for price per page + books_with_price_per_page = Book.objects.annotate( + price_per_page=F('price') / F('pages') + ).order_by('price_per_page') + + # Verify calculation + first_book = books_with_price_per_page.first() + expected_price_per_page = float(first_book.price) / first_book.pages + assert abs(float(first_book.price_per_page) - expected_price_per_page) < 0.001 + + # Test filtering on annotated field - use a lower threshold to avoid precision issues + cheap_books = Book.objects.annotate( + price_per_page=F('price') / F('pages') + ).filter(price_per_page__lte=0.15) + assert cheap_books.count() == 3 # All books have price_per_page = 0.10 + + # Clean up + TestModel.objects.all().delete() + Book.objects.all().delete() + Author.objects.all().delete() + + def test_django_values_and_values_list(self, test_environment: TestEnvironment, django_models): + """Test Django values() and values_list() methods""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + + # Test values() - returns list of dictionaries + values_result = TestModel.objects.values('name', 'age').order_by('age') + values_list = list(values_result) + assert len(values_list) == 3 + assert values_list[0] == {'name': 'Alice', 'age': 25} + assert values_list[1] == {'name': 'Bob', 'age': 30} + assert values_list[2] == {'name': 'Charlie', 'age': 35} + + # Test values_list() - returns list of tuples + values_list_result = TestModel.objects.values_list('name', 'age').order_by('age') + tuples_list = list(values_list_result) + assert len(tuples_list) == 3 + assert tuples_list[0] == ('Alice', 25) + assert tuples_list[1] == ('Bob', 30) + assert tuples_list[2] == ('Charlie', 35) + + # Test values_list() with flat=True - returns flat list + names = TestModel.objects.values_list('name', flat=True).order_by('name') + names_list = list(names) + assert names_list == ['Alice', 'Bob', 'Charlie'] + + # Test values() with filtering + active_users = TestModel.objects.filter(is_active=True).values('name', 'email') + active_list = list(active_users) + assert len(active_list) == 2 + active_names = [user['name'] for user in active_list] + assert 'Alice' in active_names + assert 'Charlie' in active_names + assert 'Bob' not in active_names + + # Clean up + TestModel.objects.all().delete() + + def test_django_distinct_queries(self, test_environment: TestEnvironment, django_models): + """Test Django distinct() functionality""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data with duplicate ages + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=25, is_active=True) + TestModel.objects.create(name="David", email="david@example.com", age=30, is_active=True) + + # Test distinct ages + distinct_ages = TestModel.objects.values_list('age', flat=True).distinct().order_by('age') + ages_list = list(distinct_ages) + assert ages_list == [25, 30] + + # Test distinct with multiple fields + distinct_age_status = TestModel.objects.values('age', 'is_active').distinct().order_by('age', 'is_active') + distinct_list = list(distinct_age_status) + assert len(distinct_list) == 3 # (25, True), (30, False), (30, True) + + # Test count with distinct + total_count = TestModel.objects.count() + distinct_age_count = TestModel.objects.values('age').distinct().count() + assert total_count == 4 + assert distinct_age_count == 2 + + # Clean up + TestModel.objects.all().delete() + + def test_django_only_and_defer(self, test_environment: TestEnvironment, django_models): + """Test Django only() and defer() for query optimization""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + obj = TestModel.objects.create( + name="Test User", + email="test@example.com", + age=30, + is_active=True + ) + + # Test only() - load only specific fields + obj_only = TestModel.objects.only('name', 'email').get(id=obj.id) + assert obj_only.name == "Test User" + assert obj_only.email == "test@example.com" + # Accessing deferred fields will trigger additional query, but should still work + assert obj_only.age == 30 + + # Test defer() - exclude specific fields from loading + obj_defer = TestModel.objects.defer('age', 'is_active').get(id=obj.id) + assert obj_defer.name == "Test User" + assert obj_defer.email == "test@example.com" + # Accessing deferred fields will trigger additional query, but should still work + assert obj_defer.age == 30 + + # Clean up + TestModel.objects.all().delete() + + def test_django_in_bulk(self, test_environment: TestEnvironment, django_models): + """Test Django in_bulk() for batch retrieval""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + obj1 = TestModel.objects.create(name="User 1", email="user1@example.com", age=25) + obj2 = TestModel.objects.create(name="User 2", email="user2@example.com", age=30) + obj3 = TestModel.objects.create(name="User 3", email="user3@example.com", age=35) + + # Test in_bulk with IDs (default behavior) + bulk_result = TestModel.objects.in_bulk([obj1.id, obj2.id, obj3.id]) + assert len(bulk_result) == 3 + assert bulk_result[obj1.id].name == "User 1" + assert bulk_result[obj2.id].name == "User 2" + assert bulk_result[obj3.id].name == "User 3" + + # Test in_bulk with all IDs (no list provided) + bulk_all = TestModel.objects.in_bulk() + assert len(bulk_all) == 3 + assert obj1.id in bulk_all + assert obj2.id in bulk_all + assert obj3.id in bulk_all + + # Test in_bulk with email field (unique field) + bulk_by_email = TestModel.objects.in_bulk( + ["user1@example.com", "user3@example.com"], + field_name='email' + ) + assert len(bulk_by_email) == 2 + assert bulk_by_email["user1@example.com"].name == "User 1" + assert bulk_by_email["user3@example.com"].name == "User 3" + + # Clean up + TestModel.objects.all().delete() + + def test_django_conditional_expressions(self, test_environment: TestEnvironment, django_models): + """Test Django Case/When conditional expressions""" + from django.db.models import Case, IntegerField, Value, When + + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + + # Test Case/When for conditional logic + results = TestModel.objects.annotate( + age_category=Case( + When(age__lt=30, then=Value('young')), + When(age__gte=30, age__lt=40, then=Value('middle')), + default=Value('senior'), + output_field=CharField() + ) + ).order_by('age') + + results_list = list(results) + assert results_list[0].age_category == 'young' # Alice, 25 + assert results_list[1].age_category == 'middle' # Bob, 30 + assert results_list[2].age_category == 'middle' # Charlie, 35 + + # Test Case/When with integer output + priority_results = TestModel.objects.annotate( + priority=Case( + When(is_active=True, age__lt=30, then=Value(1)), + When(is_active=True, then=Value(2)), + When(is_active=False, then=Value(3)), + default=Value(4), + output_field=IntegerField() + ) + ).order_by('priority', 'name') + + priority_list = list(priority_results) + assert priority_list[0].name == 'Alice' # priority 1: active and young + assert priority_list[1].name == 'Charlie' # priority 2: active but not young + assert priority_list[2].name == 'Bob' # priority 3: not active + + # Clean up + TestModel.objects.all().delete() + + def test_django_iterator(self, test_environment: TestEnvironment, django_models): + """Test Django iterator() for memory-efficient queries""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + for i in range(20): + TestModel.objects.create( + name=f"User {i}", + email=f"user{i}@example.com", + age=20 + i + ) + + # Test iterator() - processes results without caching + count = 0 + for obj in TestModel.objects.iterator(): + assert obj.name.startswith("User") + count += 1 + assert count == 20 + + # Test iterator with chunk_size + count = 0 + for obj in TestModel.objects.iterator(chunk_size=5): + assert obj.email.endswith("@example.com") + count += 1 + assert count == 20 + + # Clean up + TestModel.objects.all().delete() + +''' + diff --git a/tests/unit/test_sqlalchemy_orm.py b/tests/unit/test_sqlalchemy_orm.py index ffbbcf35a..bf8468acd 100644 --- a/tests/unit/test_sqlalchemy_orm.py +++ b/tests/unit/test_sqlalchemy_orm.py @@ -19,7 +19,7 @@ class TestSqlAlchemyORM: def test_basic_workflow(self): # Step 1: Create engine (connection to database) - engine = create_engine('postgresql+aws_wrapper://pguser:pgpassword@mydb.cluster-XYZ.us-west-1.rds.amazonaws.com:5432/somedb') + engine = create_engine('mysql+aws_wrapper_mysqlconnector://mysqlmaster:mysqlpassword@database-mysql-ulojonat.cluster-cx422ywmsto6.us-east-2.rds.amazonaws.com:3306/mysqldb') # Step 2: Define base class for declarative models Base = declarative_base() From 5ee014abc1d62f881f12903fbbc8497ed5a1e27f Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Wed, 25 Mar 2026 11:44:28 -0700 Subject: [PATCH 05/39] Revert connection string in sqlalchemy orm unit test --- tests/unit/test_sqlalchemy_orm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/unit/test_sqlalchemy_orm.py b/tests/unit/test_sqlalchemy_orm.py index bf8468acd..70acb6f58 100644 --- a/tests/unit/test_sqlalchemy_orm.py +++ b/tests/unit/test_sqlalchemy_orm.py @@ -19,8 +19,7 @@ class TestSqlAlchemyORM: def test_basic_workflow(self): # Step 1: Create engine (connection to database) - engine = create_engine('mysql+aws_wrapper_mysqlconnector://mysqlmaster:mysqlpassword@database-mysql-ulojonat.cluster-cx422ywmsto6.us-east-2.rds.amazonaws.com:3306/mysqldb') - + engine = create_engine('postgresql+aws_wrapper://pguser:pgpassword@mydb.cluster-XYZ.us-west-1.rds.amazonaws.com:5432/somedb') # Step 2: Define base class for declarative models Base = declarative_base() From 4871ac706b58d2ad8014ac9f1d40060e54169916 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Thu, 26 Mar 2026 15:27:08 -0700 Subject: [PATCH 06/39] Add __init__.py for sqlalchemy integration tests --- .../container/sqlalchemy/__init__.py | 13 ++++++++++ .../sqlalchemy/test_sqlalchemy_basic.py | 24 ++++++++++--------- 2 files changed, 26 insertions(+), 11 deletions(-) create mode 100644 tests/integration/container/sqlalchemy/__init__.py diff --git a/tests/integration/container/sqlalchemy/__init__.py b/tests/integration/container/sqlalchemy/__init__.py new file mode 100644 index 000000000..bd4acb2bf --- /dev/null +++ b/tests/integration/container/sqlalchemy/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index bea23f156..8766dbaeb 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -21,8 +21,8 @@ from typing import Any import pytest -from sqlalchemy.orm import declarative_base, Mapped, mapped_column -from sqlalchemy import Column, Integer, String +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from sqlalchemy import Column, Integer, String, Boolean, DateTime from tests.integration.container.utils.rds_test_utility import RdsTestUtility from ..utils.conditions import (disable_on_features, enable_on_deployments, @@ -52,18 +52,20 @@ def rds_utils(self): @pytest.fixture(scope='class') def sqlalchemy_models(self, sqlalchemy_setup): """Create SQLAlchemy models after SQLAlchemy is set up""" + #Base = declarative_base() - Base = declarative_base() + class Base(DeclarativeBase): + pass - class TestModel(Base): - """Basic test model for SQLAlchemy ORM functionality""" - __tablename__ = 'sqlalchemy_test_model' + class TestModel(Base): + """Basic test model for SQLAlchemy ORM functionality""" + __tablename__ = 'sqlalchemy_test_model' - name: Mapped[str] = mapped_column(String(100)) - email: Mapped[str] = mapped_column(String, primary_key=True) - age: Mapped[int] = mapped_column(Integer) - is_active: Mapped[bool] = mapped_column(Bool, server_default=True) - created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), onupdate=func.now()) + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String, primary_key=True) + age: Mapped[int] = mapped_column(Integer) + is_active: Mapped[bool] = mapped_column(Boolean) + created_at: Mapped[datetime] = mapped_column(DateTime) ''' class DataTypeModel(models.Model): From 068e2e9bb50bb4537cd1f54450e06f00fe613820 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Thu, 26 Mar 2026 15:59:06 -0700 Subject: [PATCH 07/39] Fix RdsUtils not being found --- .../cluster_topology_monitor.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/aws_advanced_python_wrapper/cluster_topology_monitor.py b/aws_advanced_python_wrapper/cluster_topology_monitor.py index 660b48bae..173cb0014 100644 --- a/aws_advanced_python_wrapper/cluster_topology_monitor.py +++ b/aws_advanced_python_wrapper/cluster_topology_monitor.py @@ -21,6 +21,7 @@ from time import perf_counter_ns from typing import TYPE_CHECKING, Dict, Optional +from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.host_availability import HostAvailability from aws_advanced_python_wrapper.hostinfo import HostInfo, Topology from aws_advanced_python_wrapper.utils import services_container @@ -344,8 +345,8 @@ def _open_any_connection_and_update_topology(self) -> Topology: self._cluster_id, self._initial_host_info.host) try: - writer_id = self._topology_utils.get_writer_host_if_connected( - conn, self._plugin_service.driver_dialect) + writer_id = self._topology_utils.get_writer_id_if_connected( + conn, self._plugin_service.driver_dialect) if writer_id: self._is_verified_writer_connection = True writer_verified_by_this_thread = True @@ -354,9 +355,10 @@ def _open_any_connection_and_update_topology(self) -> Topology: writer_host_info = self._initial_host_info self._writer_host_info.set(writer_host_info) else: - writer_host = self._instance_template.host.replace("?", writer_id) - port = self._instance_template.port \ - if self._instance_template.is_port_specified() \ + instance_template = self._get_instance_template(writer_id, conn) + writer_host = instance_template.host.replace("?", writer_id) + port = instance_template.port \ + if instance_template.is_port_specified() \ else self._initial_host_info.port writer_host_info = HostInfo( writer_host, From f81b1db37aded6861d3b4fd3a83829b7df771cb4 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Wed, 1 Apr 2026 18:21:14 -0700 Subject: [PATCH 08/39] Translate basic django test to sqlalchemy --- pyproject.toml | 3 +- .../sqlalchemy/test_sqlalchemy_basic.py | 235 +++++++----------- .../container/utils/test_database_info.py | 2 +- .../utils/test_environment_request.py | 2 +- 4 files changed, 96 insertions(+), 146 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 80d787048..7cf7f9c61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,4 +85,5 @@ filterwarnings = [ ] [tool.poetry.plugins."sqlalchemy.dialects"] -"postgresql.aws_wrapper" = "aws_advanced_python_wrapper.sqlalchemy.orm_dialect:SqlAlchemyOrmPgDialect" +"postgresql.aws_wrapper_psycopg" = "aws_advanced_python_wrapper.sqlalchemy.pg_orm_dialect:SqlAlchemyOrmPgDialect" +"mysql.aws_wrapper_mysqlconnector" = "aws_advanced_python_wrapper.sqlalchemy.mysql_orm_dialect:SqlAlchemyOrmMysqlDialect" diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index 8766dbaeb..ad080230a 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -21,8 +21,8 @@ from typing import Any import pytest -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column -from sqlalchemy import Column, Integer, String, Boolean, DateTime +from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy import create_engine, Column, ForeignKey, Integer, BigInteger, SmallInteger, Float, Numeric, String, Boolean, Date, Time, DateTime, Text, JSON from tests.integration.container.utils.rds_test_utility import RdsTestUtility from ..utils.conditions import (disable_on_features, enable_on_deployments, @@ -39,146 +39,94 @@ TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, TestEnvironmentFeatures.PERFORMANCE]) class TestSqlAlchemy: - TestModel: Any - DataTypeModel: Any - Author: Any - Book: Any - @pytest.fixture(scope='class') def rds_utils(self): region: str = TestEnvironment.get_current().get_info().get_region() return RdsTestUtility(region) - @pytest.fixture(scope='class') - def sqlalchemy_models(self, sqlalchemy_setup): - """Create SQLAlchemy models after SQLAlchemy is set up""" - #Base = declarative_base() - - class Base(DeclarativeBase): - pass + Base = declarative_base() class TestModel(Base): """Basic test model for SQLAlchemy ORM functionality""" __tablename__ = 'sqlalchemy_test_model' - name: Mapped[str] = mapped_column(String(100)) - email: Mapped[str] = mapped_column(String, primary_key=True) - age: Mapped[int] = mapped_column(Integer) - is_active: Mapped[bool] = mapped_column(Boolean) - created_at: Mapped[datetime] = mapped_column(DateTime) - -''' - class DataTypeModel(models.Model): - """Model for testing various data types""" - # String fields - char_field = models.CharField(max_length=255, null=True, blank=True) - text_field = models.TextField(null=True, blank=True) - - # Numeric fields - integer_field = models.IntegerField(null=True, blank=True) - big_integer_field = models.BigIntegerField(null=True, blank=True) - decimal_field = models.DecimalField(max_digits=10, decimal_places=2, null=True, blank=True) - float_field = models.FloatField(null=True, blank=True) - - # Boolean field - boolean_field = models.BooleanField(default=False) - - # Date/Time fields - date_field = models.DateField(null=True, blank=True) - time_field = models.TimeField(null=True, blank=True) - datetime_field = models.DateTimeField(null=True, blank=True) - - # JSON field (MySQL 5.7+) - json_field = models.JSONField(null=True, blank=True) - - class Meta: - app_label = 'test_app' - db_table = 'django_data_type_model' - - class Author(models.Model): - """Author model for relationship testing""" - name = models.CharField(max_length=100) - email = models.EmailField() - birth_date = models.DateField(null=True, blank=True) - - class Meta: - app_label = 'test_app' - db_table = 'django_author' - - # Store Author first so it's available for Book's ForeignKey - TestDjango.Author = Author - - class Book(models.Model): - """Book model for relationship testing""" - title = models.CharField(max_length=200) - author = models.ForeignKey(TestDjango.Author, on_delete=models.CASCADE, related_name='books') - publication_date = models.DateField() - pages = models.IntegerField() - price = models.DecimalField(max_digits=8, decimal_places=2) - - class Meta: - app_label = 'test_app' - db_table = 'django_book' - - # Store models as class attributes for easy access - TestDjango.TestModel = TestModel - TestDjango.DataTypeModel = DataTypeModel - TestDjango.Book = Book - - # Create tables for our test models - with connection.schema_editor() as schema_editor: - schema_editor.create_model(TestModel) - schema_editor.create_model(DataTypeModel) - schema_editor.create_model(Author) - schema_editor.create_model(Book) - - yield - - # Clean up tables - with connection.schema_editor() as schema_editor: - schema_editor.delete_model(Book) - schema_editor.delete_model(Author) - schema_editor.delete_model(DataTypeModel) - schema_editor.delete_model(TestModel) - - @pytest.fixture(scope='class') - def django_setup(self, conn_utils): - """Setup Django configuration for testing""" - # Configure Django settings - if not settings.configured: - db_config = { - 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', - 'NAME': conn_utils.dbname, - 'USER': conn_utils.user, - 'PASSWORD': conn_utils.password, - 'HOST': conn_utils.writer_cluster_host, - 'PORT': conn_utils.port, - 'OPTIONS': { - 'plugins': 'failover_v2,aurora_connection_tracker', - 'connect_timeout': 10, - 'autocommit': True, - }, - } - - settings.configure( - DEBUG=True, - DATABASES={'default': db_config}, - INSTALLED_APPS=[ - 'django.contrib.contenttypes', - 'django.contrib.auth', - ], - SECRET_KEY='test-secret-key-for-django-tests', - USE_TZ=True, - ) - - django.setup() - setup_test_environment() - - yield - connections.close_all() - - teardown_test_environment() - + id = Column(Integer, primary_key=True) + + name = Column(String(100)) + email = Column(String, primary_key=True) + age = Column(Integer) + is_active = Column(Boolean) + created_at = Column(DateTime) + + class DataTypeModel(Base): + """Model for testing various data types""" + __tablename__ = 'sqlalchemy_data_type_model' + + id = Column(Integer, primary_key=True) + + # String fields + string_field = Column(String(255)) + text_field = Column(Text) + + # Numeric fields + integer_field = Column(Integer) + small_integer_field = Column(SmallInteger) + big_integer_field = Column(BigInteger) + numeric_field = Column(Numeric) + float_field = Column(Float) + + # Boolean field + boolean_field = Column(Boolean) + + # Date/Time fields + date_field = Column(Date) + time_field = Column(Time) + datetime_field = Column(DateTime) + + # JSON field (MySQL 5.7+) + json_field = Column(JSON) + + class Author(Base): + """Author model for relationship testing""" + __tablename__ = 'sqlalchemy_author' + + id = Column(Integer, primary_key=True) + name = Column(String(100)) + email = Column(String) + birth_date = Column(Date) + + class Book(Base): + """Book model for relationship testing""" + __tablename__ = 'sqlalchemy_book' + + id = Column(Integer, primary_key=True) + title = Column(String(200)) + author = Column(String, ForeignKey("Author.id")) + publication_date = Column(Date) + pages = Column(Integer) + price = Column(Numeric) + + @pytest.fixture(scope="class") + def engine(self, conn_utils): + conn_str = f'mysql+aws_wrapper_mysqlconnector://{conn_utils.user}:{conn_utils.password}@{conn_utils.writer_cluster_host}:{conn_utils.port}/{conn_utils.dbname}' + engine = create_engine(conn_str) + Base.metadata.create_all(engine) + yield engine + Base.metadata.drop_all(engine) + + @pytest.fixture(scope="class") + def Session(self, engine): + Session = sessionmaker(bind=engine) + yield Session + + @pytest.fixture(scope="class") + def session(self, Session): + session = Session() + yield session + session.rollback() + session.close() + + ''' def test_django_backend_configuration(self, test_environment: TestEnvironment, django_models): """Test Django backend configuration with empty plugins""" # Verify that the connection is using the AWS wrapper @@ -186,26 +134,25 @@ def test_django_backend_configuration(self, test_environment: TestEnvironment, d # Test basic connection functionality assert self.TestModel.objects.count() == 0 + ''' - def test_django_basic_model_operations(self, test_environment: TestEnvironment, django_models): + def test_sqlalchemy_basic_model_operations(self, session, test_environment: TestEnvironment): """Test basic Django ORM operations (CRUD)""" - TestModel = self.TestModel - - # Ensure clean slate - TestModel.objects.all().delete() # Create - test_obj = TestModel.objects.create( + test_obj = TestModel( name="John Doe", email="john@example.com", age=30, is_active=True ) + session.add(test_obj) + session.commit() assert test_obj.id is not None assert test_obj.name == "John Doe" # Read - retrieved_obj = TestModel.objects.get(id=test_obj.id) + retrieved_obj = session.query(TestModel).filter(TestModel.id == test_obj.id).first() assert retrieved_obj.name == "John Doe" assert retrieved_obj.email == "john@example.com" assert retrieved_obj.age == 30 @@ -214,16 +161,18 @@ def test_django_basic_model_operations(self, test_environment: TestEnvironment, # Update retrieved_obj.name = "Jane Doe" retrieved_obj.age = 25 - retrieved_obj.save() + session.commit() - updated_obj = TestModel.objects.get(id=test_obj.id) + updated_obj = session.query(TestModel).filter(TestModel.id == test_obj.id).first() assert updated_obj.name == "Jane Doe" assert updated_obj.age == 25 # Delete - updated_obj.delete() - assert TestModel.objects.filter(id=test_obj.id).count() == 0 + session.delete(updated_obj) + session.commit() + assert session.query(TestModel).filter(TestModel.id == test_obj.id).count() == 0 +''' def test_django_queryset_operations(self, test_environment: TestEnvironment, django_models): """Test Django QuerySet operations""" TestModel = self.TestModel diff --git a/tests/integration/container/utils/test_database_info.py b/tests/integration/container/utils/test_database_info.py index a1b3a0944..edc49faf0 100644 --- a/tests/integration/container/utils/test_database_info.py +++ b/tests/integration/container/utils/test_database_info.py @@ -42,7 +42,7 @@ def __init__(self, database_info: Dict[str, Any]) -> None: self._username = typing.cast('str', database_info.get("username")) self._password = typing.cast('str', database_info.get("password")) - self._default_db_name = typing.cast('str', database_info.get("defaultDbName")) + self._default_db_name = "mysqldb" self._cluster_endpoint = typing.cast('str', database_info.get("clusterEndpoint")) self._cluster_endpoint_port = typing.cast('int', database_info.get("clusterEndpointPort")) self._cluster_read_only_endpoint = typing.cast('str', database_info.get("clusterReadOnlyEndpoint")) diff --git a/tests/integration/container/utils/test_environment_request.py b/tests/integration/container/utils/test_environment_request.py index db1bbeef2..def700293 100644 --- a/tests/integration/container/utils/test_environment_request.py +++ b/tests/integration/container/utils/test_environment_request.py @@ -63,7 +63,7 @@ def get_features(self) -> Set[TestEnvironmentFeatures]: return self._features def get_num_of_instances(self) -> int: - return self._num_of_instances + return 3 def get_display_name(self) -> str: return "Test environment [{0}, {1}, {2}, {3}, {4}, {5}]".format( From a1ebf4c7666c1f43b2e848cd3b538d26c513ef7d Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Sat, 4 Apr 2026 16:15:12 -0700 Subject: [PATCH 09/39] Add basic CRUD test for sqlalchemy ORM mysql tests --- .../sqlalchemy/test_sqlalchemy_basic.py | 136 +++++++++--------- 1 file changed, 71 insertions(+), 65 deletions(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index ad080230a..49ebd56cd 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -16,13 +16,16 @@ from __future__ import annotations -from datetime import date, datetime, time +from datetime import date, datetime, time, timezone from decimal import Decimal from typing import Any import pytest -from sqlalchemy.orm import declarative_base, sessionmaker -from sqlalchemy import create_engine, Column, ForeignKey, Integer, BigInteger, SmallInteger, Float, Numeric, String, Boolean, Date, Time, DateTime, Text, JSON +from sqlalchemy.orm import declarative_base, sessionmaker, relationship +from sqlalchemy import ( + create_engine, Column, ForeignKey, Integer, BigInteger, SmallInteger, + Float, Numeric, String, Boolean, Date, Time, DateTime, Text, JSON +) from tests.integration.container.utils.rds_test_utility import RdsTestUtility from ..utils.conditions import (disable_on_features, enable_on_deployments, @@ -32,79 +35,83 @@ from ..utils.test_environment import TestEnvironment from ..utils.test_environment_features import TestEnvironmentFeatures +Base = declarative_base() -@enable_on_engines([DatabaseEngine.MYSQL]) # MySQL Specific until PG is implemented -@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) -@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, - TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, - TestEnvironmentFeatures.PERFORMANCE]) -class TestSqlAlchemy: - @pytest.fixture(scope='class') - def rds_utils(self): - region: str = TestEnvironment.get_current().get_info().get_region() - return RdsTestUtility(region) +class TestModel(Base): + """Basic test model for SQLAlchemy ORM functionality""" + __tablename__ = 'sqlalchemy_test_model' - Base = declarative_base() + id = Column(Integer, primary_key=True) - class TestModel(Base): - """Basic test model for SQLAlchemy ORM functionality""" - __tablename__ = 'sqlalchemy_test_model' + name = Column(String(100), nullable=False) + email = Column(String(254), nullable=False, unique=True) + age = Column(Integer, nullable=False) + is_active = Column(Boolean, default=True) + created_at = Column(DateTime, default=datetime.now(timezone.utc)) - id = Column(Integer, primary_key=True) +class DataTypeModel(Base): + """Model for testing various data types""" + __tablename__ = 'sqlalchemy_data_type_model' - name = Column(String(100)) - email = Column(String, primary_key=True) - age = Column(Integer) - is_active = Column(Boolean) - created_at = Column(DateTime) + id = Column(Integer, primary_key=True) - class DataTypeModel(Base): - """Model for testing various data types""" - __tablename__ = 'sqlalchemy_data_type_model' + # String fields + string_field = Column(String(255)) + text_field = Column(Text) - id = Column(Integer, primary_key=True) + # Numeric fields + integer_field = Column(Integer) + small_integer_field = Column(SmallInteger) + big_integer_field = Column(BigInteger) + numeric_field = Column(Numeric(10, 2)) + float_field = Column(Float) - # String fields - string_field = Column(String(255)) - text_field = Column(Text) + # Boolean field + boolean_field = Column(Boolean, default=False) - # Numeric fields - integer_field = Column(Integer) - small_integer_field = Column(SmallInteger) - big_integer_field = Column(BigInteger) - numeric_field = Column(Numeric) - float_field = Column(Float) + # Date/Time fields + date_field = Column(Date) + time_field = Column(Time) + datetime_field = Column(DateTime) - # Boolean field - boolean_field = Column(Boolean) + # JSON field (MySQL 5.7+) + json_field = Column(JSON) - # Date/Time fields - date_field = Column(Date) - time_field = Column(Time) - datetime_field = Column(DateTime) +class Author(Base): + """Author model for relationship testing""" + __tablename__ = 'sqlalchemy_author' - # JSON field (MySQL 5.7+) - json_field = Column(JSON) + id = Column(Integer, primary_key=True) + name = Column(String(100), nullable=False) + email = Column(String(254), nullable=False) + birth_date = Column(Date) - class Author(Base): - """Author model for relationship testing""" - __tablename__ = 'sqlalchemy_author' + books = relationship('Book', back_populates='author', cascade='all, delete-orphan') - id = Column(Integer, primary_key=True) - name = Column(String(100)) - email = Column(String) - birth_date = Column(Date) +class Book(Base): + """Book model for relationship testing""" + __tablename__ = 'sqlalchemy_book' - class Book(Base): - """Book model for relationship testing""" - __tablename__ = 'sqlalchemy_book' + id = Column(Integer, primary_key=True) + title = Column(String(200), nullable=False) + author_id = Column(Integer, ForeignKey("sqlalchemy_author.id"), nullable=False) + publication_date = Column(Date, nullable=False) + pages = Column(Integer, nullable=False) + price = Column(Numeric(8, 2), nullable=False) + + author = relationship('Author', back_populates='books') + +@enable_on_engines([DatabaseEngine.MYSQL]) # MySQL Specific until PG is implemented +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestSqlAlchemy: + @pytest.fixture(scope='class') + def rds_utils(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) - id = Column(Integer, primary_key=True) - title = Column(String(200)) - author = Column(String, ForeignKey("Author.id")) - publication_date = Column(Date) - pages = Column(Integer) - price = Column(Numeric) @pytest.fixture(scope="class") def engine(self, conn_utils): @@ -152,7 +159,7 @@ def test_sqlalchemy_basic_model_operations(self, session, test_environment: Test assert test_obj.name == "John Doe" # Read - retrieved_obj = session.query(TestModel).filter(TestModel.id == test_obj.id).first() + retrieved_obj = session.query(TestModel).filter_by(id = test_obj.id).one() assert retrieved_obj.name == "John Doe" assert retrieved_obj.email == "john@example.com" assert retrieved_obj.age == 30 @@ -163,7 +170,7 @@ def test_sqlalchemy_basic_model_operations(self, session, test_environment: Test retrieved_obj.age = 25 session.commit() - updated_obj = session.query(TestModel).filter(TestModel.id == test_obj.id).first() + updated_obj = session.query(TestModel).filter_by(id = test_obj.id).one() assert updated_obj.name == "Jane Doe" assert updated_obj.age == 25 @@ -172,7 +179,7 @@ def test_sqlalchemy_basic_model_operations(self, session, test_environment: Test session.commit() assert session.query(TestModel).filter(TestModel.id == test_obj.id).count() == 0 -''' + ''' def test_django_queryset_operations(self, test_environment: TestEnvironment, django_models): """Test Django QuerySet operations""" TestModel = self.TestModel @@ -939,6 +946,5 @@ def test_django_iterator(self, test_environment: TestEnvironment, django_models) # Clean up TestModel.objects.all().delete() - -''' + ''' From f081a632da1590e77b8aa45731ca15269bb6097f Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Tue, 7 Apr 2026 16:41:43 -0700 Subject: [PATCH 10/39] Add remaining basic MySQL SQLAlchemy ORM tests --- .../sqlalchemy/test_sqlalchemy_basic.py | 1081 +++++++---------- 1 file changed, 470 insertions(+), 611 deletions(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index 49ebd56cd..822b92ac8 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -21,10 +21,15 @@ from typing import Any import pytest -from sqlalchemy.orm import declarative_base, sessionmaker, relationship +from sqlalchemy.sql import func +from sqlalchemy.orm import ( + declarative_base, sessionmaker, relationship, Session, joinedload, + subqueryload +) from sqlalchemy import ( create_engine, Column, ForeignKey, Integer, BigInteger, SmallInteger, - Float, Numeric, String, Boolean, Date, Time, DateTime, Text, JSON + Float, Numeric, String, Boolean, Date, Time, DateTime, Text, JSON, or_, + and_, text ) from tests.integration.container.utils.rds_test_utility import RdsTestUtility @@ -133,18 +138,18 @@ def session(self, Session): session.rollback() session.close() - ''' - def test_django_backend_configuration(self, test_environment: TestEnvironment, django_models): - """Test Django backend configuration with empty plugins""" + def test_sqlalchemy_backend_configuration(self, test_environment: TestEnvironment, engine): + """Test SQLAlchemy backend configuration with empty plugins""" # Verify that the connection is using the AWS wrapper - assert hasattr(connection, 'connection') + with engine.connect() as connection: + assert connection.connection is not None # Test basic connection functionality - assert self.TestModel.objects.count() == 0 - ''' + with Session(engine) as session: + assert session.query(TestModel).count() == 0 def test_sqlalchemy_basic_model_operations(self, session, test_environment: TestEnvironment): - """Test basic Django ORM operations (CRUD)""" + """Test basic SQLAlchemy ORM operations (CRUD)""" # Create test_obj = TestModel( @@ -179,772 +184,626 @@ def test_sqlalchemy_basic_model_operations(self, session, test_environment: Test session.commit() assert session.query(TestModel).filter(TestModel.id == test_obj.id).count() == 0 - ''' - def test_django_queryset_operations(self, test_environment: TestEnvironment, django_models): - """Test Django QuerySet operations""" - TestModel = self.TestModel - + def test_sqlalchemy_query_operations(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy query operations""" # Ensure clean slate - TestModel.objects.all().delete() - + session.query(TestModel).delete() + session.commit() # Create test data - TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) - TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) - + session.add_all([ + TestModel(name="Alice", email="alice@example.com", age=25, is_active=True), + TestModel(name="Bob", email="bob@example.com", age=30, is_active=False), + TestModel(name="Charlie", email="charlie@example.com", age=35, is_active=True), + ]) + session.commit() # Test filtering - active_users = TestModel.objects.filter(is_active=True) - assert active_users.count() == 2 - + active_users = session.query(TestModel).filter(TestModel.is_active == True).all() + assert len(active_users) == 2 # Test ordering - ordered_users = TestModel.objects.order_by('age') + ordered_users = session.query(TestModel).order_by(TestModel.age).all() ages = [user.age for user in ordered_users] assert ages == [25, 30, 35] - # Test complex queries - young_active_users = TestModel.objects.filter(age__lt=30, is_active=True) - assert young_active_users.count() == 1 - assert young_active_users.first().name == "Alice" - - # Test exclude - non_bob_users = TestModel.objects.exclude(name="Bob") - assert non_bob_users.count() == 2 - + young_active_users = session.query(TestModel).filter( + TestModel.age < 30, TestModel.is_active == True + ).all() + assert len(young_active_users) == 1 + assert young_active_users[0].name == "Alice" + # Test exclude (using NOT) + non_bob_users = session.query(TestModel).filter(TestModel.name != "Bob").all() + assert len(non_bob_users) == 2 # Test exists - assert TestModel.objects.filter(name="Alice").exists() - assert not TestModel.objects.filter(name="David").exists() - + assert session.query(TestModel).filter(TestModel.name == "Alice").first() is not None + assert session.query(TestModel).filter(TestModel.name == "David").first() is None # Clean up - TestModel.objects.all().delete() - - def test_django_data_types(self, test_environment: TestEnvironment, django_models): - """Test Django ORM with various data types""" - DataTypeModel = self.DataTypeModel + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_data_types(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy with various data types""" # Ensure clean slate - DataTypeModel.objects.all().delete() - + session.query(DataTypeModel).delete() + session.commit() # Create test data with various data types test_datetime = datetime(2023, 12, 25, 14, 30, 0) - test_datetime_aware = timezone.make_aware(test_datetime) - - test_data = DataTypeModel.objects.create( - char_field="Test String", + test_data = DataTypeModel( + string_field="Test String", text_field="This is a longer text field content", integer_field=42, + small_integer_field=5, big_integer_field=9223372036854775807, - decimal_field=Decimal('123.45'), + numeric_field=Decimal('123.45'), float_field=3.14159, boolean_field=True, date_field=date(2023, 12, 25), time_field=time(14, 30, 0), - datetime_field=test_datetime_aware, # Use timezone-aware datetime - json_field={"key": "value", "number": 123, "array": [1, 2, 3]} + datetime_field=test_datetime, + json_field={"key": "value", "number": 123, "array": [1, 2, 3]}, ) - + session.add(test_data) + session.commit() # Retrieve and verify data - retrieved = DataTypeModel.objects.get(id=test_data.id) - - assert retrieved.char_field == "Test String" + retrieved = session.query(DataTypeModel).get(test_data.id) + assert retrieved.string_field == "Test String" assert retrieved.text_field == "This is a longer text field content" assert retrieved.integer_field == 42 + assert retrieved.small_integer_field == 5 assert retrieved.big_integer_field == 9223372036854775807 - assert retrieved.decimal_field == Decimal('123.45') + assert retrieved.numeric_field == Decimal('123.45') assert abs(retrieved.float_field - 3.14159) < 0.001 assert retrieved.boolean_field is True assert retrieved.date_field == date(2023, 12, 25) assert retrieved.time_field == time(14, 30, 0) - # Compare timezone-aware datetimes - assert retrieved.datetime_field == test_datetime_aware + assert retrieved.datetime_field == test_datetime assert retrieved.json_field == {"key": "value", "number": 123, "array": [1, 2, 3]} - # Clean up - DataTypeModel.objects.all().delete() - - def test_django_null_values(self, test_environment: TestEnvironment, django_models): - """Test Django ORM handling of NULL values""" - DataTypeModel = self.DataTypeModel - - # First, ensure we start with a clean slate - DataTypeModel.objects.all().delete() + session.query(DataTypeModel).delete() + session.commit() + def test_sqlalchemy_null_values(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy handling of NULL values""" + # Ensure clean slate + session.query(DataTypeModel).delete() + session.commit() # Create object with NULL values - test_obj = DataTypeModel.objects.create( - char_field=None, + test_obj = DataTypeModel( + string_field=None, integer_field=None, date_field=None, - boolean_field=False # This field has default=False, so it won't be NULL + boolean_field=False, ) - + session.add(test_obj) + session.commit() # Retrieve and verify NULL values - retrieved = DataTypeModel.objects.get(id=test_obj.id) - assert retrieved.char_field is None + retrieved = session.query(DataTypeModel).get(test_obj.id) + assert retrieved.string_field is None assert retrieved.integer_field is None assert retrieved.date_field is None assert retrieved.boolean_field is False - # Test filtering with NULL values - null_char_objects = DataTypeModel.objects.filter(char_field__isnull=True) - assert null_char_objects.count() == 1 - - not_null_char_objects = DataTypeModel.objects.filter(char_field__isnull=False) - assert not_null_char_objects.count() == 0 - + null_char_objects = session.query(DataTypeModel).filter(DataTypeModel.string_field.is_(None)).all() + assert len(null_char_objects) == 1 + not_null_char_objects = session.query(DataTypeModel).filter(DataTypeModel.string_field.isnot(None)).all() + assert len(not_null_char_objects) == 0 # Create an object with non-NULL values to test the opposite - DataTypeModel.objects.create( - char_field="Not NULL", + session.add(DataTypeModel( + string_field="Not NULL", integer_field=42, - date_field=date(2023, 1, 1) - ) - + date_field=date(2023, 1, 1), + )) + session.commit() # Now test filtering again - null_char_objects = DataTypeModel.objects.filter(char_field__isnull=True) - assert null_char_objects.count() == 1 # Still one NULL object - - not_null_char_objects = DataTypeModel.objects.filter(char_field__isnull=False) - assert not_null_char_objects.count() == 1 # Now one non-NULL object - + null_string_objects = session.query(DataTypeModel).filter(DataTypeModel.string_field.is_(None)).all() + # Still one NULL object + assert len(null_string_objects) == 1 + not_null_string_objects = session.query(DataTypeModel).filter(DataTypeModel.string_field.isnot(None)).all() + # Now one non-NULL object + assert len(not_null_string_objects) == 1 # Clean up - DataTypeModel.objects.all().delete() - - def test_django_relationships(self, test_environment: TestEnvironment, django_models): - """Test Django ORM relationships (ForeignKey)""" - Author = self.Author - Book = self.Book + session.query(DataTypeModel).delete() + session.commit() + def test_sqlalchemy_relationships(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy relationships (ForeignKey)""" # Create author - author = Author.objects.create( + author = Author( name="J.K. Rowling", email="jk@example.com", - birth_date=date(1965, 7, 31) + birth_date=date(1965, 7, 31), ) - + session.add(author) + session.commit() # Create books - book1 = Book.objects.create( + book1 = Book( title="Harry Potter and the Philosopher's Stone", - author=author, + author_id=author.id, publication_date=date(1997, 6, 26), pages=223, - price=Decimal('12.99') + price=Decimal('12.99'), ) - - book2 = Book.objects.create( + book2 = Book( title="Harry Potter and the Chamber of Secrets", - author=author, + author_id=author.id, publication_date=date(1998, 7, 2), pages=251, - price=Decimal('13.99') + price=Decimal('13.99'), ) - + session.add_all([book1, book2]) + session.commit() # Test forward relationship assert book1.author.name == "J.K. Rowling" assert book2.author.email == "jk@example.com" - # Test reverse relationship - author_books = author.books.all() - assert author_books.count() == 2 - book_titles = [book.title for book in author_books.order_by('publication_date')] + assert len(author.books) == 2 + book_titles = [book.title for book in sorted(author.books, key=lambda b: b.publication_date)] assert "Harry Potter and the Philosopher's Stone" in book_titles assert "Harry Potter and the Chamber of Secrets" in book_titles - # Test related queries - books_by_author = Book.objects.filter(author__name="J.K. Rowling") - assert books_by_author.count() == 2 - - # Test select_related for optimization - book_with_author = Book.objects.select_related('author').get(id=book1.id) + books_by_author = session.query(Book).join(Author).filter(Author.name == "J.K. Rowling").all() + assert len(books_by_author) == 2 + # Test joinedload for optimization + book_with_author = session.query(Book).options( + joinedload(Book.author) + ).filter(Book.id == book1.id).one() assert book_with_author.author.name == "J.K. Rowling" - - # Clean up - Book.objects.all().delete() - Author.objects.all().delete() - - def test_django_aggregations(self, test_environment: TestEnvironment, django_models): - """Test Django ORM aggregations""" - Author = self.Author - Book = self.Book - - # Create test data - author = Author.objects.create(name="Test Author", email="test@example.com") - - Book.objects.create(title="Book 1", author=author, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')) - Book.objects.create(title="Book 2", author=author, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')) - Book.objects.create(title="Book 3", author=author, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')) - - # Test aggregations - stats = Book.objects.aggregate( - total_books=Count('id'), - total_pages=Sum('pages'), - avg_price=Avg('price'), - max_pages=Max('pages'), - min_price=Min('price') - ) - - assert stats['total_books'] == 3 - assert stats['total_pages'] == 600 - assert abs(float(stats['avg_price']) - 20.0) < 0.01 - assert stats['max_pages'] == 300 - assert stats['min_price'] == Decimal('10.00') - # Clean up - Book.objects.all().delete() - Author.objects.all().delete() - - def test_django_transactions(self, test_environment: TestEnvironment, django_models): - """Test Django transaction handling""" - TestModel = self.TestModel - - # Ensure clean slate - TestModel.objects.all().delete() + session.query(Book).delete() + session.query(Author).delete() + session.commit() - initial_count = TestModel.objects.count() + def test_sqlalchemy_aggregations(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy aggregations""" + author = Author(name="Test Author", email="test@example.com") + session.add(author) + session.flush() + books = [ + Book(title="Book 1", author_id=author.id, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')), + Book(title="Book 2", author_id=author.id, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')), + Book(title="Book 3", author_id=author.id, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')), + ] + session.add_all(books) + session.flush() + stats = session.query( + func.count(Book.id).label('total_books'), + func.sum(Book.pages).label('total_pages'), + func.avg(Book.price).label('avg_price'), + func.max(Book.pages).label('max_pages'), + func.min(Book.price).label('min_price'), + ).one() + assert stats.total_books == 3 + assert stats.total_pages == 600 + assert abs(float(stats.avg_price) - 20.0) < 0.01 + assert stats.max_pages == 300 + assert stats.min_price == Decimal('10.00') + session.rollback() + def test_sqlalchemy_transactions(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy transaction handling""" + session.query(TestModel).delete() + session.commit() + initial_count = session.query(TestModel).count() # Test successful transaction - with transaction.atomic(): - TestModel.objects.create(name="User 1", email="user1@example.com", age=25) - TestModel.objects.create(name="User 2", email="user2@example.com", age=30) - - assert TestModel.objects.count() == initial_count + 2 - + session.add(TestModel(name="User 1", email="user1@example.com", age=25)) + session.add(TestModel(name="User 2", email="user2@example.com", age=30)) + session.commit() + assert session.query(TestModel).count() == initial_count + 2 # Test rollback transaction try: - with transaction.atomic(): - TestModel.objects.create(name="User 3", email="user3@example.com", age=35) - TestModel.objects.create(name="User 4", email="user4@example.com", age=40) - # Force an error to trigger rollback - raise Exception("Force rollback") + session.add(TestModel(name="User 3", email="user3@example.com", age=35)) + session.add(TestModel(name="User 4", email="user4@example.com", age=40)) + session.flush() + raise Exception("Force rollback") except Exception: - pass # Expected exception - - # Should still have only 2 additional records (rollback occurred) - assert TestModel.objects.count() == initial_count + 2 - - # Clean up - TestModel.objects.all().delete() - - def test_django_bulk_operations(self, test_environment: TestEnvironment, django_models): - """Test Django bulk operations""" - TestModel = self.TestModel - - # Ensure clean slate - TestModel.objects.all().delete() + session.rollback() + assert session.query(TestModel).count() == initial_count + 2 + session.query(TestModel).delete() + session.commit() - # Test bulk_create - test_objects = [ + def test_sqlalchemy_bulk_operations(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy bulk operations""" + session.query(TestModel).delete() + session.commit() + # Test bulk insert + session.bulk_save_objects([ TestModel(name=f"User {i}", email=f"user{i}@example.com", age=20 + i) for i in range(10) - ] - - created_objects = TestModel.objects.bulk_create(test_objects) - assert len(created_objects) == 10 - assert TestModel.objects.count() == 10 - - # Test bulk_update - need to get the objects first and modify them - objects_to_update = list(TestModel.objects.all()) - for obj in objects_to_update: - obj.age += 5 - - TestModel.objects.bulk_update(objects_to_update, ['age']) - - # Verify updates - get fresh objects from database - ages = list(TestModel.objects.values_list('age', flat=True).order_by('name')) - expected_ages = [25 + i for i in range(10)] # 20+i+5 for i in range(10) + ]) + session.commit() + assert session.query(TestModel).count() == 10 + # Test bulk update + session.query(TestModel).update({TestModel.age: TestModel.age + 5}) + session.commit() + ages = [r.age for r in session.query(TestModel).order_by(TestModel.name).all()] + expected_ages = [25 + i for i in range(10)] assert ages == expected_ages + session.query(TestModel).delete() + session.commit() - # Clean up - TestModel.objects.all().delete() - - def test_django_complex_queries(self, test_environment: TestEnvironment, django_models): - """Test complex Django queries with Q objects and F expressions""" - TestModel = self.TestModel - - # Ensure clean slate - TestModel.objects.all().delete() - - # Create test data - TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) - TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) - TestModel.objects.create(name="David", email="david@example.com", age=28, is_active=True) - - # Test Q objects for complex conditions - complex_query = TestModel.objects.filter( - Q(age__gte=30) | Q(name__startswith='A') - ) - assert complex_query.count() == 3 # Bob (30), Charlie (35), Alice (starts with A) - - # Test F expressions - TestModel.objects.filter(age__lt=30).update(age=F('age') + 5) - - # Verify F expression update - alice = TestModel.objects.get(name="Alice") - david = TestModel.objects.get(name="David") - assert alice.age == 30 # 25 + 5 - assert david.age == 33 # 28 + 5 - - # Clean up, might get a failover error from this connection - TestModel.objects.all().delete() - - def test_django_raw_sql_queries(self, test_environment: TestEnvironment, django_models): - """Test Django raw SQL query execution""" - TestModel = self.TestModel - - # Ensure clean slate - TestModel.objects.all().delete() - - # Create test data - TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) - TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) - - # Test raw() method - raw_results = TestModel.objects.raw( - f'SELECT * FROM {TestModel._meta.db_table} WHERE age >= %s ORDER BY age', - [30] + def test_sqlalchemy_complex_queries(self, test_environment: TestEnvironment, session): + """Test complex SQLAlchemy queries with or_/and_ and column expressions""" + session.query(TestModel).delete() + session.commit() + session.add_all([ + TestModel(name="Alice", email="alice@example.com", age=25, is_active=True), + TestModel(name="Bob", email="bob@example.com", age=30, is_active=False), + TestModel(name="Charlie", email="charlie@example.com", age=35, is_active=True), + TestModel(name="David", email="david@example.com", age=28, is_active=True), + ]) + session.commit() + # Test or_ for complex conditions + results = session.query(TestModel).filter( + or_(TestModel.age >= 30, TestModel.name.like('A%')) + ).all() + assert len(results) == 3 + # Test column expression update (equivalent to Django's F expressions) + session.query(TestModel).filter(TestModel.age < 30).update( + {TestModel.age: TestModel.age + 5}, synchronize_session='fetch' ) - raw_list = list(raw_results) - assert len(raw_list) == 2 - assert raw_list[0].name == "Bob" - assert raw_list[1].name == "Charlie" - - # Test connection.cursor() for custom SQL - with connection.cursor() as cursor: - cursor.execute( - f'SELECT name, age FROM {TestModel._meta.db_table} WHERE is_active = %s ORDER BY age', - [True] - ) - rows = cursor.fetchall() - assert len(rows) == 2 - assert rows[0][0] == "Alice" # name - assert rows[0][1] == 25 # age - assert rows[1][0] == "Charlie" - assert rows[1][1] == 35 - - # Test raw SQL with connection for aggregate - with connection.cursor() as cursor: - cursor.execute(f'SELECT COUNT(*), AVG(age) FROM {TestModel._meta.db_table}') - count, avg_age = cursor.fetchone() - assert count == 3 - assert abs(float(avg_age) - 30.0) < 0.01 - - # Clean up - TestModel.objects.all().delete() - - def test_django_get_or_create(self, test_environment: TestEnvironment, django_models): - """Test Django get_or_create pattern""" - TestModel = self.TestModel + session.commit() + alice = session.query(TestModel).filter_by(name="Alice").one() + david = session.query(TestModel).filter_by(name="David").one() + assert alice.age == 30 + assert david.age == 33 + session.query(TestModel).delete() + session.commit() - # Ensure clean slate - TestModel.objects.all().delete() + def test_sqlalchemy_raw_sql_queries(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy raw SQL query execution""" + session.query(TestModel).delete() + session.commit() + session.add_all([ + TestModel(name="Alice", email="alice@example.com", age=25, is_active=True), + TestModel(name="Bob", email="bob@example.com", age=30, is_active=False), + TestModel(name="Charlie", email="charlie@example.com", age=35, is_active=True), + ]) + session.commit() + table = TestModel.__tablename__ + # Test raw SQL with text() + rows = session.execute( + text(f'SELECT * FROM {table} WHERE age >= :age ORDER BY age'), + {'age': 30} + ).fetchall() + assert len(rows) == 2 + # Test raw SQL for specific columns + rows = session.execute( + text(f'SELECT name, age FROM {table} WHERE is_active = :active ORDER BY age'), + {'active': True} + ).fetchall() + assert len(rows) == 2 + assert rows[0][0] == "Alice" + assert rows[0][1] == 25 + assert rows[1][0] == "Charlie" + assert rows[1][1] == 35 + # Test raw SQL aggregate + result = session.execute( + text(f'SELECT COUNT(*), AVG(age) FROM {table}') + ).fetchone() + assert result[0] == 3 + assert abs(float(result[1]) - 30.0) < 0.01 + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_get_or_create(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy get-or-create pattern""" + session.query(TestModel).delete() + session.commit() # Test create case - obj1, created1 = TestModel.objects.get_or_create( - email="test@example.com", - defaults={'name': 'Test User', 'age': 25, 'is_active': True} - ) + obj1 = session.query(TestModel).filter_by(email="test@example.com").first() + created1 = obj1 is None + if created1: + obj1 = TestModel(name="Test User", email="test@example.com", age=25, is_active=True) + session.add(obj1) + session.commit() assert created1 is True assert obj1.name == "Test User" assert obj1.age == 25 - - # Test get case (object already exists) - obj2, created2 = TestModel.objects.get_or_create( - email="test@example.com", - defaults={'name': 'Different Name', 'age': 30, 'is_active': False} - ) + # Test get case + obj2 = session.query(TestModel).filter_by(email="test@example.com").first() + created2 = obj2 is None + if created2: + obj2 = TestModel(name="Different Name", email="test@example.com", age=30, is_active=False) + session.add(obj2) + session.commit() assert created2 is False assert obj2.id == obj1.id - assert obj2.name == "Test User" # Should keep original values + assert obj2.name == "Test User" assert obj2.age == 25 + assert session.query(TestModel).filter_by(email="test@example.com").count() == 1 + session.query(TestModel).delete() + session.commit() - # Verify only one object exists - assert TestModel.objects.filter(email="test@example.com").count() == 1 - - # Clean up - TestModel.objects.all().delete() - - def test_django_update_or_create(self, test_environment: TestEnvironment, django_models): - """Test Django update_or_create pattern""" - TestModel = self.TestModel - - # Ensure clean slate - TestModel.objects.all().delete() - + def test_sqlalchemy_update_or_create(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy update-or-create pattern""" + session.query(TestModel).delete() + session.commit() # Test create case - obj1, created1 = TestModel.objects.update_or_create( - email="update@example.com", - defaults={'name': 'Initial Name', 'age': 25, 'is_active': True} - ) + obj1 = session.query(TestModel).filter_by(email="update@example.com").first() + created1 = obj1 is None + if created1: + obj1 = TestModel(name="Initial Name", email="update@example.com", age=25, is_active=True) + session.add(obj1) + session.commit() assert created1 is True assert obj1.name == "Initial Name" assert obj1.age == 25 - - # Test update case (object already exists) - obj2, created2 = TestModel.objects.update_or_create( - email="update@example.com", - defaults={'name': 'Updated Name', 'age': 30, 'is_active': False} - ) + # Test update case + obj2 = session.query(TestModel).filter_by(email="update@example.com").first() + created2 = obj2 is None + if created2: + obj2 = TestModel(name="Updated Name", email="update@example.com", age=30, is_active=False) + session.add(obj2) + else: + obj2.name = "Updated Name" + obj2.age = 30 + obj2.is_active = False + session.commit() assert created2 is False assert obj2.id == obj1.id - assert obj2.name == "Updated Name" # Should be updated + assert obj2.name == "Updated Name" assert obj2.age == 30 assert obj2.is_active is False - - # Verify only one object exists - assert TestModel.objects.filter(email="update@example.com").count() == 1 - - # Verify the update persisted - retrieved = TestModel.objects.get(email="update@example.com") + assert session.query(TestModel).filter_by(email="update@example.com").count() == 1 + retrieved = session.query(TestModel).filter_by(email="update@example.com").one() assert retrieved.name == "Updated Name" assert retrieved.age == 30 + session.query(TestModel).delete() + session.commit() - # Clean up - TestModel.objects.all().delete() - - def test_django_prefetch_related(self, test_environment: TestEnvironment, django_models): - """Test Django prefetch_related for optimizing queries""" - Author = self.Author - Book = self.Book - - # Create test data - author1 = Author.objects.create(name="Author 1", email="author1@example.com") - author2 = Author.objects.create(name="Author 2", email="author2@example.com") - - Book.objects.create(title="Book 1A", author=author1, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')) - Book.objects.create(title="Book 1B", author=author1, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')) - Book.objects.create(title="Book 2A", author=author2, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')) - - # Test prefetch_related - authors = Author.objects.prefetch_related('books').all() - - # Access related books (should not trigger additional queries due to prefetch) + def test_sqlalchemy_eager_loading(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy eager loading for optimizing queries""" + author1 = Author(name="Author 1", email="author1@example.com") + author2 = Author(name="Author 2", email="author2@example.com") + session.add_all([author1, author2]) + session.flush() + session.add_all([ + Book(title="Book 1A", author_id=author1.id, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')), + Book(title="Book 1B", author_id=author1.id, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')), + Book(title="Book 2A", author_id=author2.id, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')), + ]) + session.commit() + # Test subqueryload (equivalent to Django's prefetch_related) + authors = session.query(Author).options(subqueryload(Author.books)).all() for author in authors: - books = list(author.books.all()) if author.name == "Author 1": - assert len(books) == 2 - book_titles = [book.title for book in books] - assert "Book 1A" in book_titles - assert "Book 1B" in book_titles + assert len(author.books) == 2 + titles = [b.title for b in author.books] + assert "Book 1A" in titles + assert "Book 1B" in titles elif author.name == "Author 2": - assert len(books) == 1 - assert books[0].title == "Book 2A" - - # Clean up - Book.objects.all().delete() - Author.objects.all().delete() - - def test_django_database_functions(self, test_environment: TestEnvironment, django_models): - """Test Django database functions""" - TestModel = self.TestModel - - # Ensure clean slate - TestModel.objects.all().delete() + assert len(author.books) == 1 + assert author.books[0].title == "Book 2A" + session.rollback() - # Create test data - TestModel.objects.create(name="alice", email="alice@example.com", age=25) - TestModel.objects.create(name="BOB", email="bob@example.com", age=30) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35) - - # Test Upper function - upper_names = TestModel.objects.annotate(upper_name=Upper('name')).values_list('upper_name', flat=True) - upper_list = list(upper_names) - assert "ALICE" in upper_list - assert "BOB" in upper_list - assert "CHARLIE" in upper_list - - # Test Lower function - lower_names = TestModel.objects.annotate(lower_name=Lower('name')).values_list('lower_name', flat=True) - lower_list = list(lower_names) - assert "alice" in lower_list - assert "bob" in lower_list - assert "charlie" in lower_list - - # Test Length function - name_lengths = TestModel.objects.annotate(name_length=Length('name')).filter(name_length__gte=5) - assert name_lengths.count() == 2 # "alice" (5) and "Charlie" (7) - - # Test Concat function - full_info = TestModel.objects.annotate( - full_info=Concat('name', Value(' - '), 'email', output_field=CharField()) + def test_sqlalchemy_database_functions(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy database functions""" + session.query(TestModel).delete() + session.commit() + session.add_all([ + TestModel(name="alice", email="alice@example.com", age=25), + TestModel(name="BOB", email="bob@example.com", age=30), + TestModel(name="Charlie", email="charlie@example.com", age=35), + ]) + session.commit() + # Test upper + upper_names = [r[0] for r in session.query(func.upper(TestModel.name)).all()] + assert "ALICE" in upper_names + assert "BOB" in upper_names + assert "CHARLIE" in upper_names + # Test lower + lower_names = [r[0] for r in session.query(func.lower(TestModel.name)).all()] + assert "alice" in lower_names + assert "bob" in lower_names + assert "charlie" in lower_names + # Test length + results = session.query(TestModel).filter(func.length(TestModel.name) >= 5).all() + assert len(results) == 2 # "alice" (5) and "Charlie" (7) + # Test concat + result = session.query( + func.concat(TestModel.name, ' - ', TestModel.email) ).first() - assert ' - ' in full_info.full_info - assert '@example.com' in full_info.full_info - - # Clean up - TestModel.objects.all().delete() - - def test_django_annotations(self, test_environment: TestEnvironment, django_models): - """Test Django annotations with expressions""" - TestModel = self.TestModel - Book = self.Book - Author = self.Author - - # Create test data for TestModel - TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) - TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) - - # Test annotate with F expression for calculations - test_with_age_plus_ten = TestModel.objects.annotate( - age_plus_ten=F('age') + 10 - ).order_by('age') - - # Verify calculation - first_obj = test_with_age_plus_ten.first() - assert first_obj.age_plus_ten == first_obj.age + 10 - assert first_obj.age_plus_ten == 35 # 25 + 10 - - # Create books for F expression testing - author = Author.objects.create(name="Test Author", email="test@example.com") - Book.objects.create(title="Book 1", author=author, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')) - Book.objects.create(title="Book 2", author=author, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')) - Book.objects.create(title="Book 3", author=author, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')) - - # Test annotate with F expression for price per page - books_with_price_per_page = Book.objects.annotate( - price_per_page=F('price') / F('pages') - ).order_by('price_per_page') - - # Verify calculation - first_book = books_with_price_per_page.first() - expected_price_per_page = float(first_book.price) / first_book.pages - assert abs(float(first_book.price_per_page) - expected_price_per_page) < 0.001 - - # Test filtering on annotated field - use a lower threshold to avoid precision issues - cheap_books = Book.objects.annotate( - price_per_page=F('price') / F('pages') - ).filter(price_per_page__lte=0.15) - assert cheap_books.count() == 3 # All books have price_per_page = 0.10 - - # Clean up - TestModel.objects.all().delete() - Book.objects.all().delete() - Author.objects.all().delete() - - def test_django_values_and_values_list(self, test_environment: TestEnvironment, django_models): - """Test Django values() and values_list() methods""" - TestModel = self.TestModel + assert ' - ' in result[0] + assert '@example.com' in result[0] + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_values_and_values_list(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy equivalents of Django's values() and values_list() functions""" # Ensure clean slate - TestModel.objects.all().delete() - + session.query(TestModel).delete() + session.commit() # Create test data - TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) - TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) - - # Test values() - returns list of dictionaries - values_result = TestModel.objects.values('name', 'age').order_by('age') - values_list = list(values_result) - assert len(values_list) == 3 - assert values_list[0] == {'name': 'Alice', 'age': 25} - assert values_list[1] == {'name': 'Bob', 'age': 30} - assert values_list[2] == {'name': 'Charlie', 'age': 35} - - # Test values_list() - returns list of tuples - values_list_result = TestModel.objects.values_list('name', 'age').order_by('age') - tuples_list = list(values_list_result) - assert len(tuples_list) == 3 - assert tuples_list[0] == ('Alice', 25) - assert tuples_list[1] == ('Bob', 30) - assert tuples_list[2] == ('Charlie', 35) - - # Test values_list() with flat=True - returns flat list - names = TestModel.objects.values_list('name', flat=True).order_by('name') - names_list = list(names) - assert names_list == ['Alice', 'Bob', 'Charlie'] - - # Test values() with filtering - active_users = TestModel.objects.filter(is_active=True).values('name', 'email') - active_list = list(active_users) - assert len(active_list) == 2 - active_names = [user['name'] for user in active_list] + session.add_all([ + TestModel(name="Alice", email="alice@example.com", age=25, is_active=True), + TestModel(name="Bob", email="bob@example.com", age=30, is_active=False), + TestModel(name="Charlie", email="charlie@example.com", age=35, is_active=True), + ]) + session.commit() + # Convert values to dicts (equivalent to Django's values()) + values_result = session.query(TestModel.name, TestModel.age).order_by(TestModel.age).all() + assert len(values_result) == 3 + assert values_result[0] == ('Alice', 25) + assert values_result[1] == ('Bob', 30) + assert values_result[2] == ('Charlie', 35) + values_dicts = [{'name': r.name, 'age': r.age} for r in values_result] + assert values_dicts[0] == {'name': 'Alice', 'age': 25} + assert values_dicts[1] == {'name': 'Bob', 'age': 30} + assert values_dicts[2] == {'name': 'Charlie', 'age': 35} + # Test flat list (equivalent to Django's values_list with flat=True) + names = [r[0] for r in session.query(TestModel.name).order_by(TestModel.name).all()] + assert names == ['Alice', 'Bob', 'Charlie'] + # Test with filtering + active_users = session.query(TestModel.name, TestModel.email).filter( + TestModel.is_active == True + ).all() + assert len(active_users) == 2 + active_names = [r.name for r in active_users] assert 'Alice' in active_names assert 'Charlie' in active_names assert 'Bob' not in active_names - # Clean up - TestModel.objects.all().delete() - - def test_django_distinct_queries(self, test_environment: TestEnvironment, django_models): - """Test Django distinct() functionality""" - TestModel = self.TestModel + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_distinct_queries(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy distinct() functionality""" # Ensure clean slate - TestModel.objects.all().delete() - + session.query(TestModel).delete() + session.commit() # Create test data with duplicate ages - TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) - TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=25, is_active=True) - TestModel.objects.create(name="David", email="david@example.com", age=30, is_active=True) - + session.add_all([ + TestModel(name="Alice", email="alice@example.com", age=25, is_active=True), + TestModel(name="Bob", email="bob@example.com", age=30, is_active=False), + TestModel(name="Charlie", email="charlie@example.com", age=25, is_active=True), + TestModel(name="David", email="david@example.com", age=30, is_active=True), + ]) + session.commit() # Test distinct ages - distinct_ages = TestModel.objects.values_list('age', flat=True).distinct().order_by('age') - ages_list = list(distinct_ages) + ages_list = [r[0] for r in session.query(TestModel.age).distinct().order_by(TestModel.age).all()] assert ages_list == [25, 30] - # Test distinct with multiple fields - distinct_age_status = TestModel.objects.values('age', 'is_active').distinct().order_by('age', 'is_active') - distinct_list = list(distinct_age_status) + distinct_list = session.query(TestModel.age, TestModel.is_active).distinct().order_by( + TestModel.age, TestModel.is_active + ).all() assert len(distinct_list) == 3 # (25, True), (30, False), (30, True) - # Test count with distinct - total_count = TestModel.objects.count() - distinct_age_count = TestModel.objects.values('age').distinct().count() + total_count = session.query(TestModel).count() + distinct_age_count = session.query(TestModel.age).distinct().count() assert total_count == 4 assert distinct_age_count == 2 - # Clean up - TestModel.objects.all().delete() - - def test_django_only_and_defer(self, test_environment: TestEnvironment, django_models): - """Test Django only() and defer() for query optimization""" - TestModel = self.TestModel + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_load_only_and_defer(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy load_only() and defer() for query optimization""" + from sqlalchemy.orm import defer, load_only # Ensure clean slate - TestModel.objects.all().delete() - + session.query(TestModel).delete() + session.commit() # Create test data - obj = TestModel.objects.create( - name="Test User", - email="test@example.com", - age=30, - is_active=True - ) - - # Test only() - load only specific fields - obj_only = TestModel.objects.only('name', 'email').get(id=obj.id) + obj = TestModel(name="Test User", email="test@example.com", age=30, is_active=True) + session.add(obj) + session.commit() + obj_id = obj.id + session.expire_all() + # Test load_only() - load only specific fields + obj_only = session.query(TestModel).options( + load_only(TestModel.name, TestModel.email) + ).get(obj_id) assert obj_only.name == "Test User" assert obj_only.email == "test@example.com" - # Accessing deferred fields will trigger additional query, but should still work assert obj_only.age == 30 - + session.expire_all() # Test defer() - exclude specific fields from loading - obj_defer = TestModel.objects.defer('age', 'is_active').get(id=obj.id) + obj_defer = session.query(TestModel).options( + defer(TestModel.age), defer(TestModel.is_active) + ).get(obj_id) assert obj_defer.name == "Test User" assert obj_defer.email == "test@example.com" - # Accessing deferred fields will trigger additional query, but should still work assert obj_defer.age == 30 - # Clean up - TestModel.objects.all().delete() - - def test_django_in_bulk(self, test_environment: TestEnvironment, django_models): - """Test Django in_bulk() for batch retrieval""" - TestModel = self.TestModel + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_batch_retrieval(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy batch retrieval (equivalent to Django's in_bulk)""" # Ensure clean slate - TestModel.objects.all().delete() - + session.query(TestModel).delete() + session.commit() # Create test data - obj1 = TestModel.objects.create(name="User 1", email="user1@example.com", age=25) - obj2 = TestModel.objects.create(name="User 2", email="user2@example.com", age=30) - obj3 = TestModel.objects.create(name="User 3", email="user3@example.com", age=35) - - # Test in_bulk with IDs (default behavior) - bulk_result = TestModel.objects.in_bulk([obj1.id, obj2.id, obj3.id]) + obj1 = TestModel(name="User 1", email="user1@example.com", age=25) + obj2 = TestModel(name="User 2", email="user2@example.com", age=30) + obj3 = TestModel(name="User 3", email="user3@example.com", age=35) + session.add_all([obj1, obj2, obj3]) + session.commit() + # Test bulk retrieval by IDs + ids = [obj1.id, obj2.id, obj3.id] + bulk_result = {o.id: o for o in session.query(TestModel).filter(TestModel.id.in_(ids)).all()} assert len(bulk_result) == 3 assert bulk_result[obj1.id].name == "User 1" assert bulk_result[obj2.id].name == "User 2" assert bulk_result[obj3.id].name == "User 3" - - # Test in_bulk with all IDs (no list provided) - bulk_all = TestModel.objects.in_bulk() + # Test bulk retrieval of all + bulk_all = {o.id: o for o in session.query(TestModel).all()} assert len(bulk_all) == 3 assert obj1.id in bulk_all assert obj2.id in bulk_all assert obj3.id in bulk_all - - # Test in_bulk with email field (unique field) - bulk_by_email = TestModel.objects.in_bulk( - ["user1@example.com", "user3@example.com"], - field_name='email' - ) + # Test bulk retrieval by email field + emails = ["user1@example.com", "user3@example.com"] + bulk_by_email = { + o.email: o for o in session.query(TestModel).filter(TestModel.email.in_(emails)).all() + } assert len(bulk_by_email) == 2 assert bulk_by_email["user1@example.com"].name == "User 1" assert bulk_by_email["user3@example.com"].name == "User 3" - # Clean up - TestModel.objects.all().delete() - - def test_django_conditional_expressions(self, test_environment: TestEnvironment, django_models): - """Test Django Case/When conditional expressions""" - from django.db.models import Case, IntegerField, Value, When - - TestModel = self.TestModel + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_conditional_expressions(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy case() conditional expressions""" + from sqlalchemy import String, case # Ensure clean slate - TestModel.objects.all().delete() - + session.query(TestModel).delete() + session.commit() # Create test data - TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) - TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) - TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) - - # Test Case/When for conditional logic - results = TestModel.objects.annotate( - age_category=Case( - When(age__lt=30, then=Value('young')), - When(age__gte=30, age__lt=40, then=Value('middle')), - default=Value('senior'), - output_field=CharField() - ) - ).order_by('age') - - results_list = list(results) - assert results_list[0].age_category == 'young' # Alice, 25 - assert results_list[1].age_category == 'middle' # Bob, 30 - assert results_list[2].age_category == 'middle' # Charlie, 35 - - # Test Case/When with integer output - priority_results = TestModel.objects.annotate( - priority=Case( - When(is_active=True, age__lt=30, then=Value(1)), - When(is_active=True, then=Value(2)), - When(is_active=False, then=Value(3)), - default=Value(4), - output_field=IntegerField() - ) - ).order_by('priority', 'name') - - priority_list = list(priority_results) - assert priority_list[0].name == 'Alice' # priority 1: active and young - assert priority_list[1].name == 'Charlie' # priority 2: active but not young - assert priority_list[2].name == 'Bob' # priority 3: not active - + session.add_all([ + TestModel(name="Alice", email="alice@example.com", age=25, is_active=True), + TestModel(name="Bob", email="bob@example.com", age=30, is_active=False), + TestModel(name="Charlie", email="charlie@example.com", age=35, is_active=True), + ]) + session.commit() + # Test case() for conditional logic + age_category = case( + (TestModel.age < 30, 'young'), + (TestModel.age.between(30, 39), 'middle'), + else_='senior' + ).label('age_category') + results = session.query(TestModel, age_category).order_by(TestModel.age).all() + assert results[0].age_category == 'young' # Alice, 25 + assert results[1].age_category == 'middle' # Bob, 30 + assert results[2].age_category == 'middle' # Charlie, 35 + # Test case() with integer output + from sqlalchemy import Integer + priority = case( + (and_(TestModel.is_active == True, TestModel.age < 30), 1), + (TestModel.is_active == True, 2), + (TestModel.is_active == False, 3), + else_=4 + ).label('priority') + results = session.query(TestModel, priority).order_by('priority', TestModel.name).all() + assert results[0].TestModel.name == 'Alice' # priority 1 + assert results[1].TestModel.name == 'Charlie' # priority 2 + assert results[2].TestModel.name == 'Bob' # priority 3 # Clean up - TestModel.objects.all().delete() - - def test_django_iterator(self, test_environment: TestEnvironment, django_models): - """Test Django iterator() for memory-efficient queries""" - TestModel = self.TestModel + session.query(TestModel).delete() + session.commit() + def test_sqlalchemy_yield_per(self, test_environment: TestEnvironment, session): + """Test SQLAlchemy yield_per() for memory-efficient queries""" # Ensure clean slate - TestModel.objects.all().delete() - + session.query(TestModel).delete() + session.commit() # Create test data - for i in range(20): - TestModel.objects.create( - name=f"User {i}", - email=f"user{i}@example.com", - age=20 + i - ) - - # Test iterator() - processes results without caching + session.add_all([ + TestModel(name=f"User {i}", email=f"user{i}@example.com", age=20 + i) + for i in range(20) + ]) + session.commit() + # Test yield_per() - processes results without caching all at once count = 0 - for obj in TestModel.objects.iterator(): + for obj in session.query(TestModel).yield_per(100): assert obj.name.startswith("User") count += 1 assert count == 20 - - # Test iterator with chunk_size + # Test yield_per with smaller chunk size count = 0 - for obj in TestModel.objects.iterator(chunk_size=5): + for obj in session.query(TestModel).yield_per(5): assert obj.email.endswith("@example.com") count += 1 assert count == 20 - # Clean up - TestModel.objects.all().delete() - ''' + session.query(TestModel).delete() + session.commit() From 1b8d401c0c3ecd00c02fb6a5fce25c096b88a2b0 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Tue, 7 Apr 2026 16:58:33 -0700 Subject: [PATCH 11/39] Remove temporary changes to get tests to run locally --- .../container/utils/test_database_info.py | 2 +- .../utils/test_environment_request.py | 2 +- tests/unit/test_sqlalchemy_orm.py | 61 ------------------- 3 files changed, 2 insertions(+), 63 deletions(-) delete mode 100644 tests/unit/test_sqlalchemy_orm.py diff --git a/tests/integration/container/utils/test_database_info.py b/tests/integration/container/utils/test_database_info.py index edc49faf0..a1b3a0944 100644 --- a/tests/integration/container/utils/test_database_info.py +++ b/tests/integration/container/utils/test_database_info.py @@ -42,7 +42,7 @@ def __init__(self, database_info: Dict[str, Any]) -> None: self._username = typing.cast('str', database_info.get("username")) self._password = typing.cast('str', database_info.get("password")) - self._default_db_name = "mysqldb" + self._default_db_name = typing.cast('str', database_info.get("defaultDbName")) self._cluster_endpoint = typing.cast('str', database_info.get("clusterEndpoint")) self._cluster_endpoint_port = typing.cast('int', database_info.get("clusterEndpointPort")) self._cluster_read_only_endpoint = typing.cast('str', database_info.get("clusterReadOnlyEndpoint")) diff --git a/tests/integration/container/utils/test_environment_request.py b/tests/integration/container/utils/test_environment_request.py index def700293..db1bbeef2 100644 --- a/tests/integration/container/utils/test_environment_request.py +++ b/tests/integration/container/utils/test_environment_request.py @@ -63,7 +63,7 @@ def get_features(self) -> Set[TestEnvironmentFeatures]: return self._features def get_num_of_instances(self) -> int: - return 3 + return self._num_of_instances def get_display_name(self) -> str: return "Test environment [{0}, {1}, {2}, {3}, {4}, {5}]".format( diff --git a/tests/unit/test_sqlalchemy_orm.py b/tests/unit/test_sqlalchemy_orm.py deleted file mode 100644 index 70acb6f58..000000000 --- a/tests/unit/test_sqlalchemy_orm.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from sqlalchemy import create_engine, Column, Integer, String -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker - -class TestSqlAlchemyORM: - def test_basic_workflow(self): - # Step 1: Create engine (connection to database) - engine = create_engine('postgresql+aws_wrapper://pguser:pgpassword@mydb.cluster-XYZ.us-west-1.rds.amazonaws.com:5432/somedb') - # Step 2: Define base class for declarative models - Base = declarative_base() - - # Step 3: Define model class (separate from database operations) - class User(Base): - __tablename__ = 'users' - - id = Column(Integer, primary_key=True) - name = Column(String(50)) - email = Column(String(100)) - - # Step 4: Create tables - Base.metadata.create_all(engine) - - # Step 5: Create session factory - Session = sessionmaker(bind=engine) - - # Step 6: Use session for database operations - with Session() as session: - # INSERT - Create new object and add to session - new_user = User(name='John Doe', email='john@example.com') - session.add(new_user) - session.commit() # Explicit commit required - - # SELECT - Query using session - users = session.query(User).filter(User.name == 'John Doe').all() - for user in users: - print(f"{user.name}: {user.email}") - - - # UPDATE - Modify object and commit - user = session.query(User).filter(User.name == "John Doe").first() - user.email = 'newemail@example.com' - session.commit() - - # DELETE - Remove object from session - user_to_delete = session.query(User).filter(User.name == "John Doe").first() - session.delete(user_to_delete) - session.commit() From 4429afdd3c3d11663af45bfdd7e6eb31f7d7847b Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Tue, 7 Apr 2026 17:01:28 -0700 Subject: [PATCH 12/39] Add license headers and remove unused import --- .../sqlalchemy/mysql_orm_dialect.py | 15 ++++++++++++++- .../sqlalchemy/pg_orm_dialect.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index 6d7ff34db..fde407e5c 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -1,5 +1,18 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_mysqlconnector_dialect.py -from psycopg import Connection from sqlalchemy.dialects.mysql.mysqlconnector import MySQLDialect_mysqlconnector import re diff --git a/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py index c2780b861..d792ce501 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py @@ -1,3 +1,17 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_psycopg_dialect.py from psycopg import Connection from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg From 4f6500cb72450904b76214dcd47137a78deff2e3 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Thu, 9 Apr 2026 11:14:13 -0700 Subject: [PATCH 13/39] Try fixing mypy errors in tests --- .../integration/container/sqlalchemy/test_sqlalchemy_basic.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index 822b92ac8..9f099febd 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -688,7 +688,7 @@ def test_sqlalchemy_load_only_and_defer(self, test_environment: TestEnvironment, session.expire_all() # Test load_only() - load only specific fields obj_only = session.query(TestModel).options( - load_only(TestModel.name, TestModel.email) + load_only('TestModel.name', 'TestModel.email') ).get(obj_id) assert obj_only.name == "Test User" assert obj_only.email == "test@example.com" @@ -696,7 +696,7 @@ def test_sqlalchemy_load_only_and_defer(self, test_environment: TestEnvironment, session.expire_all() # Test defer() - exclude specific fields from loading obj_defer = session.query(TestModel).options( - defer(TestModel.age), defer(TestModel.is_active) + defer('TestModel.age'), defer('TestModel.is_active') ).get(obj_id) assert obj_defer.name == "Test User" assert obj_defer.email == "test@example.com" From 9bd824f3553183ea90398eae00dc06857e821d99 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Mon, 13 Apr 2026 13:22:35 -0700 Subject: [PATCH 14/39] Try to fix mypy Base class error --- .../container/sqlalchemy/test_sqlalchemy_basic.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index 9f099febd..46b2e8639 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -40,7 +40,10 @@ from ..utils.test_environment import TestEnvironment from ..utils.test_environment_features import TestEnvironmentFeatures -Base = declarative_base() +class Base: + __allow_unmapped__ = True + +Base = declarative_base(cls=Base) class TestModel(Base): """Basic test model for SQLAlchemy ORM functionality""" @@ -688,7 +691,7 @@ def test_sqlalchemy_load_only_and_defer(self, test_environment: TestEnvironment, session.expire_all() # Test load_only() - load only specific fields obj_only = session.query(TestModel).options( - load_only('TestModel.name', 'TestModel.email') + load_only(TestModel.name, TestModel.email) ).get(obj_id) assert obj_only.name == "Test User" assert obj_only.email == "test@example.com" @@ -696,7 +699,7 @@ def test_sqlalchemy_load_only_and_defer(self, test_environment: TestEnvironment, session.expire_all() # Test defer() - exclude specific fields from loading obj_defer = session.query(TestModel).options( - defer('TestModel.age'), defer('TestModel.is_active') + defer(TestModel.age), defer(TestModel.is_active) ).get(obj_id) assert obj_defer.name == "Test User" assert obj_defer.email == "test@example.com" From ee53f8c5261381f49962c1bc050d21a3c7cab22b Mon Sep 17 00:00:00 2001 From: Karen <64801825+karenc-bq@users.noreply.github.com> Date: Mon, 13 Apr 2026 15:02:35 -0700 Subject: [PATCH 15/39] fix: test failures due to using legacy sqlalchemy api (#1226) --- aws_advanced_python_wrapper/__init__.py | 18 ++-- .../sqlalchemy/mysql_orm_dialect.py | 7 +- .../sqlalchemy/pg_orm_dialect.py | 3 +- .../sqlalchemy/test_sqlalchemy_basic.py | 89 ++++++++++--------- 4 files changed, 56 insertions(+), 61 deletions(-) diff --git a/aws_advanced_python_wrapper/__init__.py b/aws_advanced_python_wrapper/__init__.py index 7d4dbf38a..459a67fab 100644 --- a/aws_advanced_python_wrapper/__init__.py +++ b/aws_advanced_python_wrapper/__init__.py @@ -14,21 +14,16 @@ from logging import DEBUG, getLogger +from aws_advanced_python_wrapper.pep249 import (DatabaseError, DataError, + Error, IntegrityError, + InterfaceError, InternalError, + NotSupportedError, + OperationalError, + ProgrammingError) from .cleanup import release_resources from .driver_info import DriverInfo from .utils.utils import LogUtils from .wrapper import AwsWrapperConnection -from aws_advanced_python_wrapper.pep249 import ( - Error, - InterfaceError, - DatabaseError, - DataError, - OperationalError, - IntegrityError, - InternalError, - ProgrammingError, - NotSupportedError -) # PEP249 compliance connect = AwsWrapperConnection.connect @@ -58,5 +53,6 @@ __version__ = DriverInfo.DRIVER_VERSION + def set_logger(name='aws_advanced_python_wrapper', level=DEBUG, format_string=None): LogUtils.setup_logger(getLogger(name), level, format_string) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index fde407e5c..dc8da6394 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -13,10 +13,8 @@ # limitations under the License. # aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_mysqlconnector_dialect.py -from sqlalchemy.dialects.mysql.mysqlconnector import MySQLDialect_mysqlconnector -import re - -from aws_advanced_python_wrapper import AwsWrapperConnection +from sqlalchemy.dialects.mysql.mysqlconnector import \ + MySQLDialect_mysqlconnector class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): @@ -29,4 +27,3 @@ class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): name = 'mysql' driver = 'aws_wrapper_mysqlconnector' - diff --git a/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py index d792ce501..c066520bc 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/pg_orm_dialect.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re + # aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_psycopg_dialect.py from psycopg import Connection from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg -import re from aws_advanced_python_wrapper import AwsWrapperConnection diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index 46b2e8639..82427f95b 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -18,19 +18,16 @@ from datetime import date, datetime, time, timezone from decimal import Decimal -from typing import Any +from typing import Any, List, Optional import pytest +from sqlalchemy import (JSON, BigInteger, Boolean, Date, DateTime, Float, + ForeignKey, Numeric, SmallInteger, String, Text, Time, + and_, create_engine, or_, text) +from sqlalchemy.orm import (DeclarativeBase, Mapped, Session, joinedload, + mapped_column, relationship, sessionmaker, + subqueryload) from sqlalchemy.sql import func -from sqlalchemy.orm import ( - declarative_base, sessionmaker, relationship, Session, joinedload, - subqueryload -) -from sqlalchemy import ( - create_engine, Column, ForeignKey, Integer, BigInteger, SmallInteger, - Float, Numeric, String, Boolean, Date, Time, DateTime, Text, JSON, or_, - and_, text -) from tests.integration.container.utils.rds_test_utility import RdsTestUtility from ..utils.conditions import (disable_on_features, enable_on_deployments, @@ -40,74 +37,76 @@ from ..utils.test_environment import TestEnvironment from ..utils.test_environment_features import TestEnvironmentFeatures -class Base: - __allow_unmapped__ = True -Base = declarative_base(cls=Base) +class Base(DeclarativeBase): + pass + class TestModel(Base): """Basic test model for SQLAlchemy ORM functionality""" __tablename__ = 'sqlalchemy_test_model' - id = Column(Integer, primary_key=True) + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String(254), unique=True) + age: Mapped[int] = mapped_column() + is_active: Mapped[Optional[bool]] = mapped_column(Boolean, default=True) + created_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=datetime.now(timezone.utc)) - name = Column(String(100), nullable=False) - email = Column(String(254), nullable=False, unique=True) - age = Column(Integer, nullable=False) - is_active = Column(Boolean, default=True) - created_at = Column(DateTime, default=datetime.now(timezone.utc)) class DataTypeModel(Base): """Model for testing various data types""" __tablename__ = 'sqlalchemy_data_type_model' - id = Column(Integer, primary_key=True) + id: Mapped[int] = mapped_column(primary_key=True) # String fields - string_field = Column(String(255)) - text_field = Column(Text) + string_field: Mapped[Optional[str]] = mapped_column(String(255)) + text_field: Mapped[Optional[str]] = mapped_column(Text) # Numeric fields - integer_field = Column(Integer) - small_integer_field = Column(SmallInteger) - big_integer_field = Column(BigInteger) - numeric_field = Column(Numeric(10, 2)) - float_field = Column(Float) + integer_field: Mapped[Optional[int]] = mapped_column() + small_integer_field: Mapped[Optional[int]] = mapped_column(SmallInteger) + big_integer_field: Mapped[Optional[int]] = mapped_column(BigInteger) + numeric_field: Mapped[Optional[Decimal]] = mapped_column(Numeric(10, 2)) + float_field: Mapped[Optional[float]] = mapped_column(Float) # Boolean field - boolean_field = Column(Boolean, default=False) + boolean_field: Mapped[Optional[bool]] = mapped_column(Boolean, default=False) # Date/Time fields - date_field = Column(Date) - time_field = Column(Time) - datetime_field = Column(DateTime) + date_field: Mapped[Optional[date]] = mapped_column(Date) + time_field: Mapped[Optional[time]] = mapped_column(Time) + datetime_field: Mapped[Optional[datetime]] = mapped_column(DateTime) # JSON field (MySQL 5.7+) - json_field = Column(JSON) + json_field: Mapped[Optional[Any]] = mapped_column(JSON) + class Author(Base): """Author model for relationship testing""" __tablename__ = 'sqlalchemy_author' - id = Column(Integer, primary_key=True) - name = Column(String(100), nullable=False) - email = Column(String(254), nullable=False) - birth_date = Column(Date) + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String(254)) + birth_date: Mapped[Optional[date]] = mapped_column(Date) + + books: Mapped[List[Book]] = relationship(back_populates='author', cascade='all, delete-orphan') - books = relationship('Book', back_populates='author', cascade='all, delete-orphan') class Book(Base): """Book model for relationship testing""" __tablename__ = 'sqlalchemy_book' - id = Column(Integer, primary_key=True) - title = Column(String(200), nullable=False) - author_id = Column(Integer, ForeignKey("sqlalchemy_author.id"), nullable=False) - publication_date = Column(Date, nullable=False) - pages = Column(Integer, nullable=False) - price = Column(Numeric(8, 2), nullable=False) + id: Mapped[int] = mapped_column(primary_key=True) + title: Mapped[str] = mapped_column(String(200)) + author_id: Mapped[int] = mapped_column(ForeignKey("sqlalchemy_author.id")) + publication_date: Mapped[date] = mapped_column(Date) + pages: Mapped[int] = mapped_column() + price: Mapped[Decimal] = mapped_column(Numeric(8, 2)) - author = relationship('Author', back_populates='books') + author: Mapped[Author] = relationship(back_populates='books') @enable_on_engines([DatabaseEngine.MYSQL]) # MySQL Specific until PG is implemented @enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) @@ -680,6 +679,7 @@ def test_sqlalchemy_distinct_queries(self, test_environment: TestEnvironment, se def test_sqlalchemy_load_only_and_defer(self, test_environment: TestEnvironment, session): """Test SQLAlchemy load_only() and defer() for query optimization""" from sqlalchemy.orm import defer, load_only + # Ensure clean slate session.query(TestModel).delete() session.commit() @@ -747,6 +747,7 @@ def test_sqlalchemy_batch_retrieval(self, test_environment: TestEnvironment, ses def test_sqlalchemy_conditional_expressions(self, test_environment: TestEnvironment, session): """Test SQLAlchemy case() conditional expressions""" from sqlalchemy import String, case + # Ensure clean slate session.query(TestModel).delete() session.commit() From 9f484bb33099e88d758669b7594cd612f75c9caa Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Mon, 13 Apr 2026 15:05:03 -0700 Subject: [PATCH 16/39] Add WIP sqlalchemy plugins tests --- .../sqlalchemy/test_sqlalchemy_plugins.py | 635 ++++++++++++++++++ 1 file changed, 635 insertions(+) create mode 100644 tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py new file mode 100644 index 000000000..98c3765fa --- /dev/null +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py @@ -0,0 +1,635 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +import uuid +from decimal import Decimal +from time import perf_counter_ns, sleep +from typing import Any, ClassVar, Dict + +import boto3 +import pytest +from boto3 import client +from botocore.exceptions import ClientError +from sqlalchemy import (Boolean, BigInteger, Column, Date, DateTime, Float, + ForeignKey, Integer, JSON, Numeric, String, Text, + Time, create_engine, text) +from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker + +from aws_advanced_python_wrapper.errors import FailoverSuccessError +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from ..utils.conditions import (disable_on_features, enable_on_deployments, + enable_on_engines, enable_on_features, + enable_on_num_instances) +from ..utils.database_engine import DatabaseEngine +from ..utils.database_engine_deployment import DatabaseEngineDeployment +from ..utils.test_environment import TestEnvironment +from ..utils.test_environment_features import TestEnvironmentFeatures + + +class Base(DeclarativeBase): + pass + + +def _build_url(user, password, host, port, dbname, plugins=None, **extra_options): + """Build a SQLAlchemy connection URL using the aws wrapper dialect.""" + query_params = {} + if plugins: + query_params['plugins'] = plugins + query_params['connect_timeout'] = str(extra_options.get('connect_timeout', 10)) + for k, v in extra_options.items(): + if k != 'connect_timeout': + query_params[k] = str(v) + + from sqlalchemy.engine import URL + return URL.create( + drivername="mysql+aws_wrapper_mysqlconnector", + username=user, + password=password, + host=host, + port=port, + database=dbname, + query=query_params, + ) + + +@enable_on_engines([DatabaseEngine.MYSQL]) +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestSqlAlchemyPlugins: + endpoint_id: ClassVar[str] = f"test-sqlalchemy-endpoint-{uuid.uuid4()}" + endpoint_info: ClassVar[Dict[str, Any]] = {} + reuse_existing_endpoint: ClassVar[bool] = False + + @pytest.fixture(scope='class') + def rds_utils(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) + + @pytest.fixture(scope='class') + def create_secret(self, conn_utils): + """Create a secret in AWS Secrets Manager with database credentials.""" + region = TestEnvironment.get_current().get_info().get_region() + sm_client = boto3.client('secretsmanager', region_name=region) + env = TestEnvironment.get_current() + + secret_name = f"TestSecret-{uuid.uuid4()}" + engine = "postgres" if env.get_engine() == "pg" else "mysql" + secret_value = { + "engine": engine, + "dbname": env.get_info().get_database_info().get_default_db_name(), + "host": env.get_info().get_database_info().get_cluster_endpoint(), + "username": conn_utils.user, + "password": conn_utils.password, + "description": "Test secret generated by integration tests." + } + + try: + response = sm_client.create_secret( + Name=secret_name, + SecretString=json.dumps(secret_value) + ) + secret_arn = response['ARN'] + yield secret_name, secret_arn + finally: + try: + sm_client.delete_secret(SecretId=secret_name, ForceDeleteWithoutRecovery=True) + except Exception: + pass + + @pytest.fixture(scope='class') + def create_custom_endpoint(self): + """Create a custom endpoint for testing""" + env_info = TestEnvironment.get_current().get_info() + region = env_info.get_region() + rds_client = client('rds', region_name=region) + + if not self.reuse_existing_endpoint: + instances = env_info.get_database_info().get_instances() + self._create_endpoint(rds_client, instances[0:1]) + + self._wait_until_endpoint_available(rds_client) + yield + if not self.reuse_existing_endpoint: + self._delete_endpoint(rds_client) + rds_client.close() + + def _wait_until_endpoint_available(self, rds_client): + end_ns = perf_counter_ns() + 5 * 60 * 1_000_000_000 + available = False + while perf_counter_ns() < end_ns: + response = rds_client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self.endpoint_id, + Filters=[{"Name": "db-cluster-endpoint-type", "Values": ["custom"]}] + ) + endpoints = response["DBClusterEndpoints"] + if len(endpoints) != 1: + sleep(3) + continue + TestSQLAlchemyPlugins.endpoint_info = endpoints[0] + if endpoints[0]["Status"] == "available": + available = True + break + sleep(3) + if not available: + pytest.fail(f"Timed out waiting for custom endpoint: '{self.endpoint_id}'.") + + def _create_endpoint(self, rds_client, instances): + instance_ids = [i.get_instance_id() for i in instances] + rds_client.create_db_cluster_endpoint( + DBClusterEndpointIdentifier=self.endpoint_id, + DBClusterIdentifier=TestEnvironment.get_current().get_cluster_name(), + EndpointType="ANY", + StaticMembers=instance_ids + ) + + def _delete_endpoint(self, rds_client): + try: + rds_client.delete_db_cluster_endpoint(DBClusterEndpointIdentifier=self.endpoint_id) + self._wait_until_endpoint_deleted(rds_client) + except ClientError as e: + if e.response['Error']['Code'] != 'DBClusterEndpointNotFoundFault': + pytest.fail(e) + + def _wait_until_endpoint_deleted(self, rds_client): + end_ns = perf_counter_ns() + 3 * 60 * 1_000_000_000 + deleted = False + while perf_counter_ns() < end_ns: + try: + response = rds_client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self.endpoint_id, + Filters=[{"Name": "db-cluster-endpoint-type", "Values": ["custom"]}] + ) + if len(response["DBClusterEndpoints"]) == 0: + deleted = True + break + sleep(3) + except ClientError as e: + if e.response['Error']['Code'] == 'DBClusterEndpointNotFoundFault': + deleted = True + break + sleep(3) + if deleted: + print(f"Custom endpoint '{self.endpoint_id}' successfully deleted.") + else: + print(f"Warning: Timed out waiting for custom endpoint deletion: '{self.endpoint_id}'.") + + @pytest.fixture(scope='function') + def sa_models(self, sa_setup): + """Create SQLAlchemy tables and provide model classes.""" + engine = sa_setup['engine'] + test_id = str(uuid.uuid4())[:8] + + class TestModel(Base): + __tablename__ = f'sa_test_model_{test_id}' + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(100), nullable=False) + email = Column(String(254), nullable=False) + age = Column(Integer, nullable=False) + is_active = Column(Boolean, default=True) + created_at = Column(DateTime, server_default=text('CURRENT_TIMESTAMP')) + + class DataTypeModel(Base): + __tablename__ = f'sa_data_type_model_{test_id}' + id = Column(Integer, primary_key=True, autoincrement=True) + char_field = Column(String(255), nullable=True) + text_field = Column(Text, nullable=True) + integer_field = Column(Integer, nullable=True) + big_integer_field = Column(BigInteger, nullable=True) + decimal_field = Column(Numeric(10, 2), nullable=True) + float_field = Column(Float, nullable=True) + boolean_field = Column(Boolean, default=False) + date_field = Column(Date, nullable=True) + time_field = Column(Time, nullable=True) + datetime_field = Column(DateTime, nullable=True) + json_field = Column(JSON, nullable=True) + + class Author(Base): + __tablename__ = f'sa_author_{test_id}' + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(100), nullable=False) + email = Column(String(254), nullable=False) + birth_date = Column(Date, nullable=True) + books = relationship('Book', back_populates='author', cascade='all, delete-orphan') + + class Book(Base): + __tablename__ = f'sa_book_{test_id}' + id = Column(Integer, primary_key=True, autoincrement=True) + title = Column(String(200), nullable=False) + author_id = Column(Integer, ForeignKey(f'sa_author_{test_id}.id'), nullable=False) + publication_date = Column(Date, nullable=False) + pages = Column(Integer, nullable=False) + price = Column(Numeric(8, 2), nullable=False) + author = relationship('Author', back_populates='books') + + Base.metadata.create_all(engine, tables=[ + TestModel.__table__, DataTypeModel.__table__, + Author.__table__, Book.__table__ + ]) + + models = { + 'TestModel': TestModel, + 'DataTypeModel': DataTypeModel, + 'Author': Author, + 'Book': Book, + } + + yield models + + Base.metadata.drop_all(engine, tables=[ + Book.__table__, Author.__table__, + DataTypeModel.__table__, TestModel.__table__ + ]) + + @pytest.fixture(scope='function') + def sa_setup(self, conn_utils, create_secret, request, create_custom_endpoint=None): + """Setup SQLAlchemy engine with configurable plugins.""" + if hasattr(request, 'param') and isinstance(request.param, dict): + config = request.param + plugins_config = config.get('plugins', 'aurora_connection_tracker,failover_v2') + extra_options = config.get('options', {}) + use_custom_endpoint = config.get('use_custom_endpoint', False) + custom_endpoint_host = None + if use_custom_endpoint and create_custom_endpoint: + custom_endpoint_host = self.endpoint_info.get('Endpoint') + + if 'iam' in plugins_config: + user = conn_utils.iam_user + elif 'aws_secrets_manager' in plugins_config: + user = None + _, secret_arn = create_secret + extra_options['secrets_manager_secret_id'] = secret_arn + else: + user = config.get('user', conn_utils.user) + + if 'iam' in plugins_config or 'aws_secrets_manager' in plugins_config: + password = None + else: + password = config.get('password', conn_utils.password) + + host = custom_endpoint_host or config.get('host', conn_utils.writer_cluster_host) + else: + plugins_config = 'aurora_connection_tracker,failover_v2' + extra_options = {} + user = conn_utils.user + password = conn_utils.password + host = conn_utils.writer_host + + url = _build_url(user, password, host, conn_utils.port, conn_utils.dbname, + plugins=plugins_config, **extra_options) + engine = create_engine(url) + SessionLocal = sessionmaker(bind=engine) + + yield {'engine': engine, 'SessionLocal': SessionLocal, + 'plugins': plugins_config, 'options': extra_options} + + engine.dispose() + + def test_sqlalchemy_basic_insert_with_plugins(self, test_environment: TestEnvironment, sa_models, sa_setup): + """Test basic SQLAlchemy insert operation with plugins enabled""" + TestModel = sa_models['TestModel'] + session: Session = sa_setup['SessionLocal']() + + try: + session.query(TestModel).delete() + obj = TestModel(name="Plugin Test User", email="plugin@example.com", age=25, is_active=True) + session.add(obj) + session.commit() + + assert obj.id is not None + assert obj.name == "Plugin Test User" + + retrieved = session.get(TestModel, obj.id) + assert retrieved.name == "Plugin Test User" + + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{'plugins': ''}], indirect=True) + def test_sqlalchemy_with_no_plugins(self, test_environment: TestEnvironment, sa_models, sa_setup): + """Test SQLAlchemy with no plugins enabled""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert config['plugins'] == '' + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="No Plugins User", email="noplugins@example.com", age=30) + session.add(obj) + session.commit() + assert obj.id is not None + assert obj.name == "No Plugins User" + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{'plugins': 'failover_v2'}], indirect=True) + def test_sqlalchemy_with_failover_only(self, test_environment: TestEnvironment, sa_models, sa_setup): + """Test SQLAlchemy with only failover plugin""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert config['plugins'] == 'failover_v2' + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="Failover Only User", email="failover@example.com", age=35) + session.add(obj) + session.commit() + assert obj.id is not None + assert obj.name == "Failover Only User" + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{'plugins': 'aurora_connection_tracker,failover_v2'}], indirect=True) + def test_sqlalchemy_with_multiple_plugins(self, test_environment: TestEnvironment, sa_models, sa_setup): + """Test SQLAlchemy with multiple plugins enabled""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert config['plugins'] == 'aurora_connection_tracker,failover_v2' + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="Multi Plugin User", email="multiplugin@example.com", age=40) + session.add(obj) + session.commit() + assert obj.id is not None + assert obj.name == "Multi Plugin User" + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{ + 'plugins': 'aws_secrets_manager', + 'use_secrets_manager': True + }], indirect=True) + def test_sqlalchemy_with_secrets_manager_plugin(self, test_environment: TestEnvironment, sa_setup, sa_models): + """Test SQLAlchemy with AWS Secrets Manager plugin""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert config['plugins'] == 'aws_secrets_manager' + assert 'secrets_manager_secret_id' in config['options'] + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="Secrets Manager User", email="secrets@example.com", age=45) + session.add(obj) + session.commit() + assert obj.id is not None + + retrieved = session.get(TestModel, obj.id) + assert retrieved.email == "secrets@example.com" + + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{ + 'plugins': 'iam', + 'password': '', + 'options': {} + }], indirect=True) + def test_sqlalchemy_with_iam_plugin(self, test_environment: TestEnvironment, sa_models, sa_setup, conn_utils): + """Test SQLAlchemy with IAM authentication plugin""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert config['plugins'] == 'iam' + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="IAM User", email="iam@example.com", age=50) + session.add(obj) + session.commit() + assert obj.id is not None + + retrieved = session.get(TestModel, obj.id) + assert retrieved.email == "iam@example.com" + + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{ + 'plugins': 'failover_v2', + 'options': { + 'socket_timeout': 10, + 'connect_timeout': 10, + 'monitoring-connect_timeout': 5, + 'monitoring-socket_timeout': 5, + 'topology_refresh_ms': 10 + } + }], indirect=True) + @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) + @enable_on_num_instances(min_instances=2) + def test_sqlalchemy_failover_during_query(self, test_environment: TestEnvironment, sa_setup, sa_models, rds_utils): + """Test SQLAlchemy failover during query operations""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert 'failover_v2' in config['plugins'] + + initial_writer_id = rds_utils.get_cluster_writer_instance_id() + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="Failover Test User", email="failover@example.com", age=30) + session.add(obj) + session.commit() + + result = session.query(TestModel).filter_by(id=obj.id).first() + assert result is not None + assert result.name == "Failover Test User" + + rds_utils.failover_cluster_and_wait_until_writer_changed() + + with pytest.raises(FailoverSuccessError): + session.query(TestModel).filter_by(id=obj.id).first() + + result = session.query(TestModel).filter_by(id=obj.id).first() + assert result is not None + assert result.name == "Failover Test User" + + row = session.execute(text(RdsTestUtility.get_instance_id_query())).fetchone() + current_writer_id = row[0] + assert rds_utils.is_db_instance_writer(current_writer_id) is True + assert current_writer_id != initial_writer_id + + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.mark.parametrize('sa_setup', [{ + 'plugins': 'custom_endpoint,failover_v2', + 'use_custom_endpoint': True, + 'options': { + 'socket_timeout': 10, + 'connect_timeout': 10, + 'monitoring-connect_timeout': 5, + 'monitoring-socket_timeout': 5, + 'topology_refresh_ms': 10 + } + }], indirect=True) + @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) + @enable_on_num_instances(min_instances=2) + def test_sqlalchemy_custom_endpoint_failover_during_query( + self, test_environment: TestEnvironment, create_custom_endpoint, + sa_setup, sa_models, rds_utils): + """Test SQLAlchemy failover with custom endpoint during query operations""" + TestModel = sa_models['TestModel'] + config = sa_setup + assert 'custom_endpoint' in config['plugins'] + assert 'failover_v2' in config['plugins'] + + initial_writer_id = rds_utils.get_cluster_writer_instance_id() + + session: Session = config['SessionLocal']() + try: + obj = TestModel(name="Custom Endpoint Failover Test User", email="custom_failover@example.com", age=35) + session.add(obj) + session.commit() + + result = session.query(TestModel).filter_by(id=obj.id).first() + assert result is not None + assert result.name == "Custom Endpoint Failover Test User" + + rds_utils.failover_cluster_and_wait_until_writer_changed() + + with pytest.raises(FailoverSuccessError): + session.query(TestModel).filter_by(id=obj.id).first() + + result = session.query(TestModel).filter_by(id=obj.id).first() + assert result is not None + assert result.name == "Custom Endpoint Failover Test User" + + row = session.execute(text(RdsTestUtility.get_instance_id_query())).fetchone() + current_writer_id = row[0] + assert rds_utils.is_db_instance_writer(current_writer_id) is True + assert current_writer_id != initial_writer_id + + session.query(TestModel).delete() + session.commit() + finally: + session.close() + + @pytest.fixture(scope='function') + def sa_rw_split_setup(self, conn_utils): + """Setup SQLAlchemy with read/write splitting configuration""" + writer_url = _build_url( + conn_utils.user, conn_utils.password, conn_utils.writer_cluster_host, + conn_utils.port, conn_utils.dbname, plugins='read_write_splitting') + reader_url = _build_url( + conn_utils.user, conn_utils.password, conn_utils.writer_cluster_host, + conn_utils.port, conn_utils.dbname, plugins='read_write_splitting', + read_only='true') + + writer_engine = create_engine(writer_url) + reader_engine = create_engine(reader_url) + WriterSession = sessionmaker(bind=writer_engine) + ReaderSession = sessionmaker(bind=reader_engine) + + test_id = str(uuid.uuid4())[:8] + + class RWSplitTestModel(Base): + __tablename__ = f'sa_rw_split_test_{test_id}' + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(100), nullable=False) + value = Column(Integer, nullable=False) + + Base.metadata.create_all(writer_engine, tables=[RWSplitTestModel.__table__]) + + yield { + 'model': RWSplitTestModel, + 'writer_engine': writer_engine, + 'reader_engine': reader_engine, + 'WriterSession': WriterSession, + 'ReaderSession': ReaderSession, + } + + Base.metadata.drop_all(writer_engine, tables=[RWSplitTestModel.__table__]) + writer_engine.dispose() + reader_engine.dispose() + + @enable_on_num_instances(min_instances=2) + def test_sqlalchemy_read_write_splitting(self, test_environment: TestEnvironment, sa_rw_split_setup, rds_utils): + """Test SQLAlchemy with read/write splitting using separate engines""" + setup = sa_rw_split_setup + RWSplitTestModel = setup['model'] + + # Verify writer connection + with setup['writer_engine'].connect() as conn: + row = conn.execute(text(RdsTestUtility.get_instance_id_query())).fetchone() + writer_instance_id = row[0] + assert rds_utils.is_db_instance_writer(writer_instance_id) + + # Verify reader connection + with setup['reader_engine'].connect() as conn: + row = conn.execute(text(RdsTestUtility.get_instance_id_query())).fetchone() + reader_instance_id = row[0] + assert not rds_utils.is_db_instance_writer(reader_instance_id) + + assert writer_instance_id != reader_instance_id + + # Write operations + w_session: Session = setup['WriterSession']() + r_session: Session = setup['ReaderSession']() + try: + obj = RWSplitTestModel(name="Test Write", value=42) + w_session.add(obj) + w_session.commit() + assert obj.id is not None + + # Read via reader + retrieved = r_session.get(RWSplitTestModel, obj.id) + assert retrieved.name == "Test Write" + assert retrieved.value == 42 + + # Bulk create + w_session.add_all([ + RWSplitTestModel(name="Object 1", value=10), + RWSplitTestModel(name="Object 2", value=20), + RWSplitTestModel(name="Object 3", value=30), + ]) + w_session.commit() + + r_session.expire_all() + assert r_session.query(RWSplitTestModel).count() == 4 + + filtered = r_session.query(RWSplitTestModel).filter(RWSplitTestModel.value >= 20).all() + assert len(filtered) == 3 + + # Update + w_session.query(RWSplitTestModel).filter_by(name="Object 1").update({"value": 15}) + w_session.commit() + + r_session.expire_all() + updated = r_session.query(RWSplitTestModel).filter_by(name="Object 1").first() + assert updated.value == 15 + + w_session.query(RWSplitTestModel).delete() + w_session.commit() + finally: + w_session.close() + r_session.close() + From 0ab9c6693e4a24b9ec783d5f2475c18fe1f31d77 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Tue, 14 Apr 2026 15:43:09 -0700 Subject: [PATCH 17/39] Fix multiple class definition errors --- .../sqlalchemy/test_sqlalchemy_plugins.py | 120 +++++++++++------- 1 file changed, 74 insertions(+), 46 deletions(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py index 98c3765fa..3515476af 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py @@ -12,22 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +# flake8: noqa: N806 + from __future__ import annotations import json import uuid from decimal import Decimal +from datetime import date, datetime, time, timezone from time import perf_counter_ns, sleep -from typing import Any, ClassVar, Dict +from typing import Any, ClassVar, Dict, List, Optional import boto3 import pytest from boto3 import client from botocore.exceptions import ClientError from sqlalchemy import (Boolean, BigInteger, Column, Date, DateTime, Float, - ForeignKey, Integer, JSON, Numeric, String, Text, + ForeignKey, Integer, JSON, Numeric, SmallInteger, String, Text, Time, create_engine, text) -from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker +from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker, Mapped, mapped_column from aws_advanced_python_wrapper.errors import FailoverSuccessError from tests.integration.container.utils.rds_test_utility import RdsTestUtility @@ -43,6 +46,71 @@ class Base(DeclarativeBase): pass +class TestModel(Base): + """Basic test model for SQLAlchemy ORM functionality""" + __tablename__ = 'sqlalchemy_test_model' + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String(254), unique=True) + age: Mapped[int] = mapped_column() + is_active: Mapped[Optional[bool]] = mapped_column(Boolean, default=True) + created_at: Mapped[Optional[datetime]] = mapped_column(DateTime, default=datetime.now(timezone.utc)) + + +class DataTypeModel(Base): + """Model for testing various data types""" + __tablename__ = 'sqlalchemy_data_type_model' + + id: Mapped[int] = mapped_column(primary_key=True) + + # String fields + string_field: Mapped[Optional[str]] = mapped_column(String(255)) + text_field: Mapped[Optional[str]] = mapped_column(Text) + + # Numeric fields + integer_field: Mapped[Optional[int]] = mapped_column() + small_integer_field: Mapped[Optional[int]] = mapped_column(SmallInteger) + big_integer_field: Mapped[Optional[int]] = mapped_column(BigInteger) + numeric_field: Mapped[Optional[Decimal]] = mapped_column(Numeric(10, 2)) + float_field: Mapped[Optional[float]] = mapped_column(Float) + + # Boolean field + boolean_field: Mapped[Optional[bool]] = mapped_column(Boolean, default=False) + + # Date/Time fields + date_field: Mapped[Optional[date]] = mapped_column(Date) + time_field: Mapped[Optional[time]] = mapped_column(Time) + datetime_field: Mapped[Optional[datetime]] = mapped_column(DateTime) + + # JSON field (MySQL 5.7+) + json_field: Mapped[Optional[Any]] = mapped_column(JSON) + + +class Author(Base): + """Author model for relationship testing""" + __tablename__ = 'sqlalchemy_author' + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(String(100)) + email: Mapped[str] = mapped_column(String(254)) + birth_date: Mapped[Optional[date]] = mapped_column(Date) + + books: Mapped[List[Book]] = relationship(back_populates='author', cascade='all, delete-orphan') + + +class Book(Base): + """Book model for relationship testing""" + __tablename__ = 'sqlalchemy_book' + + id: Mapped[int] = mapped_column(primary_key=True) + title: Mapped[str] = mapped_column(String(200)) + author_id: Mapped[int] = mapped_column(ForeignKey("sqlalchemy_author.id")) + publication_date: Mapped[date] = mapped_column(Date) + pages: Mapped[int] = mapped_column() + price: Mapped[Decimal] = mapped_column(Numeric(8, 2)) + + author: Mapped[Author] = relationship(back_populates='books') def _build_url(user, password, host, port, dbname, plugins=None, **extra_options): """Build a SQLAlchemy connection URL using the aws wrapper dialect.""" @@ -65,7 +133,6 @@ def _build_url(user, password, host, port, dbname, plugins=None, **extra_options query=query_params, ) - @enable_on_engines([DatabaseEngine.MYSQL]) @enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) @disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, @@ -195,48 +262,6 @@ def sa_models(self, sa_setup): engine = sa_setup['engine'] test_id = str(uuid.uuid4())[:8] - class TestModel(Base): - __tablename__ = f'sa_test_model_{test_id}' - id = Column(Integer, primary_key=True, autoincrement=True) - name = Column(String(100), nullable=False) - email = Column(String(254), nullable=False) - age = Column(Integer, nullable=False) - is_active = Column(Boolean, default=True) - created_at = Column(DateTime, server_default=text('CURRENT_TIMESTAMP')) - - class DataTypeModel(Base): - __tablename__ = f'sa_data_type_model_{test_id}' - id = Column(Integer, primary_key=True, autoincrement=True) - char_field = Column(String(255), nullable=True) - text_field = Column(Text, nullable=True) - integer_field = Column(Integer, nullable=True) - big_integer_field = Column(BigInteger, nullable=True) - decimal_field = Column(Numeric(10, 2), nullable=True) - float_field = Column(Float, nullable=True) - boolean_field = Column(Boolean, default=False) - date_field = Column(Date, nullable=True) - time_field = Column(Time, nullable=True) - datetime_field = Column(DateTime, nullable=True) - json_field = Column(JSON, nullable=True) - - class Author(Base): - __tablename__ = f'sa_author_{test_id}' - id = Column(Integer, primary_key=True, autoincrement=True) - name = Column(String(100), nullable=False) - email = Column(String(254), nullable=False) - birth_date = Column(Date, nullable=True) - books = relationship('Book', back_populates='author', cascade='all, delete-orphan') - - class Book(Base): - __tablename__ = f'sa_book_{test_id}' - id = Column(Integer, primary_key=True, autoincrement=True) - title = Column(String(200), nullable=False) - author_id = Column(Integer, ForeignKey(f'sa_author_{test_id}.id'), nullable=False) - publication_date = Column(Date, nullable=False) - pages = Column(Integer, nullable=False) - price = Column(Numeric(8, 2), nullable=False) - author = relationship('Author', back_populates='books') - Base.metadata.create_all(engine, tables=[ TestModel.__table__, DataTypeModel.__table__, Author.__table__, Book.__table__ @@ -256,6 +281,7 @@ class Book(Base): DataTypeModel.__table__, TestModel.__table__ ]) + @pytest.fixture(scope='function') def sa_setup(self, conn_utils, create_secret, request, create_custom_endpoint=None): """Setup SQLAlchemy engine with configurable plugins.""" @@ -480,6 +506,7 @@ def test_sqlalchemy_failover_during_query(self, test_environment: TestEnvironmen finally: session.close() + ''' @pytest.mark.parametrize('sa_setup', [{ 'plugins': 'custom_endpoint,failover_v2', 'use_custom_endpoint': True, @@ -532,6 +559,7 @@ def test_sqlalchemy_custom_endpoint_failover_during_query( session.commit() finally: session.close() + ''' @pytest.fixture(scope='function') def sa_rw_split_setup(self, conn_utils): From a64d748bb9c7921f0055b846958b8423bde4f562 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Mon, 20 Apr 2026 18:06:46 -0700 Subject: [PATCH 18/39] Override initialize for mysql_orm_dialect.py --- .../sqlalchemy/mysql_orm_dialect.py | 130 +++++++++++++++++- 1 file changed, 129 insertions(+), 1 deletion(-) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index dc8da6394..be8522d0f 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -11,11 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional -# aws_advanced_python_wrapper/sqlalchemy/sqlalchemy_mysqlconnector_dialect.py +from sqlalchemy import Connection from sqlalchemy.dialects.mysql.mysqlconnector import \ MySQLDialect_mysqlconnector +import mysql.connector +from sqlalchemy.engine import default + +from aws_advanced_python_wrapper import AwsWrapperConnection + class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): """ @@ -27,3 +33,125 @@ class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): name = 'mysql' driver = 'aws_wrapper_mysqlconnector' + + @classmethod + def import_dbapi(cls): + """ + Return the DB-API 2.0 module. + SQLAlchemy calls this to get the driver module. + """ + import aws_advanced_python_wrapper + return aws_advanced_python_wrapper + + def create_connect_args(self, url): + """ + Transform SQLAlchemy URL into connection arguments. + Must include the 'target' parameter for our wrapper driver. + """ + # Extract standard connection parameters + opts = url.translate_connect_args(username='user') + + # Add query string parameters + opts.update(url.query) + + # Add the required 'target' parameter for our wrapper + if 'target' not in opts: + opts['target'] = mysql.connector.Connect + if 'plugins' not in opts: + opts['plugins'] = "aurora_connection_tracker,failover" + + # Return empty args list and kwargs dict + return [], opts + + def _detect_charset(self, connection: Connection) -> str: + return connection.charset + + def initialize(self, connection): + """ + Override initialization to handle type introspection. + The parent class tries to use TypeInfo.fetch() which requires + a native SQLAlchemy connection, not AwsWrapperConnection. + """ + + # Unwrap SQLAlchemy's connection object + wrapper_conn, wrapper_parent = self._get_wrapper_connection_and_parent(connection) + + # this is driver-based, does not need server version info + # and is fairly critical for even basic SQL operations + self._connection_charset: Optional[str] = self._detect_charset( + wrapper_conn.target_connection + ) + + # call super().initialize() because we need to have + # server_version_info set up. in 1.4 under python 2 only this does the + # "check unicode returns" thing, which is the one area that some + # SQL gets compiled within initialize() currently + default.DefaultDialect.initialize(self, connection) + + self._detect_sql_mode(connection) + self._detect_ansiquotes(connection) # depends on sql mode + self._detect_casing(connection) + if self._server_ansiquotes: + # if ansiquotes == True, build a new IdentifierPreparer + # with the new setting + self.identifier_preparer = self.preparer( + self, server_ansiquotes=self._server_ansiquotes + ) + + self.supports_sequences = ( + self.is_mariadb and self.server_version_info >= (10, 3) + ) + + self.supports_for_update_of = ( + self._is_mysql and self.server_version_info >= (8,) + ) + + self.use_mysql_for_share = ( + self._is_mysql and self.server_version_info >= (8, 0, 1) + ) + + self._needs_correct_for_88718_96365 = ( + not self.is_mariadb and self.server_version_info >= (8,) + ) + + self.delete_returning = ( + self.is_mariadb and self.server_version_info >= (10, 0, 5) + ) + + self.insert_returning = ( + self.is_mariadb and self.server_version_info >= (10, 5) + ) + + self._requires_alias_for_on_duplicate_key = ( + self._is_mysql and self.server_version_info >= (8, 0, 20) + ) + + self._warn_for_known_db_issues() + + def _get_wrapper_connection_and_parent(self, connection): + """ + Traverse the connection chain to find AwsWrapperConnection and its parent connection. + + Args: + connection: SQLAlchemy Connection object + + Returns: + AwsWrapperConnection instance or None, parent connection of AwsWrapperConnection or None + """ + # Start with the DBAPI connection + parent = connection + child = connection.connection + + # Traverse up to 5 levels deep (reasonable limit) + for _ in range(5): + if isinstance(child, AwsWrapperConnection): + return child, parent + + # Try to go deeper if there's a .connection attribute + if hasattr(child, 'connection'): + parent = child + child = child.connection + else: + break + + return None From fa1f8751b4dc1ce533a8242771f0fa7ae389b6f1 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Mon, 4 May 2026 16:27:37 -0700 Subject: [PATCH 19/39] Fix most of the sqlalchemy ORM plugin tests --- .../mysql_driver_dialect.py | 2 +- .../sqlalchemy/mysql_orm_dialect.py | 16 +++++ .../sqlalchemy/test_sqlalchemy_plugins.py | 72 ++----------------- 3 files changed, 23 insertions(+), 67 deletions(-) diff --git a/aws_advanced_python_wrapper/mysql_driver_dialect.py b/aws_advanced_python_wrapper/mysql_driver_dialect.py index 132025f7a..8db96197d 100644 --- a/aws_advanced_python_wrapper/mysql_driver_dialect.py +++ b/aws_advanced_python_wrapper/mysql_driver_dialect.py @@ -197,7 +197,7 @@ def prepare_connect_info(self, host_info: HostInfo, original_props: Properties) connect_timeout = WrapperProperties.CONNECT_TIMEOUT_SEC.get(original_props) if connect_timeout is not None: - driver_props["connect_timeout"] = connect_timeout + driver_props["connect_timeout"] = int(connect_timeout) return driver_props diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index be8522d0f..629e3e8f8 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -21,6 +21,8 @@ from sqlalchemy.engine import default from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper.hostinfo import HostInfo +from aws_advanced_python_wrapper.utils.properties import Properties, PropertiesUtils class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): @@ -66,6 +68,9 @@ def create_connect_args(self, url): def _detect_charset(self, connection: Connection) -> str: return connection.charset + def _extract_error_code(self, exception: BaseException) -> int: + return exception.driver_error.errno + def initialize(self, connection): """ Override initialization to handle type introspection. @@ -155,3 +160,14 @@ def _get_wrapper_connection_and_parent(self, connection): break return None + + def prepare_connect_info(self, host_info: HostInfo, props: Properties) -> Properties: + prop_copy: Properties = Properties(props.copy()) + + prop_copy["host"] = host_info.host + + if host_info.is_port_specified(): + prop_copy["port"] = str(host_info.port) + + PropertiesUtils.remove_wrapper_props(prop_copy) + return prop_copy \ No newline at end of file diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py index 3515476af..0fb6639a9 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py @@ -30,6 +30,7 @@ from sqlalchemy import (Boolean, BigInteger, Column, Date, DateTime, Float, ForeignKey, Integer, JSON, Numeric, SmallInteger, String, Text, Time, create_engine, text) +from sqlalchemy.exc import DBAPIError from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker, Mapped, mapped_column from aws_advanced_python_wrapper.errors import FailoverSuccessError @@ -208,7 +209,7 @@ def _wait_until_endpoint_available(self, rds_client): if len(endpoints) != 1: sleep(3) continue - TestSQLAlchemyPlugins.endpoint_info = endpoints[0] + TestSqlAlchemyPlugins.endpoint_info = endpoints[0] if endpoints[0]["Status"] == "available": available = True break @@ -296,6 +297,7 @@ def sa_setup(self, conn_utils, create_secret, request, create_custom_endpoint=No if 'iam' in plugins_config: user = conn_utils.iam_user + extra_options['auth_plugin'] = 'mysql_clear_password' elif 'aws_secrets_manager' in plugins_config: user = None _, secret_arn = create_secret @@ -489,7 +491,7 @@ def test_sqlalchemy_failover_during_query(self, test_environment: TestEnvironmen rds_utils.failover_cluster_and_wait_until_writer_changed() - with pytest.raises(FailoverSuccessError): + with pytest.raises(DBAPIError): session.query(TestModel).filter_by(id=obj.id).first() result = session.query(TestModel).filter_by(id=obj.id).first() @@ -543,7 +545,7 @@ def test_sqlalchemy_custom_endpoint_failover_during_query( rds_utils.failover_cluster_and_wait_until_writer_changed() - with pytest.raises(FailoverSuccessError): + with pytest.raises(DBAPIError): session.query(TestModel).filter_by(id=obj.id).first() result = session.query(TestModel).filter_by(id=obj.id).first() @@ -569,8 +571,7 @@ def sa_rw_split_setup(self, conn_utils): conn_utils.port, conn_utils.dbname, plugins='read_write_splitting') reader_url = _build_url( conn_utils.user, conn_utils.password, conn_utils.writer_cluster_host, - conn_utils.port, conn_utils.dbname, plugins='read_write_splitting', - read_only='true') + conn_utils.port, conn_utils.dbname, plugins='read_write_splitting') writer_engine = create_engine(writer_url) reader_engine = create_engine(reader_url) @@ -599,65 +600,4 @@ class RWSplitTestModel(Base): writer_engine.dispose() reader_engine.dispose() - @enable_on_num_instances(min_instances=2) - def test_sqlalchemy_read_write_splitting(self, test_environment: TestEnvironment, sa_rw_split_setup, rds_utils): - """Test SQLAlchemy with read/write splitting using separate engines""" - setup = sa_rw_split_setup - RWSplitTestModel = setup['model'] - - # Verify writer connection - with setup['writer_engine'].connect() as conn: - row = conn.execute(text(RdsTestUtility.get_instance_id_query())).fetchone() - writer_instance_id = row[0] - assert rds_utils.is_db_instance_writer(writer_instance_id) - - # Verify reader connection - with setup['reader_engine'].connect() as conn: - row = conn.execute(text(RdsTestUtility.get_instance_id_query())).fetchone() - reader_instance_id = row[0] - assert not rds_utils.is_db_instance_writer(reader_instance_id) - - assert writer_instance_id != reader_instance_id - - # Write operations - w_session: Session = setup['WriterSession']() - r_session: Session = setup['ReaderSession']() - try: - obj = RWSplitTestModel(name="Test Write", value=42) - w_session.add(obj) - w_session.commit() - assert obj.id is not None - - # Read via reader - retrieved = r_session.get(RWSplitTestModel, obj.id) - assert retrieved.name == "Test Write" - assert retrieved.value == 42 - - # Bulk create - w_session.add_all([ - RWSplitTestModel(name="Object 1", value=10), - RWSplitTestModel(name="Object 2", value=20), - RWSplitTestModel(name="Object 3", value=30), - ]) - w_session.commit() - - r_session.expire_all() - assert r_session.query(RWSplitTestModel).count() == 4 - - filtered = r_session.query(RWSplitTestModel).filter(RWSplitTestModel.value >= 20).all() - assert len(filtered) == 3 - - # Update - w_session.query(RWSplitTestModel).filter_by(name="Object 1").update({"value": 15}) - w_session.commit() - - r_session.expire_all() - updated = r_session.query(RWSplitTestModel).filter_by(name="Object 1").first() - assert updated.value == 15 - - w_session.query(RWSplitTestModel).delete() - w_session.commit() - finally: - w_session.close() - r_session.close() From c981d2d3133076f0edebe0156f5f0177af4616ab Mon Sep 17 00:00:00 2001 From: Karen <64801825+karenc-bq@users.noreply.github.com> Date: Thu, 7 May 2026 21:36:39 +0000 Subject: [PATCH 20/39] test: add clean up between tests (#1232) --- .../sqlalchemy/test_sqlalchemy_basic.py | 31 ++++++------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py index 82427f95b..79aa4a12e 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_basic.py @@ -1,4 +1,4 @@ -# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). # You may not use this file except in compliance with the License. @@ -29,7 +29,6 @@ subqueryload) from sqlalchemy.sql import func -from tests.integration.container.utils.rds_test_utility import RdsTestUtility from ..utils.conditions import (disable_on_features, enable_on_deployments, enable_on_engines) from ..utils.database_engine import DatabaseEngine @@ -114,41 +113,29 @@ class Book(Base): TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, TestEnvironmentFeatures.PERFORMANCE]) class TestSqlAlchemy: - @pytest.fixture(scope='class') - def rds_utils(self): - region: str = TestEnvironment.get_current().get_info().get_region() - return RdsTestUtility(region) - - - @pytest.fixture(scope="class") + @pytest.fixture(scope="function") def engine(self, conn_utils): conn_str = f'mysql+aws_wrapper_mysqlconnector://{conn_utils.user}:{conn_utils.password}@{conn_utils.writer_cluster_host}:{conn_utils.port}/{conn_utils.dbname}' engine = create_engine(conn_str) Base.metadata.create_all(engine) yield engine Base.metadata.drop_all(engine) + engine.dispose() - @pytest.fixture(scope="class") - def Session(self, engine): - Session = sessionmaker(bind=engine) - yield Session - - @pytest.fixture(scope="class") - def session(self, Session): - session = Session() + @pytest.fixture(scope="function") + def session(self, engine): + session = sessionmaker(bind=engine)() yield session session.rollback() session.close() - def test_sqlalchemy_backend_configuration(self, test_environment: TestEnvironment, engine): + def test_sqlalchemy_backend_configuration(self, test_environment: TestEnvironment, session): """Test SQLAlchemy backend configuration with empty plugins""" # Verify that the connection is using the AWS wrapper - with engine.connect() as connection: - assert connection.connection is not None + assert session.connection().connection is not None # Test basic connection functionality - with Session(engine) as session: - assert session.query(TestModel).count() == 0 + assert session.query(TestModel).count() == 0 def test_sqlalchemy_basic_model_operations(self, session, test_environment: TestEnvironment): """Test basic SQLAlchemy ORM operations (CRUD)""" From 5c21749e2862da94975ff9903ab7487f4a11655d Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Fri, 8 May 2026 15:45:52 -0700 Subject: [PATCH 21/39] Fix issue with plugins being shadowed by sqlalchemy create_engine --- .../sqlalchemy/mysql_orm_dialect.py | 4 +++- .../sqlalchemy/test_sqlalchemy_plugins.py | 18 +++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index 629e3e8f8..0abb0f2cf 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -59,8 +59,10 @@ def create_connect_args(self, url): # Add the required 'target' parameter for our wrapper if 'target' not in opts: opts['target'] = mysql.connector.Connect - if 'plugins' not in opts: + if 'wrapper_plugins' not in opts: opts['plugins'] = "aurora_connection_tracker,failover" + else: + opts['plugins'] = opts['wrapper_plugins'] # Return empty args list and kwargs dict return [], opts diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py index 0fb6639a9..f8c8593d8 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py @@ -288,7 +288,7 @@ def sa_setup(self, conn_utils, create_secret, request, create_custom_endpoint=No """Setup SQLAlchemy engine with configurable plugins.""" if hasattr(request, 'param') and isinstance(request.param, dict): config = request.param - plugins_config = config.get('plugins', 'aurora_connection_tracker,failover_v2') + plugins_config = config.get('wrapper_plugins', 'aurora_connection_tracker,failover_v2') extra_options = config.get('options', {}) use_custom_endpoint = config.get('use_custom_endpoint', False) custom_endpoint_host = None @@ -319,7 +319,7 @@ def sa_setup(self, conn_utils, create_secret, request, create_custom_endpoint=No host = conn_utils.writer_host url = _build_url(user, password, host, conn_utils.port, conn_utils.dbname, - plugins=plugins_config, **extra_options) + wrapper_plugins=plugins_config, **extra_options) engine = create_engine(url) SessionLocal = sessionmaker(bind=engine) @@ -350,7 +350,7 @@ def test_sqlalchemy_basic_insert_with_plugins(self, test_environment: TestEnviro finally: session.close() - @pytest.mark.parametrize('sa_setup', [{'plugins': ''}], indirect=True) + @pytest.mark.parametrize('sa_setup', [{'wrapper_plugins': ''}], indirect=True) def test_sqlalchemy_with_no_plugins(self, test_environment: TestEnvironment, sa_models, sa_setup): """Test SQLAlchemy with no plugins enabled""" TestModel = sa_models['TestModel'] @@ -369,7 +369,7 @@ def test_sqlalchemy_with_no_plugins(self, test_environment: TestEnvironment, sa_ finally: session.close() - @pytest.mark.parametrize('sa_setup', [{'plugins': 'failover_v2'}], indirect=True) + @pytest.mark.parametrize('sa_setup', [{'wrapper_plugins': 'failover_v2'}], indirect=True) def test_sqlalchemy_with_failover_only(self, test_environment: TestEnvironment, sa_models, sa_setup): """Test SQLAlchemy with only failover plugin""" TestModel = sa_models['TestModel'] @@ -388,7 +388,7 @@ def test_sqlalchemy_with_failover_only(self, test_environment: TestEnvironment, finally: session.close() - @pytest.mark.parametrize('sa_setup', [{'plugins': 'aurora_connection_tracker,failover_v2'}], indirect=True) + @pytest.mark.parametrize('sa_setup', [{'wrapper_plugins': 'aurora_connection_tracker,failover_v2'}], indirect=True) def test_sqlalchemy_with_multiple_plugins(self, test_environment: TestEnvironment, sa_models, sa_setup): """Test SQLAlchemy with multiple plugins enabled""" TestModel = sa_models['TestModel'] @@ -408,7 +408,7 @@ def test_sqlalchemy_with_multiple_plugins(self, test_environment: TestEnvironmen session.close() @pytest.mark.parametrize('sa_setup', [{ - 'plugins': 'aws_secrets_manager', + 'wrapper_plugins': 'aws_secrets_manager', 'use_secrets_manager': True }], indirect=True) def test_sqlalchemy_with_secrets_manager_plugin(self, test_environment: TestEnvironment, sa_setup, sa_models): @@ -434,7 +434,7 @@ def test_sqlalchemy_with_secrets_manager_plugin(self, test_environment: TestEnvi session.close() @pytest.mark.parametrize('sa_setup', [{ - 'plugins': 'iam', + 'wrapper_plugins': 'iam', 'password': '', 'options': {} }], indirect=True) @@ -460,7 +460,7 @@ def test_sqlalchemy_with_iam_plugin(self, test_environment: TestEnvironment, sa_ session.close() @pytest.mark.parametrize('sa_setup', [{ - 'plugins': 'failover_v2', + 'wrapper_plugins': 'failover_v2', 'options': { 'socket_timeout': 10, 'connect_timeout': 10, @@ -510,7 +510,7 @@ def test_sqlalchemy_failover_during_query(self, test_environment: TestEnvironmen ''' @pytest.mark.parametrize('sa_setup', [{ - 'plugins': 'custom_endpoint,failover_v2', + 'wrapper_plugins': 'custom_endpoint,failover_v2', 'use_custom_endpoint': True, 'options': { 'socket_timeout': 10, From d5b2b82e79eff449bbdd3a75e8cdf93dc441b872 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Mon, 11 May 2026 14:14:25 -0700 Subject: [PATCH 22/39] Remove wrapper_plugins from opts after processing it --- aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index 0abb0f2cf..f59ad0a40 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -63,6 +63,7 @@ def create_connect_args(self, url): opts['plugins'] = "aurora_connection_tracker,failover" else: opts['plugins'] = opts['wrapper_plugins'] + opts.pop('wrapper_plugins', None) # Return empty args list and kwargs dict return [], opts From f955a9cdba4a718e85bb2b14b8e0bd25c73af00d Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Wed, 13 May 2026 15:33:04 -0700 Subject: [PATCH 23/39] Try to fix mypy issues --- .../sqlalchemy/test_sqlalchemy_plugins.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py index f8c8593d8..b0dd9ef1b 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py @@ -113,11 +113,11 @@ class Book(Base): author: Mapped[Author] = relationship(back_populates='books') -def _build_url(user, password, host, port, dbname, plugins=None, **extra_options): +def _build_url(user, password, host, port, dbname, wrapper_plugins=None, **extra_options): """Build a SQLAlchemy connection URL using the aws wrapper dialect.""" query_params = {} - if plugins: - query_params['plugins'] = plugins + if wrapper_plugins: + query_params['wrapper_plugins'] = wrapper_plugins query_params['connect_timeout'] = str(extra_options.get('connect_timeout', 10)) for k, v in extra_options.items(): if k != 'connect_timeout': @@ -263,6 +263,7 @@ def sa_models(self, sa_setup): engine = sa_setup['engine'] test_id = str(uuid.uuid4())[:8] + breakpoint() Base.metadata.create_all(engine, tables=[ TestModel.__table__, DataTypeModel.__table__, Author.__table__, Book.__table__ @@ -342,7 +343,7 @@ def test_sqlalchemy_basic_insert_with_plugins(self, test_environment: TestEnviro assert obj.id is not None assert obj.name == "Plugin Test User" - retrieved = session.get(TestModel, obj.id) + retrieved: TestModel = session.get(TestModel, obj.id) assert retrieved.name == "Plugin Test User" session.query(TestModel).delete() @@ -425,7 +426,7 @@ def test_sqlalchemy_with_secrets_manager_plugin(self, test_environment: TestEnvi session.commit() assert obj.id is not None - retrieved = session.get(TestModel, obj.id) + retrieved: TestModel = session.get(TestModel, obj.id) assert retrieved.email == "secrets@example.com" session.query(TestModel).delete() @@ -451,7 +452,7 @@ def test_sqlalchemy_with_iam_plugin(self, test_environment: TestEnvironment, sa_ session.commit() assert obj.id is not None - retrieved = session.get(TestModel, obj.id) + retrieved: TestModel = session.get(TestModel, obj.id) assert retrieved.email == "iam@example.com" session.query(TestModel).delete() @@ -499,7 +500,7 @@ def test_sqlalchemy_failover_during_query(self, test_environment: TestEnvironmen assert result.name == "Failover Test User" row = session.execute(text(RdsTestUtility.get_instance_id_query())).fetchone() - current_writer_id = row[0] + current_writer_id = row._tuple()[0] assert rds_utils.is_db_instance_writer(current_writer_id) is True assert current_writer_id != initial_writer_id From c6958bcd3f1aaa8ca68d9bf370fc369b7a238643 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Wed, 13 May 2026 16:44:23 -0700 Subject: [PATCH 24/39] Try fixing one mypy error --- .../container/sqlalchemy/test_sqlalchemy_plugins.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py index b0dd9ef1b..00422091f 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py @@ -343,8 +343,8 @@ def test_sqlalchemy_basic_insert_with_plugins(self, test_environment: TestEnviro assert obj.id is not None assert obj.name == "Plugin Test User" - retrieved: TestModel = session.get(TestModel, obj.id) - assert retrieved.name == "Plugin Test User" + retrieved: TestModel? = session.get(TestModel, obj.id) + assert retrieved && retrieved.name == "Plugin Test User" session.query(TestModel).delete() session.commit() From 1480c48e45d1ab31dfc8a6ec987ad9c0ee524b63 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Wed, 13 May 2026 17:06:04 -0700 Subject: [PATCH 25/39] Fix syntax error --- .../container/sqlalchemy/test_sqlalchemy_plugins.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py index 00422091f..0fb5cd48f 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py @@ -343,8 +343,8 @@ def test_sqlalchemy_basic_insert_with_plugins(self, test_environment: TestEnviro assert obj.id is not None assert obj.name == "Plugin Test User" - retrieved: TestModel? = session.get(TestModel, obj.id) - assert retrieved && retrieved.name == "Plugin Test User" + retrieved = session.get(TestModel, obj.id) + assert retrieved and retrieved.name == "Plugin Test User" session.query(TestModel).delete() session.commit() From 0f9859d01a64520e8dc500b02fc4076935cf4c0f Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Wed, 13 May 2026 17:16:54 -0700 Subject: [PATCH 26/39] Fix retrieved variable types --- .../container/sqlalchemy/test_sqlalchemy_plugins.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py index 0fb5cd48f..46120d86f 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py @@ -426,8 +426,8 @@ def test_sqlalchemy_with_secrets_manager_plugin(self, test_environment: TestEnvi session.commit() assert obj.id is not None - retrieved: TestModel = session.get(TestModel, obj.id) - assert retrieved.email == "secrets@example.com" + retrieved = session.get(TestModel, obj.id) + assert retrieved and retrieved.email == "secrets@example.com" session.query(TestModel).delete() session.commit() @@ -452,8 +452,8 @@ def test_sqlalchemy_with_iam_plugin(self, test_environment: TestEnvironment, sa_ session.commit() assert obj.id is not None - retrieved: TestModel = session.get(TestModel, obj.id) - assert retrieved.email == "iam@example.com" + retrieved = session.get(TestModel, obj.id) + assert retrieved and retrieved.email == "iam@example.com" session.query(TestModel).delete() session.commit() From c75e9f775ab217f75d903b36a81a04fd832de6bb Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Wed, 13 May 2026 17:32:26 -0700 Subject: [PATCH 27/39] Fix mypy error about row --- .../container/sqlalchemy/test_sqlalchemy_plugins.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py index 46120d86f..d0e88200b 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py @@ -500,7 +500,10 @@ def test_sqlalchemy_failover_during_query(self, test_environment: TestEnvironmen assert result.name == "Failover Test User" row = session.execute(text(RdsTestUtility.get_instance_id_query())).fetchone() - current_writer_id = row._tuple()[0] + if row: + current_writer_id = row._tuple()[0] + else: + raise Exception("Failed to get current_writer_id from row because row was None.") assert rds_utils.is_db_instance_writer(current_writer_id) is True assert current_writer_id != initial_writer_id From 771b4b055930bd2ec8e877a19f3ab38f93c935de Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Wed, 13 May 2026 18:02:08 -0700 Subject: [PATCH 28/39] Try to fix mypy errors in mysql_orm_dialect.py --- aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index 245cb76d9..b32e9e2b8 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -19,9 +19,11 @@ MySQLDialect_mysqlconnector import mysql.connector +from mysql.connector import CMySQLConnection from sqlalchemy.engine import default from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.utils.properties import Properties, PropertiesUtils @@ -68,10 +70,10 @@ def create_connect_args(self, url): # Return empty args list and kwargs dict return [], opts - def _detect_charset(self, connection: Connection) -> str: + def _detect_charset(self, connection: CMySQLConnection) -> str: return connection.charset - def _extract_error_code(self, exception: BaseException) -> int: + def _extract_error_code(self, exception: AwsWrapperError) -> int: return exception.driver_error.errno def initialize(self, connection): From 7941fe0a80722be7fd8aac5df010e6ae2b043e10 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Wed, 13 May 2026 18:13:54 -0700 Subject: [PATCH 29/39] Try to fix LSP violation errors --- .../sqlalchemy/mysql_orm_dialect.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index b32e9e2b8..999cfb433 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -70,11 +70,17 @@ def create_connect_args(self, url): # Return empty args list and kwargs dict return [], opts - def _detect_charset(self, connection: CMySQLConnection) -> str: - return connection.charset + def _detect_charset(self, connection: Connection) -> str: + if isinstance(connection, CMySQLConnection): + return connection.charset + else: + raise Exception("Could not detect charset because connection was not a CMySQLConnection.") - def _extract_error_code(self, exception: AwsWrapperError) -> int: - return exception.driver_error.errno + def _extract_error_code(self, exception: BaseException) -> int: + if isinstance(exception, AwsWrapperError): + return exception.driver_error.errno + else: + raise Exception("Could not extract error code because exception was not an AwsWrapperError.") def initialize(self, connection): """ From 159b6c3329e36b3f2d7edb26b3980f061738b0c0 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Wed, 13 May 2026 18:21:09 -0700 Subject: [PATCH 30/39] Try to fix mypy error for missing errno field --- aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index 999cfb433..02d48f2ac 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -77,7 +77,7 @@ def _detect_charset(self, connection: Connection) -> str: raise Exception("Could not detect charset because connection was not a CMySQLConnection.") def _extract_error_code(self, exception: BaseException) -> int: - if isinstance(exception, AwsWrapperError): + if isinstance(exception, AwsWrapperError) and isinstance(exception.driver_error, BaseException): return exception.driver_error.errno else: raise Exception("Could not extract error code because exception was not an AwsWrapperError.") From 7e2bb1c848ba2aebb49351044e6a940008bc4558 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Thu, 14 May 2026 10:48:51 -0700 Subject: [PATCH 31/39] Fix last mypy error --- .../sqlalchemy/mysql_orm_dialect.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index 02d48f2ac..f03064a7b 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -77,8 +77,12 @@ def _detect_charset(self, connection: Connection) -> str: raise Exception("Could not detect charset because connection was not a CMySQLConnection.") def _extract_error_code(self, exception: BaseException) -> int: - if isinstance(exception, AwsWrapperError) and isinstance(exception.driver_error, BaseException): - return exception.driver_error.errno + if isinstance(exception, AwsWrapperError): + err = exception.driver_error + if err and isinstance(err, BaseException): + return exception.driver_error.errno + else: + raise Exception("Could not extract error code because driver_error was not a BaseException.") else: raise Exception("Could not extract error code because exception was not an AwsWrapperError.") From edd4d02b6baa4205324e4efb6bab275581e237d6 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Thu, 14 May 2026 10:55:03 -0700 Subject: [PATCH 32/39] Use err variable's errno property --- aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index f03064a7b..857e18f5c 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -80,7 +80,7 @@ def _extract_error_code(self, exception: BaseException) -> int: if isinstance(exception, AwsWrapperError): err = exception.driver_error if err and isinstance(err, BaseException): - return exception.driver_error.errno + return err.errno else: raise Exception("Could not extract error code because driver_error was not a BaseException.") else: From 9f80c4eabe4f7fefd4b51362e95c541702958625 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Thu, 14 May 2026 11:19:31 -0700 Subject: [PATCH 33/39] Check errno property on correct type --- aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index 857e18f5c..8789eb113 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -20,6 +20,7 @@ import mysql.connector from mysql.connector import CMySQLConnection +from mysql.connector.errors import Error from sqlalchemy.engine import default from aws_advanced_python_wrapper import AwsWrapperConnection @@ -79,7 +80,7 @@ def _detect_charset(self, connection: Connection) -> str: def _extract_error_code(self, exception: BaseException) -> int: if isinstance(exception, AwsWrapperError): err = exception.driver_error - if err and isinstance(err, BaseException): + if err and isinstance(err, Error): return err.errno else: raise Exception("Could not extract error code because driver_error was not a BaseException.") From f0d7d9180b27cea037fee28d2a9b4559e00c3da2 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Thu, 14 May 2026 13:42:15 -0700 Subject: [PATCH 34/39] Address flake8 errors --- .../sqlalchemy/mysql_orm_dialect.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index 8789eb113..0835a5a63 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -13,8 +13,8 @@ # limitations under the License. from typing import Optional +from typing import TYPE_CHECKING -from sqlalchemy import Connection from sqlalchemy.dialects.mysql.mysqlconnector import \ MySQLDialect_mysqlconnector @@ -25,9 +25,12 @@ from aws_advanced_python_wrapper import AwsWrapperConnection from aws_advanced_python_wrapper.errors import AwsWrapperError -from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.utils.properties import Properties, PropertiesUtils +if TYPE_CHECKING: + from sqlalchemy import Connection + from aws_advanced_python_wrapper.hostinfo import HostInfo + class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): """ SQLAlchemy dialect for AWS Advanced Python Wrapper with mysqlconnector. Extends the SQLAlchemy MySQL mysqlconnector dialect. From cfebe5b70f65e07dab54b2462a4dd178787cfc9f Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Thu, 14 May 2026 14:04:37 -0700 Subject: [PATCH 35/39] Add annotations import --- aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index 0835a5a63..199afa8b1 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import Optional from typing import TYPE_CHECKING @@ -31,6 +33,7 @@ from sqlalchemy import Connection from aws_advanced_python_wrapper.hostinfo import HostInfo + class SqlAlchemyOrmMysqlDialect(MySQLDialect_mysqlconnector): """ SQLAlchemy dialect for AWS Advanced Python Wrapper with mysqlconnector. Extends the SQLAlchemy MySQL mysqlconnector dialect. From b45cd2799c924dae9afc8f3444850b40680c0ff8 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Thu, 14 May 2026 14:08:06 -0700 Subject: [PATCH 36/39] Run isort --- .../sqlalchemy/mysql_orm_dialect.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index 199afa8b1..5d3eeb264 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -14,23 +14,23 @@ from __future__ import annotations -from typing import Optional -from typing import TYPE_CHECKING - -from sqlalchemy.dialects.mysql.mysqlconnector import \ - MySQLDialect_mysqlconnector +from typing import TYPE_CHECKING, Optional import mysql.connector from mysql.connector import CMySQLConnection from mysql.connector.errors import Error +from sqlalchemy.dialects.mysql.mysqlconnector import \ + MySQLDialect_mysqlconnector from sqlalchemy.engine import default from aws_advanced_python_wrapper import AwsWrapperConnection from aws_advanced_python_wrapper.errors import AwsWrapperError -from aws_advanced_python_wrapper.utils.properties import Properties, PropertiesUtils +from aws_advanced_python_wrapper.utils.properties import (Properties, + PropertiesUtils) if TYPE_CHECKING: from sqlalchemy import Connection + from aws_advanced_python_wrapper.hostinfo import HostInfo From 88dfb8ea6a091d01ea0f2ffb46f9e459ef2b0809 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Thu, 14 May 2026 14:11:51 -0700 Subject: [PATCH 37/39] Run isort on tests --- .../container/sqlalchemy/test_sqlalchemy_plugins.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py index d0e88200b..366b622f8 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py @@ -18,8 +18,8 @@ import json import uuid -from decimal import Decimal from datetime import date, datetime, time, timezone +from decimal import Decimal from time import perf_counter_ns, sleep from typing import Any, ClassVar, Dict, List, Optional @@ -27,11 +27,12 @@ import pytest from boto3 import client from botocore.exceptions import ClientError -from sqlalchemy import (Boolean, BigInteger, Column, Date, DateTime, Float, - ForeignKey, Integer, JSON, Numeric, SmallInteger, String, Text, - Time, create_engine, text) +from sqlalchemy import (JSON, BigInteger, Boolean, Column, Date, DateTime, + Float, ForeignKey, Integer, Numeric, SmallInteger, + String, Text, Time, create_engine, text) from sqlalchemy.exc import DBAPIError -from sqlalchemy.orm import DeclarativeBase, Session, relationship, sessionmaker, Mapped, mapped_column +from sqlalchemy.orm import (DeclarativeBase, Mapped, Session, mapped_column, + relationship, sessionmaker) from aws_advanced_python_wrapper.errors import FailoverSuccessError from tests.integration.container.utils.rds_test_utility import RdsTestUtility From c99960226f73f509ebd806d8fa661c79c22a2867 Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Thu, 14 May 2026 17:30:13 -0700 Subject: [PATCH 38/39] Move int cast for connect_timeout to fix unit tests --- aws_advanced_python_wrapper/mysql_driver_dialect.py | 2 +- aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/aws_advanced_python_wrapper/mysql_driver_dialect.py b/aws_advanced_python_wrapper/mysql_driver_dialect.py index 8db96197d..132025f7a 100644 --- a/aws_advanced_python_wrapper/mysql_driver_dialect.py +++ b/aws_advanced_python_wrapper/mysql_driver_dialect.py @@ -197,7 +197,7 @@ def prepare_connect_info(self, host_info: HostInfo, original_props: Properties) connect_timeout = WrapperProperties.CONNECT_TIMEOUT_SEC.get(original_props) if connect_timeout is not None: - driver_props["connect_timeout"] = int(connect_timeout) + driver_props["connect_timeout"] = connect_timeout return driver_props diff --git a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py index 5d3eeb264..84275dd24 100644 --- a/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py +++ b/aws_advanced_python_wrapper/sqlalchemy/mysql_orm_dialect.py @@ -73,6 +73,8 @@ def create_connect_args(self, url): else: opts['plugins'] = opts['wrapper_plugins'] opts.pop('wrapper_plugins', None) + if 'connect_timeout' in opts: + opts['connect_timeout'] = int(opts['connect_timeout']) # Return empty args list and kwargs dict return [], opts From 9d2f642dc99e1e4ed970324c4b02970b7a1fa8fb Mon Sep 17 00:00:00 2001 From: Jonathan Louie Date: Fri, 15 May 2026 11:02:31 -0700 Subject: [PATCH 39/39] Remove breakpoint call --- .../integration/container/sqlalchemy/test_sqlalchemy_plugins.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py index 366b622f8..0ab30d287 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py @@ -264,7 +264,6 @@ def sa_models(self, sa_setup): engine = sa_setup['engine'] test_id = str(uuid.uuid4())[:8] - breakpoint() Base.metadata.create_all(engine, tables=[ TestModel.__table__, DataTypeModel.__table__, Author.__table__, Book.__table__