diff --git a/aws_advanced_python_wrapper/django/__init__.py b/aws_advanced_python_wrapper/django/__init__.py new file mode 100644 index 00000000..bd4acb2b --- /dev/null +++ b/aws_advanced_python_wrapper/django/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/aws_advanced_python_wrapper/django/backends/__init__.py b/aws_advanced_python_wrapper/django/backends/__init__.py new file mode 100644 index 00000000..bd4acb2b --- /dev/null +++ b/aws_advanced_python_wrapper/django/backends/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/aws_advanced_python_wrapper/django/backends/mysql_connector/__init__.py b/aws_advanced_python_wrapper/django/backends/mysql_connector/__init__.py new file mode 100644 index 00000000..bd4acb2b --- /dev/null +++ b/aws_advanced_python_wrapper/django/backends/mysql_connector/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/aws_advanced_python_wrapper/django/backends/mysql_connector/base.py b/aws_advanced_python_wrapper/django/backends/mysql_connector/base.py new file mode 100644 index 00000000..ff533d71 --- /dev/null +++ b/aws_advanced_python_wrapper/django/backends/mysql_connector/base.py @@ -0,0 +1,78 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import mysql.connector +import mysql.connector.django.base as base +from django.utils.asyncio import async_unsafe +from django.utils.functional import cached_property +from django.utils.regex_helper import _lazy_re_compile + +from aws_advanced_python_wrapper import AwsWrapperConnection + +# This should match the numerical portion of the version numbers (we can treat +# versions like 5.0.24 and 5.0.24a as the same). +server_version_re = _lazy_re_compile(r"(\d{1,2})\.(\d{1,2})\.(\d{1,2})") + + +class DatabaseWrapper(base.DatabaseWrapper): + """Custom MySQL Connector backend for Django""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._read_only = False + + @async_unsafe + def get_new_connection(self, conn_params): + if "converter_class" not in conn_params: + conn_params["converter_class"] = base.DjangoMySQLConverter + conn = AwsWrapperConnection.connect( + mysql.connector.Connect, + **conn_params + ) + + if not self._read_only: + return conn + else: + conn.read_only = True + return conn + + def get_connection_params(self): + kwargs = super().get_connection_params() + self._read_only = kwargs.pop("read_only", False) + return kwargs + + @cached_property + def mysql_server_info(self): + return self.mysql_server_data["version"] + + @cached_property + def mysql_version(self): + match = server_version_re.match(self.mysql_server_info) + if not match: + raise Exception( + "Unable to determine MySQL version from version string %r" + % self.mysql_server_info + ) + return tuple(int(x) for x in match.groups()) + + @cached_property + def mysql_is_mariadb(self): + return "mariadb" in self.mysql_server_info.lower() + + @cached_property + def sql_mode(self): + sql_mode = self.mysql_server_data["sql_mode"] + return set(sql_mode.split(",") if sql_mode else ()) diff --git a/aws_advanced_python_wrapper/wrapper.py b/aws_advanced_python_wrapper/wrapper.py index 111c87e2..64162cc2 100644 --- a/aws_advanced_python_wrapper/wrapper.py +++ b/aws_advanced_python_wrapper/wrapper.py @@ -267,6 +267,11 @@ def rowcount(self) -> int: def arraysize(self) -> int: return self.target_cursor.arraysize + # Optional for PEP249 + @property + def lastrowid(self) -> int: + return self.target_cursor.lastrowid # type: ignore[attr-defined] + def close(self) -> None: self._plugin_manager.execute(self.target_cursor, DbApiMethod.CURSOR_CLOSE, lambda: self.target_cursor.close()) diff --git a/docs/examples/MySQLDjangoFailover.py b/docs/examples/MySQLDjangoFailover.py new file mode 100644 index 00000000..deab8c13 --- /dev/null +++ b/docs/examples/MySQLDjangoFailover.py @@ -0,0 +1,233 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Django ORM Failover Example with AWS Advanced Python Wrapper + +This example demonstrates how to handle failover events when using Django ORM +with the AWS Advanced Python Wrapper. + +""" + +import django +from django.conf import settings +from django.db import connection, models + +from aws_advanced_python_wrapper import release_resources +from aws_advanced_python_wrapper.errors import ( + FailoverFailedError, FailoverSuccessError, + TransactionResolutionUnknownError) + +# Django settings configuration +DJANGO_SETTINGS = { + 'DATABASES': { + 'default': { + 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', + 'NAME': 'test_db', + 'USER': 'admin', + 'PASSWORD': 'password', + 'HOST': 'database.cluster-xyz.us-east-1.rds.amazonaws.com', + 'PORT': 3306, + 'OPTIONS': { + 'plugins': 'failover', + 'connect_timeout': 10, + 'autocommit': True, + }, + }, + }, +} + +# Configure Django settings +if not settings.configured: + settings.configure(**DJANGO_SETTINGS) +django.setup() + + +class BankAccount(models.Model): + """Example model for demonstrating failover handling.""" + name: str = models.CharField(max_length=100) # type: ignore[assignment] + account_balance: int = models.IntegerField() # type: ignore[assignment] + + class Meta: + app_label = 'myapp' + db_table = 'bank_test' + + def __str__(self) -> str: + return f"{self.name}: ${self.account_balance}" + + +def execute_query_with_failover_handling(query_func): + """ + Execute a Django ORM query with failover error handling. + + Args: + query_func: A callable that executes the desired query + + Returns: + The result of the query function + """ + try: + return query_func() + + except FailoverSuccessError: + # Query execution failed and AWS Advanced Python Wrapper successfully failed over to an available instance. + # https://github.com/aws/aws-advanced-python-wrapper/blob/main/docs/using-the-python-driver/using-plugins/UsingTheFailoverPlugin.md#failoversuccesserror + + # The connection has been re-established. Retry the query. + print("Failover successful! Retrying query...") + + # Retry the query + return query_func() + + except FailoverFailedError as e: + # Failover failed. The application should open a new connection, + # check the results of the failed transaction and re-run it if needed. + # https://github.com/aws/aws-advanced-python-wrapper/blob/main/docs/using-the-python-driver/using-plugins/UsingTheFailoverPlugin.md#failoverfailederror + print(f"Failover failed: {e}") + print("Application should open a new connection and retry the transaction.") + raise e + + except TransactionResolutionUnknownError as e: + # The transaction state is unknown. The application should check the status + # of the failed transaction and restart it if needed. + # https://github.com/aws/aws-advanced-python-wrapper/blob/main/docs/using-the-python-driver/using-plugins/UsingTheFailoverPlugin.md#transactionresolutionunknownerror + print(f"Transaction resolution unknown: {e}") + print("Application should check transaction status and retry if needed.") + raise e + + +def create_table(): + """Create the database table with failover handling.""" + def _create(): + with connection.cursor() as cursor: + cursor.execute(""" + CREATE TABLE IF NOT EXISTS bank_test ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100), + account_balance INT + ) + """) + print("Table created successfully") + + execute_query_with_failover_handling(_create) + + +def drop_table(): + """Drop the database table with failover handling.""" + def _drop(): + with connection.cursor() as cursor: + cursor.execute("DROP TABLE IF EXISTS bank_test") + print("Table dropped successfully") + + execute_query_with_failover_handling(_drop) + + +def insert_records(): + """Insert records with failover handling.""" + print("\n--- Inserting Records ---") + + def _insert1(): + account = BankAccount.objects.create(name="Jane Doe", account_balance=200) + print(f"Inserted: {account}") + return account + + def _insert2(): + account = BankAccount.objects.create(name="John Smith", account_balance=200) + print(f"Inserted: {account}") + return account + + execute_query_with_failover_handling(_insert1) + execute_query_with_failover_handling(_insert2) + + +def query_records(): + """Query records with failover handling.""" + print("\n--- Querying Records ---") + + def _query(): + accounts = list(BankAccount.objects.all()) + for account in accounts: + print(f" {account}") + return accounts + + return execute_query_with_failover_handling(_query) + + +def update_record(): + """Update a record with failover handling.""" + print("\n--- Updating Record ---") + + def _update(): + account = BankAccount.objects.filter(name="Jane Doe").first() + if account: + account.account_balance = 300 + account.save() + print(f"Updated: {account}") + return account + + return execute_query_with_failover_handling(_update) + + +def filter_records(): + """Filter records with failover handling.""" + print("\n--- Filtering Records ---") + + def _filter(): + accounts = list(BankAccount.objects.filter(account_balance__gte=250)) + print(f"Found {len(accounts)} accounts with balance >= $250:") + for account in accounts: + print(f" {account}") + return accounts + + return execute_query_with_failover_handling(_filter) + + +if __name__ == "__main__": + try: + print("Django ORM Failover Example with AWS Advanced Python Wrapper") + print("=" * 60) + + # Create table + create_table() + + # Insert records + insert_records() + + # Query records + query_records() + + # Update a record + update_record() + + # Query again to see the update + query_records() + + # Filter records + filter_records() + + # Cleanup + print("\n--- Cleanup ---") + drop_table() + + print("\n" + "=" * 60) + print("Example completed successfully!") + + except Exception as e: + print(f"Error: {e}") + import traceback + traceback.print_exc() + + finally: + # Clean up AWS Advanced Python Wrapper resources + release_resources() diff --git a/docs/examples/MySQLDjangoReadWriteSplitting.py b/docs/examples/MySQLDjangoReadWriteSplitting.py new file mode 100644 index 00000000..8ae5c23c --- /dev/null +++ b/docs/examples/MySQLDjangoReadWriteSplitting.py @@ -0,0 +1,345 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Django ORM Read/Write Splitting Example with AWS Advanced Python Wrapper + +This example demonstrates how to use the AWS Advanced Python Wrapper with Django ORM +to leverage Aurora features like failover handling and read/write splitting with +internal connection pooling. +""" + +from typing import TYPE_CHECKING, Any, Dict + +import django +from django.conf import settings +from django.db import connection, models + +from aws_advanced_python_wrapper import release_resources +from aws_advanced_python_wrapper.connection_provider import \ + ConnectionProviderManager +from aws_advanced_python_wrapper.errors import ( + FailoverFailedError, FailoverSuccessError, + TransactionResolutionUnknownError) +from aws_advanced_python_wrapper.sql_alchemy_connection_provider import \ + SqlAlchemyPooledConnectionProvider + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.hostinfo import HostInfo + + +def configure_pool(host_info: "HostInfo", props: Dict[str, Any]) -> Dict[str, Any]: + """Configure connection pool settings for each host.""" + return {"pool_size": 5} + + +def get_pool_key(host_info: "HostInfo", props: Dict[str, Any]) -> str: + """ + Generate a unique key for connection pooling. + Include the URL, user, and database in the connection pool key so that a new + connection pool will be opened for each different instance-user-database combination. + """ + url = host_info.url + user = props.get("user", "") + db = props.get("database", "") + return f"{url}{user}{db}" + + +# Database connection configuration +DB_CONFIG = { + 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', + 'NAME': 'test_db', + 'USER': 'admin', + 'PASSWORD': 'password', + 'HOST': 'database.cluster-xyz.us-east-1.rds.amazonaws.com', + 'PORT': 3306, +} + +# Django settings configuration +DJANGO_SETTINGS = { + 'DATABASES': { + 'default': { # Writer connection + **DB_CONFIG, + 'OPTIONS': { + 'plugins': 'read_write_splitting,failover', + 'connect_timeout': 10, + 'autocommit': True, + }, + }, + 'read': { # Reader connection + **DB_CONFIG, + 'OPTIONS': { + 'plugins': 'read_write_splitting,failover', + 'connect_timeout': 10, + 'autocommit': True, + 'read_only': True, # This connection will use reader instances + 'reader_host_selector_strategy': 'least_connections', + }, + }, + }, + 'DATABASE_ROUTERS': ['__main__.ReadWriteRouter'], +} + + +# Database Router for Read/Write Splitting +class ReadWriteRouter: + """ + A router to control database operations for read/write splitting. + """ + + def db_for_read(self, model, **hints): + """ + Direct all read operations to the 'read' database. + """ + return 'read' + + def db_for_write(self, model, **hints): + """ + Direct all write operations to the 'default' database. + """ + return 'default' + + def allow_relation(self, obj1, obj2, **hints): + """ + Allow relations between objects in the same database. + """ + return True + + def allow_migrate(self, db, app_label, model_name=None, **hints): + """ + Allow migrations on all databases. + """ + return True + + +# Configure Django settings +if not settings.configured: + settings.configure(**DJANGO_SETTINGS) +django.setup() + + +class BankAccount(models.Model): + """Example model for demonstrating read/write splitting.""" + name: str = models.CharField(max_length=100) # type: ignore[assignment] + account_balance: int = models.IntegerField() # type: ignore[assignment] + + class Meta: + app_label = 'myapp' + db_table = 'bank_accounts' + + def __str__(self) -> str: + return f"{self.name}: ${self.account_balance}" + + +def execute_query_with_failover_handling(query_func): + """ + Execute a Django ORM query with failover error handling. + + Args: + query_func: A callable that executes the desired query + + Returns: + The result of the query function + """ + try: + return query_func() + + except FailoverSuccessError: + # Query execution failed and AWS Advanced Python Wrapper successfully failed over to an available instance. + # https://github.com/aws/aws-advanced-python-wrapper/blob/main/docs/using-the-python-driver/using-plugins/UsingTheFailoverPlugin.md#failoversuccesserror + + # The connection has been re-established. Retry the query. + print("Failover successful! Retrying query...") + + # Retry the query + return query_func() + + except FailoverFailedError as e: + # Failover failed. The application should open a new connection, + # check the results of the failed transaction and re-run it if needed. + # https://github.com/aws/aws-advanced-python-wrapper/blob/main/docs/using-the-python-driver/using-plugins/UsingTheFailoverPlugin.md#failoverfailederror + print(f"Failover failed: {e}") + print("Application should open a new connection and retry the transaction.") + raise e + + except TransactionResolutionUnknownError as e: + # The transaction state is unknown. The application should check the status + # of the failed transaction and restart it if needed. + # https://github.com/aws/aws-advanced-python-wrapper/blob/main/docs/using-the-python-driver/using-plugins/UsingTheFailoverPlugin.md#transactionresolutionunknownerror + print(f"Transaction resolution unknown: {e}") + print("Application should check transaction status and retry if needed.") + raise e + + +def create_table(): + """Create the database table with failover handling.""" + def _create(): + with connection.cursor() as cursor: + cursor.execute(""" + CREATE TABLE IF NOT EXISTS bank_accounts ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100), + account_balance INT + ) + """) + print("Table created successfully") + + execute_query_with_failover_handling(_create) + + +def drop_table(): + """Drop the database table with failover handling.""" + def _drop(): + with connection.cursor() as cursor: + cursor.execute("DROP TABLE IF EXISTS bank_accounts") + print("Table dropped successfully") + + execute_query_with_failover_handling(_drop) + + +def demonstrate_write_operations(): + """Demonstrate write operations with failover handling (uses 'default' database - writer instance).""" + print("\n--- Write Operations (Writer Instance) ---") + + # Create new records with failover handling + def _create1(): + account = BankAccount.objects.create(name="Jane Doe", account_balance=1000) + print(f"Created: {account}") + return account + + def _create2(): + account = BankAccount.objects.create(name="John Smith", account_balance=1500) + print(f"Created: {account}") + return account + + account1 = execute_query_with_failover_handling(_create1) + execute_query_with_failover_handling(_create2) + + # Update a record with failover handling + def _update(): + account1.account_balance = 1200 + account1.save() + print(f"Updated: {account1}") + return account1 + + execute_query_with_failover_handling(_update) + + +def demonstrate_read_operations(): + """Demonstrate read operations with failover handling (uses 'read' database - reader instance).""" + print("\n--- Read Operations (Reader Instance) ---") + + # Query all records with failover handling + def _query_all(): + accounts = list(BankAccount.objects.all()) + print(f"Total accounts: {len(accounts)}") + for account in accounts: + print(f" {account}") + return accounts + + execute_query_with_failover_handling(_query_all) + + # Filter records with failover handling + def _query_filtered(): + high_balance = list(BankAccount.objects.filter(account_balance__gte=1200)) + print("\nAccounts with balance >= $1200:") + for account in high_balance: + print(f" {account}") + return high_balance + + execute_query_with_failover_handling(_query_filtered) + + +def demonstrate_explicit_database_selection(): + """Demonstrate explicitly selecting which database to use with failover handling.""" + print("\n--- Explicit Database Selection ---") + + # Force read from writer database with failover handling + def _read_from_writer(): + print("Reading from writer (default) database:") + accounts = list(BankAccount.objects.using('default').all()) + for account in accounts: + print(f" {account}") + return accounts + + execute_query_with_failover_handling(_read_from_writer) + + # Force read from reader database with failover handling + def _read_from_reader(): + print("\nReading from reader database:") + accounts = list(BankAccount.objects.using('read').all()) + for account in accounts: + print(f" {account}") + return accounts + + execute_query_with_failover_handling(_read_from_reader) + + +def demonstrate_raw_sql(): + """Demonstrate raw SQL queries with Django and failover handling.""" + print("\n--- Raw SQL Queries ---") + + # Execute raw SQL query with failover handling + def _raw_query(): + accounts = list(BankAccount.objects.raw('SELECT * FROM bank_accounts WHERE account_balance > %s', [1000])) + print("Accounts with balance > $1000:") + for account in accounts: + print(f" {account}") + return accounts + + execute_query_with_failover_handling(_raw_query) + + +if __name__ == "__main__": + # Configure read/write splitting to use internal connection pools. + provider = SqlAlchemyPooledConnectionProvider(configure_pool, get_pool_key) + ConnectionProviderManager.set_connection_provider(provider) + + try: + print("Django ORM Read/Write Splitting Example with AWS Advanced Python Wrapper") + print("=" * 60) + + # Create table + create_table() + + # Demonstrate write operations (uses writer instance) + demonstrate_write_operations() + + # Demonstrate read operations (uses reader instance) + demonstrate_read_operations() + + # Demonstrate explicit database selection + demonstrate_explicit_database_selection() + + # Demonstrate raw SQL + demonstrate_raw_sql() + + # Cleanup + print("\n--- Cleanup ---") + drop_table() + + print("\n" + "=" * 60) + print("Example completed successfully!") + + except Exception as e: + print(f"Error: {e}") + import traceback + traceback.print_exc() + + finally: + # Clean up connection pools + ConnectionProviderManager.release_resources() + + # Clean up AWS Advanced Python Wrapper resources + release_resources() diff --git a/docs/using-the-python-driver/DjangoSupport.md b/docs/using-the-python-driver/DjangoSupport.md new file mode 100644 index 00000000..9ff2d96a --- /dev/null +++ b/docs/using-the-python-driver/DjangoSupport.md @@ -0,0 +1,259 @@ +# Django ORM Support + +> [!IMPORTANT] +> Django ORM support is currently only available for **MySQL databases**. + +The AWS ADvanced Python Wrapper provides a custom Django database backend that enables Django applications to leverage AWS and Aurora functionalities such as failover handling, IAM authentication, and read/write splitting. + +## Prerequisites + +- Django 3.2+ + +## Basic Configuration + +To use the AWS ADvanced Python Wrapper with Django, configure your database settings in `settings.py` to use the custom backend in the `ENGINE` parameter, as well as any wrapper-specific properties in the `OPTIONS` parameter: + +```python +DATABASES = { + 'default': { + 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', + 'NAME': 'your_database_name', + 'USER': 'your_username', + 'PASSWORD': 'your_password', + 'HOST': 'your-cluster-endpoint.cluster-xyz.us-east-1.rds.amazonaws.com', + 'PORT': 3306, + 'OPTIONS': { + 'plugins': 'failover,iam', + 'connect_timeout': 10, + 'autocommit': True, + }, + } +} +``` + +### Supported Engines + +| Underlying Driver | Database Dialect | Engine Value | +|-------------------|------------------|-------------| +| `mysql-connector-python` | MySQL | `'aws_advanced_python_wrapper.django.backends.mysql_connector'` | + + +### OPTIONS Properties + +The `OPTIONS` dictionary supports all standard [AWS ADvanced Python Wrapper parameters](./UsingThePythonDriver.md#aws-advanced-python-driver-parameters) as well as parameters for the underlying driver. + +For a complete list of available plugins and their supported parameters, see the [List of Available Plugins](./UsingThePythonDriver.md#list-of-available-plugins). + +## Using Plugins with Django + +The AWS ADvanced Python Wrapper supports a variety of plugins that enhance your Django application with features like failover handling, IAM authentication, and more. Most plugins can be enabled simply by adding them to the `plugins` parameter in your database `OPTIONS`. + +For a complete list of available plugins, see the [List of Available Plugins](./UsingThePythonDriver.md#list-of-available-plugins) in the main driver documentation. + +Below are two examples of plugins that require additional setup or code changes in your Django application: + +### Failover Plugin + +The Failover Plugin provides automatic failover handling for Aurora clusters. When a database instance becomes unavailable, the plugin automatically connects to a healthy instance in the cluster. + +```python +DATABASES = { + 'default': { + 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', + 'NAME': 'mydb', + 'USER': 'admin', + 'PASSWORD': 'password', + 'HOST': 'my-cluster.cluster-xyz.us-east-1.rds.amazonaws.com', + 'PORT': 3306, + 'OPTIONS': { + 'plugins': 'failover,host_monitoring_v2', + 'connect_timeout': 10, + 'autocommit': True, + }, + } +} +``` + +#### Handling Failover Events + +During a failover event, the driver will throw a `FailoverSuccessError` exception after successfully connecting to a new instance. Your application should catch this exception and retry the failed query: + +```python +from aws_advanced_python_wrapper.errors import FailoverSuccessError + +def execute_query_with_failover_handling(query_func): + try: + return query_func() + except FailoverSuccessError: + # Failover successful, retry the query + return query_func() +``` + +For a complete example, see [MySQLDjangoFailover.py](../examples/MySQLDjangoFailover.py). + +For more information about the Failover Plugin, see the [Failover Plugin documentation](./using-plugins/UsingTheFailoverPlugin.md). + + +## Read/Write Splitting with Django + +The Read/Write Splitting Plugin enables Django applications to automatically route read queries to reader instances and write queries to writer instances in an Aurora cluster. This plugin requires additional configuration to set up multiple database connections and a database router. + +### Configuration + +To use read/write splitting with Django: + +1. Configure multiple database connections (one for writes, one for reads) +2. Set the `read_only` parameter to `True` for the reader connection to ensure all connection objects route to reader instances +3. Create a database router to direct queries to the appropriate connection + +#### Step 1: Configure Database Connections + +```python +DATABASES = { + 'default': { # Writer connection + 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', + 'NAME': 'mydb', + 'USER': 'admin', + 'PASSWORD': 'password', + 'HOST': 'my-cluster.cluster-xyz.us-east-1.rds.amazonaws.com', + 'PORT': 3306, + 'OPTIONS': { + 'plugins': 'read_write_splitting,failover', + 'connect_timeout': 10, + 'autocommit': True, + }, + }, + 'read': { # Reader connection + 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', + 'NAME': 'mydb', + 'USER': 'admin', + 'PASSWORD': 'password', + 'HOST': 'my-cluster.cluster-xyz.us-east-1.rds.amazonaws.com', + 'PORT': 3306, + 'OPTIONS': { + 'plugins': 'read_write_splitting,failover', + 'connect_timeout': 10, + 'autocommit': True, + 'read_only': True, # This connection will use reader instances + }, + }, +} +``` + +#### Step 2: Create a Database Router + +Create a database router to direct read operations to the `read` database and write operations to the `default` database: + +```python +# myapp/routers.py + +class ReadWriteRouter: + """ + A router to control database operations for read/write splitting. + """ + + def db_for_read(self, model, **hints): + """ + Direct all read operations to the 'read' database. + """ + return 'read' + + def db_for_write(self, model, **hints): + """ + Direct all write operations to the 'default' database. + """ + return 'default' + + def allow_relation(self, obj1, obj2, **hints): + """ + Allow relations between objects in the same database. + """ + return True + + def allow_migrate(self, db, app_label, model_name=None, **hints): + """ + Allow migrations on all databases. + """ + return True +``` + +#### Step 3: Register the Router + +Add the router to your Django settings: + +```python +DATABASE_ROUTERS = ['myapp.routers.ReadWriteRouter'] +``` + +### Using Read/Write Splitting + +Once configured, Django will automatically route queries: + +```python +from myapp.models import MyModel + +# This will use the 'read' database (reader instance) +objects = MyModel.objects.all() + +# This will use the 'default' database (writer instance) +MyModel.objects.create(name='New Object') + +# You can also explicitly specify which database to use +MyModel.objects.using('default').all() # Force use of writer +MyModel.objects.using('read').all() # Force use of reader +``` + +### Connection Strategies + +By default, the Read/Write Splitting Plugin randomly selects a reader instance. You can configure different connection strategies using the `reader_host_selector_strategy` parameter: + +```python +'OPTIONS': { + 'plugins': 'read_write_splitting,failover', + 'reader_host_selector_strategy': 'least_connections', + 'read_only': True, +} +``` + +For a complete example including connection pooling, see [MySQLDjangoReadWriteSplitting.py](../examples/MySQLDjangoReadWriteSplitting.py). + +For a list of available connection strategies, see the [Read/Write Splitting Plugin documentation](./using-plugins/UsingTheReadWriteSplittingPlugin.md#connection-strategies). + + +## Resource Management + +The AWS ADvanced Python Wrapper creates background threads for monitoring and connection management. To ensure proper cleanup when your Django application shuts down, add cleanup code to your application's shutdown process: + +```python +# In your Django app's apps.py + +from django.apps import AppConfig +from aws_advanced_python_wrapper import release_resources + +class MyAppConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'myapp' + + def ready(self): + # Register cleanup on application shutdown + import atexit + atexit.register(release_resources) +``` + +Or in your WSGI/ASGI application file: + +```python +# wsgi.py or asgi.py + +import os +from django.core.wsgi import get_wsgi_application +from aws_advanced_python_wrapper import release_resources +import atexit + +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'myproject.settings') + +application = get_wsgi_application() + +# Register cleanup +atexit.register(release_resources) +``` diff --git a/poetry.lock b/poetry.lock index 0ced2492..52d29937 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,23 @@ # This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +[[package]] +name = "asgiref" +version = "3.11.0" +description = "ASGI specs, helper code, and adapters" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "asgiref-3.11.0-py3-none-any.whl", hash = "sha256:1db9021efadb0d9512ce8ffaf72fcef601c7b73a8807a1bb2ef143dc6b14846d"}, + {file = "asgiref-3.11.0.tar.gz", hash = "sha256:13acff32519542a1736223fb79a715acdebe24286d98e8b164a73085f40da2c4"}, +] + +[package.dependencies] +typing_extensions = {version = ">=4", markers = "python_version < \"3.11\""} + +[package.extras] +tests = ["mypy (>=1.14.0)", "pytest", "pytest-asyncio"] + [[package]] name = "aws-xray-sdk" version = "2.15.0" @@ -411,6 +429,67 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "django" +version = "5.2.10" +description = "A high-level Python web framework that encourages rapid development and clean, pragmatic design." +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "django-5.2.10-py3-none-any.whl", hash = "sha256:cf85067a64250c95d5f9067b056c5eaa80591929f7e16fbcd997746e40d6c45c"}, + {file = "django-5.2.10.tar.gz", hash = "sha256:74df100784c288c50a2b5cad59631d71214f40f72051d5af3fdf220c20bdbbbe"}, +] + +[package.dependencies] +asgiref = ">=3.8.1" +sqlparse = ">=0.3.1" +tzdata = {version = "*", markers = "sys_platform == \"win32\""} + +[package.extras] +argon2 = ["argon2-cffi (>=19.1.0)"] +bcrypt = ["bcrypt"] + +[[package]] +name = "django-stubs" +version = "5.2.8" +description = "Mypy stubs for Django" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "django_stubs-5.2.8-py3-none-any.whl", hash = "sha256:a3c63119fd7062ac63d58869698d07c9e5ec0561295c4e700317c54e8d26716c"}, + {file = "django_stubs-5.2.8.tar.gz", hash = "sha256:9bba597c9a8ed8c025cae4696803d5c8be1cf55bfc7648a084cbf864187e2f8b"}, +] + +[package.dependencies] +django = "*" +django-stubs-ext = ">=5.2.8" +tomli = {version = "*", markers = "python_full_version < \"3.11.0\""} +types-pyyaml = "*" +typing-extensions = ">=4.11.0" + +[package.extras] +compatible-mypy = ["mypy (>=1.13,<1.20)"] +oracle = ["oracledb"] +redis = ["redis", "types-redis"] + +[[package]] +name = "django-stubs-ext" +version = "5.2.8" +description = "Monkey-patching and extensions for django-stubs" +optional = false +python-versions = ">=3.10" +groups = ["dev"] +files = [ + {file = "django_stubs_ext-5.2.8-py3-none-any.whl", hash = "sha256:1dd5470c9675591362c78a157a3cf8aec45d0e7a7f0cf32f227a1363e54e0652"}, + {file = "django_stubs_ext-5.2.8.tar.gz", hash = "sha256:b39938c46d7a547cd84e4a6378dbe51a3dd64d70300459087229e5fee27e5c6b"}, +] + +[package.dependencies] +django = "*" +typing-extensions = "*" + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -1588,6 +1667,22 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "sqlparse" +version = "0.5.5" +description = "A non-validating SQL parser." +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "sqlparse-0.5.5-py3-none-any.whl", hash = "sha256:12a08b3bf3eec877c519589833aed092e2444e68240a3577e8e26148acc7b1ba"}, + {file = "sqlparse-0.5.5.tar.gz", hash = "sha256:e20d4a9b0b8585fdf63b10d30066c7c94c5d7a7ec47c889a2d83a3caa93ff28e"}, +] + +[package.extras] +dev = ["build"] +doc = ["sphinx"] + [[package]] name = "tabulate" version = "0.9.0" @@ -2097,6 +2192,18 @@ workspaces-thin-client = ["types-boto3-workspaces-thin-client (>=1.40.0,<1.41.0) workspaces-web = ["types-boto3-workspaces-web (>=1.40.0,<1.41.0)"] xray = ["types-boto3-xray (>=1.40.0,<1.41.0)"] +[[package]] +name = "types-pyyaml" +version = "6.0.12.20250915" +description = "Typing stubs for PyYAML" +optional = false +python-versions = ">=3.9" +groups = ["dev"] +files = [ + {file = "types_pyyaml-6.0.12.20250915-py3-none-any.whl", hash = "sha256:e7d4d9e064e89a3b3cae120b4990cd370874d2bf12fa5f46c97018dd5d3c9ab6"}, + {file = "types_pyyaml-6.0.12.20250915.tar.gz", hash = "sha256:0f8b54a528c303f0e6f7165687dd33fafa81c807fcac23f632b63aa624ced1d3"}, +] + [[package]] name = "types-s3transfer" version = "0.13.0" @@ -2255,4 +2362,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = "^3.10.0" -content-hash = "5c9676388fe69de1cd60d813b75285505ccbc9e872168e83f9af0d7130d7cb75" +content-hash = "c6c2fd4bf7806ed4e13b325fed01a2b92eac1c41cb745bdc481423bf98e26dae" diff --git a/pyproject.toml b/pyproject.toml index 84ba6836..cf0164ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,8 @@ SQLAlchemy = "^2.0.30" psycopg = "^3.3.1" psycopg-binary = "^3.3.1" mysql-connector-python = "^9.5.0" +django = "^5.0" +django-stubs = "^5.2.8" [tool.poetry.group.test.dependencies] boto3 = "^1.34.111" diff --git a/tests/integration/container/django/__init__.py b/tests/integration/container/django/__init__.py new file mode 100644 index 00000000..bd4acb2b --- /dev/null +++ b/tests/integration/container/django/__init__.py @@ -0,0 +1,13 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/integration/container/django/test_django_basic.py b/tests/integration/container/django/test_django_basic.py new file mode 100644 index 00000000..05ead4af --- /dev/null +++ b/tests/integration/container/django/test_django_basic.py @@ -0,0 +1,995 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa: N806 +# N806: Django model class references stored in variables use PascalCase intentionally + +from __future__ import annotations + +from datetime import date, datetime, time +from decimal import Decimal +from typing import Any + +import django +import pytest +from django.conf import settings +from django.db import connection, connections, models, transaction +from django.db.models import Avg, CharField, Count, F, Max, Min, Q, Sum, Value +from django.db.models.functions import Concat, Length, Lower, Upper +from django.test.utils import setup_test_environment, teardown_test_environment +from django.utils import timezone + +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from ..utils.conditions import (disable_on_features, enable_on_deployments, + enable_on_engines) +from ..utils.database_engine import DatabaseEngine +from ..utils.database_engine_deployment import DatabaseEngineDeployment +from ..utils.test_environment import TestEnvironment +from ..utils.test_environment_features import TestEnvironmentFeatures + + +@enable_on_engines([DatabaseEngine.MYSQL]) # MySQL Specific until PG is implemented +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestDjango: + TestModel: Any + DataTypeModel: Any + Author: Any + Book: Any + + @pytest.fixture(scope='class') + def rds_utils(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) + + @pytest.fixture(scope='class') + def django_models(self, django_setup): + """Create Django models after Django is set up""" + + class TestModel(models.Model): + """Basic test model for Django ORM functionality""" + name = models.CharField(max_length=100) + email = models.EmailField(unique=True) + age = models.IntegerField() + is_active = models.BooleanField(default=True) + created_at = models.DateTimeField(auto_now_add=True) + + class Meta: + app_label = 'test_app' + db_table = 'django_test_model' + + class DataTypeModel(models.Model): + """Model for testing various data types""" + # String fields + char_field = models.CharField(max_length=255, null=True, blank=True) + text_field = models.TextField(null=True, blank=True) + + # Numeric fields + integer_field = models.IntegerField(null=True, blank=True) + big_integer_field = models.BigIntegerField(null=True, blank=True) + decimal_field = models.DecimalField(max_digits=10, decimal_places=2, null=True, blank=True) + float_field = models.FloatField(null=True, blank=True) + + # Boolean field + boolean_field = models.BooleanField(default=False) + + # Date/Time fields + date_field = models.DateField(null=True, blank=True) + time_field = models.TimeField(null=True, blank=True) + datetime_field = models.DateTimeField(null=True, blank=True) + + # JSON field (MySQL 5.7+) + json_field = models.JSONField(null=True, blank=True) + + class Meta: + app_label = 'test_app' + db_table = 'django_data_type_model' + + class Author(models.Model): + """Author model for relationship testing""" + name = models.CharField(max_length=100) + email = models.EmailField() + birth_date = models.DateField(null=True, blank=True) + + class Meta: + app_label = 'test_app' + db_table = 'django_author' + + # Store Author first so it's available for Book's ForeignKey + TestDjango.Author = Author + + class Book(models.Model): + """Book model for relationship testing""" + title = models.CharField(max_length=200) + author = models.ForeignKey(TestDjango.Author, on_delete=models.CASCADE, related_name='books') + publication_date = models.DateField() + pages = models.IntegerField() + price = models.DecimalField(max_digits=8, decimal_places=2) + + class Meta: + app_label = 'test_app' + db_table = 'django_book' + + # Store models as class attributes for easy access + TestDjango.TestModel = TestModel + TestDjango.DataTypeModel = DataTypeModel + TestDjango.Book = Book + + # Create tables for our test models + with connection.schema_editor() as schema_editor: + schema_editor.create_model(TestModel) + schema_editor.create_model(DataTypeModel) + schema_editor.create_model(Author) + schema_editor.create_model(Book) + + yield + + # Clean up tables + with connection.schema_editor() as schema_editor: + schema_editor.delete_model(Book) + schema_editor.delete_model(Author) + schema_editor.delete_model(DataTypeModel) + schema_editor.delete_model(TestModel) + + @pytest.fixture(scope='class') + def django_setup(self, conn_utils): + """Setup Django configuration for testing""" + # Configure Django settings + if not settings.configured: + db_config = { + 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', + 'NAME': conn_utils.dbname, + 'USER': conn_utils.user, + 'PASSWORD': conn_utils.password, + 'HOST': conn_utils.writer_cluster_host, + 'PORT': conn_utils.port, + 'OPTIONS': { + 'plugins': 'failover,aurora_connection_tracker', + 'connect_timeout': 10, + 'autocommit': True, + }, + } + + settings.configure( + DEBUG=True, + DATABASES={'default': db_config}, + INSTALLED_APPS=[ + 'django.contrib.contenttypes', + 'django.contrib.auth', + ], + SECRET_KEY='test-secret-key-for-django-tests', + USE_TZ=True, + ) + + django.setup() + setup_test_environment() + + yield + connections.close_all() + + teardown_test_environment() + + def test_django_backend_configuration(self, test_environment: TestEnvironment, django_models): + """Test Django backend configuration with empty plugins""" + # Verify that the connection is using the AWS wrapper + assert hasattr(connection, 'connection') + + # Test basic connection functionality + assert self.TestModel.objects.count() == 0 + + def test_django_basic_model_operations(self, test_environment: TestEnvironment, django_models): + """Test basic Django ORM operations (CRUD)""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create + test_obj = TestModel.objects.create( + name="John Doe", + email="john@example.com", + age=30, + is_active=True + ) + assert test_obj.id is not None + assert test_obj.name == "John Doe" + + # Read + retrieved_obj = TestModel.objects.get(id=test_obj.id) + assert retrieved_obj.name == "John Doe" + assert retrieved_obj.email == "john@example.com" + assert retrieved_obj.age == 30 + assert retrieved_obj.is_active is True + + # Update + retrieved_obj.name = "Jane Doe" + retrieved_obj.age = 25 + retrieved_obj.save() + + updated_obj = TestModel.objects.get(id=test_obj.id) + assert updated_obj.name == "Jane Doe" + assert updated_obj.age == 25 + + # Delete + updated_obj.delete() + assert TestModel.objects.filter(id=test_obj.id).count() == 0 + + def test_django_queryset_operations(self, test_environment: TestEnvironment, django_models): + """Test Django QuerySet operations""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + + # Test filtering + active_users = TestModel.objects.filter(is_active=True) + assert active_users.count() == 2 + + # Test ordering + ordered_users = TestModel.objects.order_by('age') + ages = [user.age for user in ordered_users] + assert ages == [25, 30, 35] + + # Test complex queries + young_active_users = TestModel.objects.filter(age__lt=30, is_active=True) + assert young_active_users.count() == 1 + assert young_active_users.first().name == "Alice" + + # Test exclude + non_bob_users = TestModel.objects.exclude(name="Bob") + assert non_bob_users.count() == 2 + + # Test exists + assert TestModel.objects.filter(name="Alice").exists() + assert not TestModel.objects.filter(name="David").exists() + + # Clean up + TestModel.objects.all().delete() + + def test_django_data_types(self, test_environment: TestEnvironment, django_models): + """Test Django ORM with various data types""" + DataTypeModel = self.DataTypeModel + + # Ensure clean slate + DataTypeModel.objects.all().delete() + + # Create test data with various data types + test_datetime = datetime(2023, 12, 25, 14, 30, 0) + test_datetime_aware = timezone.make_aware(test_datetime) + + test_data = DataTypeModel.objects.create( + char_field="Test String", + text_field="This is a longer text field content", + integer_field=42, + big_integer_field=9223372036854775807, + decimal_field=Decimal('123.45'), + float_field=3.14159, + boolean_field=True, + date_field=date(2023, 12, 25), + time_field=time(14, 30, 0), + datetime_field=test_datetime_aware, # Use timezone-aware datetime + json_field={"key": "value", "number": 123, "array": [1, 2, 3]} + ) + + # Retrieve and verify data + retrieved = DataTypeModel.objects.get(id=test_data.id) + + assert retrieved.char_field == "Test String" + assert retrieved.text_field == "This is a longer text field content" + assert retrieved.integer_field == 42 + assert retrieved.big_integer_field == 9223372036854775807 + assert retrieved.decimal_field == Decimal('123.45') + assert abs(retrieved.float_field - 3.14159) < 0.001 + assert retrieved.boolean_field is True + assert retrieved.date_field == date(2023, 12, 25) + assert retrieved.time_field == time(14, 30, 0) + # Compare timezone-aware datetimes + assert retrieved.datetime_field == test_datetime_aware + assert retrieved.json_field == {"key": "value", "number": 123, "array": [1, 2, 3]} + + # Clean up + DataTypeModel.objects.all().delete() + + def test_django_null_values(self, test_environment: TestEnvironment, django_models): + """Test Django ORM handling of NULL values""" + DataTypeModel = self.DataTypeModel + + # First, ensure we start with a clean slate + DataTypeModel.objects.all().delete() + + # Create object with NULL values + test_obj = DataTypeModel.objects.create( + char_field=None, + integer_field=None, + date_field=None, + boolean_field=False # This field has default=False, so it won't be NULL + ) + + # Retrieve and verify NULL values + retrieved = DataTypeModel.objects.get(id=test_obj.id) + assert retrieved.char_field is None + assert retrieved.integer_field is None + assert retrieved.date_field is None + assert retrieved.boolean_field is False + + # Test filtering with NULL values + null_char_objects = DataTypeModel.objects.filter(char_field__isnull=True) + assert null_char_objects.count() == 1 + + not_null_char_objects = DataTypeModel.objects.filter(char_field__isnull=False) + assert not_null_char_objects.count() == 0 + + # Create an object with non-NULL values to test the opposite + DataTypeModel.objects.create( + char_field="Not NULL", + integer_field=42, + date_field=date(2023, 1, 1) + ) + + # Now test filtering again + null_char_objects = DataTypeModel.objects.filter(char_field__isnull=True) + assert null_char_objects.count() == 1 # Still one NULL object + + not_null_char_objects = DataTypeModel.objects.filter(char_field__isnull=False) + assert not_null_char_objects.count() == 1 # Now one non-NULL object + + # Clean up + DataTypeModel.objects.all().delete() + + def test_django_relationships(self, test_environment: TestEnvironment, django_models): + """Test Django ORM relationships (ForeignKey)""" + Author = self.Author + Book = self.Book + + # Create author + author = Author.objects.create( + name="J.K. Rowling", + email="jk@example.com", + birth_date=date(1965, 7, 31) + ) + + # Create books + book1 = Book.objects.create( + title="Harry Potter and the Philosopher's Stone", + author=author, + publication_date=date(1997, 6, 26), + pages=223, + price=Decimal('12.99') + ) + + book2 = Book.objects.create( + title="Harry Potter and the Chamber of Secrets", + author=author, + publication_date=date(1998, 7, 2), + pages=251, + price=Decimal('13.99') + ) + + # Test forward relationship + assert book1.author.name == "J.K. Rowling" + assert book2.author.email == "jk@example.com" + + # Test reverse relationship + author_books = author.books.all() + assert author_books.count() == 2 + book_titles = [book.title for book in author_books.order_by('publication_date')] + assert "Harry Potter and the Philosopher's Stone" in book_titles + assert "Harry Potter and the Chamber of Secrets" in book_titles + + # Test related queries + books_by_author = Book.objects.filter(author__name="J.K. Rowling") + assert books_by_author.count() == 2 + + # Test select_related for optimization + book_with_author = Book.objects.select_related('author').get(id=book1.id) + assert book_with_author.author.name == "J.K. Rowling" + + # Clean up + Book.objects.all().delete() + Author.objects.all().delete() + + def test_django_aggregations(self, test_environment: TestEnvironment, django_models): + """Test Django ORM aggregations""" + Author = self.Author + Book = self.Book + + # Create test data + author = Author.objects.create(name="Test Author", email="test@example.com") + + Book.objects.create(title="Book 1", author=author, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')) + Book.objects.create(title="Book 2", author=author, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')) + Book.objects.create(title="Book 3", author=author, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')) + + # Test aggregations + stats = Book.objects.aggregate( + total_books=Count('id'), + total_pages=Sum('pages'), + avg_price=Avg('price'), + max_pages=Max('pages'), + min_price=Min('price') + ) + + assert stats['total_books'] == 3 + assert stats['total_pages'] == 600 + assert abs(float(stats['avg_price']) - 20.0) < 0.01 + assert stats['max_pages'] == 300 + assert stats['min_price'] == Decimal('10.00') + + # Clean up + Book.objects.all().delete() + Author.objects.all().delete() + + def test_django_transactions(self, test_environment: TestEnvironment, django_models): + """Test Django transaction handling""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + initial_count = TestModel.objects.count() + + # Test successful transaction + with transaction.atomic(): + TestModel.objects.create(name="User 1", email="user1@example.com", age=25) + TestModel.objects.create(name="User 2", email="user2@example.com", age=30) + + assert TestModel.objects.count() == initial_count + 2 + + # Test rollback transaction + try: + with transaction.atomic(): + TestModel.objects.create(name="User 3", email="user3@example.com", age=35) + TestModel.objects.create(name="User 4", email="user4@example.com", age=40) + # Force an error to trigger rollback + raise Exception("Force rollback") + except Exception: + pass # Expected exception + + # Should still have only 2 additional records (rollback occurred) + assert TestModel.objects.count() == initial_count + 2 + + # Clean up + TestModel.objects.all().delete() + + def test_django_bulk_operations(self, test_environment: TestEnvironment, django_models): + """Test Django bulk operations""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Test bulk_create + test_objects = [ + TestModel(name=f"User {i}", email=f"user{i}@example.com", age=20 + i) + for i in range(10) + ] + + created_objects = TestModel.objects.bulk_create(test_objects) + assert len(created_objects) == 10 + assert TestModel.objects.count() == 10 + + # Test bulk_update - need to get the objects first and modify them + objects_to_update = list(TestModel.objects.all()) + for obj in objects_to_update: + obj.age += 5 + + TestModel.objects.bulk_update(objects_to_update, ['age']) + + # Verify updates - get fresh objects from database + ages = list(TestModel.objects.values_list('age', flat=True).order_by('name')) + expected_ages = [25 + i for i in range(10)] # 20+i+5 for i in range(10) + assert ages == expected_ages + + # Clean up + TestModel.objects.all().delete() + + def test_django_complex_queries(self, test_environment: TestEnvironment, django_models): + """Test complex Django queries with Q objects and F expressions""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + TestModel.objects.create(name="David", email="david@example.com", age=28, is_active=True) + + # Test Q objects for complex conditions + complex_query = TestModel.objects.filter( + Q(age__gte=30) | Q(name__startswith='A') + ) + assert complex_query.count() == 3 # Bob (30), Charlie (35), Alice (starts with A) + + # Test F expressions + TestModel.objects.filter(age__lt=30).update(age=F('age') + 5) + + # Verify F expression update + alice = TestModel.objects.get(name="Alice") + david = TestModel.objects.get(name="David") + assert alice.age == 30 # 25 + 5 + assert david.age == 33 # 28 + 5 + + # Clean up, might get a failover error from this connection + TestModel.objects.all().delete() + + def test_django_raw_sql_queries(self, test_environment: TestEnvironment, django_models): + """Test Django raw SQL query execution""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + + # Test raw() method + raw_results = TestModel.objects.raw( + f'SELECT * FROM {TestModel._meta.db_table} WHERE age >= %s ORDER BY age', + [30] + ) + raw_list = list(raw_results) + assert len(raw_list) == 2 + assert raw_list[0].name == "Bob" + assert raw_list[1].name == "Charlie" + + # Test connection.cursor() for custom SQL + with connection.cursor() as cursor: + cursor.execute( + f'SELECT name, age FROM {TestModel._meta.db_table} WHERE is_active = %s ORDER BY age', + [True] + ) + rows = cursor.fetchall() + assert len(rows) == 2 + assert rows[0][0] == "Alice" # name + assert rows[0][1] == 25 # age + assert rows[1][0] == "Charlie" + assert rows[1][1] == 35 + + # Test raw SQL with connection for aggregate + with connection.cursor() as cursor: + cursor.execute(f'SELECT COUNT(*), AVG(age) FROM {TestModel._meta.db_table}') + count, avg_age = cursor.fetchone() + assert count == 3 + assert abs(float(avg_age) - 30.0) < 0.01 + + # Clean up + TestModel.objects.all().delete() + + def test_django_get_or_create(self, test_environment: TestEnvironment, django_models): + """Test Django get_or_create pattern""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Test create case + obj1, created1 = TestModel.objects.get_or_create( + email="test@example.com", + defaults={'name': 'Test User', 'age': 25, 'is_active': True} + ) + assert created1 is True + assert obj1.name == "Test User" + assert obj1.age == 25 + + # Test get case (object already exists) + obj2, created2 = TestModel.objects.get_or_create( + email="test@example.com", + defaults={'name': 'Different Name', 'age': 30, 'is_active': False} + ) + assert created2 is False + assert obj2.id == obj1.id + assert obj2.name == "Test User" # Should keep original values + assert obj2.age == 25 + + # Verify only one object exists + assert TestModel.objects.filter(email="test@example.com").count() == 1 + + # Clean up + TestModel.objects.all().delete() + + def test_django_update_or_create(self, test_environment: TestEnvironment, django_models): + """Test Django update_or_create pattern""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Test create case + obj1, created1 = TestModel.objects.update_or_create( + email="update@example.com", + defaults={'name': 'Initial Name', 'age': 25, 'is_active': True} + ) + assert created1 is True + assert obj1.name == "Initial Name" + assert obj1.age == 25 + + # Test update case (object already exists) + obj2, created2 = TestModel.objects.update_or_create( + email="update@example.com", + defaults={'name': 'Updated Name', 'age': 30, 'is_active': False} + ) + assert created2 is False + assert obj2.id == obj1.id + assert obj2.name == "Updated Name" # Should be updated + assert obj2.age == 30 + assert obj2.is_active is False + + # Verify only one object exists + assert TestModel.objects.filter(email="update@example.com").count() == 1 + + # Verify the update persisted + retrieved = TestModel.objects.get(email="update@example.com") + assert retrieved.name == "Updated Name" + assert retrieved.age == 30 + + # Clean up + TestModel.objects.all().delete() + + def test_django_prefetch_related(self, test_environment: TestEnvironment, django_models): + """Test Django prefetch_related for optimizing queries""" + Author = self.Author + Book = self.Book + + # Create test data + author1 = Author.objects.create(name="Author 1", email="author1@example.com") + author2 = Author.objects.create(name="Author 2", email="author2@example.com") + + Book.objects.create(title="Book 1A", author=author1, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')) + Book.objects.create(title="Book 1B", author=author1, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')) + Book.objects.create(title="Book 2A", author=author2, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')) + + # Test prefetch_related + authors = Author.objects.prefetch_related('books').all() + + # Access related books (should not trigger additional queries due to prefetch) + for author in authors: + books = list(author.books.all()) + if author.name == "Author 1": + assert len(books) == 2 + book_titles = [book.title for book in books] + assert "Book 1A" in book_titles + assert "Book 1B" in book_titles + elif author.name == "Author 2": + assert len(books) == 1 + assert books[0].title == "Book 2A" + + # Clean up + Book.objects.all().delete() + Author.objects.all().delete() + + def test_django_database_functions(self, test_environment: TestEnvironment, django_models): + """Test Django database functions""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="alice", email="alice@example.com", age=25) + TestModel.objects.create(name="BOB", email="bob@example.com", age=30) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35) + + # Test Upper function + upper_names = TestModel.objects.annotate(upper_name=Upper('name')).values_list('upper_name', flat=True) + upper_list = list(upper_names) + assert "ALICE" in upper_list + assert "BOB" in upper_list + assert "CHARLIE" in upper_list + + # Test Lower function + lower_names = TestModel.objects.annotate(lower_name=Lower('name')).values_list('lower_name', flat=True) + lower_list = list(lower_names) + assert "alice" in lower_list + assert "bob" in lower_list + assert "charlie" in lower_list + + # Test Length function + name_lengths = TestModel.objects.annotate(name_length=Length('name')).filter(name_length__gte=5) + assert name_lengths.count() == 2 # "alice" (5) and "Charlie" (7) + + # Test Concat function + full_info = TestModel.objects.annotate( + full_info=Concat('name', Value(' - '), 'email', output_field=CharField()) + ).first() + assert ' - ' in full_info.full_info + assert '@example.com' in full_info.full_info + + # Clean up + TestModel.objects.all().delete() + + def test_django_annotations(self, test_environment: TestEnvironment, django_models): + """Test Django annotations with expressions""" + TestModel = self.TestModel + Book = self.Book + Author = self.Author + + # Create test data for TestModel + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + + # Test annotate with F expression for calculations + test_with_age_plus_ten = TestModel.objects.annotate( + age_plus_ten=F('age') + 10 + ).order_by('age') + + # Verify calculation + first_obj = test_with_age_plus_ten.first() + assert first_obj.age_plus_ten == first_obj.age + 10 + assert first_obj.age_plus_ten == 35 # 25 + 10 + + # Create books for F expression testing + author = Author.objects.create(name="Test Author", email="test@example.com") + Book.objects.create(title="Book 1", author=author, publication_date=date(2020, 1, 1), pages=100, price=Decimal('10.00')) + Book.objects.create(title="Book 2", author=author, publication_date=date(2021, 1, 1), pages=200, price=Decimal('20.00')) + Book.objects.create(title="Book 3", author=author, publication_date=date(2022, 1, 1), pages=300, price=Decimal('30.00')) + + # Test annotate with F expression for price per page + books_with_price_per_page = Book.objects.annotate( + price_per_page=F('price') / F('pages') + ).order_by('price_per_page') + + # Verify calculation + first_book = books_with_price_per_page.first() + expected_price_per_page = float(first_book.price) / first_book.pages + assert abs(float(first_book.price_per_page) - expected_price_per_page) < 0.001 + + # Test filtering on annotated field - use a lower threshold to avoid precision issues + cheap_books = Book.objects.annotate( + price_per_page=F('price') / F('pages') + ).filter(price_per_page__lte=0.15) + assert cheap_books.count() == 3 # All books have price_per_page = 0.10 + + # Clean up + TestModel.objects.all().delete() + Book.objects.all().delete() + Author.objects.all().delete() + + def test_django_values_and_values_list(self, test_environment: TestEnvironment, django_models): + """Test Django values() and values_list() methods""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + + # Test values() - returns list of dictionaries + values_result = TestModel.objects.values('name', 'age').order_by('age') + values_list = list(values_result) + assert len(values_list) == 3 + assert values_list[0] == {'name': 'Alice', 'age': 25} + assert values_list[1] == {'name': 'Bob', 'age': 30} + assert values_list[2] == {'name': 'Charlie', 'age': 35} + + # Test values_list() - returns list of tuples + values_list_result = TestModel.objects.values_list('name', 'age').order_by('age') + tuples_list = list(values_list_result) + assert len(tuples_list) == 3 + assert tuples_list[0] == ('Alice', 25) + assert tuples_list[1] == ('Bob', 30) + assert tuples_list[2] == ('Charlie', 35) + + # Test values_list() with flat=True - returns flat list + names = TestModel.objects.values_list('name', flat=True).order_by('name') + names_list = list(names) + assert names_list == ['Alice', 'Bob', 'Charlie'] + + # Test values() with filtering + active_users = TestModel.objects.filter(is_active=True).values('name', 'email') + active_list = list(active_users) + assert len(active_list) == 2 + active_names = [user['name'] for user in active_list] + assert 'Alice' in active_names + assert 'Charlie' in active_names + assert 'Bob' not in active_names + + # Clean up + TestModel.objects.all().delete() + + def test_django_distinct_queries(self, test_environment: TestEnvironment, django_models): + """Test Django distinct() functionality""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data with duplicate ages + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=25, is_active=True) + TestModel.objects.create(name="David", email="david@example.com", age=30, is_active=True) + + # Test distinct ages + distinct_ages = TestModel.objects.values_list('age', flat=True).distinct().order_by('age') + ages_list = list(distinct_ages) + assert ages_list == [25, 30] + + # Test distinct with multiple fields + distinct_age_status = TestModel.objects.values('age', 'is_active').distinct().order_by('age', 'is_active') + distinct_list = list(distinct_age_status) + assert len(distinct_list) == 3 # (25, True), (30, False), (30, True) + + # Test count with distinct + total_count = TestModel.objects.count() + distinct_age_count = TestModel.objects.values('age').distinct().count() + assert total_count == 4 + assert distinct_age_count == 2 + + # Clean up + TestModel.objects.all().delete() + + def test_django_only_and_defer(self, test_environment: TestEnvironment, django_models): + """Test Django only() and defer() for query optimization""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + obj = TestModel.objects.create( + name="Test User", + email="test@example.com", + age=30, + is_active=True + ) + + # Test only() - load only specific fields + obj_only = TestModel.objects.only('name', 'email').get(id=obj.id) + assert obj_only.name == "Test User" + assert obj_only.email == "test@example.com" + # Accessing deferred fields will trigger additional query, but should still work + assert obj_only.age == 30 + + # Test defer() - exclude specific fields from loading + obj_defer = TestModel.objects.defer('age', 'is_active').get(id=obj.id) + assert obj_defer.name == "Test User" + assert obj_defer.email == "test@example.com" + # Accessing deferred fields will trigger additional query, but should still work + assert obj_defer.age == 30 + + # Clean up + TestModel.objects.all().delete() + + def test_django_in_bulk(self, test_environment: TestEnvironment, django_models): + """Test Django in_bulk() for batch retrieval""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + obj1 = TestModel.objects.create(name="User 1", email="user1@example.com", age=25) + obj2 = TestModel.objects.create(name="User 2", email="user2@example.com", age=30) + obj3 = TestModel.objects.create(name="User 3", email="user3@example.com", age=35) + + # Test in_bulk with IDs (default behavior) + bulk_result = TestModel.objects.in_bulk([obj1.id, obj2.id, obj3.id]) + assert len(bulk_result) == 3 + assert bulk_result[obj1.id].name == "User 1" + assert bulk_result[obj2.id].name == "User 2" + assert bulk_result[obj3.id].name == "User 3" + + # Test in_bulk with all IDs (no list provided) + bulk_all = TestModel.objects.in_bulk() + assert len(bulk_all) == 3 + assert obj1.id in bulk_all + assert obj2.id in bulk_all + assert obj3.id in bulk_all + + # Test in_bulk with email field (unique field) + bulk_by_email = TestModel.objects.in_bulk( + ["user1@example.com", "user3@example.com"], + field_name='email' + ) + assert len(bulk_by_email) == 2 + assert bulk_by_email["user1@example.com"].name == "User 1" + assert bulk_by_email["user3@example.com"].name == "User 3" + + # Clean up + TestModel.objects.all().delete() + + def test_django_conditional_expressions(self, test_environment: TestEnvironment, django_models): + """Test Django Case/When conditional expressions""" + from django.db.models import Case, IntegerField, Value, When + + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + TestModel.objects.create(name="Alice", email="alice@example.com", age=25, is_active=True) + TestModel.objects.create(name="Bob", email="bob@example.com", age=30, is_active=False) + TestModel.objects.create(name="Charlie", email="charlie@example.com", age=35, is_active=True) + + # Test Case/When for conditional logic + results = TestModel.objects.annotate( + age_category=Case( + When(age__lt=30, then=Value('young')), + When(age__gte=30, age__lt=40, then=Value('middle')), + default=Value('senior'), + output_field=CharField() + ) + ).order_by('age') + + results_list = list(results) + assert results_list[0].age_category == 'young' # Alice, 25 + assert results_list[1].age_category == 'middle' # Bob, 30 + assert results_list[2].age_category == 'middle' # Charlie, 35 + + # Test Case/When with integer output + priority_results = TestModel.objects.annotate( + priority=Case( + When(is_active=True, age__lt=30, then=Value(1)), + When(is_active=True, then=Value(2)), + When(is_active=False, then=Value(3)), + default=Value(4), + output_field=IntegerField() + ) + ).order_by('priority', 'name') + + priority_list = list(priority_results) + assert priority_list[0].name == 'Alice' # priority 1: active and young + assert priority_list[1].name == 'Charlie' # priority 2: active but not young + assert priority_list[2].name == 'Bob' # priority 3: not active + + # Clean up + TestModel.objects.all().delete() + + def test_django_iterator(self, test_environment: TestEnvironment, django_models): + """Test Django iterator() for memory-efficient queries""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create test data + for i in range(20): + TestModel.objects.create( + name=f"User {i}", + email=f"user{i}@example.com", + age=20 + i + ) + + # Test iterator() - processes results without caching + count = 0 + for obj in TestModel.objects.iterator(): + assert obj.name.startswith("User") + count += 1 + assert count == 20 + + # Test iterator with chunk_size + count = 0 + for obj in TestModel.objects.iterator(chunk_size=5): + assert obj.email.endswith("@example.com") + count += 1 + assert count == 20 + + # Clean up + TestModel.objects.all().delete() diff --git a/tests/integration/container/django/test_django_plugins.py b/tests/integration/container/django/test_django_plugins.py new file mode 100644 index 00000000..a72e3a19 --- /dev/null +++ b/tests/integration/container/django/test_django_plugins.py @@ -0,0 +1,880 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa: N806 +# N806: Django model class references stored in variables use PascalCase intentionally + +from __future__ import annotations + +import json +import uuid +from time import perf_counter_ns, sleep +from typing import Any, ClassVar, Dict + +import boto3 +import django +import pytest +from boto3 import client +from botocore.exceptions import ClientError +from django.conf import settings +from django.db import connection, connections, models +from django.test.utils import setup_test_environment, teardown_test_environment + +from aws_advanced_python_wrapper.errors import FailoverSuccessError +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from ..utils.conditions import (disable_on_features, enable_on_deployments, + enable_on_engines, enable_on_features, + enable_on_num_instances) +from ..utils.database_engine import DatabaseEngine +from ..utils.database_engine_deployment import DatabaseEngineDeployment +from ..utils.test_environment import TestEnvironment +from ..utils.test_environment_features import TestEnvironmentFeatures + + +@enable_on_engines([DatabaseEngine.MYSQL]) # Django backends are MySQL-specific +@enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, + TestEnvironmentFeatures.PERFORMANCE]) +class TestDjangoPlugins: + # Type hints for dynamically created Django models (using Any to avoid mypy errors) + TestModel: Any + + # Class variables for custom endpoint + endpoint_id: ClassVar[str] = f"test-django-endpoint-{uuid.uuid4()}" + endpoint_info: ClassVar[Dict[str, Any]] = {} + reuse_existing_endpoint: ClassVar[bool] = False + + @pytest.fixture(scope='class') + def rds_utils(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) + + @pytest.fixture(scope='class') + def create_secret(self, conn_utils): + """Create a secret in AWS Secrets Manager with database credentials.""" + region = TestEnvironment.get_current().get_info().get_region() + client = boto3.client('secretsmanager', region_name=region) + env = TestEnvironment.get_current() + + secret_name = f"TestSecret-{uuid.uuid4()}" + + engine = "postgres" if env.get_engine() == "pg" else "mysql" + secret_value = { + "engine": engine, + "dbname": env.get_info().get_database_info().get_default_db_name(), + "host": env.get_info().get_database_info().get_cluster_endpoint(), + "username": conn_utils.user, + "password": conn_utils.password, + "description": "Test secret generated by integration tests." + } + + try: + response = client.create_secret( + Name=secret_name, + SecretString=json.dumps(secret_value) + ) + secret_arn = response['ARN'] + yield secret_name, secret_arn + finally: + try: + client.delete_secret( + SecretId=secret_name, + ForceDeleteWithoutRecovery=True + ) + except Exception: + pass + + @pytest.fixture(scope='class') + def create_custom_endpoint(self): + """Create a custom endpoint for testing""" + env_info = TestEnvironment.get_current().get_info() + region = env_info.get_region() + + rds_client = client('rds', region_name=region) + if not self.reuse_existing_endpoint: + instances = env_info.get_database_info().get_instances() + self._create_endpoint(rds_client, instances[0:1]) + + self._wait_until_endpoint_available(rds_client) + + yield + + if not self.reuse_existing_endpoint: + self._delete_endpoint(rds_client) + + rds_client.close() + + def _wait_until_endpoint_available(self, rds_client): + """Wait until the custom endpoint becomes available""" + end_ns = perf_counter_ns() + 5 * 60 * 1_000_000_000 # 5 minutes + available = False + + while perf_counter_ns() < end_ns: + response = rds_client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self.endpoint_id, + Filters=[ + { + "Name": "db-cluster-endpoint-type", + "Values": ["custom"] + } + ] + ) + + response_endpoints = response["DBClusterEndpoints"] + if len(response_endpoints) != 1: + sleep(3) # Endpoint needs more time to get created. + continue + + response_endpoint = response_endpoints[0] + TestDjangoPlugins.endpoint_info = response_endpoint + available = "available" == response_endpoint["Status"] + if available: + break + + sleep(3) + + if not available: + pytest.fail(f"The test setup step timed out while waiting for the test custom endpoint to become available: " + f"'{TestDjangoPlugins.endpoint_id}'.") + + def _create_endpoint(self, rds_client, instances): + """Create the custom endpoint""" + instance_ids = [instance.get_instance_id() for instance in instances] + rds_client.create_db_cluster_endpoint( + DBClusterEndpointIdentifier=self.endpoint_id, + DBClusterIdentifier=TestEnvironment.get_current().get_cluster_name(), + EndpointType="ANY", + StaticMembers=instance_ids + ) + + def _delete_endpoint(self, rds_client): + """Delete the custom endpoint and wait for deletion""" + try: + rds_client.delete_db_cluster_endpoint(DBClusterEndpointIdentifier=self.endpoint_id) + # Wait for the endpoint to be deleted + self._wait_until_endpoint_deleted(rds_client) + except ClientError as e: + # If the custom endpoint already does not exist, we can continue. Otherwise, fail the test. + if e.response['Error']['Code'] != 'DBClusterEndpointNotFoundFault': + pytest.fail(e) + + def _wait_until_endpoint_deleted(self, rds_client): + """Wait until the custom endpoint is deleted (max 3 minutes)""" + end_ns = perf_counter_ns() + 3 * 60 * 1_000_000_000 # 3 minutes + deleted = False + + while perf_counter_ns() < end_ns: + try: + response = rds_client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self.endpoint_id, + Filters=[ + { + "Name": "db-cluster-endpoint-type", + "Values": ["custom"] + } + ] + ) + + response_endpoints = response["DBClusterEndpoints"] + if len(response_endpoints) == 0: + deleted = True + break + + # Check if endpoint is in deleting state + endpoint_status = response_endpoints[0]["Status"] + if endpoint_status == "deleting": + sleep(3) + continue + + except ClientError as e: + # If we get DBClusterEndpointNotFoundFault, the endpoint is deleted + if e.response['Error']['Code'] == 'DBClusterEndpointNotFoundFault': + deleted = True + break + else: + # Some other error occurred + sleep(3) + continue + + sleep(3) + + if not deleted: + print(f"Warning: Timed out waiting for custom endpoint to be deleted: '{self.endpoint_id}'. " + f"The endpoint may still be in the process of being deleted.") + else: + print(f"Custom endpoint '{self.endpoint_id}' successfully deleted.") + + @pytest.fixture(scope='function') # Changed from 'class' to 'function' + def django_models(self, django_setup): + """Create Django models after Django is set up""" + + # Use unique app labels and table names for each test to avoid conflicts + test_id = str(uuid.uuid4())[:8] # Short unique identifier + + # Store models in class attributes so they're accessible + class TestModel(models.Model): + """Basic test model for Django ORM functionality""" + name = models.CharField(max_length=100) + email = models.EmailField() + age = models.IntegerField() + is_active = models.BooleanField(default=True) + created_at = models.DateTimeField(auto_now_add=True) + + class Meta: + app_label = f'test_app_{test_id}' # Unique app label + db_table = f'django_test_model_plugins_{test_id}' + + class DataTypeModel(models.Model): + """Model for testing various data types""" + # String fields + char_field = models.CharField(max_length=255, null=True, blank=True) + text_field = models.TextField(null=True, blank=True) + + # Numeric fields + integer_field = models.IntegerField(null=True, blank=True) + big_integer_field = models.BigIntegerField(null=True, blank=True) + decimal_field = models.DecimalField(max_digits=10, decimal_places=2, null=True, blank=True) + float_field = models.FloatField(null=True, blank=True) + + # Boolean field + boolean_field = models.BooleanField(default=False) + + # Date/Time fields + date_field = models.DateField(null=True, blank=True) + time_field = models.TimeField(null=True, blank=True) + datetime_field = models.DateTimeField(null=True, blank=True) + + # JSON field (MySQL 5.7+) + json_field = models.JSONField(null=True, blank=True) + + class Meta: + app_label = f'test_app_{test_id}' # Unique app label + db_table = f'django_data_type_model_plugins_{test_id}' + + class Author(models.Model): + """Author model for relationship testing""" + name = models.CharField(max_length=100) + email = models.EmailField() + birth_date = models.DateField(null=True, blank=True) + + class Meta: + app_label = f'test_app_{test_id}' # Unique app label + db_table = f'django_author_plugins_{test_id}' + + class Book(models.Model): + """Book model for relationship testing""" + title = models.CharField(max_length=200) + author = models.ForeignKey(Author, on_delete=models.CASCADE, related_name='books') + publication_date = models.DateField() + pages = models.IntegerField() + price = models.DecimalField(max_digits=8, decimal_places=2) + + class Meta: + app_label = f'test_app_{test_id}' # Unique app label + db_table = f'django_book_plugins_{test_id}' + + # Store models as class attributes for easy access + TestDjangoPlugins.TestModel = TestModel + TestDjangoPlugins.DataTypeModel = DataTypeModel + TestDjangoPlugins.Author = Author + TestDjangoPlugins.Book = Book + + # Create tables for our test models + with connection.schema_editor() as schema_editor: + schema_editor.create_model(TestModel) + schema_editor.create_model(DataTypeModel) + schema_editor.create_model(Author) + schema_editor.create_model(Book) + + yield + + # Clean up tables - this will drop and recreate for each test + with connection.schema_editor() as schema_editor: + schema_editor.delete_model(Book) + schema_editor.delete_model(Author) + schema_editor.delete_model(DataTypeModel) + schema_editor.delete_model(TestModel) + + @pytest.fixture(scope='function') # Changed from 'class' to 'function' + def django_setup(self, conn_utils, create_secret, request, create_custom_endpoint=None): + """Setup Django configuration for testing with configurable plugins""" + # Get configuration from test parameter or use defaults + if hasattr(request, 'param') and isinstance(request.param, dict): + config = request.param + plugins_config = config.get('plugins', 'aurora_connection_tracker,failover') + extra_options = config.get('options', {}) + + # Check if we need to use custom endpoint + use_custom_endpoint = config.get('use_custom_endpoint', False) + custom_endpoint_host = None + + if use_custom_endpoint and create_custom_endpoint: + # create_custom_endpoint is a fixture that sets up the endpoint + # The endpoint info is stored in the class variable + custom_endpoint_host = self.endpoint_info.get('Endpoint') + + # Determine user based on plugin type + if 'iam' in plugins_config: + user = conn_utils.iam_user + elif 'aws_secrets_manager' in plugins_config: + user = None # Secrets manager will provide credentials + _, secret_arn = create_secret + extra_options['secrets_manager_secret_id'] = secret_arn + else: + user = config.get('user', conn_utils.user) + + # Handle password + if 'iam' in plugins_config or 'aws_secrets_manager' in plugins_config: + password = None # Secrets manager will provide credentials + else: + password = config.get('password', conn_utils.password) + + # Use custom endpoint host if provided, otherwise use default + if custom_endpoint_host: + host = custom_endpoint_host + else: + host = config.get('host', conn_utils.writer_cluster_host) + else: + plugins_config = 'aurora_connection_tracker,failover' + extra_options = {} + user = conn_utils.user + password = conn_utils.password + host = conn_utils.writer_host + + # Configure Django settings + if not settings.configured: + db_config = { + 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', + 'NAME': conn_utils.dbname, + "USER": user, + "PASSWORD": password, + 'HOST': host, + 'PORT': conn_utils.port, + 'OPTIONS': { + 'plugins': plugins_config, # Configurable plugins + 'connect_timeout': 10, + 'autocommit': True, + **extra_options # Add any extra options + }, + } + + settings.configure( + DEBUG=True, + DATABASES={'default': db_config}, + INSTALLED_APPS=[ + 'django.contrib.contenttypes', + 'django.contrib.auth', + ], + SECRET_KEY='test-secret-key-for-django-plugins-tests', + USE_TZ=True, + ) + else: + # If settings are already configured, update the database config + settings.DATABASES['default']['USER'] = user + settings.DATABASES['default']['PASSWORD'] = password + settings.DATABASES['default']['HOST'] = host + settings.DATABASES['default']['OPTIONS']['plugins'] = plugins_config + settings.DATABASES['default']['OPTIONS'].update(extra_options) + + django.setup() + setup_test_environment() + + yield {'plugins': plugins_config, 'options': extra_options} # Return the config so tests can access it + + # Close all Django database connections after each test + connections.close_all() + teardown_test_environment() + + def test_django_basic_insert_with_plugins(self, test_environment: TestEnvironment, django_models): + """Test basic Django insert operation with plugins enabled""" + TestModel = self.TestModel + + # Ensure clean slate + TestModel.objects.all().delete() + + # Create a simple test record + test_obj = TestModel.objects.create( + name="Plugin Test User", + email="plugin@example.com", + age=25, + is_active=True + ) + + # Verify the record was created successfully + assert test_obj.id is not None + assert test_obj.name == "Plugin Test User" + assert test_obj.email == "plugin@example.com" + assert test_obj.age == 25 + assert test_obj.is_active is True + + # Verify we can retrieve it + retrieved_obj = TestModel.objects.get(id=test_obj.id) + assert retrieved_obj.name == "Plugin Test User" + + # Clean up + TestModel.objects.all().delete() + + @pytest.mark.parametrize('django_setup', [{'plugins': ''}], indirect=True) + def test_django_with_no_plugins(self, test_environment: TestEnvironment, django_models, django_setup): + """Test Django with no plugins enabled""" + TestModel = self.TestModel + config = django_setup # Get the config from fixture + + # Verify no plugins are configured + assert config['plugins'] == '' + + # Test basic functionality still works + test_obj = TestModel.objects.create( + name="No Plugins User", + email="noplugins@example.com", + age=30 + ) + + assert test_obj.id is not None + assert test_obj.name == "No Plugins User" + + # Clean up + TestModel.objects.all().delete() + + @pytest.mark.parametrize('django_setup', [{'plugins': 'failover'}], indirect=True) + def test_django_with_failover_only(self, test_environment: TestEnvironment, django_models, django_setup): + """Test Django with only failover plugin""" + TestModel = self.TestModel + config = django_setup # Get the config from fixture + + # Verify only failover plugin is configured + assert config['plugins'] == 'failover' + + # Test basic functionality works with failover plugin + test_obj = TestModel.objects.create( + name="Failover Only User", + email="failover@example.com", + age=35 + ) + + assert test_obj.id is not None + assert test_obj.name == "Failover Only User" + + # Clean up + TestModel.objects.all().delete() + + @pytest.mark.parametrize('django_setup', [{'plugins': 'aurora_connection_tracker,failover'}], indirect=True) + def test_django_with_multiple_plugins(self, test_environment: TestEnvironment, django_models, django_setup): + """Test Django with multiple plugins enabled""" + TestModel = self.TestModel + config = django_setup # Get the config from fixture + + # Verify multiple plugins are configured + assert config['plugins'] == 'aurora_connection_tracker,failover' + + # Test basic functionality works with multiple plugins + test_obj = TestModel.objects.create( + name="Multi Plugin User", + email="multiplugin@example.com", + age=40 + ) + + assert test_obj.id is not None + assert test_obj.name == "Multi Plugin User" + + # Clean up + TestModel.objects.all().delete() + + @pytest.mark.parametrize('django_setup', [{ + 'plugins': 'aws_secrets_manager', + 'use_secrets_manager': True + }], indirect=True) + def test_django_with_secrets_manager_plugin(self, test_environment: TestEnvironment, django_setup, django_models): + """Test Django with AWS Secrets Manager plugin""" + TestModel = self.TestModel + config = django_setup + + # Verify secrets manager plugin is configured + assert config['plugins'] == 'aws_secrets_manager' + assert 'secrets_manager_secret_id' in config['options'] + + # Test basic functionality works with secrets manager + test_obj = TestModel.objects.create( + name="Secrets Manager User", + email="secrets@example.com", + age=45 + ) + + assert test_obj.id is not None + assert test_obj.name == "Secrets Manager User" + + # Verify we can retrieve the record + retrieved_obj = TestModel.objects.get(id=test_obj.id) + assert retrieved_obj.email == "secrets@example.com" + + # Clean up + TestModel.objects.all().delete() + + @pytest.mark.parametrize('django_setup', [{ + 'plugins': 'iam', + 'password': '', # IAM doesn't use password + 'options': {} + }], indirect=True) + def test_django_with_iam_plugin(self, test_environment: TestEnvironment, django_models, django_setup, conn_utils): + """Test Django with IAM authentication plugin""" + TestModel = self.TestModel + config = django_setup + + # Verify IAM plugin is configured + assert config['plugins'] == 'iam' + + # Add iam_host to the configuration + if settings.configured: + # Get the writer instance for IAM host + writer_instance = test_environment.get_writer() + settings.DATABASES['default']['OPTIONS']['iam_host'] = writer_instance.get_host() + + # Test basic functionality works with IAM authentication + test_obj = TestModel.objects.create( + name="IAM User", + email="iam@example.com", + age=50 + ) + + assert test_obj.id is not None + assert test_obj.name == "IAM User" + + # Verify we can retrieve the record + retrieved_obj = TestModel.objects.get(id=test_obj.id) + assert retrieved_obj.email == "iam@example.com" + + # Clean up + TestModel.objects.all().delete() + + @pytest.mark.parametrize('django_setup', [{ + 'plugins': 'failover', + 'options': { + 'socket_timeout': 10, + 'connect_timeout': 10, + 'monitoring-connect_timeout': 5, + 'monitoring-socket_timeout': 5, + 'topology_refresh_ms': 10 + } + }], indirect=True) + @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) + @enable_on_num_instances(min_instances=2) + def test_django_failover_during_query(self, test_environment: TestEnvironment, django_setup, django_models, rds_utils): + """Test Django failover during query operations""" + TestModel = self.TestModel + config = django_setup + + # Verify failover plugin is configured + assert 'failover' in config['plugins'] + + # Get initial writer ID + initial_writer_id = rds_utils.get_cluster_writer_instance_id() + + # Create a test record + test_obj = TestModel.objects.create( + name="Failover Test User", + email="failover@example.com", + age=30 + ) + + # Select something from the TestModel - should work fine + result = TestModel.objects.filter(id=test_obj.id).first() + assert result is not None + assert result.name == "Failover Test User" + + # Trigger failover + rds_utils.failover_cluster_and_wait_until_writer_changed() + + # Try selecting again - should throw FailoverSuccessError + with pytest.raises(FailoverSuccessError): + TestModel.objects.filter(id=test_obj.id).first() + + # Select again - should work fine now (connected to new writer) + result = TestModel.objects.filter(id=test_obj.id).first() + assert result is not None + assert result.name == "Failover Test User" + + # Verify we're now connected to a new writer + with connection.cursor() as cursor: + cursor.execute(RdsTestUtility.get_instance_id_query()) + current_writer_id = cursor.fetchone()[0] + assert rds_utils.is_db_instance_writer(current_writer_id) is True + assert current_writer_id != initial_writer_id, "Should be connected to a new writer after failover" + + # Clean up test data + TestModel.objects.all().delete() + + @pytest.mark.parametrize('django_setup', [{ + 'plugins': 'custom_endpoint,failover', + 'use_custom_endpoint': True, + 'options': { + 'socket_timeout': 10, + 'connect_timeout': 10, + 'monitoring-connect_timeout': 5, + 'monitoring-socket_timeout': 5, + 'topology_refresh_ms': 10 + } + }], indirect=True) + @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) + @enable_on_num_instances(min_instances=2) + def test_django_custom_endpoint_failover_during_query( + self, test_environment: TestEnvironment, create_custom_endpoint, + django_setup, django_models, rds_utils): + """Test Django failover with custom endpoint during query operations""" + TestModel = self.TestModel + config = django_setup + + # Verify custom_endpoint and failover plugins are configured + assert 'custom_endpoint' in config['plugins'] + assert 'failover' in config['plugins'] + + # Get initial writer ID + initial_writer_id = rds_utils.get_cluster_writer_instance_id() + + # Create a test record + test_obj = TestModel.objects.create( + name="Custom Endpoint Failover Test User", + email="custom_failover@example.com", + age=35 + ) + + # Step 1: Select something from the TestModel - should work fine + result = TestModel.objects.filter(id=test_obj.id).first() + assert result is not None + assert result.name == "Custom Endpoint Failover Test User" + + # Trigger failover + rds_utils.failover_cluster_and_wait_until_writer_changed() + + # Try selecting again - should throw FailoverSuccessError + with pytest.raises(FailoverSuccessError): + TestModel.objects.filter(id=test_obj.id).first() + + # Select again - should work fine now (connected to new writer) + result = TestModel.objects.filter(id=test_obj.id).first() + assert result is not None + assert result.name == "Custom Endpoint Failover Test User" + + # Verify we're now connected to a new writer + with connection.cursor() as cursor: + cursor.execute(RdsTestUtility.get_instance_id_query()) + current_writer_id = cursor.fetchone()[0] + assert rds_utils.is_db_instance_writer(current_writer_id) is True + assert current_writer_id != initial_writer_id, "Should be connected to a new writer after failover" + + # Clean up test data + TestModel.objects.all().delete() + + @pytest.fixture(scope='function') + def django_rw_split_setup(self, conn_utils): + """Setup Django with read/write splitting configuration""" + # Define a router class for read/write splitting + class ReadWriteRouter: + """Router to direct reads to 'read' database and writes to 'default' database""" + + def db_for_read(self, model, **hints): + """Direct all read operations to the 'read' database""" + return 'read' + + def db_for_write(self, model, **hints): + """Direct all write operations to the 'default' database""" + return 'default' + + def allow_relation(self, obj1, obj2, **hints): + """Allow relations between objects in the same database""" + return True + + def allow_migrate(self, db, app_label, model_name=None, **hints): + """Allow migrations on all databases""" + return True + + # Configure Django with two database connections + if not settings.configured: + db_config_writer = { + 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', + 'NAME': conn_utils.dbname, + 'USER': conn_utils.user, + 'PASSWORD': conn_utils.password, + 'HOST': conn_utils.writer_cluster_host, + 'PORT': conn_utils.port, + 'OPTIONS': { + 'plugins': 'read_write_splitting', + 'connect_timeout': 10, + 'autocommit': True, + }, + } + + db_config_reader = { + 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', + 'NAME': conn_utils.dbname, + 'USER': conn_utils.user, + 'PASSWORD': conn_utils.password, + 'HOST': conn_utils.writer_cluster_host, + 'PORT': conn_utils.port, + 'OPTIONS': { + 'plugins': 'read_write_splitting', + 'connect_timeout': 10, + 'autocommit': True, + 'read_only': True, + }, + } + + settings.configure( + DEBUG=True, + DATABASES={ + 'default': db_config_writer, # Writer connection + 'read': db_config_reader, # Reader connection + }, + DATABASE_ROUTERS=[ReadWriteRouter()], + INSTALLED_APPS=[ + 'django.contrib.contenttypes', + 'django.contrib.auth', + ], + SECRET_KEY='test-secret-key-for-rw-splitting', + USE_TZ=True, + ) + else: + # Update existing settings without overwriting + settings.DATABASES['default'].update({ + 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', + 'NAME': conn_utils.dbname, + 'USER': conn_utils.user, + 'PASSWORD': conn_utils.password, + 'HOST': conn_utils.writer_cluster_host, + 'PORT': conn_utils.port, + }) + settings.DATABASES['default']['OPTIONS'].update({ + 'plugins': 'read_write_splitting', + 'connect_timeout': 10, + 'autocommit': True, + }) + + # Create or update the 'read' database configuration + if 'read' not in settings.DATABASES: + settings.DATABASES['read'] = settings.DATABASES['default'].copy() + settings.DATABASES['read']['OPTIONS'] = settings.DATABASES['default']['OPTIONS'].copy() + settings.DATABASES['read'].update({ + 'OPTIONS': { + 'plugins': 'read_write_splitting', + 'connect_timeout': 10, + 'autocommit': True, + 'read_only': True, + }, + }) + else: + settings.DATABASES['read'].update({ + 'ENGINE': 'aws_advanced_python_wrapper.django.backends.mysql_connector', + 'NAME': conn_utils.dbname, + 'USER': conn_utils.user, + 'PASSWORD': conn_utils.password, + 'HOST': conn_utils.writer_cluster_host, + 'PORT': conn_utils.port, + }) + settings.DATABASES['read']['OPTIONS'].update({ + 'plugins': 'read_write_splitting', + 'connect_timeout': 10, + 'autocommit': True, + 'read_only': True, + }) + + settings.DATABASE_ROUTERS = [ReadWriteRouter()] + + django.setup() + setup_test_environment() + + # Create a test model + test_id = str(uuid.uuid4())[:8] + + class RWSplitTestModel(models.Model): + name = models.CharField(max_length=100) + value = models.IntegerField() + + class Meta: + app_label = f'test_app_{test_id}' + db_table = f'django_rw_split_test_{test_id}' + + # Create table once on the default (writer) connection + # Both connections point to the same database, so only create schema once + with connections['default'].schema_editor() as schema_editor: + schema_editor.create_model(RWSplitTestModel) + + yield RWSplitTestModel + + # Cleanup: Drop table and close connections + with connections['default'].schema_editor() as schema_editor: + schema_editor.delete_model(RWSplitTestModel) + + connections.close_all() + teardown_test_environment() + + @enable_on_num_instances(min_instances=2) + def test_django_read_write_splitting(self, test_environment: TestEnvironment, django_rw_split_setup, rds_utils): + """Test Django with read/write splitting using multiple database connections""" + RWSplitTestModel = django_rw_split_setup + + # Verify writer connection is connected to writer endpoint + with connections['default'].cursor() as cursor: + cursor.execute(RdsTestUtility.get_instance_id_query()) + writer_instance_id = cursor.fetchone()[0] + assert rds_utils.is_db_instance_writer(writer_instance_id), \ + f"Default connection should be connected to writer, but got {writer_instance_id}" + + # Verify reader connection is connected to reader endpoint + with connections['read'].cursor() as cursor: + cursor.execute(RdsTestUtility.get_instance_id_query()) + reader_instance_id = cursor.fetchone()[0] + assert not rds_utils.is_db_instance_writer(reader_instance_id), \ + f"Read connection should be connected to reader, but got {reader_instance_id}" + + # Verify they're different instances + assert writer_instance_id != reader_instance_id, \ + "Writer and reader should be connected to different instances" + + # Perform write operation (should use 'default' database) + test_obj = RWSplitTestModel.objects.using('default').create( + name="Test Write", + value=42 + ) + assert test_obj.id is not None + + # Perform read operation (should use 'read' database via router) + # Note: The router directs reads to 'read' database + retrieved_obj = RWSplitTestModel.objects.get(id=test_obj.id) + assert retrieved_obj.name == "Test Write" + assert retrieved_obj.value == 42 + + # Verify we can do more complex operations + # Create multiple objects + RWSplitTestModel.objects.using('default').create(name="Object 1", value=10) + RWSplitTestModel.objects.using('default').create(name="Object 2", value=20) + RWSplitTestModel.objects.using('default').create(name="Object 3", value=30) + + # Read them back (router will use 'read' database) + all_objects = RWSplitTestModel.objects.all() + assert all_objects.count() == 4 # Including the first test_obj + + # Filter operation (read) - should find objects with value >= 20 + # This includes: test_obj (42), Object 2 (20), Object 3 (30) = 3 objects + filtered = RWSplitTestModel.objects.filter(value__gte=20) + assert filtered.count() == 3 + + # Update operation (write - should use 'default') + RWSplitTestModel.objects.filter(name="Object 1").update(value=15) + + # Verify update (read) + updated_obj = RWSplitTestModel.objects.get(name="Object 1") + assert updated_obj.value == 15 + + # Clean up test data + RWSplitTestModel.objects.using('default').all().delete() diff --git a/tests/integration/container/test_custom_endpoint.py b/tests/integration/container/test_custom_endpoint.py index 131a95a5..88f5859d 100644 --- a/tests/integration/container/test_custom_endpoint.py +++ b/tests/integration/container/test_custom_endpoint.py @@ -140,11 +140,59 @@ def _create_endpoint(self, rds_client, instances): def delete_endpoint(self, rds_client): try: rds_client.delete_db_cluster_endpoint(DBClusterEndpointIdentifier=self.endpoint_id) + # Wait for the endpoint to be deleted + self._wait_until_endpoint_deleted(rds_client) except ClientError as e: # If the custom endpoint already does not exist, we can continue. Otherwise, fail the test. if e.response['Error']['Code'] != 'DBClusterEndpointNotFoundFault': pytest.fail(e) + def _wait_until_endpoint_deleted(self, rds_client): + """Wait until the custom endpoint is deleted (max 3 minutes)""" + end_ns = perf_counter_ns() + 3 * 60 * 1_000_000_000 # 3 minutes + deleted = False + + while perf_counter_ns() < end_ns: + try: + response = rds_client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self.endpoint_id, + Filters=[ + { + "Name": "db-cluster-endpoint-type", + "Values": ["custom"] + } + ] + ) + + response_endpoints = response["DBClusterEndpoints"] + if len(response_endpoints) == 0: + deleted = True + break + + # Check if endpoint is in deleting state + endpoint_status = response_endpoints[0]["Status"] + if endpoint_status == "deleting": + sleep(3) + continue + + except ClientError as e: + # If we get DBClusterEndpointNotFoundFault, the endpoint is deleted + if e.response['Error']['Code'] == 'DBClusterEndpointNotFoundFault': + deleted = True + break + else: + # Some other error occurred + sleep(3) + continue + + sleep(3) + + if not deleted: + self.logger.warning(f"Timed out waiting for custom endpoint to be deleted: '{self.endpoint_id}'. " + f"The endpoint may still be in the process of being deleted.") + else: + self.logger.debug(f"Custom endpoint '{self.endpoint_id}' successfully deleted.") + def wait_until_endpoint_has_members(self, rds_client, expected_members: Set[str]): start_ns = perf_counter_ns() end_ns = perf_counter_ns() + 20 * 60 * 1_000_000_000 # 20 minutes diff --git a/tests/integration/container/utils/rds_test_utility.py b/tests/integration/container/utils/rds_test_utility.py index 501d548b..dbc4cbbe 100644 --- a/tests/integration/container/utils/rds_test_utility.py +++ b/tests/integration/container/utils/rds_test_utility.py @@ -275,14 +275,28 @@ def query_host_role( else: return HostRole.WRITER - def _query_aurora_instance_id(self, conn: Connection, engine: DatabaseEngine) -> str: + @staticmethod + def get_instance_id_query(engine: Optional[DatabaseEngine] = None) -> str: + """Get the SQL query to retrieve the Aurora instance ID based on the database engine. + + Args: + engine: The database engine. If None, uses the current test environment's engine. + + Returns: + The SQL query string to get the instance ID. + """ + if engine is None: + engine = TestEnvironment.get_current().get_engine() + if engine == DatabaseEngine.MYSQL: - sql = "SELECT @@aurora_server_id" + return "SELECT @@aurora_server_id" elif engine == DatabaseEngine.PG: - sql = "SELECT pg_catalog.aurora_db_instance_identifier()" + return "SELECT pg_catalog.aurora_db_instance_identifier()" else: raise UnsupportedOperationError(engine.value) + def _query_aurora_instance_id(self, conn: Connection, engine: DatabaseEngine) -> str: + sql = self.get_instance_id_query(engine) with closing(conn.cursor()) as cursor: cursor.execute(sql) record = cursor.fetchone() diff --git a/tests/unit/test_django_mysql_connector.py b/tests/unit/test_django_mysql_connector.py new file mode 100644 index 00000000..ff14b184 --- /dev/null +++ b/tests/unit/test_django_mysql_connector.py @@ -0,0 +1,144 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, patch + +import pytest + + +class TestDatabaseWrapper: + """Unit tests for Django MySQL Connector DatabaseWrapper""" + + @pytest.fixture + def database_wrapper(self): + """Create a DatabaseWrapper instance with mocked dependencies""" + with patch('aws_advanced_python_wrapper.django.backends.mysql_connector.base.base.DatabaseWrapper.__init__'): + from aws_advanced_python_wrapper.django.backends.mysql_connector.base import \ + DatabaseWrapper + wrapper = DatabaseWrapper.__new__(DatabaseWrapper) + wrapper._read_only = False + return wrapper + + def test_get_connection_params_extracts_read_only(self, database_wrapper): + """Test that get_connection_params extracts and removes read_only parameter""" + with patch('aws_advanced_python_wrapper.django.backends.mysql_connector.base.base.DatabaseWrapper.get_connection_params') as mock_super: + mock_super.return_value = { + 'host': 'localhost', + 'read_only': True + } + + result = database_wrapper.get_connection_params() + + assert database_wrapper._read_only is True + assert 'read_only' not in result + + @patch('aws_advanced_python_wrapper.django.backends.mysql_connector.base.AwsWrapperConnection.connect') + @patch('aws_advanced_python_wrapper.django.backends.mysql_connector.base.mysql.connector.Connect') + @patch('aws_advanced_python_wrapper.django.backends.mysql_connector.base.base.DjangoMySQLConverter') + def test_get_new_connection_adds_converter_and_creates_wrapper(self, mock_converter, mock_connector, mock_wrapper_connect, database_wrapper): + """Test that get_new_connection adds converter_class and creates AwsWrapperConnection""" + mock_conn = MagicMock() + mock_wrapper_connect.return_value = mock_conn + database_wrapper._read_only = False + + conn_params = {'host': 'localhost'} + result = database_wrapper.get_new_connection(conn_params) + + assert 'converter_class' in conn_params + assert conn_params['converter_class'] == mock_converter + mock_wrapper_connect.assert_called_once_with(mock_connector, **conn_params) + assert result == mock_conn + + @patch('aws_advanced_python_wrapper.django.backends.mysql_connector.base.AwsWrapperConnection.connect') + @patch('aws_advanced_python_wrapper.django.backends.mysql_connector.base.mysql.connector.Connect') + def test_get_new_connection_sets_read_only_when_true(self, mock_connector, mock_wrapper_connect, database_wrapper): + """Test that get_new_connection sets read_only=True on connection when _read_only is True""" + mock_conn = MagicMock() + mock_wrapper_connect.return_value = mock_conn + database_wrapper._read_only = True + + result = database_wrapper.get_new_connection({'host': 'localhost'}) + + assert result.read_only is True + + @patch('aws_advanced_python_wrapper.django.backends.mysql_connector.base.AwsWrapperConnection.connect') + @patch('aws_advanced_python_wrapper.django.backends.mysql_connector.base.mysql.connector.Connect') + def test_get_new_connection_passes_wrapper_properties(self, mock_connector, mock_wrapper_connect, database_wrapper): + """Test that get_new_connection passes AWS wrapper properties like plugins""" + mock_conn = MagicMock() + mock_wrapper_connect.return_value = mock_conn + database_wrapper._read_only = False + + conn_params = { + 'host': 'localhost', + 'user': 'test_user', + 'password': 'test_password', + 'plugins': 'failover,aurora_connection_tracker', + 'failover_timeout': 60, + 'cluster_id': 'my-cluster' + } + + database_wrapper.get_new_connection(conn_params) + + # Verify all parameters including AWS wrapper properties are passed to connect + mock_wrapper_connect.assert_called_once() + call_args = mock_wrapper_connect.call_args[1] + assert call_args['host'] == 'localhost' + assert call_args['user'] == 'test_user' + assert call_args['password'] == 'test_password' + assert call_args['plugins'] == 'failover,aurora_connection_tracker' + assert call_args['failover_timeout'] == 60 + assert call_args['cluster_id'] == 'my-cluster' + + def test_mysql_version_parses_standard_version(self, database_wrapper): + """Test mysql_version parses standard MySQL version""" + database_wrapper.mysql_server_data = {'version': '8.0.32'} + assert database_wrapper.mysql_version == (8, 0, 32) + + def test_mysql_version_parses_version_with_suffix(self, database_wrapper): + """Test mysql_version parses version with suffix""" + database_wrapper.mysql_server_data = {'version': '5.7.42-log'} + assert database_wrapper.mysql_version == (5, 7, 42) + + def test_mysql_version_raises_exception_for_invalid_format(self, database_wrapper): + """Test mysql_version raises exception for invalid version format""" + database_wrapper.mysql_server_data = {'version': 'invalid-version'} + + with pytest.raises(Exception, match="Unable to determine MySQL version"): + _ = database_wrapper.mysql_version + + def test_mysql_is_mariadb_detects_mariadb(self, database_wrapper): + """Test mysql_is_mariadb detects MariaDB""" + database_wrapper.mysql_server_data = {'version': '10.11.2-MariaDB'} + assert database_wrapper.mysql_is_mariadb is True + + def test_mysql_is_mariadb_detects_mysql(self, database_wrapper): + """Test mysql_is_mariadb detects MySQL""" + database_wrapper.mysql_server_data = {'version': '8.0.32'} + assert database_wrapper.mysql_is_mariadb is False + + def test_sql_mode_parses_comma_separated_modes(self, database_wrapper): + """Test sql_mode parses comma-separated modes""" + database_wrapper.mysql_server_data = {'sql_mode': 'MODE1,MODE2,MODE3'} + assert database_wrapper.sql_mode == {'MODE1', 'MODE2', 'MODE3'} + + def test_sql_mode_handles_empty_values(self, database_wrapper): + """Test sql_mode handles empty string and None""" + database_wrapper.mysql_server_data = {'sql_mode': ''} + assert database_wrapper.sql_mode == set() + + del database_wrapper.__dict__['sql_mode'] + + database_wrapper.mysql_server_data = {'sql_mode': None} + assert database_wrapper.sql_mode == set()