diff --git a/aws_advanced_python_wrapper/cluster_topology_monitor.py b/aws_advanced_python_wrapper/cluster_topology_monitor.py index 9ca70ae34..47b178f4b 100644 --- a/aws_advanced_python_wrapper/cluster_topology_monitor.py +++ b/aws_advanced_python_wrapper/cluster_topology_monitor.py @@ -504,8 +504,8 @@ def __call__(self) -> None: if is_writer: try: - if self._monitor._topology_utils.get_host_role( - connection, self._monitor._plugin_service.driver_dialect) != HostRole.WRITER: + if self._monitor._plugin_service.get_host_role( + connection) != HostRole.WRITER: is_writer = False except Exception as ex: logger.debug("HostMonitor.InvalidWriterQuery", ex) diff --git a/aws_advanced_python_wrapper/database_dialect.py b/aws_advanced_python_wrapper/database_dialect.py index 59b3f5d8f..2ef01a0d6 100644 --- a/aws_advanced_python_wrapper/database_dialect.py +++ b/aws_advanced_python_wrapper/database_dialect.py @@ -36,8 +36,9 @@ from enum import Enum, auto from aws_advanced_python_wrapper.errors import (AwsWrapperError, - QueryTimeoutError) -from aws_advanced_python_wrapper.hostinfo import HostInfo + QueryTimeoutError, + UnsupportedOperationError) +from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole from aws_advanced_python_wrapper.thread_pool_container import \ ThreadPoolContainer from aws_advanced_python_wrapper.utils.decorators import \ @@ -90,7 +91,6 @@ class TargetDriverType(Enum): class TopologyAwareDatabaseDialect(Protocol): _TOPOLOGY_QUERY: str _HOST_ID_QUERY: str - _IS_READER_QUERY: str _WRITER_HOST_QUERY: str @property @@ -101,10 +101,6 @@ def topology_query(self) -> str: def host_id_query(self) -> str: return self._HOST_ID_QUERY - @property - def is_reader_query(self) -> str: - return self._IS_READER_QUERY - @property def writer_id_query(self) -> str: return self._WRITER_HOST_QUERY @@ -148,6 +144,16 @@ def host_alias_query(self) -> str: def server_version_query(self) -> str: ... + @property + @abstractmethod + def host_id_query(self) -> str: + ... + + @property + @abstractmethod + def is_reader_query(self) -> str: + ... + @property @abstractmethod def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]: @@ -189,6 +195,9 @@ class MysqlDatabaseDialect(DatabaseDialect): _DIALECT_UPDATE_CANDIDATES: Tuple[DialectCode, ...] = ( DialectCode.AURORA_MYSQL, DialectCode.GLOBAL_AURORA_MYSQL, DialectCode.MULTI_AZ_CLUSTER_MYSQL, DialectCode.RDS_MYSQL) _exception_handler: Optional[ExceptionHandler] = None + _HOST_ID_EXPRESSION = "CONCAT(@@hostname, ':', @@port)" + _HOST_ID_QUERY = f"SELECT @@hostname AS host, {_HOST_ID_EXPRESSION} AS host_id" + _IS_READER_QUERY = "SELECT @@read_only" @property def default_port(self) -> int: @@ -196,12 +205,20 @@ def default_port(self) -> int: @property def host_alias_query(self) -> str: - return "SELECT CONCAT(@@hostname, ':', @@port)" + return f"SELECT {self._HOST_ID_EXPRESSION}" @property def server_version_query(self) -> str: return "SHOW VARIABLES LIKE 'version_comment'" + @property + def host_id_query(self) -> str: + return self._HOST_ID_QUERY + + @property + def is_reader_query(self) -> str: + return self._IS_READER_QUERY + @property def exception_handler(self) -> Optional[ExceptionHandler]: if MysqlDatabaseDialect._exception_handler is None: @@ -240,6 +257,9 @@ class PgDatabaseDialect(DatabaseDialect): _DIALECT_UPDATE_CANDIDATES: Tuple[DialectCode, ...] = ( DialectCode.AURORA_PG, DialectCode.GLOBAL_AURORA_PG, DialectCode.MULTI_AZ_CLUSTER_PG, DialectCode.RDS_PG) _exception_handler: Optional[ExceptionHandler] = None + _HOST_ID_EXPRESSION = "pg_catalog.CONCAT(pg_catalog.inet_server_addr(), ':', pg_catalog.inet_server_port())" + _HOST_ID_QUERY = f"SELECT pg_catalog.inet_server_addr() AS host, {_HOST_ID_EXPRESSION} AS host_id" + _IS_READER_QUERY = "SELECT pg_catalog.pg_is_in_recovery()" @property def default_port(self) -> int: @@ -247,12 +267,20 @@ def default_port(self) -> int: @property def host_alias_query(self) -> str: - return "SELECT pg_catalog.CONCAT(pg_catalog.inet_server_addr(), ':', pg_catalog.inet_server_port())" + return f"SELECT {self._HOST_ID_EXPRESSION}" @property def server_version_query(self) -> str: return "SELECT 'version', pg_catalog.VERSION()" + @property + def host_id_query(self) -> str: + return self._HOST_ID_QUERY + + @property + def is_reader_query(self) -> str: + return self._IS_READER_QUERY + @property def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]: return PgDatabaseDialect._DIALECT_UPDATE_CANDIDATES @@ -298,10 +326,17 @@ def is_blue_green_status_available(self, conn: Connection) -> bool: class RdsMysqlDialect(MysqlDatabaseDialect, BlueGreenDialect): _DIALECT_UPDATE_CANDIDATES = (DialectCode.AURORA_MYSQL, DialectCode.GLOBAL_AURORA_MYSQL, DialectCode.MULTI_AZ_CLUSTER_MYSQL) + _HOST_ID_QUERY = ("SELECT id, SUBSTRING_INDEX(endpoint, '.', 1) " + "FROM mysql.rds_topology " + "WHERE id = @@server_id") _BG_STATUS_QUERY = "SELECT version, endpoint, port, role, status FROM mysql.rds_topology" _BG_STATUS_EXISTS_QUERY = \ "SELECT 1 AS tmp FROM information_schema.tables WHERE table_schema = 'mysql' AND table_name = 'rds_topology'" + @property + def host_id_query(self) -> str: + return self._HOST_ID_QUERY + def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool: initial_transaction_status: bool = driver_dialect.is_in_transaction(conn) try: @@ -352,10 +387,17 @@ class RdsPgDialect(PgDatabaseDialect, BlueGreenDialect): "WHERE name OPERATOR(pg_catalog.=) 'rds.extensions'") _DIALECT_UPDATE_CANDIDATES = (DialectCode.AURORA_PG, DialectCode.GLOBAL_AURORA_PG, DialectCode.MULTI_AZ_CLUSTER_PG) + _HOST_ID_QUERY = ("SELECT id, SUBSTRING(endpoint FROM 0 FOR POSITION('.' IN endpoint)) " + "FROM rds_tools.show_topology() " + "WHERE id OPERATOR(pg_catalog.=) rds_tools.dbi_resource_id()") _BG_STATUS_QUERY = (f"SELECT version, endpoint, port, role, status " f"FROM rds_tools.show_topology('aws_advanced_python_wrapper-{DriverInfo.DRIVER_VERSION}')") _BG_STATUS_EXISTS_QUERY = "SELECT 'rds_tools.show_topology'::regproc" + @property + def host_id_query(self) -> str: + return self._HOST_ID_QUERY + def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool: initial_transaction_status: bool = driver_dialect.is_in_transaction(conn) if not super().is_dialect(conn, driver_dialect): @@ -401,7 +443,7 @@ class AuroraMysqlDialect(MysqlDatabaseDialect, TopologyAwareDatabaseDialect, Blu "FROM information_schema.replica_host_status " "WHERE time_to_sec(timediff(now(), LAST_UPDATE_TIMESTAMP)) <= 300 " "OR SESSION_ID = 'MASTER_SESSION_ID' ") - _HOST_ID_QUERY = "SELECT @@aurora_server_id" + _HOST_ID_QUERY = "SELECT @@aurora_server_id, @@aurora_server_id" _IS_READER_QUERY = "SELECT @@innodb_read_only" _WRITER_HOST_QUERY = \ ("SELECT SERVER_ID FROM information_schema.replica_host_status " @@ -411,6 +453,10 @@ class AuroraMysqlDialect(MysqlDatabaseDialect, TopologyAwareDatabaseDialect, Blu _BG_STATUS_EXISTS_QUERY = \ "SELECT 1 AS tmp FROM information_schema.tables WHERE table_schema = 'mysql' AND table_name = 'rds_topology'" + @property + def is_reader_query(self) -> str: + return self._IS_READER_QUERY + @property def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]: return AuroraMysqlDialect._DIALECT_UPDATE_CANDIDATES @@ -465,8 +511,7 @@ class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect, AuroraLim "OR SESSION_ID OPERATOR(pg_catalog.=) 'MASTER_SESSION_ID' " "OR LAST_UPDATE_TIMESTAMP IS NULL") - _HOST_ID_QUERY = "SELECT pg_catalog.aurora_db_instance_identifier()" - _IS_READER_QUERY = "SELECT pg_catalog.pg_is_in_recovery()" + _HOST_ID_QUERY = "SELECT pg_catalog.aurora_db_instance_identifier(), pg_catalog.aurora_db_instance_identifier()" _LIMITLESS_ROUTER_ENDPOINT_QUERY = "SELECT router_endpoint, load FROM pg_catalog.aurora_limitless_router_endpoints()" _BG_STATUS_QUERY = (f"SELECT version, endpoint, port, role, status " @@ -644,8 +689,9 @@ class MultiAzClusterMysqlDialect(MysqlDatabaseDialect, TopologyAwareDatabaseDial _TOPOLOGY_QUERY = "SELECT id, endpoint, port FROM mysql.rds_topology" _WRITER_HOST_QUERY = "SHOW REPLICA STATUS" _WRITER_HOST_COLUMN_INDEX = 39 - _HOST_ID_QUERY = "SELECT @@server_id" - _IS_READER_QUERY = "SELECT @@read_only" + _HOST_ID_QUERY = ("SELECT id, SUBSTRING_INDEX(endpoint, '.', 1) " + "FROM mysql.rds_topology " + "WHERE id = @@server_id") @property def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]: @@ -701,8 +747,9 @@ class MultiAzClusterPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect): f"SELECT id, endpoint, port FROM rds_tools.show_topology('aws_python_driver-{DriverInfo.DRIVER_VERSION}')" _WRITER_HOST_QUERY = \ "SELECT multi_az_db_cluster_source_dbi_resource_id FROM rds_tools.multi_az_db_cluster_source_dbi_resource_id()" - _HOST_ID_QUERY = "SELECT dbi_resource_id FROM rds_tools.dbi_resource_id()" - _IS_READER_QUERY = "SELECT pg_catalog.pg_is_in_recovery()" + _HOST_ID_QUERY = ("SELECT id, SUBSTRING(endpoint FROM 0 FOR POSITION('.' IN endpoint)) " + "FROM rds_tools.show_topology() " + "WHERE id OPERATOR(pg_catalog.=) rds_tools.dbi_resource_id()") _exception_handler: Optional[ExceptionHandler] = None @property @@ -761,6 +808,16 @@ def host_alias_query(self) -> str: def server_version_query(self) -> str: return "" + @property + def host_id_query(self) -> str: + raise UnsupportedOperationError( + Messages.get_formatted("UnknownDialect.UnsupportedMethod", "host_id_query")) + + @property + def is_reader_query(self) -> str: + raise UnsupportedOperationError( + Messages.get_formatted("UnknownDialect.UnsupportedMethod", "is_reader_query")) + @property def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]: return UnknownDatabaseDialect._DIALECT_UPDATE_CANDIDATES @@ -984,4 +1041,43 @@ def check_existence_queries(conn: Connection, existence_queries: Tuple[str, ...] return False return True - # not do we need to add the transaction try catch here or is it better to surround the calling method + + @staticmethod + def get_host_role(conn: Connection, driver_dialect: DriverDialect, is_reader_query: str, + thread_pool, timeout_sec: float) -> HostRole: + try: + cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout( + thread_pool, timeout_sec, driver_dialect, conn)(DialectUtils._execute_is_reader_query) + result = cursor_execute_func_with_timeout(conn, is_reader_query) + if result is not None: + is_reader = bool(result[0]) + return HostRole.READER if is_reader else HostRole.WRITER + except TimeoutError as e: + raise QueryTimeoutError(Messages.get("DialectUtils.GetHostRoleTimeout")) from e + except Exception as e: + raise AwsWrapperError(Messages.get("DialectUtils.ErrorGettingHostRole")) from e + + raise AwsWrapperError(Messages.get("DialectUtils.ErrorGettingHostRole")) + + @staticmethod + def _execute_is_reader_query(conn: Connection, is_reader_query: str): + with closing(conn.cursor()) as cursor: + cursor.execute(is_reader_query) + return cursor.fetchone() + + @staticmethod + def get_instance_id(conn: Connection, driver_dialect: DriverDialect, instance_id_query: str, + thread_pool, timeout_sec: float) -> Optional[Tuple[str, str]]: + cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout( + thread_pool, timeout_sec, driver_dialect, conn)(DialectUtils._execute_instance_id_query) + result = cursor_execute_func_with_timeout(conn, instance_id_query) + if result is not None and len(result) >= 2: + return (str(result[0]), str(result[1])) + + return None + + @staticmethod + def _execute_instance_id_query(conn: Connection, instance_id_query: str): + with closing(conn.cursor()) as cursor: + cursor.execute(instance_id_query) + return cursor.fetchone() diff --git a/aws_advanced_python_wrapper/host_list_provider.py b/aws_advanced_python_wrapper/host_list_provider.py index 253813ab0..f0aabd747 100644 --- a/aws_advanced_python_wrapper/host_list_provider.py +++ b/aws_advanced_python_wrapper/host_list_provider.py @@ -81,18 +81,6 @@ def get_current_topology(self, connection: Connection, initial_host_info: HostIn def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Topology: ... - def get_host_role(self, connection: Connection) -> HostRole: - """ - Evaluates the host role of the given connection - either a writer or a reader. - - :param connection: a connection to the database instance whose role should be determined. - :return: the role of the given connection - either a writer or a reader. - """ - ... - - def identify_connection(self, connection: Optional[Connection]) -> Optional[HostInfo]: - ... - def get_cluster_id(self) -> str: ... @@ -336,49 +324,6 @@ def force_refresh(self, connection: Optional[Connection] = None) -> Topology: self._hosts = topology.hosts return tuple(self._hosts) - def get_host_role(self, connection: Connection) -> HostRole: - driver_dialect = self._host_list_provider_service.driver_dialect - - return self._topology_utils.get_host_role(connection, driver_dialect) - - def identify_connection(self, connection: Optional[Connection]) -> Optional[HostInfo]: - """ - Identify which host the given connection points to. - :param connection: an opened connection. - :return: a :py:class:`HostInfo` object containing host information for the given connection. - """ - if connection is None: - raise AwsWrapperError(Messages.get("RdsHostListProvider.ErrorIdentifyConnection")) - - driver_dialect = self._host_list_provider_service.driver_dialect - try: - host_id = self._topology_utils.get_host_id(connection, driver_dialect) - if host_id is not None: - hosts = self.refresh(connection) - is_force_refresh = False - if not hosts: - hosts = self.force_refresh(connection) - is_force_refresh = True - - if not hosts: - return None - - found_host: Optional[HostInfo] = next((host_info for host_info in hosts if host_info.host_id == host_id), None) - if not found_host and not is_force_refresh: - hosts = self.force_refresh(connection) - if not hosts: - return None - - found_host = next( - (host_info for host_info in hosts if host_info.host_id == host_id), - None) - - return found_host - except TimeoutError as e: - raise QueryTimeoutError(Messages.get("RdsHostListProvider.IdentifyConnectionTimeout")) from e - - raise AwsWrapperError(Messages.get("RdsHostListProvider.ErrorIdentifyConnection")) - def get_cluster_id(self): self._initialize() return self._cluster_id @@ -428,14 +373,6 @@ def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) raise AwsWrapperError( Messages.get_formatted("HostListProvider.ForceMonitoringRefreshUnsupported", "ConnectionStringHostListProvider")) - def get_host_role(self, connection: Connection) -> HostRole: - raise UnsupportedOperationError( - Messages.get_formatted("ConnectionStringHostListProvider.UnsupportedMethod", "get_host_role")) - - def identify_connection(self, connection: Optional[Connection]) -> Optional[HostInfo]: - raise UnsupportedOperationError( - Messages.get_formatted("ConnectionStringHostListProvider.UnsupportedMethod", "identify_connection")) - def get_cluster_id(self): return "" @@ -622,44 +559,6 @@ def create_host( host_info.add_alias(host_id) return host_info - def get_host_role(self, connection: Connection, driver_dialect: DriverDialect) -> HostRole: - try: - cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout( - self._thread_pool, self._max_timeout_sec, driver_dialect, connection)(self._get_host_role) - result = cursor_execute_func_with_timeout(connection) - if result is not None: - is_reader = result[0] - return HostRole.READER if is_reader else HostRole.WRITER - except TimeoutError as e: - raise QueryTimeoutError(Messages.get("RdsHostListProvider.GetHostRoleTimeout")) from e - - raise AwsWrapperError(Messages.get("RdsHostListProvider.ErrorGettingHostRole")) - - def _get_host_role(self, conn: Connection): - with closing(conn.cursor()) as cursor: - cursor.execute(self._dialect.is_reader_query) - return cursor.fetchone() - - def get_host_id(self, connection: Connection, driver_dialect: DriverDialect) -> Optional[str]: - """ - Identify which host the given connection points to. - :param connection: an opened connection. - :return: a str of the current host's id - """ - - cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout( - self._thread_pool, self._max_timeout_sec, driver_dialect, connection)(self._get_host_id) - result = cursor_execute_func_with_timeout(connection) - if result: - host_id: str = result[0] - return host_id - return None - - def _get_host_id(self, conn: Connection): - with closing(conn.cursor()) as cursor: - cursor.execute(self._dialect.host_id_query) - return cursor.fetchone() - def get_writer_id_if_connected(self, connection: Connection, driver_dialect: DriverDialect) -> Optional[str]: try: cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout( diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index 0c082ae84..9ab0fb93c 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -57,7 +57,7 @@ from aws_advanced_python_wrapper.connection_provider import ( ConnectionProvider, ConnectionProviderManager) from aws_advanced_python_wrapper.database_dialect import ( - DatabaseDialect, DatabaseDialectManager, TopologyAwareDatabaseDialect, + DatabaseDialect, DatabaseDialectManager, DialectUtils, UnknownDatabaseDialect) from aws_advanced_python_wrapper.default_plugin import DefaultPlugin from aws_advanced_python_wrapper.developer_plugin import DeveloperPluginFactory @@ -566,7 +566,13 @@ def get_host_role(self, connection: Optional[Connection] = None) -> HostRole: if connection is None: raise AwsWrapperError(Messages.get("PluginServiceImpl.GetHostRoleConnectionNone")) - return self._host_list_provider.get_host_role(connection) + timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get_float(self._props) + return DialectUtils.get_host_role( + connection, + self._driver_dialect, + self.database_dialect.is_reader_query, + self._thread_pool, + timeout_sec) def refresh_host_list(self, connection: Optional[Connection] = None): connection = self.current_connection if connection is None else connection @@ -610,11 +616,51 @@ def set_availability(self, host_aliases: FrozenSet[str], availability: HostAvail def identify_connection(self, connection: Optional[Connection] = None) -> Optional[HostInfo]: connection = self.current_connection if connection is None else connection + if connection is None: + raise AwsWrapperError(Messages.get("PluginServiceImpl.ErrorIdentifyConnection")) - if not isinstance(self.database_dialect, TopologyAwareDatabaseDialect): - return None + try: + timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get_float(self._props) + instance_ids = DialectUtils.get_instance_id( + connection, + self._driver_dialect, + self.database_dialect.host_id_query, + self._thread_pool, + timeout_sec) + + if instance_ids is None: + raise AwsWrapperError(Messages.get("PluginServiceImpl.ErrorIdentifyConnection")) + + topology = self.host_list_provider.refresh(connection) + is_force_refresh = False + if topology is None: + topology = self.host_list_provider.force_refresh(connection) + is_force_refresh = True + + if topology is None: + return None + + instance_name = instance_ids[1] + found_host: Optional[HostInfo] = next( + (host for host in topology if host.host_id == instance_name), + None) - return self.host_list_provider.identify_connection(connection) + if found_host is None and not is_force_refresh: + topology = self.host_list_provider.force_refresh(connection) + if topology is None: + return None + + found_host = next( + (host for host in topology if host.host_id == instance_name), + None) + + return found_host + except TimeoutError as e: + raise QueryTimeoutError(Messages.get("PluginServiceImpl.IdentifyConnectionTimeout")) from e + except UnsupportedOperationError as e: + raise e + except Exception as e: + raise AwsWrapperError(Messages.get("PluginServiceImpl.ErrorIdentifyConnection")) from e def fill_aliases(self, connection: Optional[Connection] = None, host_info: Optional[HostInfo] = None): connection = self.current_connection if connection is None else connection diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index 1b6c76317..95c95cdcd 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -96,8 +96,6 @@ ConnectTimePlugin.ConnectTime=[ConnectTimePlugin] Connected in {} nanos. ConnectionProvider.UnsupportedHostSelectorStrategy=[ConnectionProvider] Unsupported host selection strategy '{}' specified for this connection provider '{}'. Please visit the documentation for all supported strategies. -ConnectionStringHostListProvider.UnsupportedMethod = [ConnectionStringHostListProvider] ConnectionStringHostListProvider does not support {}. - CustomEndpointMonitor.DetectedChangeInCustomEndpointInfo=[CustomEndpointMonitor] Detected change in custom endpoint info for '{}':\n{} CustomEndpointMonitor.Exception=[CustomEndpointMonitor] Encountered an exception while monitoring custom endpoint '{}': {}. CustomEndpointMonitor.Interrupted=[CustomEndpointMonitor] Custom endpoint monitor for '{}' was interrupted. @@ -122,6 +120,9 @@ DefaultTelemetryFactory.NoTracingBackendProvided=[DefaultTelemetryFactory] No te DialectCode.InvalidStringValue=[DialectCode] '{}' is not a valid DialectCode value. If you are using the 'wrapper_dialect' connection property, please ensure you set it to one of the following: pg, rds-pg, aurora-pg, mysql, rds-mysql, aurora-mysql, or custom. +DialectUtils.GetHostRoleTimeout=[DialectUtils] The timeout limit was reached while querying for the current host's role. +DialectUtils.ErrorGettingHostRole=[DialectUtils] An error occurred while obtaining the connected host's role. This could occur if the connection is broken or if you are not connected to an unknown database. + DatabaseDialectManager.CurrentDialectCanUpdate=[DatabaseDialectManager] Current dialect: {}, {}, can_update: {} DatabaseDialectManager.QueryForDialectTimeout=[DatabaseDialectManager] The timeout limit was reached while querying for the current database dialect. DatabaseDialectManager.UnknownDialect=[DatabaseDialectManager] The database dialect could not be identified. Please use the 'wrapper_dialect' configuration parameter to configure it. @@ -345,6 +346,8 @@ PluginServiceImpl.SetCurrentHostInfo=[PluginServiceImpl] Set current host info t PluginServiceImpl.UnableToUpdateTransactionStatus=[PluginServiceImpl] Unable to update transaction status, current connection is None. PluginServiceImpl.UpdateDialectConnectionNone=[PluginServiceImpl] The plugin service attempted to update the current dialect but could not identify a connection to use. PluginServiceImpl.UnsupportedStrategy=[PluginServiceImpl] The driver does not support the requested host selection strategy: {} +PluginServiceImpl.ErrorIdentifyConnection=[PluginServiceImpl] An error occurred while obtaining the connection's host ID. +PluginServiceImpl.IdentifyConnectionTimeout=[PluginServiceImpl] The timeout limit was reached while querying for the current host's ID. PropertiesUtils.ErrorParsingConnectionString=[PropertiesUtils] An error occurred while parsing the connection string: '{}'. Please ensure the format of your connection string is valid. PropertiesUtils.InvalidPgSchemeUrl=[PropertiesUtils] PropertiesUtils.parse_pg_scheme_url was called, but the passed in string did not begin with 'postgresql://' or 'postgres://'. Detected connection string: '{}'. @@ -353,10 +356,6 @@ PropertiesUtils.NoHostDefined=[PropertiesUtils] PropertiesUtils.get_url was call RdsHostListProvider.ClusterInstanceHostPatternNotSupportedForRDSCustom=[RdsHostListProvider] An RDS Custom url can't be used as the 'cluster_instance_host_pattern' configuration setting. RdsHostListProvider.ClusterInstanceHostPatternNotSupportedForRDSProxy=[RdsHostListProvider] An RDS Proxy url can't be used as the 'cluster_instance_host_pattern' configuration setting. -RdsHostListProvider.ErrorGettingHostRole=[RdsHostListProvider] An error occurred while obtaining the connected host's role. This could occur if the connection is broken or if you are not connected to an Aurora database. -RdsHostListProvider.ErrorIdentifyConnection=[RdsHostListProvider] An error occurred while obtaining the connection's host ID. -RdsHostListProvider.GetHostRoleTimeout=[RdsHostListProvider] The timeout limit was reached while querying for the current host's role. -RdsHostListProvider.IdentifyConnectionTimeout=[RdsHostListProvider] The timeout limit was reached while querying for the current host's ID. RdsHostListProvider.InvalidPattern=[RdsHostListProvider] Invalid value for the 'cluster_instance_host_pattern' configuration setting - the host pattern must contain a '?' character as a placeholder for the DB instance identifiers of the instances in the cluster. RdsHostListProvider.InvalidQuery=[RdsHostListProvider] Error obtaining host list. Provided database might not be an Aurora Db cluster RdsHostListProvider.InvalidTopology=[RdsHostListProvider] The topology query returned an invalid topology - no writer instance detected. @@ -464,6 +463,7 @@ Testing._get_multi_az_instance_ids=[Testing] Get topology: {}. Testing._get_multi_az_instance_ids_connecting=[Testing] Connecting to {}. UnknownDialect.AbortConnection=[UnknownDialect] abort_connection was called, but the database dialect is unknown. A valid database dialect must be detected in order to abort a connection. +UnknownDialect.UnsupportedMethod = [UnknownDialect] UnknownDialect does not support {}. Wrapper.ConnectMethod=[Wrapper] Target driver should be a target driver's connect() method/function. Wrapper.RequiredTargetDriver=[Wrapper] Target driver is required. diff --git a/docs/using-the-python-wrapper/using-plugins/UsingTheSimpleReadWriteSplittingPlugin.md b/docs/using-the-python-wrapper/using-plugins/UsingTheSimpleReadWriteSplittingPlugin.md index a2b264dea..b110477e9 100644 --- a/docs/using-the-python-wrapper/using-plugins/UsingTheSimpleReadWriteSplittingPlugin.md +++ b/docs/using-the-python-wrapper/using-plugins/UsingTheSimpleReadWriteSplittingPlugin.md @@ -51,10 +51,15 @@ Additionally, to consistently ensure the role of connections made with the plugi If it is unable to return a verified initial connection, it will log a message and continue with the normal workflow of the other plugins. When connecting with custom endpoints and other non-standard URLs, role verification on the initial connection can also be triggered by providing the expected role through the `srw_verify_initial_connection_type` parameter. Set this to `writer` or `reader` accordingly. -## Limitations When Verifying Connections +The AWS Advanced Python Wrapper supports verifying the role of connections to PostgreSQL, MySQL, and MariaDB databases through using the following queries: -#### Non-RDS clusters -The verification step determines the role of the connection by executing a query against it. The AWS Advanced Python Wrapper does not support gathering such information for databases that are not Aurora or RDS clusters. Thus, when connecting to non-RDS clusters `verifyNewSrwConnections` must be set to `false`. +| DB Type | Query | +|----------------|-----------------------------------------| +| PostgreSQL | `SELECT pg_catalog.pg_is_in_recovery()` | +| Aurora MySQL | `SELECT @@innodb_read_only` | +| MySQL, MariaDB | `SELECT @@read_only` | + +Role-verification can be disabled by setting the `verifyNewSrwConnections` parameter to `false`. The Simple Read/Write Splitting Plugin will continue to function, relying purely on the endpoints from the `srwWriteEndpoint` and `srwReadEndpoint` parameters. #### Autocommit The verification logic results in errors such as `Cannot change transaction read-only property in the middle of a transaction` from the underlying driver when: diff --git a/tests/unit/test_cluster_topology_monitor.py b/tests/unit/test_cluster_topology_monitor.py index e63e0a4ec..a740c831d 100644 --- a/tests/unit/test_cluster_topology_monitor.py +++ b/tests/unit/test_cluster_topology_monitor.py @@ -248,8 +248,8 @@ def test_call_connection_success_writer_detected(self, monitor_impl_mock, topolo monitor = HostMonitor(monitor_impl_mock, host_info, None) connection_mock = MagicMock() monitor_impl_mock._plugin_service.force_connect.return_value = connection_mock + monitor_impl_mock._plugin_service.get_host_role.return_value = HostRole.WRITER topology_utils_mock.get_writer_id_if_connected.return_value = "writer.com" - topology_utils_mock.get_host_role.return_value = HostRole.WRITER call_count = [0] diff --git a/tests/unit/test_connection_string_host_list_provider.py b/tests/unit/test_connection_string_host_list_provider.py index dbe049a51..fb0998de3 100644 --- a/tests/unit/test_connection_string_host_list_provider.py +++ b/tests/unit/test_connection_string_host_list_provider.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License.s -import pytest +import pytest # type: ignore from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.host_list_provider import \ @@ -36,20 +36,6 @@ def props(): return Properties({"host": "instance-1.xyz.us-east-2.rds.amazonaws.com"}) -def test_get_host_role(mock_provider_service, mock_cursor, props): - provider = ConnectionStringHostListProvider(mock_provider_service, props) - - with pytest.raises(AwsWrapperError): - provider.get_host_role("ConnectionStringHostListProvider.ErrorDoesNotSupportHostRole") - - -def test_identify_connection_no_dialect(mock_provider_service, props): - provider = ConnectionStringHostListProvider(mock_provider_service, props) - - with pytest.raises(AwsWrapperError): - provider.identify_connection("ConnectionStringHostListProvider.ErrorDoesNotSupportIdentifyConnection") - - def test_refresh(mock_provider_service, props): provider = ConnectionStringHostListProvider(mock_provider_service, props) expected_host = HostInfo(props.get("host")) diff --git a/tests/unit/test_multi_az_rds_host_list_provider.py b/tests/unit/test_multi_az_rds_host_list_provider.py index 7b95cf5bc..fc2d6a3c6 100644 --- a/tests/unit/test_multi_az_rds_host_list_provider.py +++ b/tests/unit/test_multi_az_rds_host_list_provider.py @@ -23,7 +23,7 @@ QueryTimeoutError) from aws_advanced_python_wrapper.host_list_provider import ( MultiAzTopologyUtils, RdsHostListProvider) -from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import ProgrammingError from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) @@ -176,46 +176,6 @@ def test_get_topology_no_connection(mocker, mock_provider_service, initial_hosts mock_monitor.force_refresh_with_connection.assert_not_called() -def test_identify_connection_errors(mock_provider_service, mock_conn, mock_cursor, props): - mock_cursor.fetchone.return_value = None - provider = create_provider(mock_provider_service, props) - - with pytest.raises(AwsWrapperError): - provider.identify_connection(mock_conn) - - mock_cursor.execute.side_effect = TimeoutError() - with pytest.raises(QueryTimeoutError): - provider.identify_connection(mock_conn) - - -def test_identify_connection_no_match_in_topology(mocker, mock_provider_service, mock_conn, mock_cursor, props): - mock_cursor.fetchone.return_value = ("non-matching-host",) - provider = create_provider(mock_provider_service, props) - mock_monitor = mocker.MagicMock() - mock_monitor.force_refresh_with_connection.return_value = () - mocker.patch.object(provider, '_get_or_create_monitor', return_value=mock_monitor) - - assert provider.identify_connection(mock_conn) is None - - -def test_identify_connection_empty_topology(mocker, mock_provider_service, mock_conn, mock_cursor, props): - provider = create_provider(mock_provider_service, props) - mock_cursor.fetchone.return_value = ("instance-1",) - - provider.refresh = mocker.MagicMock(return_value=[]) - assert provider.identify_connection(mock_conn) is None - - -def test_identify_connection_host_in_topology(mock_provider_service, mock_conn, mock_cursor, props): - provider = create_provider(mock_provider_service, props) - mock_cursor.fetchone.return_value = ("instance-1",) - mock_topology_query(mock_conn, mock_cursor, [("instance-1", "instance-1.xyz.us-east-2.rds.amazonaws.com", 5432)]) - - host_info = provider.identify_connection(mock_conn) - assert "instance-1.xyz.us-east-2.rds.amazonaws.com" == host_info.host - assert "instance-1" == host_info.host_id - - def test_host_pattern_setting(mock_provider_service, props): props = Properties({"host": "127:0:0:1", "port": 5432, WrapperProperties.CLUSTER_INSTANCE_HOST_PATTERN.name: "?.custom-domain.com"}) @@ -241,21 +201,6 @@ def test_host_pattern_setting(mock_provider_service, props): provider._initialize() -def test_get_host_role(mock_provider_service, mock_conn, mock_cursor, props): - mock_cursor.fetchone.return_value = (True,) - provider = create_provider(mock_provider_service, props) - - assert HostRole.READER == provider.get_host_role(mock_conn) - - mock_cursor.fetchone.return_value = None - with pytest.raises(AwsWrapperError): - provider.get_host_role(mock_conn) - - mock_cursor.execute.side_effect = TimeoutError() - with pytest.raises(QueryTimeoutError): - provider.get_host_role(mock_conn) - - def test_cluster_id_setting(mock_provider_service): props = Properties({"host": "my-cluster.cluster-xyz.us-east-2.rds.amazonaws.com", "port": 5432, WrapperProperties.CLUSTER_ID.name: "my-cluster-id"}) diff --git a/tests/unit/test_plugin_service.py b/tests/unit/test_plugin_service.py new file mode 100644 index 000000000..974ac10aa --- /dev/null +++ b/tests/unit/test_plugin_service.py @@ -0,0 +1,414 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.s + +from concurrent.futures import TimeoutError +from unittest.mock import MagicMock + +import pytest # type: ignore + +from aws_advanced_python_wrapper.database_dialect import ( + AuroraPgDialect, MultiAzClusterPgDialect, UnknownDatabaseDialect) +from aws_advanced_python_wrapper.errors import (AwsWrapperError, + QueryTimeoutError, + UnsupportedOperationError) +from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.plugin_service import PluginServiceImpl +from aws_advanced_python_wrapper.utils.properties import Properties + + +def test_get_host_role_unknown_dialect(mocker): + mock_conn = MagicMock() + mock_container = mocker.MagicMock() + mock_container.plugin_manager = mocker.MagicMock() + + plugin_service = PluginServiceImpl( + mock_container, + Properties({"host": "test.com"}), + lambda: None, + mocker.MagicMock(), + mocker.MagicMock() + ) + plugin_service._database_dialect = UnknownDatabaseDialect() + + with pytest.raises(UnsupportedOperationError): + plugin_service.get_host_role(mock_conn) + + +def test_identify_connection_unknown_dialect(mocker): + mock_conn = MagicMock() + mock_container = mocker.MagicMock() + mock_container.plugin_manager = mocker.MagicMock() + + plugin_service = PluginServiceImpl( + mock_container, + Properties({"host": "test.com"}), + lambda: None, + mocker.MagicMock(), + mocker.MagicMock() + ) + plugin_service._database_dialect = UnknownDatabaseDialect() + plugin_service._current_connection = mock_conn + + with pytest.raises(UnsupportedOperationError): + plugin_service.identify_connection(mock_conn) + + +def test_get_host_role_reader(mocker): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.fetchone.return_value = (True,) + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_conn.cursor.return_value = mock_cursor + + mock_container = mocker.MagicMock() + mock_container.plugin_manager = mocker.MagicMock() + + plugin_service = PluginServiceImpl( + mock_container, + Properties({"host": "test.com"}), + lambda: None, + mocker.MagicMock(), + mocker.MagicMock() + ) + plugin_service._database_dialect = AuroraPgDialect() + plugin_service._current_connection = mock_conn + + # Mock preserve_transaction_status_with_timeout to execute directly + def mock_preserve(thread_pool, timeout, driver_dialect, conn): + def decorator(func): + return func + return decorator + + mocker.patch('aws_advanced_python_wrapper.database_dialect.preserve_transaction_status_with_timeout', mock_preserve) + + assert HostRole.READER == plugin_service.get_host_role(mock_conn) + + +def test_get_host_role_writer(mocker): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.fetchone.return_value = (False,) + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_conn.cursor.return_value = mock_cursor + + mock_container = mocker.MagicMock() + mock_container.plugin_manager = mocker.MagicMock() + + plugin_service = PluginServiceImpl( + mock_container, + Properties({"host": "test.com"}), + lambda: None, + mocker.MagicMock(), + mocker.MagicMock() + ) + plugin_service._database_dialect = AuroraPgDialect() + plugin_service._current_connection = mock_conn + + # Mock preserve_transaction_status_with_timeout to execute directly + def mock_preserve(thread_pool, timeout, driver_dialect, conn): + def decorator(func): + return func + return decorator + + mocker.patch('aws_advanced_python_wrapper.database_dialect.preserve_transaction_status_with_timeout', mock_preserve) + + assert HostRole.WRITER == plugin_service.get_host_role(mock_conn) + + +def test_get_host_role_error(mocker): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.execute.side_effect = ValueError() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_conn.cursor.return_value = mock_cursor + + mock_container = mocker.MagicMock() + mock_container.plugin_manager = mocker.MagicMock() + + plugin_service = PluginServiceImpl( + mock_container, + Properties({"host": "test.com"}), + lambda: None, + mocker.MagicMock(), + mocker.MagicMock() + ) + plugin_service._database_dialect = AuroraPgDialect() + plugin_service._current_connection = mock_conn + + # Mock preserve_transaction_status_with_timeout to execute directly + def mock_preserve(thread_pool, timeout, driver_dialect, conn): + def decorator(func): + return func + return decorator + + mocker.patch('aws_advanced_python_wrapper.database_dialect.preserve_transaction_status_with_timeout', mock_preserve) + + with pytest.raises(AwsWrapperError): + plugin_service.get_host_role(mock_conn) + + +def test_get_host_role_timeout(mocker): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.execute.side_effect = TimeoutError() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_conn.cursor.return_value = mock_cursor + + mock_container = mocker.MagicMock() + mock_container.plugin_manager = mocker.MagicMock() + + plugin_service = PluginServiceImpl( + mock_container, + Properties({"host": "test.com"}), + lambda: None, + mocker.MagicMock(), + mocker.MagicMock() + ) + plugin_service._database_dialect = AuroraPgDialect() + plugin_service._current_connection = mock_conn + + # Mock preserve_transaction_status_with_timeout to execute directly + def mock_preserve(thread_pool, timeout, driver_dialect, conn): + def decorator(func): + return func + return decorator + + mocker.patch('aws_advanced_python_wrapper.database_dialect.preserve_transaction_status_with_timeout', mock_preserve) + + with pytest.raises(QueryTimeoutError): + plugin_service.get_host_role(mock_conn) + + +def test_identify_connection_error_no_result(mocker): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.fetchone.return_value = None + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_conn.cursor.return_value = mock_cursor + + mock_container = mocker.MagicMock() + mock_container.plugin_manager = mocker.MagicMock() + mock_host_list_provider = mocker.MagicMock() + + plugin_service = PluginServiceImpl( + mock_container, + Properties({"host": "test.com"}), + lambda: None, + mocker.MagicMock(), + mocker.MagicMock() + ) + plugin_service._database_dialect = AuroraPgDialect() + plugin_service._current_connection = mock_conn + plugin_service._host_list_provider = mock_host_list_provider + + # Mock preserve_transaction_status_with_timeout to execute directly + def mock_preserve(thread_pool, timeout, driver_dialect, conn): + def decorator(func): + return func + return decorator + + mocker.patch('aws_advanced_python_wrapper.database_dialect.preserve_transaction_status_with_timeout', mock_preserve) + + with pytest.raises(AwsWrapperError): + plugin_service.identify_connection(mock_conn) + + +def test_identify_connection_timeout(mocker): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.execute.side_effect = TimeoutError() + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_conn.cursor.return_value = mock_cursor + + mock_container = mocker.MagicMock() + mock_container.plugin_manager = mocker.MagicMock() + + plugin_service = PluginServiceImpl( + mock_container, + Properties({"host": "test.com"}), + lambda: None, + mocker.MagicMock(), + mocker.MagicMock() + ) + plugin_service._database_dialect = AuroraPgDialect() + plugin_service._current_connection = mock_conn + + # Mock preserve_transaction_status_with_timeout to execute directly + def mock_preserve(thread_pool, timeout, driver_dialect, conn): + def decorator(func): + return func + return decorator + + mocker.patch('aws_advanced_python_wrapper.database_dialect.preserve_transaction_status_with_timeout', mock_preserve) + + with pytest.raises(QueryTimeoutError): + plugin_service.identify_connection(mock_conn) + + +def test_identify_connection_no_match_in_topology(mocker): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.fetchone.return_value = ("host-value", "non-matching-host") + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_conn.cursor.return_value = mock_cursor + + mock_container = mocker.MagicMock() + mock_container.plugin_manager = mocker.MagicMock() + mock_host_list_provider = mocker.MagicMock() + mock_host_list_provider.refresh.return_value = () + mock_host_list_provider.force_refresh.return_value = () + + plugin_service = PluginServiceImpl( + mock_container, + Properties({"host": "test.com"}), + lambda: None, + mocker.MagicMock(), + mocker.MagicMock() + ) + plugin_service._database_dialect = AuroraPgDialect() + plugin_service._current_connection = mock_conn + plugin_service._host_list_provider = mock_host_list_provider + + # Mock preserve_transaction_status_with_timeout to execute directly + def mock_preserve(thread_pool, timeout, driver_dialect, conn): + def decorator(func): + return func + return decorator + + mocker.patch('aws_advanced_python_wrapper.database_dialect.preserve_transaction_status_with_timeout', mock_preserve) + + assert plugin_service.identify_connection(mock_conn) is None + + +def test_identify_connection_empty_topology(mocker): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.fetchone.return_value = ("host-value", "instance-1") + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_conn.cursor.return_value = mock_cursor + + mock_container = mocker.MagicMock() + mock_container.plugin_manager = mocker.MagicMock() + mock_host_list_provider = mocker.MagicMock() + mock_host_list_provider.refresh.return_value = [] + + plugin_service = PluginServiceImpl( + mock_container, + Properties({"host": "test.com"}), + lambda: None, + mocker.MagicMock(), + mocker.MagicMock() + ) + plugin_service._database_dialect = AuroraPgDialect() + plugin_service._current_connection = mock_conn + plugin_service._host_list_provider = mock_host_list_provider + + # Mock preserve_transaction_status_with_timeout to execute directly + def mock_preserve(thread_pool, timeout, driver_dialect, conn): + def decorator(func): + return func + return decorator + + mocker.patch('aws_advanced_python_wrapper.database_dialect.preserve_transaction_status_with_timeout', mock_preserve) + + assert plugin_service.identify_connection(mock_conn) is None + + +def test_identify_connection_host_in_topology_aurora(mocker): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.fetchone.return_value = ("instance-1", "instance-1") + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_conn.cursor.return_value = mock_cursor + + expected_host = HostInfo("instance-1.xyz.us-east-2.rds.amazonaws.com", 5432, HostRole.WRITER, host_id="instance-1") + + mock_container = mocker.MagicMock() + mock_container.plugin_manager = mocker.MagicMock() + mock_host_list_provider = mocker.MagicMock() + mock_host_list_provider.refresh.return_value = (expected_host,) + + plugin_service = PluginServiceImpl( + mock_container, + Properties({"host": "test.com"}), + lambda: None, + mocker.MagicMock(), + mocker.MagicMock() + ) + plugin_service._database_dialect = AuroraPgDialect() + plugin_service._current_connection = mock_conn + plugin_service._host_list_provider = mock_host_list_provider + + # Mock preserve_transaction_status_with_timeout to execute directly + def mock_preserve(thread_pool, timeout, driver_dialect, conn): + def decorator(func): + return func + return decorator + + mocker.patch('aws_advanced_python_wrapper.database_dialect.preserve_transaction_status_with_timeout', mock_preserve) + + host_info = plugin_service.identify_connection(mock_conn) + assert host_info is not None + assert "instance-1.xyz.us-east-2.rds.amazonaws.com" == host_info.host + assert "instance-1" == host_info.host_id + + +def test_identify_connection_host_in_topology_multiaz(mocker): + mock_conn = MagicMock() + mock_cursor = MagicMock() + # Multi-AZ returns different values: (instanceId, instanceName) + mock_cursor.fetchone.return_value = ("db-WQFQKBTL2LQUPIEFIFBGENS4ZQ", "instance-1") + mock_cursor.__enter__ = MagicMock(return_value=mock_cursor) + mock_cursor.__exit__ = MagicMock(return_value=False) + mock_conn.cursor.return_value = mock_cursor + + expected_host = HostInfo("instance-1.xyz.us-east-2.rds.amazonaws.com", 5432, HostRole.WRITER, host_id="instance-1") + + mock_container = mocker.MagicMock() + mock_container.plugin_manager = mocker.MagicMock() + mock_host_list_provider = mocker.MagicMock() + mock_host_list_provider.refresh.return_value = (expected_host,) + + plugin_service = PluginServiceImpl( + mock_container, + Properties({"host": "test.com", "port": 5432}), + lambda: None, + mocker.MagicMock(), + mocker.MagicMock() + ) + plugin_service._database_dialect = MultiAzClusterPgDialect() + plugin_service._current_connection = mock_conn + plugin_service._host_list_provider = mock_host_list_provider + + # Mock preserve_transaction_status_with_timeout to execute directly + def mock_preserve(thread_pool, timeout, driver_dialect, conn): + def decorator(func): + return func + return decorator + + mocker.patch('aws_advanced_python_wrapper.database_dialect.preserve_transaction_status_with_timeout', mock_preserve) + + host_info = plugin_service.identify_connection(mock_conn) + assert host_info is not None + assert "instance-1.xyz.us-east-2.rds.amazonaws.com" == host_info.host + assert "instance-1" == host_info.host_id diff --git a/tests/unit/test_rds_host_list_provider.py b/tests/unit/test_rds_host_list_provider.py index c4f74bf51..78f5afdf1 100644 --- a/tests/unit/test_rds_host_list_provider.py +++ b/tests/unit/test_rds_host_list_provider.py @@ -189,50 +189,6 @@ def test_get_topology_no_connection(mocker, mock_provider_service, initial_hosts mock_monitor.force_refresh_with_connection.assert_not_called() -def test_identify_connection_errors(mock_provider_service, mock_conn, mock_cursor, props): - mock_cursor.fetchone.return_value = None - topology_utils = AuroraTopologyUtils(AuroraPgDialect(), props) - provider = RdsHostListProvider(mock_provider_service, mock_provider_service, props, topology_utils) - - with pytest.raises(AwsWrapperError): - provider.identify_connection(mock_conn) - - mock_cursor.execute.side_effect = TimeoutError() - with pytest.raises(QueryTimeoutError): - provider.identify_connection(mock_conn) - - -def test_identify_connection_no_match_in_topology(mocker, mock_provider_service, mock_conn, mock_cursor, props): - mock_cursor.fetchone.return_value = ("non-matching-host",) - topology_utils = AuroraTopologyUtils(AuroraPgDialect(), props) - provider = RdsHostListProvider(mock_provider_service, mock_provider_service, props, topology_utils) - mock_monitor = mocker.MagicMock() - mock_monitor.force_refresh_with_connection.return_value = () - mocker.patch.object(provider, '_get_or_create_monitor', return_value=mock_monitor) - - assert provider.identify_connection(mock_conn) is None - - -def test_identify_connection_empty_topology(mocker, mock_provider_service, mock_conn, mock_cursor, props): - topology_utils = AuroraTopologyUtils(AuroraPgDialect(), props) - provider = RdsHostListProvider(mock_provider_service, mock_provider_service, props, topology_utils) - mock_cursor.fetchone.return_value = ("instance-1",) - - provider.refresh = mocker.MagicMock(return_value=[]) - assert provider.identify_connection(mock_conn) is None - - -def test_identify_connection_host_in_topology(mock_provider_service, mock_conn, mock_cursor, props): - topology_utils = AuroraTopologyUtils(AuroraPgDialect(), props) - provider = RdsHostListProvider(mock_provider_service, mock_provider_service, props, topology_utils) - mock_cursor.fetchone.return_value = ("instance-1",) - mock_topology_query(mock_conn, mock_cursor, [("instance-1", True)]) - - host_info = provider.identify_connection(mock_conn) - assert "instance-1.xyz.us-east-2.rds.amazonaws.com" == host_info.host - assert "instance-1" == host_info.host_id - - def test_host_pattern_setting(mock_provider_service, props): props = Properties({"host": "127:0:0:1", WrapperProperties.CLUSTER_INSTANCE_HOST_PATTERN.name: "?.custom-domain.com"}) @@ -254,22 +210,6 @@ def test_host_pattern_setting(mock_provider_service, props): provider = RdsHostListProvider(mock_provider_service, mock_provider_service, props, AuroraTopologyUtils(AuroraPgDialect(), props)) -def test_get_host_role(mock_provider_service, mock_conn, mock_cursor, props): - mock_cursor.fetchone.return_value = (True,) - topology_utils = AuroraTopologyUtils(AuroraPgDialect(), props) - provider = RdsHostListProvider(mock_provider_service, mock_provider_service, props, topology_utils) - - assert HostRole.READER == provider.get_host_role(mock_conn) - - mock_cursor.fetchone.return_value = None - with pytest.raises(AwsWrapperError): - provider.get_host_role(mock_conn) - - mock_cursor.execute.side_effect = TimeoutError() - with pytest.raises(QueryTimeoutError): - provider.get_host_role(mock_conn) - - def test_cluster_id_setting(mock_provider_service): props = Properties({"host": "my-cluster.cluster-xyz.us-east-2.rds.amazonaws.com", WrapperProperties.CLUSTER_ID.name: "my-cluster-id"})