diff --git a/aws_advanced_python_wrapper/blue_green_plugin.py b/aws_advanced_python_wrapper/blue_green_plugin.py index f0f95ac6..7e3a655d 100644 --- a/aws_advanced_python_wrapper/blue_green_plugin.py +++ b/aws_advanced_python_wrapper/blue_green_plugin.py @@ -1093,7 +1093,7 @@ def _init_host_list_provider(self): logger.warning("BlueGreenStatusMonitor.HostInfoNone") return - host_list_provider_supplier = self._plugin_service.database_dialect.get_host_list_provider_supplier() + host_list_provider_supplier = self._plugin_service.database_dialect.get_host_list_provider_supplier(self._plugin_service) host_list_provider_service: HostListProviderService = cast('HostListProviderService', self._plugin_service) self._host_list_provider = host_list_provider_supplier(host_list_provider_service, props_copy) diff --git a/aws_advanced_python_wrapper/cluster_topology_monitor.py b/aws_advanced_python_wrapper/cluster_topology_monitor.py new file mode 100644 index 00000000..b1acf26f --- /dev/null +++ b/aws_advanced_python_wrapper/cluster_topology_monitor.py @@ -0,0 +1,537 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from __future__ import annotations + +import random +import threading +import time +from abc import ABC, abstractmethod +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING, Dict, Optional, Tuple + +from aws_advanced_python_wrapper.hostinfo import HostInfo +from aws_advanced_python_wrapper.utils.atomic import AtomicReference +from aws_advanced_python_wrapper.utils.cache_map import CacheMap +from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.utils import LogUtils + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.pep249 import Connection + from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.properties import Properties + from aws_advanced_python_wrapper.host_list_provider import TopologyUtils + +from aws_advanced_python_wrapper.hostinfo import HostRole +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.properties import (PropertiesUtils, + WrapperProperties) + +logger = Logger(__name__) + + +class ClusterTopologyMonitor(ABC): + @abstractmethod + def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]: + pass + + @abstractmethod + def force_refresh_with_connection(self, connection: Connection, timeout_sec: int) -> Tuple[HostInfo, ...]: + pass + + @abstractmethod + def can_dispose(self) -> bool: + pass + + @abstractmethod + def close(self) -> None: + pass + + +class ClusterTopologyMonitorImpl(ClusterTopologyMonitor): + MONITOR_TERMINATION_TIMEOUT_SEC = 30 + CLOSE_CONNECTION_NETWORK_TIMEOUT_MS = 500 + DEFAULT_CONNECT_TIMEOUT_SEC = 5 + DEFAULT_SOCKET_TIMEOUT_SEC = 5 + TOPOLOGY_CACHE_EXPIRATION_NANO = 5 * 60 * 1_000_000_000 # 5 minutes in nanoseconds + + HIGH_REFRESH_PERIOD_AFTER_PANIC_NANO = 30 * 1_000_000_000 # 30 seconds in nanoseconds + 
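+ # Once a writer connection has been re-verified, topology refresh requests that arrive within this window are served from the cached topology (see force_refresh below).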
IGNORE_TOPOLOGY_REQUEST_NANO = 10 * 1_000_000_000 # 10 seconds in nanoseconds + + INITIAL_BACKOFF_MS = 100 + MAX_BACKOFF_MS = 10000 + + _topology_map: CacheMap[str, Tuple[HostInfo, ...]] = CacheMap() + + def __init__(self, plugin_service: PluginService, topology_utils: TopologyUtils, cluster_id: str, + initial_host_info: HostInfo, properties: Properties, instance_template: HostInfo, + refresh_rate_nano: int, high_refresh_rate_nano: int): + self._plugin_service = plugin_service + self._topology_utils = topology_utils + self._cluster_id = cluster_id + self._initial_host_info: HostInfo = initial_host_info + self._properties = properties + self._instance_template = instance_template + self._refresh_rate_nano = refresh_rate_nano + self._high_refresh_rate_nano = high_refresh_rate_nano + + self._writer_host_info: AtomicReference[Optional[HostInfo]] = AtomicReference(None) + self._monitoring_connection: AtomicReference[Optional[Connection]] = AtomicReference(None) + + self._topology_updated = threading.Event() + self._request_to_update_topology = threading.Event() + self._ignore_new_topology_requests_end_time_nano = -1 + self._submitted_hosts: Dict[str, bool] = {} + + self._thread_pool_executor: AtomicReference[Optional[ThreadPoolExecutor]] = AtomicReference(None) + self._host_threads_stop = threading.Event() + self._host_threads_writer_connection: AtomicReference[Optional[Connection]] = AtomicReference(None) + self._host_threads_writer_host_info: AtomicReference[Optional[HostInfo]] = AtomicReference(None) + self._host_threads_reader_connection: AtomicReference[Optional[Connection]] = AtomicReference(None) + self._host_threads_latest_topology: AtomicReference[Optional[Tuple[HostInfo, ...]]] = AtomicReference(None) + + self._is_verified_writer_connection = False + self._high_refresh_rate_end_time_nano = 0 + self._stop = threading.Event() + self._monitor_thread: Optional[threading.Thread] = None + + self._monitoring_properties = PropertiesUtils.create_topology_monitoring_properties(properties) + if WrapperProperties.SOCKET_TIMEOUT_SEC.get(self._monitoring_properties) is None: + WrapperProperties.SOCKET_TIMEOUT_SEC.set(self._monitoring_properties, self.DEFAULT_SOCKET_TIMEOUT_SEC) + if WrapperProperties.CONNECT_TIMEOUT_SEC.get(self._monitoring_properties) is None: + WrapperProperties.CONNECT_TIMEOUT_SEC.set(self._monitoring_properties, self.DEFAULT_CONNECT_TIMEOUT_SEC) + + self._start_monitoring() + + def force_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]: + current_time_nano = time.time_ns() + if (self._ignore_new_topology_requests_end_time_nano > 0 and + current_time_nano < self._ignore_new_topology_requests_end_time_nano): + current_hosts = self._get_stored_hosts() + if current_hosts is not None: + logger.debug("ClusterTopologyMonitorImpl.IgnoringTopologyRequest", self._cluster_id, LogUtils.log_topology(current_hosts)) + return current_hosts + + if should_verify_writer: + self._close_connection_from_ref(self._monitoring_connection) + self._is_verified_writer_connection = False + + result = self._wait_till_topology_gets_updated(timeout_sec) + return result + + def force_refresh_with_connection(self, connection: Connection, timeout_sec: int) -> Tuple[HostInfo, ...]: + if self._is_verified_writer_connection: + return self._wait_till_topology_gets_updated(timeout_sec) + return self._fetch_topology_and_update_cache(connection) + + def _wait_till_topology_gets_updated(self, timeout_sec: int) -> Tuple[HostInfo, ...]: + current_hosts = self._get_stored_hosts() + + 
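+ # Wake the monitoring thread for an immediate refresh: _delay() returns early while this event is set, and _update_topology_cache() clears it once fresh topology is stored.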
self._request_to_update_topology.set() + + if timeout_sec == 0: + logger.debug("ClusterTopologyMonitorImpl.TimeoutSetToZero", self._cluster_id, LogUtils.log_topology(current_hosts)) + return current_hosts + + end_time = time.time() + timeout_sec + while time.time() < end_time: + latest_hosts = self._get_stored_hosts() + if latest_hosts is not current_hosts: + return latest_hosts + + if self._topology_updated.wait(1.0): + self._topology_updated.clear() + + raise TimeoutError( + Messages.get_formatted( + "ClusterTopologyMonitorImpl.TopologyNotUpdated", + self._cluster_id, timeout_sec * 1000)) + + def _get_stored_hosts(self) -> Tuple[HostInfo, ...]: + hosts = ClusterTopologyMonitorImpl._topology_map.get(self._cluster_id) + if hosts is None: + return () + return hosts + + def can_dispose(self) -> bool: + return self._stop.is_set() + + def close(self) -> None: + logger.debug("ClusterTopologyMonitorImpl.ClosingMonitor", self._cluster_id) + self._stop.set() + self._request_to_update_topology.set() + + self._close_host_monitors() + + if self._monitor_thread and self._monitor_thread.is_alive(): + self._monitor_thread.join(self.MONITOR_TERMINATION_TIMEOUT_SEC) + + # Now safe to close connections - no threads are using them + self._close_connection_from_ref(self._monitoring_connection) + self._close_connection_from_ref(self._host_threads_writer_connection) + self._close_connection_from_ref(self._host_threads_reader_connection) + + def _start_monitoring(self) -> None: + self._monitor_thread = threading.Thread(target=self._monitor, daemon=True) + self._monitor_thread.start() + + def _monitor(self) -> None: + try: + logger.debug("ClusterTopologyMonitor.StartMonitoringThread", self._cluster_id, self._initial_host_info.host) + + while not self._stop.is_set(): + if self._is_in_panic_mode(): + if not self._submitted_hosts: + self._close_host_monitors() + self._host_threads_stop.clear() + self._host_threads_writer_host_info.set(None) + self._host_threads_latest_topology.set(None) + + hosts = self._get_stored_hosts() + if not hosts: + hosts = self._open_any_connection_and_update_topology() + + if hosts and not self._is_verified_writer_connection: + logger.debug("ClusterTopologyMonitorImpl.StartingHostMonitoringThreads", self._cluster_id) + writer_host_info = self._writer_host_info.get() + for host_info in hosts: + if host_info.host not in self._submitted_hosts: + try: + worker = self._get_host_monitor(host_info, writer_host_info) + self._get_host_executor_service().submit(worker) + self._submitted_hosts[host_info.host] = True + except Exception as e: + logger.debug( + "ClusterTopologyMonitorImpl.ExceptionStartingHostMonitor", + self._cluster_id, host_info.host, e) + else: + # Check if writer has been detected + writer_host_info = self._host_threads_writer_host_info.get() + writer_connection = self._host_threads_writer_connection.get() + if (writer_connection is not None and writer_host_info is not None): + logger.debug("ClusterTopologyMonitorImpl.WriterPickedUpFromHostMonitors", self._cluster_id, writer_host_info.host) + self._close_connection_from_ref(self._monitoring_connection) + self._monitoring_connection.set(writer_connection) + self._writer_host_info.set(writer_host_info) + self._is_verified_writer_connection = True + self._high_refresh_rate_end_time_nano = ( + time.time_ns() + self.HIGH_REFRESH_PERIOD_AFTER_PANIC_NANO) + + if self._ignore_new_topology_requests_end_time_nano == -1: + self._ignore_new_topology_requests_end_time_nano = 0 + else: + self._ignore_new_topology_requests_end_time_nano
= ( + time.time_ns() + self.IGNORE_TOPOLOGY_REQUEST_NANO) + + self._host_threads_stop.set() + self._close_host_monitors() + self._submitted_hosts.clear() + continue + + # Update host monitors with new topology + host_threads_topology = self._host_threads_latest_topology.get() + if host_threads_topology is not None and not self._host_threads_stop.is_set(): + for host_info in host_threads_topology: + if host_info.host not in self._submitted_hosts: + try: + worker = self._get_host_monitor(host_info, self._writer_host_info.get()) + self._get_host_executor_service().submit(worker) + self._submitted_hosts[host_info.host] = True + except Exception as e: + logger.debug( + "ClusterTopologyMonitorImpl.ExceptionStartingHostMonitor", + self._cluster_id, host_info.host, e) + + self._delay(True) + else: + # Regular mode + if self._submitted_hosts: + self._close_host_monitors() + self._submitted_hosts.clear() + + hosts = self._fetch_topology_and_update_cache(self._monitoring_connection.get()) + if not hosts: + self._close_connection_from_ref(self._monitoring_connection) + self._is_verified_writer_connection = False + self._writer_host_info.set(None) + continue + + current_time_nano = time.time_ns() + if (self._high_refresh_rate_end_time_nano > 0 and + current_time_nano > self._high_refresh_rate_end_time_nano): + self._high_refresh_rate_end_time_nano = 0 + + self._delay(False) + + if (self._ignore_new_topology_requests_end_time_nano > 0 and + time.time_ns() > self._ignore_new_topology_requests_end_time_nano): + self._ignore_new_topology_requests_end_time_nano = 0 + + except Exception as ex: + logger.info("ClusterTopologyMonitorImpl.ExceptionDuringMonitoringStop", self._cluster_id, ex) + finally: + self._stop.set() + self._close_host_monitors() + self._close_connection_from_ref(self._monitoring_connection) + logger.debug("ClusterTopologyMonitor.StopMonitoringThread", self._cluster_id, self._initial_host_info.host) + + def _is_in_panic_mode(self) -> bool: + return self._monitoring_connection.get() is None or not self._is_verified_writer_connection + + def _get_host_monitor(self, host_info: HostInfo, writer_host_info: Optional[HostInfo]): + return HostMonitor(self, host_info, writer_host_info) + + def _open_any_connection_and_update_topology(self) -> Tuple[HostInfo, ...]: + writer_verified_by_this_thread = False + if self._monitoring_connection.get() is None: + # Try to connect to the initial host first + try: + conn = self._plugin_service.force_connect(self._initial_host_info, self._monitoring_properties) + self._monitoring_connection.set(conn) + logger.debug("ClusterTopologyMonitorImpl.OpenedMonitoringConnection", self._cluster_id, self._initial_host_info.host) + + try: + writer_host = self._topology_utils.get_writer_host_if_connected( + conn, self._plugin_service.driver_dialect) + if writer_host: + self._is_verified_writer_connection = True + writer_verified_by_this_thread = True + self._writer_host_info.set(HostInfo(writer_host, self._initial_host_info.port)) + logger.debug("ClusterTopologyMonitorImpl.WriterMonitoringConnection", self._cluster_id, writer_host) + except Exception: + pass + except Exception: + return () + + hosts = self._fetch_topology_and_update_cache(self._monitoring_connection.get()) + if writer_verified_by_this_thread: + if self._ignore_new_topology_requests_end_time_nano == -1: + self._ignore_new_topology_requests_end_time_nano = 0 + else: + self._ignore_new_topology_requests_end_time_nano = ( + time.time_ns() + self.IGNORE_TOPOLOGY_REQUEST_NANO) + + if len(hosts) == 0: + 
self._close_connection_from_ref(self._monitoring_connection) + self._is_verified_writer_connection = False + self._writer_host_info.set(None) + + return hosts + + def _close_connection(self, connection: Optional[Connection]) -> None: + try: + if connection is not None: + connection.close() + except Exception: + pass + + def _close_connection_from_ref(self, connection: AtomicReference[Optional[Connection]]) -> None: + connection_to_close: Optional[Connection] = connection.get_and_set(None) + self._close_connection(connection_to_close) + + def _host_thread_connection_cleanup(self) -> None: + writer_connection = self._host_threads_writer_connection.get_and_set(None) + if self._monitoring_connection.get() != writer_connection: + self._close_connection(writer_connection) + + reader_connection = self._host_threads_reader_connection.get_and_set(None) + if self._monitoring_connection.get() != reader_connection: + self._close_connection(reader_connection) + + def _close_host_monitors(self) -> None: + self._host_threads_stop.set() + + thread_pool_executor = self._thread_pool_executor.get_and_set(None) + if thread_pool_executor is not None: + thread_pool_executor.shutdown(wait=True, cancel_futures=True) + self._host_thread_connection_cleanup() + + self._submitted_hosts.clear() + + def _get_host_executor_service(self) -> ThreadPoolExecutor: + if self._stop.is_set(): + raise RuntimeError(Messages.get_formatted( + "ClusterTopologyMonitorImpl.CannotCreateExecutorWhenStopped", self._cluster_id)) + thread_pool_executor = self._thread_pool_executor.get() + if thread_pool_executor is None: + thread_pool_executor = ThreadPoolExecutor(thread_name_prefix=self._cluster_id) + self._thread_pool_executor.compare_and_set(None, thread_pool_executor) + return thread_pool_executor + + def _delay(self, use_high_refresh_rate: bool) -> None: + current_time_nano = time.time_ns() + if (self._high_refresh_rate_end_time_nano > 0 and + current_time_nano < self._high_refresh_rate_end_time_nano): + use_high_refresh_rate = True + + if self._request_to_update_topology.is_set(): + use_high_refresh_rate = True + + refresh_rate = self._high_refresh_rate_nano if use_high_refresh_rate else self._refresh_rate_nano + delay_sec = refresh_rate / 1_000_000_000.0 + + start_time = time.time() + end_time = start_time + delay_sec + + while not self._request_to_update_topology.is_set() and time.time() < end_time and not self._stop.is_set(): + time.sleep(0.05) + + def _fetch_topology_and_update_cache(self, connection: Optional[Connection]) -> Tuple[HostInfo, ...]: + if connection is None: + return () + + try: + hosts = self._query_for_topology(connection) + if hosts: + self._update_topology_cache(hosts) + return hosts + return () + except Exception as ex: + logger.debug("ClusterTopologyMonitorImpl.ErrorFetchingTopology", self._cluster_id, ex) + return () + + def _query_for_topology(self, connection: Connection) -> Tuple[HostInfo, ...]: + hosts = self._topology_utils.query_for_topology(connection, self._plugin_service.driver_dialect) + if hosts is not None: + return hosts + return () + + def _update_topology_cache(self, hosts: Tuple[HostInfo, ...]) -> None: + ClusterTopologyMonitorImpl._topology_map.put( + self._cluster_id, hosts, ClusterTopologyMonitorImpl.TOPOLOGY_CACHE_EXPIRATION_NANO) + + # Notify waiting threads + self._request_to_update_topology.clear() + self._topology_updated.set() + + +class HostMonitor: + def __init__(self, monitor: ClusterTopologyMonitorImpl, host_info: HostInfo, + writer_host_info: Optional[HostInfo]): + self._monitor: 
ClusterTopologyMonitorImpl = monitor + self._host_info = host_info + self._writer_host_info = writer_host_info + self._writer_changed = False + self._connection_attempts = 0 + + def __call__(self) -> None: + connection = None + update_topology = False + start_time = time.time() + + try: + while not self._monitor._host_threads_stop.is_set(): + if self._monitor._host_threads_stop.is_set(): + return + + if connection is None: + try: + connection = self._monitor._plugin_service.force_connect( + self._host_info, self._monitor._monitoring_properties) + self._connection_attempts = 0 + except Exception as ex: + if self._monitor._host_threads_stop.is_set(): + return + + if self._monitor._plugin_service.is_network_exception(ex): + time.sleep(0.1) + continue + elif self._monitor._plugin_service.is_login_exception(ex): + raise RuntimeError(ex) + else: + backoff = self._calculate_backoff_with_jitter(self._connection_attempts) + self._connection_attempts += 1 + time.sleep(backoff / 1000.0) + continue + + if self._monitor._host_threads_stop.is_set(): + return + + if connection is not None: + is_writer = False + try: + is_writer = self._monitor._topology_utils.get_writer_host_if_connected( + connection, self._monitor._plugin_service.driver_dialect) is not None + except Exception: + self._monitor._close_connection(connection) + connection = None + continue + + if is_writer: + try: + if self._monitor._topology_utils.get_host_role( + connection, self._monitor._plugin_service.driver_dialect) != HostRole.WRITER: + is_writer = False + except Exception as ex: + logger.debug("HostMonitor.InvalidWriterQuery", ex) + continue + + if is_writer: + if self._monitor._host_threads_writer_connection.compare_and_set(None, connection): + self._monitor._fetch_topology_and_update_cache(connection) + self._monitor._host_threads_writer_host_info.set(self._host_info) + logger.debug("HostMonitor.DetectedWriter", self._host_info.host) + self._monitor._host_threads_stop.set() + connection = None # Prevent cleanup + return + else: + self._monitor._close_connection(connection) + connection = None + return + elif connection is not None: + # Reader connection + if self._monitor._host_threads_writer_connection.get() is None: + if update_topology: + self._reader_thread_fetch_topology(connection) + elif self._monitor._host_threads_reader_connection.compare_and_set(None, connection): + update_topology = True + self._reader_thread_fetch_topology(connection) + + time.sleep(0.1) + + except Exception as ex: + logger.debug("HostMonitor.Exception", self._host_info.host, ex) + finally: + self._monitor._close_connection(connection) + elapsed_time = (time.time() - start_time) * 1000 + logger.debug("HostMonitor.MonitorCompleted", self._host_info.host, elapsed_time) + + def _reader_thread_fetch_topology(self, connection: Connection) -> None: + if connection is None: + return + + try: + hosts = self._monitor._query_for_topology(connection) + if hosts is None: + return + except Exception: + return + + self._monitor._host_threads_latest_topology.set(hosts) + + if self._writer_changed: + self._monitor._update_topology_cache(hosts) + return + + latest_writer_host = next((host for host in hosts if host.role == HostRole.WRITER), None) + if (latest_writer_host is not None and self._writer_host_info is not None and + (latest_writer_host.host != self._writer_host_info.host or + latest_writer_host.port != self._writer_host_info.port)): + self._writer_changed = True + logger.debug("HostMonitor.WriterHostChanged", self._writer_host_info.host, 
latest_writer_host.host) + self._monitor._update_topology_cache(hosts) + + def _calculate_backoff_with_jitter(self, attempt: int) -> int: + backoff = ClusterTopologyMonitorImpl.INITIAL_BACKOFF_MS * (2 ** min(attempt, 6)) + backoff = min(backoff, ClusterTopologyMonitorImpl.MAX_BACKOFF_MS) + return int(backoff * (0.5 + random.random() * 0.5)) diff --git a/aws_advanced_python_wrapper/database_dialect.py b/aws_advanced_python_wrapper/database_dialect.py index e6f4b973..96e90af0 100644 --- a/aws_advanced_python_wrapper/database_dialect.py +++ b/aws_advanced_python_wrapper/database_dialect.py @@ -18,14 +18,16 @@ Protocol, Tuple, runtime_checkable) from aws_advanced_python_wrapper.driver_info import DriverInfo +from aws_advanced_python_wrapper.failover_v2_plugin import FailoverV2Plugin from aws_advanced_python_wrapper.host_list_provider import ( - AuroraTopologyUtils, MultiAzTopologyUtils) + AuroraTopologyUtils, MonitoringRdsHostListProvider, MultiAzTopologyUtils) from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType if TYPE_CHECKING: from aws_advanced_python_wrapper.pep249 import Connection from .driver_dialect import DriverDialect from .exception_handling import ExceptionHandler + from aws_advanced_python_wrapper.plugin_service import PluginService from abc import ABC, abstractmethod from concurrent.futures import TimeoutError @@ -88,6 +90,7 @@ class TopologyAwareDatabaseDialect(Protocol): _TOPOLOGY_QUERY: str _HOST_ID_QUERY: str _IS_READER_QUERY: str + _WRITER_HOST_QUERY: str @property def topology_query(self) -> str: @@ -101,6 +104,10 @@ def host_id_query(self) -> str: def is_reader_query(self) -> str: return self._IS_READER_QUERY + @property + def writer_id_query(self) -> str: + return self._WRITER_HOST_QUERY + @runtime_checkable class AuroraLimitlessDialect(Protocol): @@ -147,7 +154,7 @@ def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool: ... @abstractmethod - def get_host_list_provider_supplier(self) -> Callable: + def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable: ... 
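For intuition, here is a standalone sketch of the retry schedule that _calculate_backoff_with_jitter above produces. It mirrors INITIAL_BACKOFF_MS = 100 and MAX_BACKOFF_MS = 10000 from this patch; the demo loop itself is illustrative and not part of the change.

import random

INITIAL_BACKOFF_MS = 100
MAX_BACKOFF_MS = 10000

def backoff_with_jitter(attempt: int) -> int:
    # Exponential growth capped at 2**6 doublings; with these constants the
    # 10 s ceiling never binds, because 100 ms * 2**6 = 6400 ms.
    backoff = min(INITIAL_BACKOFF_MS * (2 ** min(attempt, 6)), MAX_BACKOFF_MS)
    # Jitter spreads each delay uniformly across [backoff / 2, backoff] so that
    # concurrent host monitors do not retry in lockstep.
    return int(backoff * (0.5 + random.random() * 0.5))

# Pre-jitter delays: attempt 0 -> 100 ms, 1 -> 200 ms, ..., 6 and above -> 6400 ms.
for attempt in range(8):
    print(attempt, backoff_with_jitter(attempt), "ms")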
@abstractmethod @@ -213,7 +220,7 @@ def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool: return False - def get_host_list_provider_supplier(self) -> Callable: + def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable: return lambda provider_service, props: ConnectionStringHostListProvider(provider_service, props) def prepare_conn_props(self, props: Properties): @@ -261,7 +268,7 @@ def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool: return False - def get_host_list_provider_supplier(self) -> Callable: + def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable: return lambda provider_service, props: ConnectionStringHostListProvider(provider_service, props) def prepare_conn_props(self, props: Properties): @@ -387,6 +394,9 @@ class AuroraMysqlDialect(MysqlDatabaseDialect, TopologyAwareDatabaseDialect, Blu "OR SESSION_ID = 'MASTER_SESSION_ID' ") _HOST_ID_QUERY = "SELECT @@aurora_server_id" _IS_READER_QUERY = "SELECT @@innodb_read_only" + _WRITER_HOST_QUERY = \ + ("SELECT SERVER_ID FROM information_schema.replica_host_status " + "WHERE SESSION_ID = 'MASTER_SESSION_ID' AND SERVER_ID = @@aurora_server_id") _BG_STATUS_QUERY = "SELECT version, endpoint, port, role, status FROM mysql.rds_topology" _BG_STATUS_EXISTS_QUERY = \ @@ -410,7 +420,13 @@ def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool: return False - def get_host_list_provider_supplier(self) -> Callable: + def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable: + if plugin_service.is_plugin_in_use(FailoverV2Plugin): + return lambda provider_service, props: MonitoringRdsHostListProvider( + provider_service, + props, AuroraTopologyUtils(self, props), + plugin_service) + return lambda provider_service, props: RdsHostListProvider(provider_service, props, AuroraTopologyUtils(self, props)) @property @@ -449,6 +465,10 @@ class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect, AuroraLim _BG_STATUS_QUERY = (f"SELECT version, endpoint, port, role, status " f"FROM pg_catalog.get_blue_green_fast_switchover_metadata('aws_advanced_python_wrapper-{DriverInfo.DRIVER_VERSION}')") _BG_STATUS_EXISTS_QUERY = "SELECT 'pg_catalog.get_blue_green_fast_switchover_metadata'::regproc" + _WRITER_HOST_QUERY = \ + ("SELECT SERVER_ID FROM pg_catalog.aurora_replica_status() " + "WHERE SESSION_ID OPERATOR(pg_catalog.=) 'MASTER_SESSION_ID' " + "AND SERVER_ID OPERATOR(pg_catalog.=) pg_catalog.aurora_db_instance_identifier()") @property def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]: @@ -483,7 +503,13 @@ def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool: return False - def get_host_list_provider_supplier(self) -> Callable: + def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable: + if plugin_service.is_plugin_in_use(FailoverV2Plugin): + return lambda provider_service, props: MonitoringRdsHostListProvider( + provider_service, + props, + AuroraTopologyUtils(self, props), plugin_service) + return lambda provider_service, props: RdsHostListProvider(provider_service, props, AuroraTopologyUtils(self, props)) @property @@ -533,7 +559,12 @@ def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool: return False - def get_host_list_provider_supplier(self) -> Callable: + def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable: + if 
plugin_service.is_plugin_in_use(FailoverV2Plugin): + return lambda provider_service, props: MonitoringRdsHostListProvider( + provider_service, props, + MultiAzTopologyUtils(self, props, self._WRITER_HOST_QUERY, self._WRITER_HOST_COLUMN_INDEX), plugin_service) + return lambda provider_service, props: RdsHostListProvider( provider_service, props, @@ -588,7 +619,12 @@ def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool: return False - def get_host_list_provider_supplier(self) -> Callable: + def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable: + if plugin_service.is_plugin_in_use(FailoverV2Plugin): + return lambda provider_service, props: MonitoringRdsHostListProvider( + provider_service, props, + MultiAzTopologyUtils(self, props, self._WRITER_HOST_QUERY), plugin_service) + return lambda provider_service, props: RdsHostListProvider( provider_service, props, @@ -629,7 +665,7 @@ def exception_handler(self) -> Optional[ExceptionHandler]: def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool: return False - def get_host_list_provider_supplier(self) -> Callable: + def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable: return lambda provider_service, props: ConnectionStringHostListProvider(provider_service, props) def prepare_conn_props(self, props: Properties): diff --git a/aws_advanced_python_wrapper/exception_handling.py b/aws_advanced_python_wrapper/exception_handling.py index b072b349..69485571 100644 --- a/aws_advanced_python_wrapper/exception_handling.py +++ b/aws_advanced_python_wrapper/exception_handling.py @@ -41,6 +41,15 @@ def is_login_exception(self, error: Optional[Exception] = None, sql_state: Optio """ pass + def is_read_only_connection_exception(self, error: Optional[Exception] = None, sql_state: Optional[str] = None) -> bool: + """ + Checks whether the given error indicates that the connection is in read-only mode. + :param error: The error raised by the target driver. + :param sql_state: The SQL State associated with the error. + :return: True if the error indicates a read-only connection, False otherwise.
+ """ + pass + class ExceptionManager: custom_handler: Optional[ExceptionHandler] = None @@ -67,6 +76,13 @@ def is_login_exception(self, dialect: Optional[DatabaseDialect], error: Optional return handler.is_login_exception(error=error, sql_state=sql_state) return False + def is_read_only_connection_exception(self, dialect: Optional[DatabaseDialect], error: Optional[Exception] = None, + sql_state: Optional[str] = None) -> bool: + handler = self._get_handler(dialect) + if handler is not None: + return handler.is_read_only_connection_exception(error=error, sql_state=sql_state) + return False + def _get_handler(self, dialect: Optional[DatabaseDialect]) -> Optional[ExceptionHandler]: if dialect is None: return None diff --git a/aws_advanced_python_wrapper/failover_plugin.py b/aws_advanced_python_wrapper/failover_plugin.py index be47fc5f..b306c6d3 100644 --- a/aws_advanced_python_wrapper/failover_plugin.py +++ b/aws_advanced_python_wrapper/failover_plugin.py @@ -217,9 +217,7 @@ def _update_topology(self, force_update: bool): if not self._is_failover_enabled(): return - conn = self._plugin_service.current_connection - driver_dialect = self._plugin_service.driver_dialect - if conn is None or (driver_dialect is not None and driver_dialect.is_closed(conn)): + if self._plugin_service.current_connection is None: return if force_update: @@ -377,11 +375,10 @@ def _invalidate_current_connection(self): except Exception: pass - if not driver_dialect.is_closed(conn): - try: - return driver_dialect.execute(DbApiMethod.CONNECTION_CLOSE.method_name, lambda: conn.close()) - except Exception: - pass + try: + return driver_dialect.execute(DbApiMethod.CONNECTION_CLOSE.method_name, lambda: conn.close()) + except Exception: + pass return None diff --git a/aws_advanced_python_wrapper/failover_v2_plugin.py b/aws_advanced_python_wrapper/failover_v2_plugin.py new file mode 100644 index 00000000..4f2c3a61 --- /dev/null +++ b/aws_advanced_python_wrapper/failover_v2_plugin.py @@ -0,0 +1,451 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# A copy of the License is located at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# or in the "license" file accompanying this file. This file is distributed +# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. 
+ + from __future__ import annotations + + import time + from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set + + from aws_advanced_python_wrapper.pep249_methods import DbApiMethod + from aws_advanced_python_wrapper.utils.utils import LogUtils + + if TYPE_CHECKING: + from aws_advanced_python_wrapper.host_list_provider import HostListProviderService + from aws_advanced_python_wrapper.pep249 import Connection + from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.notifications import HostEvent + from aws_advanced_python_wrapper.driver_dialect import DriverDialect + + from aws_advanced_python_wrapper.errors import ( + AwsWrapperError, FailoverFailedError, FailoverSuccessError, + TransactionResolutionUnknownError) + from aws_advanced_python_wrapper.host_availability import HostAvailability + from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole + from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory + from aws_advanced_python_wrapper.stale_dns_plugin import StaleDnsHelper + from aws_advanced_python_wrapper.utils.failover_mode import (FailoverMode, + get_failover_mode) + from aws_advanced_python_wrapper.utils.log import Logger + from aws_advanced_python_wrapper.utils.messages import Messages + from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) + from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType + from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils + from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ + TelemetryTraceLevel + + logger = Logger(__name__) + + + class ReaderFailoverResult: + def __init__(self, connection: Connection, host_info: HostInfo): + self._connection = connection + self._host_info = host_info + + @property + def connection(self) -> Connection: + return self._connection + + @property + def host_info(self) -> HostInfo: + return self._host_info + + + class FailoverV2Plugin(Plugin): + """ + Failover Plugin v2 + This plugin provides cluster-aware failover features. The plugin switches connections upon + detecting communication-related exceptions and/or cluster topology changes.
+ """ + + _SUBSCRIBED_METHODS: Set[str] = {DbApiMethod.INIT_HOST_PROVIDER.method_name, + DbApiMethod.CONNECT.method_name, + DbApiMethod.NOTIFY_HOST_LIST_CHANGED.method_name} + + def __init__(self, plugin_service: PluginService, props: Properties): + self._plugin_service = plugin_service + self._properties = props + + self._failover_timeout_sec = WrapperProperties.FAILOVER_TIMEOUT_SEC.get_int(self._properties) + self._failover_mode: Optional[FailoverMode] = None + self._telemetry_failover_additional_top_trace = ( + WrapperProperties.TELEMETRY_FAILOVER_ADDITIONAL_TOP_TRACE.get_bool(self._properties)) + strategy = WrapperProperties.FAILOVER_READER_HOST_SELECTOR_STRATEGY.get(self._properties) + self._failover_reader_host_selector_strategy: str = strategy if strategy is not None else "" + self._enable_connect_failover = WrapperProperties.ENABLE_CONNECT_FAILOVER.get_bool(self._properties) + + self._closed_explicitly = False + self._is_closed = False + self._rds_helper = RdsUtils() + self._last_exception_dealt_with: Optional[Exception] = None + self._is_in_transaction = False + self._rds_url_type: Optional[RdsUrlType] = None + self._stale_dns_helper = StaleDnsHelper(plugin_service) + self._SUBSCRIBED_METHODS.update(self._plugin_service.network_bound_methods) + + @property + def subscribed_methods(self) -> Set[str]: + return self._SUBSCRIBED_METHODS + + def execute(self, target: type, method_name: str, execute_func: Callable, *args, **kwargs) -> Any: + self._is_in_transaction = self._plugin_service.is_in_transaction + + if self._can_direct_execute(method_name): + if method_name == DbApiMethod.CONNECTION_CLOSE.method_name: + self._closed_explicitly = True + return execute_func() + + if self._is_closed: + self._invalid_invocation_on_closed_connection() + + try: + result = execute_func() + except Exception as e: + logger.debug("FailoverPlugin.DetectedException", str(e)) + self._deal_with_original_exception(e) + + return result + + def init_host_provider( + self, + props: Properties, + host_list_provider_service: HostListProviderService, + init_host_provider_func: Callable): + self._host_list_provider_service: HostListProviderService = host_list_provider_service + init_host_provider_func() + + def connect( + self, + target_driver_func: Callable, + driver_dialect: DriverDialect, + host_info: HostInfo, + props: Properties, + is_initial_connection: bool, + connect_func: Callable) -> Connection: + if self._host_list_provider_service is None: + raise AwsWrapperError("Host list provider service not initialized") + + self._init_failover_mode() + + if not self._enable_connect_failover: + return self._stale_dns_helper.get_verified_connection( + is_initial_connection, self._host_list_provider_service, + host_info, props, connect_func) + + host_with_availability = None + for host in self._plugin_service.hosts: + if host.host == host_info.host and host.port == host_info.port: + host_with_availability = host + break + + conn = None + if (host_with_availability is None or host_with_availability.availability != HostAvailability.UNAVAILABLE): + try: + conn = self._stale_dns_helper.get_verified_connection( + is_initial_connection, self._host_list_provider_service, + host_info, props, connect_func) + except Exception as e: + if not self._should_exception_trigger_connection_switch(e): + raise e + + self._plugin_service.set_availability(host_info.as_aliases(), HostAvailability.UNAVAILABLE) + + try: + self._failover() + except FailoverSuccessError: + conn = self._plugin_service.current_connection + else: + try: + 
self._plugin_service.refresh_host_list() + self._failover() + except FailoverSuccessError: + conn = self._plugin_service.current_connection + + if conn is None: + raise AwsWrapperError("Unable to connect") + + if is_initial_connection: + self._plugin_service.refresh_host_list(conn) + + return conn + + def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]) -> None: + self._stale_dns_helper.notify_host_list_changed(changes) + + def _is_failover_enabled(self) -> bool: + return (self._rds_url_type != RdsUrlType.RDS_PROXY and + len(self._plugin_service.all_hosts) > 0) + + def _invalid_invocation_on_closed_connection(self) -> None: + if not self._closed_explicitly: + self._is_closed = False + self._pick_new_connection() + logger.info(Messages.get("Failover.connectionChangedError")) + raise FailoverSuccessError() + else: + raise AwsWrapperError("No operations allowed after connection closed") + + def _deal_with_original_exception(self, original_exception: Exception) -> None: + if (self._last_exception_dealt_with != original_exception and + (self._should_exception_trigger_connection_switch(original_exception))): + self._invalidate_current_connection() + self._plugin_service.set_availability( + self._plugin_service.current_host_info.as_aliases(), HostAvailability.UNAVAILABLE) + self._pick_new_connection() + self._last_exception_dealt_with = original_exception + + raise AwsWrapperError(Messages.get_formatted("FailoverPlugin.DetectedException", str(original_exception))) \ + from original_exception + + def _failover(self) -> None: + if self._failover_mode == FailoverMode.STRICT_WRITER: + self._failover_writer() + else: + self._failover_reader() + + def _failover_reader(self) -> None: + logger.info("FailoverPlugin.StartReaderFailover") + telemetry_context = self._plugin_service.get_telemetry_factory().open_telemetry_context( + "failover to replica", TelemetryTraceLevel.NESTED) + + failover_start_time = time.time() + try: + if not self._plugin_service.force_monitoring_refresh_host_list(False, 0): + raise FailoverFailedError(Messages.get("FailoverPlugin.UnableToRefreshHostList")) + + try: + result = self._get_reader_failover_connection() + self._plugin_service.set_current_connection(result.connection, result.host_info) + except TimeoutError: + raise FailoverFailedError(Messages.get("FailoverPlugin.UnableToConnectToReader")) + + logger.info("FailoverPlugin.EstablishedConnection", self._plugin_service.current_host_info) + self._throw_failover_success_exception() + + except FailoverSuccessError as ex: + if telemetry_context: + telemetry_context.set_success(True) + telemetry_context.set_exception(ex) + raise ex + except Exception as ex: + if telemetry_context: + telemetry_context.set_success(False) + telemetry_context.set_exception(ex) + raise ex + finally: + elapsed_time = (time.time() - failover_start_time) * 1000 + logger.info("FailoverPlugin.ReaderFailoverTime", elapsed_time) + if telemetry_context: + telemetry_context.close_context() + if self._telemetry_failover_additional_top_trace: + self._plugin_service.get_telemetry_factory().post_copy( + telemetry_context, TelemetryTraceLevel.FORCE_TOP_LEVEL) + + def _get_reader_failover_connection(self) -> ReaderFailoverResult: + failover_end_time = time.time() + self._failover_timeout_sec + + hosts = self._plugin_service.hosts + reader_candidates = [host for host in hosts if host.role == HostRole.READER] + original_writer = next((host for host in hosts if host.role == HostRole.WRITER), None) + is_original_writer_still_writer = False + + while 
time.time() < failover_end_time: + # Try all original readers + remaining_readers = reader_candidates.copy() + while remaining_readers and time.time() < failover_end_time: + try: + reader_candidate: Optional[HostInfo] = self._plugin_service.get_host_info_by_strategy( + HostRole.READER, self._failover_reader_host_selector_strategy, remaining_readers) + except Exception: + break + + if reader_candidate is None: + break + + try: + candidate_conn = self._plugin_service.connect(reader_candidate, self._properties, self) + role = self._plugin_service.get_host_role(candidate_conn) + if role == HostRole.READER or self._failover_mode != FailoverMode.STRICT_READER: + updated_host_info = HostInfo(reader_candidate.host, reader_candidate.port, role) + return ReaderFailoverResult(candidate_conn, updated_host_info) + + remaining_readers.remove(reader_candidate) + self._plugin_service.driver_dialect.execute( + DbApiMethod.CONNECTION_CLOSE.method_name, lambda: candidate_conn.close()) + + if role == HostRole.WRITER: + reader_candidates.remove(reader_candidate) + except Exception: + remaining_readers.remove(reader_candidate) + + # Try original writer + if (original_writer is None or time.time() > failover_end_time or (self._failover_mode == FailoverMode.STRICT_READER + and is_original_writer_still_writer)): + continue + + try: + candidate_conn = self._plugin_service.connect(original_writer, self._properties, self) + role = self._plugin_service.get_host_role(candidate_conn) + if role == HostRole.READER or self._failover_mode != FailoverMode.STRICT_READER: + updated_host_info = HostInfo(original_writer.host, original_writer.port, role) + return ReaderFailoverResult(candidate_conn, updated_host_info) + + self._plugin_service.driver_dialect.execute( + DbApiMethod.CONNECTION_CLOSE.method_name, lambda: candidate_conn.close()) + if role == HostRole.WRITER: + is_original_writer_still_writer = True + except Exception: + pass + + raise TimeoutError("Failover reader timeout") + + def _throw_failover_success_exception(self) -> None: + if self._is_in_transaction or self._plugin_service.is_in_transaction: + self._plugin_service.update_in_transaction(False) + error_msg = "FailoverPlugin.TransactionResolutionUnknownError" + logger.warning(error_msg) + raise TransactionResolutionUnknownError(Messages.get(error_msg)) + else: + error_msg = "FailoverPlugin.ConnectionChangedError" + logger.error(error_msg) + raise FailoverSuccessError(Messages.get(error_msg)) + + def _failover_writer(self) -> None: + logger.info("FailoverPlugin.StartWriterFailover") + telemetry_context = self._plugin_service.get_telemetry_factory().open_telemetry_context( + "failover to writer host", TelemetryTraceLevel.NESTED) + + failover_start_time = time.time() + try: + if not self._plugin_service.force_monitoring_refresh_host_list(True, self._failover_timeout_sec): + raise FailoverFailedError(Messages.get("FailoverPlugin.UnableToRefreshHostList")) + + updated_hosts = self._plugin_service.all_hosts + writer_candidate = next((host for host in updated_hosts if host.role == HostRole.WRITER), None) + + if writer_candidate is None: + raise FailoverFailedError(Messages.get_formatted( + "FailoverPlugin.NoWriterHostInTopology", + LogUtils.log_topology(updated_hosts))) + + logger.info("FailoverPlugin.FoundWriterCandidate", writer_candidate) + allowed_hosts = self._plugin_service.hosts + if not any(host.host == writer_candidate.host and host.port == writer_candidate.port + for host in allowed_hosts): + raise FailoverFailedError( + Messages.get_formatted( + 
"FailoverPlugin.NewWriterNotAllowed", + "" if writer_candidate is None else writer_candidate.host, + LogUtils.log_topology(allowed_hosts))) + + try: + writer_candidate_conn = self._plugin_service.connect(writer_candidate, self._properties, self) + except Exception as e: + raise FailoverFailedError(Messages.get_formatted( + "FailoverPlugin.ExceptionConnectingToWriter", e)) + + role = self._plugin_service.get_host_role(writer_candidate_conn) + if role != HostRole.WRITER: + try: + self._plugin_service.driver_dialect.execute( + DbApiMethod.CONNECTION_CLOSE.method_name, lambda: writer_candidate_conn.close()) + except Exception: + pass + raise FailoverFailedError(Messages.get_formatted( + "FailoverPlugin.WriterFailoverConnectedToReader", + writer_candidate.host)) + + self._plugin_service.set_current_connection(writer_candidate_conn, writer_candidate) + logger.info("FailoverPlugin.EstablishedConnection", self._plugin_service.current_host_info) + self._throw_failover_success_exception() + + except FailoverSuccessError as ex: + if telemetry_context: + telemetry_context.set_success(True) + telemetry_context.set_exception(ex) + raise ex + except Exception as ex: + if telemetry_context: + telemetry_context.set_success(False) + telemetry_context.set_exception(ex) + raise ex + finally: + elapsed_time = (time.time() - failover_start_time) * 1000 + logger.info("FailoverPlugin.WriterFailoverTime", elapsed_time) + if telemetry_context: + telemetry_context.close_context() + if self._telemetry_failover_additional_top_trace: + self._plugin_service.get_telemetry_factory().post_copy( + telemetry_context, TelemetryTraceLevel.FORCE_TOP_LEVEL) + + def _invalidate_current_connection(self) -> None: + conn = self._plugin_service.current_connection + if conn is None: + return + + if self._is_in_transaction: + try: + conn.rollback() + except Exception: + pass + + try: + self._plugin_service.driver_dialect.execute( + DbApiMethod.CONNECTION_CLOSE.method_name, lambda: conn.close()) + except Exception: + pass + + def _pick_new_connection(self) -> None: + if self._is_closed and self._closed_explicitly: + logger.debug("FailoverPlugin.NoOperationsAfterConnectionClosed") + return + + self._failover() + + def _should_exception_trigger_connection_switch(self, exception: Exception) -> bool: + if not self._is_failover_enabled(): + logger.debug("FailoverPlugin.FailoverDisabled") + return False + + if self._plugin_service.is_network_exception(exception): + return True + + # For STRICT_WRITER failover mode when connection exception indicate that the connection's in read-only mode, + # initiate a failover by returning true. 
+ return self._failover_mode == FailoverMode.STRICT_WRITER and \ + self._plugin_service.is_read_only_connection_exception(exception) + + def _can_direct_execute(self, method_name: str) -> bool: + return method_name == DbApiMethod.CONNECTION_CLOSE.method_name or \ + method_name == DbApiMethod.CONNECTION_IS_CLOSED.method_name or \ + method_name == DbApiMethod.CURSOR_CLOSE.method_name + + def _init_failover_mode(self) -> None: + if self._rds_url_type is None: + self._failover_mode = get_failover_mode(self._properties) + initial_host_spec = self._host_list_provider_service.initial_connection_host_info + if initial_host_spec is not None: + self._rds_url_type = self._rds_helper.identify_rds_type(initial_host_spec.host) + + if self._failover_mode is None: + self._failover_mode = (FailoverMode.READER_OR_WRITER + if self._rds_url_type is not None and self._rds_url_type == RdsUrlType.RDS_READER_CLUSTER + else FailoverMode.STRICT_WRITER) + + logger.debug("FailoverPlugin.ParameterValue", "FAILOVER_MODE", self._failover_mode) + + +class FailoverV2PluginFactory(PluginFactory): + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: + return FailoverV2Plugin(plugin_service, props) diff --git a/aws_advanced_python_wrapper/host_list_provider.py b/aws_advanced_python_wrapper/host_list_provider.py index 81409443..bd551568 100644 --- a/aws_advanced_python_wrapper/host_list_provider.py +++ b/aws_advanced_python_wrapper/host_list_provider.py @@ -24,11 +24,16 @@ from typing import (TYPE_CHECKING, ClassVar, List, Optional, Protocol, Tuple, runtime_checkable) +from aws_advanced_python_wrapper.cluster_topology_monitor import ( + ClusterTopologyMonitor, ClusterTopologyMonitorImpl) from aws_advanced_python_wrapper.utils.decorators import \ preserve_transaction_status_with_timeout +from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \ + SlidingExpirationCacheWithCleanupThread if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect + from aws_advanced_python_wrapper.plugin_service import PluginService import aws_advanced_python_wrapper.database_dialect as db_dialect from aws_advanced_python_wrapper.errors import (AwsWrapperError, @@ -60,6 +65,9 @@ def refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, .. def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostInfo, ...]: ... + def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]: + ... + def get_host_role(self, connection: Connection) -> HostRole: """ Evaluates the host role of the given connection - either a writer or a reader. 
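A comment-only recap of the refresh call chain that this new protocol method participates in, summarizing code found elsewhere in this patch:

# FailoverV2Plugin._failover_writer() / _failover_reader()
#   -> PluginService.force_monitoring_refresh_host_list(should_verify_writer, timeout_sec)
#     -> MonitoringRdsHostListProvider.force_monitoring_refresh(...)
#       -> ClusterTopologyMonitorImpl.force_refresh(...)
# Providers without a background monitor (RdsHostListProvider,
# ConnectionStringHostListProvider) raise AwsWrapperError for this call instead.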
@@ -245,7 +253,7 @@ def _get_topology(self, conn: Optional[Connection], force_update: bool = False) try: driver_dialect = self._host_list_provider_service.driver_dialect - hosts = self._topology_utils.query_for_topology(conn, driver_dialect) + hosts = self.query_for_topology(conn, driver_dialect) if hosts is not None and len(hosts) > 0: RdsHostListProvider._topology_cache.put(self._cluster_id, hosts, self._refresh_rate_ns) if self._is_primary_cluster_id and cached_hosts is None: @@ -262,6 +270,9 @@ def _get_topology(self, conn: Optional[Connection], force_update: bool = False) else: return RdsHostListProvider.FetchTopologyResult(self._initial_hosts, False) + def query_for_topology(self, conn, driver_dialect) -> Optional[Tuple[HostInfo, ...]]: + return self._topology_utils.query_for_topology(conn, driver_dialect) + def _suggest_cluster_id(self, primary_cluster_id_hosts: Tuple[HostInfo, ...]): if not primary_cluster_id_hosts: return @@ -317,6 +328,10 @@ def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostIn self._hosts = topology.hosts return tuple(self._hosts) + def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]: + raise AwsWrapperError( + Messages.get_formatted("HostListProvider.ForceMonitoringRefreshUnsupported", "RdsHostListProvider")) + def get_host_role(self, connection: Connection) -> HostRole: driver_dialect = self._host_list_provider_service.driver_dialect @@ -406,6 +421,10 @@ def force_refresh(self, connection: Optional[Connection] = None) -> Tuple[HostIn self._initialize() return tuple(self._hosts) + def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]: + raise AwsWrapperError( + Messages.get_formatted("HostListProvider.ForceMonitoringRefreshUnsupported", "ConnectionStringHostListProvider")) + def get_host_role(self, connection: Connection) -> HostRole: raise UnsupportedOperationError( Messages.get_formatted("ConnectionStringHostListProvider.UnsupportedMethod", "get_host_role")) @@ -457,7 +476,7 @@ def __init__(self, dialect: db_dialect.TopologyAwareDatabaseDialect, props: Prop self._validate_host_pattern(instance_template.host) self.instance_template: HostInfo = instance_template - self._max_timeout = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get_int(props) + self._max_timeout_sec = WrapperProperties.AUXILIARY_QUERY_TIMEOUT_SEC.get_int(props) def _validate_host_pattern(self, host: str): if not self._rds_utils.is_dns_pattern_valid(host): @@ -489,8 +508,9 @@ def query_for_topology( an empty tuple will be returned. 
""" query_for_topology_func_with_timeout = preserve_transaction_status_with_timeout( - ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout, driver_dialect, conn)(self._query_for_topology) - return query_for_topology_func_with_timeout(conn) + ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, conn)(self._query_for_topology) + x = query_for_topology_func_with_timeout(conn) + return x @abstractmethod def _query_for_topology(self, conn: Connection) -> Optional[Tuple[HostInfo, ...]]: @@ -551,7 +571,7 @@ def create_host( def get_host_role(self, connection: Connection, driver_dialect: DriverDialect) -> HostRole: try: cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout( - ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout, driver_dialect, connection)(self._get_host_role) + ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, connection)(self._get_host_role) result = cursor_execute_func_with_timeout(connection) if result is not None: is_reader = result[0] @@ -574,7 +594,7 @@ def get_host_id(self, connection: Connection, driver_dialect: DriverDialect) -> """ cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout( - ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout, driver_dialect, connection)(self._get_host_id) + ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, connection)(self._get_host_id) result = cursor_execute_func_with_timeout(connection) if result: host_id: str = result[0] @@ -586,6 +606,23 @@ def _get_host_id(self, conn: Connection): cursor.execute(self._dialect.host_id_query) return cursor.fetchone() + def get_writer_host_if_connected(self, connection: Connection, driver_dialect: DriverDialect) -> Optional[str]: + try: + cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout( + ThreadPoolContainer.get_thread_pool(self._executor_name), self._max_timeout_sec, driver_dialect, connection)(self._get_writer_id) + result = cursor_execute_func_with_timeout(connection) + if result: + host_id: str = result[0] + return host_id + return None + except Exception: + return None + + def _get_writer_id(self, conn: Connection): + with closing(conn.cursor()) as cursor: + cursor.execute(self._dialect.writer_id_query) + return cursor.fetchone() + class AuroraTopologyUtils(TopologyUtils): @@ -713,3 +750,61 @@ def _create_multi_az_host(self, record: Tuple, writer_id: str) -> HostInfo: host=host, port=port, role=role, availability=HostAvailability.AVAILABLE, weight=0, host_id=id) host_info.add_alias(host) return host_info + + +class MonitoringRdsHostListProvider(RdsHostListProvider): + _CACHE_CLEANUP_NANO = 1 * 60 * 1_000_000_000 # 1 minute + _MONITOR_CLEANUP_NANO = 15 * 60 * 1_000_000_000 # 15 minutes + + _monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, ClusterTopologyMonitor]] = \ + SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_NANO, + should_dispose_func=lambda monitor: monitor.can_dispose(), + item_disposal_func=lambda monitor: monitor.close()) + + def __init__( + self, + host_list_provider_service: HostListProviderService, + props: Properties, + topology_utils: TopologyUtils, + plugin_service: PluginService + ): + super().__init__(host_list_provider_service, props, topology_utils) + self._plugin_service: PluginService = plugin_service + self._high_refresh_rate_ns = ( + 
WrapperProperties.CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.get_int(self._props) * 1_000_000) + + def _get_monitor(self) -> Optional[ClusterTopologyMonitor]: + return self._monitors.compute_if_absent_with_disposal(self.get_cluster_id(), + lambda k: ClusterTopologyMonitorImpl( + self._plugin_service, + self._topology_utils, + self._cluster_id, + self._topology_utils.initial_host_info, + self._props, + self._topology_utils.instance_template, + self._refresh_rate_ns, + self._high_refresh_rate_ns + ), MonitoringRdsHostListProvider._MONITOR_CLEANUP_NANO) + + def query_for_topology(self, connection: Connection, driver_dialect) -> Optional[Tuple[HostInfo, ...]]: + monitor = self._get_monitor() + + if monitor is None: + return None + + try: + return monitor.force_refresh_with_connection(connection, self._topology_utils._max_timeout_sec) + except TimeoutError: + return None + + def force_monitoring_refresh(self, should_verify_writer: bool, timeout_sec: int) -> Tuple[HostInfo, ...]: + monitor = self._get_monitor() + + if monitor is None: + return () + + return monitor.force_refresh(should_verify_writer, timeout_sec) + + @staticmethod + def release_resources(): + MonitoringRdsHostListProvider._monitors.clear() diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index 789c364a..2f95d8a2 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -23,6 +23,8 @@ BlueGreenPluginFactory from aws_advanced_python_wrapper.custom_endpoint_plugin import \ CustomEndpointPluginFactory +from aws_advanced_python_wrapper.failover_v2_plugin import \ + FailoverV2PluginFactory from aws_advanced_python_wrapper.fastest_response_strategy_plugin import \ FastestResponseStrategyPluginFactory from aws_advanced_python_wrapper.federated_plugin import \ @@ -257,6 +259,9 @@ def refresh_host_list(self, connection: Optional[Connection] = None): def force_refresh_host_list(self, connection: Optional[Connection] = None): ... + def force_monitoring_refresh_host_list(self, should_verify_writer: bool, timeout_sec: int) -> bool: + ... + def connect(self, host_info: HostInfo, props: Properties, plugin_to_skip: Optional[Plugin] = None) -> Connection: """ Establishes a connection to the given host using the given driver protocol and properties.
If a @@ -539,7 +544,7 @@ def update_dialect(self, connection: Optional[Connection] = None): self.driver_dialect) if original_dialect != self._database_dialect: - host_list_provider_init = self._database_dialect.get_host_list_provider_supplier() + host_list_provider_init = self._database_dialect.get_host_list_provider_supplier(self) self.host_list_provider = host_list_provider_init(self, self._props) self.refresh_host_list(connection) @@ -576,6 +581,18 @@ def force_refresh_host_list(self, connection: Optional[Connection] = None): self._update_host_availability(updated_host_list) self._update_hosts(updated_host_list) + def force_monitoring_refresh_host_list(self, should_verify_writer: bool, timeout_sec: int) -> bool: + try: + updated_host_list = self.host_list_provider.force_monitoring_refresh(should_verify_writer, timeout_sec) + if updated_host_list is not None: + self._update_host_availability(updated_host_list) + self._update_hosts(updated_host_list) + return True + except TimeoutError: + logger.debug(f"Force refresh timed out after {timeout_sec} sec") + + return False + def connect(self, host_info: HostInfo, props: Properties, plugin_to_skip: Optional[Plugin] = None) -> Connection: plugin_manager: PluginManager = self._container.plugin_manager return plugin_manager.connect( @@ -646,6 +663,10 @@ def is_login_exception(self, error: Optional[Exception] = None, sql_state: Optio return self._exception_manager.is_login_exception( dialect=self.database_dialect, error=error, sql_state=sql_state) + def is_read_only_connection_exception(self, error: Optional[Exception] = None, sql_state: Optional[str] = None) -> bool: + return self._exception_manager.is_read_only_connection_exception( + dialect=self.database_dialect, error=error, sql_state=sql_state) + def get_connection_provider_manager(self) -> ConnectionProviderManager: return self._container.plugin_manager.connection_provider_manager @@ -762,6 +783,7 @@ class PluginManager(CanReleaseResources): "host_monitoring": HostMonitoringPluginFactory, "host_monitoring_v2": HostMonitoringV2PluginFactory, "failover": FailoverPluginFactory, + "failover_v2": FailoverV2PluginFactory, "read_write_splitting": ReadWriteSplittingPluginFactory, "srw": SimpleReadWriteSplittingPluginFactory, "fastest_response_strategy": FastestResponseStrategyPluginFactory, @@ -790,6 +812,7 @@ class PluginManager(CanReleaseResources): ReadWriteSplittingPluginFactory: 300, SimpleReadWriteSplittingPluginFactory: 310, FailoverPluginFactory: 400, + FailoverV2PluginFactory: 410, HostMonitoringPluginFactory: 500, HostMonitoringV2PluginFactory: 510, BlueGreenPluginFactory: 550, diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index 1c1c0fbd..ecb5fcab 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -75,6 +75,21 @@ BlueGreenStatusProvider.UnsupportedDialect=[BlueGreenStatusProvider] [bgdId: '{} CloseConnectionExecuteRouting.InProgressConnectionClosed=[CloseConnectionExecuteRouting] Connection has been closed since Blue/Green switchover is in progress. +ClusterTopologyMonitor.StartMonitoringThread=[ClusterTopologyMonitor, clusterId: '{}'] Starting cluster topology monitoring thread for '{}'.
+ClusterTopologyMonitor.StopMonitoringThread=[ClusterTopologyMonitor, clusterId: '{}'] Stopping cluster topology monitoring thread for '{}'. +ClusterTopologyMonitorImpl.IgnoringTopologyRequest=[ClusterTopologyMonitor, clusterId: '{}'] A topology refresh was requested, but the topology was already updated recently. Returning cached hosts: +ClusterTopologyMonitorImpl.TopologyNotUpdated=[ClusterTopologyMonitor, clusterId: '{}'] Topology has not been updated after {} ms. +ClusterTopologyMonitorImpl.TimeoutSetToZero=[ClusterTopologyMonitor, clusterId: '{}'] A topology refresh was requested, but the given timeout for the request was 0ms. Returning cached hosts: +ClusterTopologyMonitorImpl.StartingHostMonitoringThreads=[ClusterTopologyMonitor, clusterId: '{}'] Starting host monitoring threads. +ClusterTopologyMonitorImpl.ExceptionStartingHostMonitor=[ClusterTopologyMonitor, clusterId: '{}'] Exception starting monitor for host '{}': '{}'. +ClusterTopologyMonitorImpl.WriterPickedUpFromHostMonitors=[ClusterTopologyMonitor, clusterId: '{}'] The writer host detected by the host monitors was picked up by the topology monitor: '{}'. +ClusterTopologyMonitorImpl.ExceptionDuringMonitoringStop=[ClusterTopologyMonitor, clusterId: '{}'] Stopping cluster topology monitoring after an unhandled exception was thrown in monitoring thread '{}'. +ClusterTopologyMonitorImpl.ClosingMonitor=[ClusterTopologyMonitor, clusterId: '{}'] Closing monitor. +ClusterTopologyMonitorImpl.OpenedMonitoringConnection=[ClusterTopologyMonitor, clusterId: '{}'] Opened monitoring connection to host '{}'. +ClusterTopologyMonitorImpl.WriterMonitoringConnection=[ClusterTopologyMonitor, clusterId: '{}'] The monitoring connection is connected to a writer: '{}'. +ClusterTopologyMonitorImpl.ErrorFetchingTopology=[ClusterTopologyMonitor, clusterId: '{}'] An error occurred while querying for topology: {} +ClusterTopologyMonitorImpl.CannotCreateExecutorWhenStopped=[ClusterTopologyMonitor, clusterId: '{}'] The monitor is stopped; an executor cannot be created. + conftest.ExceptionWhileObtainingInstanceIDs=[conftest] An exception was thrown while attempting to obtain the cluster's instance IDs: '{}' ConnectTimePlugin.ConnectTime=[ConnectTimePlugin] Connected in {} nanos. @@ -145,6 +160,13 @@ FailoverPlugin.StartWriterFailover=[Failover] Starting writer failover procedure FailoverPlugin.TransactionResolutionUnknownError=[Failover] Transaction resolution unknown. Please re-configure session state if required and try restarting the transaction. FailoverPlugin.UnableToConnectToReader=[Failover] Unable to establish SQL connection to the reader instance. FailoverPlugin.UnableToConnectToWriter=[Failover] Unable to establish SQL connection to the writer instance. +FailoverPlugin.ExceptionConnectingToWriter=[Failover] Unable to establish SQL connection to the writer instance. Exception: {}. +FailoverPlugin.UnableToRefreshHostList=[Failover] Unable to refresh host list. +FailoverPlugin.FoundWriterCandidate=[Failover] Found writer candidate: {} +FailoverPlugin.NoWriterHostInTopology=[Failover] No writer host found in topology: {} +FailoverPlugin.WriterFailoverConnectedToReader=[Failover] Writer failover unexpectedly connected to a reader: {} +FailoverPlugin.WriterFailoverTime=[Failover] Writer failover elapsed: {}ms. +FailoverPlugin.ReaderFailoverTime=[Failover] Reader failover elapsed: {}ms. FastestResponseStrategyPlugin.RandomHostSelected=[FastestResponseStrategyPlugin] Fastest host not calculated. Random host selected instead.
FastestResponseStrategyPlugin.UnsupportedHostSelectorStrategy=[FastestResponseStrategyPlugin] Unsupported host selector strategy: '{}'. To use the fastest response strategy plugin, please ensure the property reader_host_selector_strategy is set to fastest_response. @@ -165,6 +187,14 @@ FederatedAuthPluginFactory.UnsupportedIdp=[FederatedAuthPluginFactory] Unsupport HostAvailabilityStrategy.InvalidInitialBackoffTime=[HostAvailabilityStrategy] Invalid value of {} for configuration parameter `host_availability_strategy_initial_backoff_time`. It must be an integer greater than 1. HostAvailabilityStrategy.InvalidMaxRetries=[HostAvailabilityStrategy] Invalid value of {} for configuration parameter `host_availability_strategy_max_retries`. It must be an integer greater than 1. +HostListProvider.ForceMonitoringRefreshUnsupported=[{}] Force monitoring refresh is not supported. + +HostMonitor.DetectedWriter=[HostMonitor] Writer detected by host monitoring thread: {}. +HostMonitor.InvalidWriterQuery=[HostMonitor] The writer topology query is invalid: {} +HostMonitor.Exception=[HostMonitor] Host monitor for host {} is exiting due to an unknown exception: {} +HostMonitor.MonitorCompleted=[HostMonitor] Host monitor for {} completed in {} ms. +HostMonitor.WriterHostChanged=[HostMonitor] Writer host changed from {} to {}. + HostMonitoringPlugin.ActivatedMonitoring=[HostMonitoringPlugin] Executing method '{}', monitoring is activated. HostMonitoringPlugin.ClusterEndpointHostInfo=[HostMonitoringPlugin] The HostInfo to monitor is associated with a cluster endpoint. The plugin will attempt to identify the connected database instance. HostMonitoringPlugin.ErrorIdentifyingConnection=[HostMonitoringPlugin] An error occurred while identifying the connection database instance: '{}'. 
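The atomic.py hunk that follows adds two lock-based compare/swap helpers to AtomicReference. A minimal standalone sketch of their semantics (illustrative only; the string values here stand in for the connection objects the monitor code actually stores):

    from aws_advanced_python_wrapper.utils.atomic import AtomicReference

    ref = AtomicReference(None)

    # compare_and_set swaps the value only when it still equals the expected one,
    # so exactly one racing thread can install its value (e.g. a writer connection).
    assert ref.compare_and_set(None, "writer-conn") is True
    assert ref.compare_and_set(None, "other-conn") is False  # lost the race; value unchanged

    # get_and_set atomically replaces the value and returns the previous one,
    # letting the caller dispose of the old value without holding a separate lock.
    assert ref.get_and_set("new-conn") == "writer-conn"
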
diff --git a/aws_advanced_python_wrapper/utils/atomic.py b/aws_advanced_python_wrapper/utils/atomic.py index 01e61c4e..ac7c079d 100644 --- a/aws_advanced_python_wrapper/utils/atomic.py +++ b/aws_advanced_python_wrapper/utils/atomic.py @@ -90,3 +90,16 @@ def get(self) -> T: def set(self, new_value: T) -> None: with self._lock: self._value = new_value + + def get_and_set(self, new_value: T) -> T: + with self._lock: + value = self._value + self._value = new_value + return value + + def compare_and_set(self, old_value: T, new_value: T) -> bool: + with self._lock: + if self._value == old_value: + self._value = new_value + return True + return False diff --git a/aws_advanced_python_wrapper/utils/mysql_exception_handler.py b/aws_advanced_python_wrapper/utils/mysql_exception_handler.py index 19d99e25..8712931f 100644 --- a/aws_advanced_python_wrapper/utils/mysql_exception_handler.py +++ b/aws_advanced_python_wrapper/utils/mysql_exception_handler.py @@ -24,6 +24,13 @@ class MySQLExceptionHandler(ExceptionHandler): _PAM_AUTHENTICATION_FAILED_MSG = "PAM authentication failed" _UNAVAILABLE_CONNECTION = "MySQL Connection not available" + _READ_ONLY_ERROR_MESSAGES: List[str] = [ + # ERROR 1290 (HY000): The MySQL server is running with the --read-only option so it cannot execute this statement + "running with the --read-only option so it cannot execute this statement", + # ERROR 1836 (HY000): Running in read-only mode + "Running in read-only mode" + ] + _NETWORK_ERRORS: List[int] = [ 2001, # Can't create UNIX socket 2002, # Can't connect to local MySQL server through socket @@ -71,3 +78,16 @@ def is_login_exception(self, error: Optional[Exception] = None, sql_state: Optio return True return False + + def is_read_only_connection_exception(self, error: Optional[Exception] = None, sql_state: Optional[str] = None) -> bool: + if hasattr(error, "errno"): + errno = getattr(error, "errno") + if errno == 1836: # ERROR 1836 (HY000): Running in read-only mode; errno is an int + return True + + if hasattr(error, "msg"): + error_msg = getattr(error, "msg") + if error_msg and any(msg in error_msg for msg in self._READ_ONLY_ERROR_MESSAGES): + return True + + return False diff --git a/aws_advanced_python_wrapper/utils/pg_exception_handler.py b/aws_advanced_python_wrapper/utils/pg_exception_handler.py index 741caac3..02a79c54 100644 --- a/aws_advanced_python_wrapper/utils/pg_exception_handler.py +++ b/aws_advanced_python_wrapper/utils/pg_exception_handler.py @@ -14,9 +14,9 @@ from typing import List, Optional -from psycopg.errors import (ConnectionTimeout, +from psycopg.errors import (ConnectionTimeout, InternalError, InvalidAuthorizationSpecification, InvalidPassword, - OperationalError) + OperationalError, ReadOnlySqlTransaction) from aws_advanced_python_wrapper.errors import QueryTimeoutError from aws_advanced_python_wrapper.exception_handling import ExceptionHandler @@ -38,8 +38,11 @@ class PgExceptionHandler(ExceptionHandler): _PASSWORD_AUTHENTICATION_FAILED_MSG, _PAM_AUTHENTICATION_FAILED_MSG ] + # ERROR: cannot execute {} in a read-only transaction + _READ_ONLY_ERROR_MSG: str = "in a read-only transaction" _NETWORK_ERROR_CODES: List[str] _ACCESS_ERROR_CODES: List[str] + _READ_ONLY_ERROR_CODE: str = "25006" # read only sql transaction def is_network_exception(self, error: Optional[Exception] = None, sql_state: Optional[str] = None) -> bool: if isinstance(error, QueryTimeoutError) or isinstance(error, ConnectionTimeout): @@ -87,6 +90,27 @@ def is_login_exception(self, error: Optional[Exception] = None, sql_state: Optio return False + def
is_read_only_connection_exception(self, error: Optional[Exception] = None, sql_state: Optional[str] = None) -> bool: + if isinstance(error, ReadOnlySqlTransaction): + return True + + if sql_state is None and error is not None and getattr(error, "sqlstate", None) is not None: + sql_state = error.sqlstate + + if sql_state is not None and sql_state == self._READ_ONLY_ERROR_CODE: + return True + + if isinstance(error, InternalError): + if len(error.args) == 0: + return False + + # Check the error message + error_msg: str = error.args[0] + if self._READ_ONLY_ERROR_MSG in error_msg: + return True + + return False + class SingleAzPgExceptionHandler(PgExceptionHandler): _NETWORK_ERROR_CODES: List[str] = [ diff --git a/aws_advanced_python_wrapper/utils/properties.py b/aws_advanced_python_wrapper/utils/properties.py index 9c48172f..e4dfcdf5 100644 --- a/aws_advanced_python_wrapper/utils/properties.py +++ b/aws_advanced_python_wrapper/utils/properties.py @@ -262,6 +262,22 @@ class WrapperProperties: 30, ) + # FailoverV2Plugin properties + FAILOVER_READER_HOST_SELECTOR_STRATEGY = WrapperProperty( + "failover_reader_host_selector_strategy", + "The strategy that should be used to select a new reader host while opening a new connection.", + "random") + ENABLE_CONNECT_FAILOVER = WrapperProperty( + "enable_connect_failover", + "Enable/disable cluster-aware failover if the initial connection to the database fails due to a network exception.", + False) + + # ClusterTopologyMonitor properties + CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS = WrapperProperty( + "cluster_topology_high_refresh_rate_ms", + "The topology refresh rate, in milliseconds, used while the cluster topology monitor is in high refresh rate (increased monitoring) mode.", + 100) + # CustomEndpointPlugin CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS = WrapperProperty( "custom_endpoint_info_refresh_rate_ms", @@ -584,6 +600,7 @@ class WrapperProperties: class PropertiesUtils: _MONITORING_PROPERTY_PREFIX = "monitoring-" + _TOPOLOGY_MONITORING_PROPERTY_PREFIX = "topology-monitoring-" @staticmethod def parse_properties(conn_info: str, **kwargs: Any) -> Properties: @@ -727,7 +744,8 @@ def remove_wrapper_props(props: Properties): monitor_prop_keys = [ key for key in props - if key.startswith(PropertiesUtils._MONITORING_PROPERTY_PREFIX) + if key.startswith(PropertiesUtils._MONITORING_PROPERTY_PREFIX) or + key.startswith(PropertiesUtils._TOPOLOGY_MONITORING_PROPERTY_PREFIX) ] for key in monitor_prop_keys: props.pop(key, None) @@ -758,11 +776,21 @@ def mask_properties(props: Properties) -> Properties: return masked_properties @staticmethod - def create_monitoring_properties(props: Properties) -> Properties: + def create_filtered_properties(props: Properties, prefix: str) -> Properties: monitoring_properties = copy.deepcopy(props) for property_key in list(monitoring_properties.keys()): - if property_key.startswith(PropertiesUtils._MONITORING_PROPERTY_PREFIX): + if property_key.startswith(prefix): monitoring_properties[ - property_key[len(PropertiesUtils._MONITORING_PROPERTY_PREFIX):] + property_key[len(prefix):] ] = monitoring_properties.pop(property_key) return monitoring_properties + + @staticmethod + def create_monitoring_properties(props: Properties) -> Properties: + return PropertiesUtils.create_filtered_properties(props, + PropertiesUtils._MONITORING_PROPERTY_PREFIX) + + @staticmethod + def create_topology_monitoring_properties(props: Properties) -> Properties: + return PropertiesUtils.create_filtered_properties(props, + PropertiesUtils._TOPOLOGY_MONITORING_PROPERTY_PREFIX) diff --git
a/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py b/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py index cd3900dc..8dd9c219 100644 --- a/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py +++ b/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py @@ -122,6 +122,30 @@ def __init__( self._cleanup_thread = Thread(target=self._cleanup_thread_internal, daemon=True) self._cleanup_thread.start() + def compute_if_absent_with_disposal(self, key: K, mapping_func: Callable, item_expiration_ns: int) -> Optional[V]: + self._remove_if_disposable(key) + cache_item = self._cdict.compute_if_absent( + key, lambda k: CacheItem(mapping_func(k), perf_counter_ns() + item_expiration_ns)) + return None if cache_item is None else cache_item.update_expiration(item_expiration_ns).item + + def _remove_if_disposable(self, key: K): + item = None + + def _remove_if_disposable_internal(_, cache_item): + if self._should_dispose_func is not None and self._should_dispose_func(cache_item.item): + nonlocal item + item = cache_item.item + return None + + return cache_item + + self._cdict.compute_if_present(key, _remove_if_disposable_internal) + + if item is None or self._item_disposal_func is None: + return + + self._item_disposal_func(item) + def _cleanup_thread_internal(self): while True: try: diff --git a/aws_advanced_python_wrapper/wrapper.py b/aws_advanced_python_wrapper/wrapper.py index 111c87e2..5a0c7b62 100644 --- a/aws_advanced_python_wrapper/wrapper.py +++ b/aws_advanced_python_wrapper/wrapper.py @@ -56,7 +56,7 @@ def __init__( self._plugin_service = plugin_service self._plugin_manager = plugin_manager - host_list_provider_init = plugin_service.database_dialect.get_host_list_provider_supplier() + host_list_provider_init = plugin_service.database_dialect.get_host_list_provider_supplier(plugin_service) plugin_service.host_list_provider = host_list_provider_init(host_list_provider_service, plugin_service.props) plugin_manager.init_host_provider(plugin_service.props, host_list_provider_service) diff --git a/tests/integration/container/conftest.py b/tests/integration/container/conftest.py index cc42d7cc..e3c7d2f9 100644 --- a/tests/integration/container/conftest.py +++ b/tests/integration/container/conftest.py @@ -27,7 +27,8 @@ from aws_advanced_python_wrapper.driver_dialect_manager import \ DriverDialectManager from aws_advanced_python_wrapper.exception_handling import ExceptionManager -from aws_advanced_python_wrapper.host_list_provider import RdsHostListProvider +from aws_advanced_python_wrapper.host_list_provider import ( + MonitoringRdsHostListProvider, RdsHostListProvider) from aws_advanced_python_wrapper.host_monitoring_plugin import \ MonitoringThreadContainer from aws_advanced_python_wrapper.plugin_service import PluginServiceImpl @@ -148,6 +149,7 @@ def pytest_runtest_setup(item): CustomEndpointMonitor._custom_endpoint_info_cache.clear() MonitoringThreadContainer.clean_up() ThreadPoolContainer.release_resources(wait=True) + MonitoringRdsHostListProvider._monitors.clear() ConnectionProviderManager.release_resources() ConnectionProviderManager.reset_provider() diff --git a/tests/integration/container/test_aurora_failover.py b/tests/integration/container/test_aurora_failover.py index 2634e456..8fffe062 100644 --- a/tests/integration/container/test_aurora_failover.py +++ b/tests/integration/container/test_aurora_failover.py @@ -18,10 +18,12 @@ from time import sleep from typing import TYPE_CHECKING, List -import pytest +import pytest # type: ignore from 
aws_advanced_python_wrapper.errors import ( FailoverSuccessError, TransactionResolutionUnknownError) +from aws_advanced_python_wrapper.host_list_provider import \ + MonitoringRdsHostListProvider from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) from .utils.conditions import (disable_on_features, enable_on_deployments, @@ -32,6 +34,7 @@ if TYPE_CHECKING: from .utils.test_instance_info import TestInstanceInfo from .utils.test_driver import TestDriver + from aws_advanced_python_wrapper import release_resources from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.wrapper import AwsWrapperConnection @@ -56,6 +59,8 @@ class TestAuroraFailover: def setup_method(self, request): self.logger.info(f"Starting test: {request.node.name}") yield + # Clean up topology monitors created during the test; release_resources() below handles the rest + MonitoringRdsHostListProvider.release_resources() self.logger.info(f"Ending test: {request.node.name}") release_resources() gc.collect() @@ -68,12 +74,10 @@ def aurora_utility(self): @pytest.fixture(scope='class') def props(self): p: Properties = Properties({ - "plugins": "failover", "socket_timeout": 10, "connect_timeout": 10, "monitoring-connect_timeout": 5, "monitoring-socket_timeout": 5, - "topology_refresh_ms": 10, "autocommit": True }) @@ -96,32 +100,35 @@ def proxied_props(self, props, conn_utils): WrapperProperties.CLUSTER_INSTANCE_HOST_PATTERN.set(props_copy, f"?.{endpoint_suffix}:{conn_utils.proxy_port}") return props_copy + @pytest.mark.parametrize("plugins", ["failover", "failover_v2"]) @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) def test_fail_from_writer_to_new_writer_fail_on_connection_invocation( - self, test_environment: TestEnvironment, test_driver: TestDriver, props, conn_utils, aurora_utility): + self, test_driver: TestDriver, props, conn_utils, aurora_utility, plugins): target_driver_connect = DriverHelper.get_connect_func(test_driver) initial_writer_id = aurora_utility.get_cluster_writer_instance_id() + props["plugins"] = plugins with AwsWrapperConnection.connect( target_driver_connect, **conn_utils.get_connect_params(), **props) as aws_conn: # crash instance1 and nominate a new writer aurora_utility.failover_cluster_and_wait_until_writer_changed() # failure occurs on Connection invocation - with pytest.raises(FailoverSuccessError): - aws_conn.commit() + aurora_utility.assert_first_query_throws(aws_conn, FailoverSuccessError) # assert that we are connected to the new writer after failover happens.
current_connection_id = aurora_utility.query_instance_id(aws_conn) assert aurora_utility.is_db_instance_writer(current_connection_id) is True assert current_connection_id != initial_writer_id + @pytest.mark.parametrize("plugins", ["failover", "failover_v2"]) @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) def test_fail_from_writer_to_new_writer_fail_on_connection_bound_object_invocation( - self, test_environment: TestEnvironment, test_driver: TestDriver, props, conn_utils, aurora_utility): + self, test_driver: TestDriver, props, conn_utils, aurora_utility, plugins): target_driver_connect = DriverHelper.get_connect_func(test_driver) initial_writer_id = aurora_utility.get_cluster_writer_instance_id() + props["plugins"] = plugins with AwsWrapperConnection.connect( target_driver_connect, **conn_utils.get_connect_params(), **props) as aws_conn: # crash instance1 and nominate a new writer @@ -135,7 +142,8 @@ def test_fail_from_writer_to_new_writer_fail_on_connection_bound_object_invocati assert aurora_utility.is_db_instance_writer(current_connection_id) is True assert current_connection_id != initial_writer_id - @pytest.mark.parametrize("plugins", ["failover,host_monitoring", "failover,host_monitoring_v2"]) + @pytest.mark.parametrize("plugins", ["failover,host_monitoring", "failover,host_monitoring_v2", + "failover_v2,host_monitoring", "failover_v2,host_monitoring_v2"]) @enable_on_features([TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED, TestEnvironmentFeatures.ABORT_CONNECTION_SUPPORTED]) def test_fail_from_reader_to_writer( @@ -162,11 +170,14 @@ def test_fail_from_reader_to_writer( assert writer_id == current_connection_id assert aurora_utility.is_db_instance_writer(current_connection_id) is True + @pytest.mark.parametrize("plugins", ["failover", "failover_v2"]) @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) - def test_fail_from_writer_with_session_states_autocommit(self, test_driver: TestDriver, props, conn_utils, aurora_utility): + def test_fail_from_writer_with_session_states_autocommit(self, test_driver: TestDriver, props, conn_utils, aurora_utility, + plugins): target_driver_connect = DriverHelper.get_connect_func(test_driver) initial_writer_id = aurora_utility.get_cluster_writer_instance_id() + props["plugins"] = plugins with AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **props) as conn: conn.autocommit = False @@ -200,11 +211,14 @@ def test_fail_from_writer_with_session_states_autocommit(self, test_driver: Test # Assert autocommit is still False after failover. assert conn.autocommit is False + @pytest.mark.parametrize("plugins", ["failover", "failover_v2"]) @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) - def test_fail_from_writer_with_session_states_readonly(self, test_driver: TestDriver, props, conn_utils, aurora_utility): + def test_fail_from_writer_with_session_states_readonly(self, test_driver: TestDriver, props, conn_utils, aurora_utility, + plugins): target_driver_connect = DriverHelper.get_connect_func(test_driver) initial_writer_id = aurora_utility.get_cluster_writer_instance_id() + props["plugins"] = plugins with AwsWrapperConnection.connect(target_driver_connect, **conn_utils.get_connect_params(), **props) as conn: assert conn.read_only is False conn.read_only = True @@ -225,12 +239,14 @@ def test_fail_from_writer_with_session_states_readonly(self, test_driver: TestDr # Assert readonly is still True after failover. 
assert conn.read_only is True + @pytest.mark.parametrize("plugins", ["failover", "failover_v2"]) @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) def test_writer_fail_within_transaction_set_autocommit_false( - self, test_driver: TestDriver, test_environment: TestEnvironment, props, conn_utils, aurora_utility): + self, test_driver: TestDriver, test_environment: TestEnvironment, props, conn_utils, aurora_utility, plugins): target_driver_connect = DriverHelper.get_connect_func(test_driver) initial_writer_id = test_environment.get_writer().get_instance_id() + props["plugins"] = plugins with AwsWrapperConnection.connect( target_driver_connect, **conn_utils.get_connect_params(), **props) as conn, conn.cursor() as cursor_1: cursor_1.execute("DROP TABLE IF EXISTS test3_2") @@ -266,12 +282,15 @@ def test_writer_fail_within_transaction_set_autocommit_false( cursor_3.execute("DROP TABLE IF EXISTS test3_2") conn.commit() + @pytest.mark.parametrize("plugins", ["failover", "failover_v2"]) @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) def test_writer_fail_within_transaction_start_transaction( - self, test_driver: TestDriver, test_environment: TestEnvironment, props, conn_utils, aurora_utility): + self, test_driver: TestDriver, test_environment: TestEnvironment, props, conn_utils, aurora_utility, + plugins): target_driver_connect = DriverHelper.get_connect_func(test_driver) initial_writer_id = test_environment.get_writer().get_instance_id() + props["plugins"] = plugins with AwsWrapperConnection.connect( target_driver_connect, **conn_utils.get_connect_params(), **props) as conn: with conn.cursor() as cursor_1: @@ -309,14 +328,15 @@ def test_writer_fail_within_transaction_start_transaction( cursor_3.execute("DROP TABLE IF EXISTS test3_3") conn.commit() + @pytest.mark.parametrize("plugins", ["aurora_connection_tracker,failover", "aurora_connection_tracker,failover_v2"]) @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) def test_writer_failover_in_idle_connections( - self, test_environment: TestEnvironment, test_driver: TestDriver, props, conn_utils, aurora_utility): + self, test_driver: TestDriver, props, conn_utils, aurora_utility, plugins): target_driver_connect = DriverHelper.get_connect_func(test_driver) current_writer_id = aurora_utility.get_cluster_writer_instance_id() idle_connections: List[AwsWrapperConnection] = [] - props["plugins"] = "aurora_connection_tracker,failover" + props["plugins"] = plugins for i in range(self.IDLE_CONNECTIONS_NUM): idle_connections.append( diff --git a/tests/integration/container/test_aws_secrets_manager.py b/tests/integration/container/test_aws_secrets_manager.py index f8cdbb53..1a26bcc8 100644 --- a/tests/integration/container/test_aws_secrets_manager.py +++ b/tests/integration/container/test_aws_secrets_manager.py @@ -199,13 +199,14 @@ def test_incorrect_region(self, test_driver, conn_utils, create_secret, props): ) as conn: conn.cursor() + @pytest.mark.parametrize("plugins", ["failover,aws_secrets_manager", "failover_v2,aws_secrets_manager"]) @enable_on_num_instances(min_instances=2) @disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.BLUE_GREEN_DEPLOYMENT, TestEnvironmentFeatures.PERFORMANCE]) @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED, TestEnvironmentFeatures.IAM]) def test_failover_with_secrets_manager( - self, test_driver, props, conn_utils, create_secret): + self, test_driver, props, conn_utils, create_secret, plugins): region = 
TestEnvironment.get_current().get_info().get_region() aurora_utility = RdsTestUtility(region) target_driver_connect = DriverHelper.get_connect_func(test_driver) @@ -213,7 +214,7 @@ def test_failover_with_secrets_manager( secret_name, _ = create_secret props.update({ - "plugins": "failover,aws_secrets_manager", + "plugins": plugins, "secrets_manager_secret_id": secret_name, "secrets_manager_region": region, "socket_timeout": 10, diff --git a/tests/integration/container/test_iam_authentication.py b/tests/integration/container/test_iam_authentication.py index 29c5eb0c..c83f70a2 100644 --- a/tests/integration/container/test_iam_authentication.py +++ b/tests/integration/container/test_iam_authentication.py @@ -130,6 +130,7 @@ def test_iam_valid_connection_properties_no_password( self.validate_connection(target_driver_connect, **params, **props) + @pytest.mark.parametrize("plugins", ["failover,iam", "failover_v2,iam"]) @enable_on_num_instances(min_instances=2) @enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER]) @disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, @@ -137,14 +138,14 @@ def test_iam_valid_connection_properties_no_password( TestEnvironmentFeatures.PERFORMANCE]) @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED, TestEnvironmentFeatures.IAM]) def test_failover_with_iam( - self, test_environment: TestEnvironment, test_driver: TestDriver, props, conn_utils): + self, test_driver: TestDriver, props, conn_utils, plugins): target_driver_connect = DriverHelper.get_connect_func(test_driver) region = TestEnvironment.get_current().get_info().get_region() aurora_utility = RdsTestUtility(region) initial_writer_id = aurora_utility.get_cluster_writer_instance_id() props.update({ - "plugins": "failover,iam", + "plugins": plugins, "socket_timeout": 10, "connect_timeout": 10, "monitoring-connect_timeout": 5, diff --git a/tests/unit/test_cluster_topology_monitor.py b/tests/unit/test_cluster_topology_monitor.py new file mode 100644 index 00000000..244578aa --- /dev/null +++ b/tests/unit/test_cluster_topology_monitor.py @@ -0,0 +1,359 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
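+ +# These tests drive ClusterTopologyMonitorImpl and HostMonitor directly; the cluster_monitor +# fixture patches threading.Thread so no real monitoring thread is started.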
+ +from __future__ import annotations + +import time +from unittest.mock import MagicMock, patch + +import pytest + +from aws_advanced_python_wrapper.cluster_topology_monitor import ( + ClusterTopologyMonitorImpl, HostMonitor) +from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) + + +@pytest.fixture +def plugin_service_mock(): + mock = MagicMock() + mock.force_connect.return_value = MagicMock() + mock.is_network_exception.return_value = False + mock.is_login_exception.return_value = False + mock.driver_dialect = MagicMock() + return mock + + +@pytest.fixture +def topology_utils_mock(): + mock = MagicMock() + mock.query_for_topology.return_value = ( + HostInfo("writer.com", 5432, HostRole.WRITER), + HostInfo("reader1.com", 5432, HostRole.READER), + HostInfo("reader2.com", 5432, HostRole.READER) + ) + mock.get_writer_host_if_connected.return_value = "writer.com" + mock.get_host_role.return_value = HostRole.WRITER + return mock + + +@pytest.fixture +def monitor_properties(): + props = Properties() + WrapperProperties.TOPOLOGY_REFRESH_MS.set(props, "1000") + WrapperProperties.CLUSTER_TOPOLOGY_HIGH_REFRESH_RATE_MS.set(props, "100") + return props + + +@pytest.fixture +def cluster_monitor(plugin_service_mock, topology_utils_mock, monitor_properties): + cluster_id = "test-cluster" + initial_host = HostInfo("writer.com", 5432, HostRole.WRITER) + instance_template = HostInfo("?.com", 5432) + refresh_rate_ns = 1000 * 1_000_000 + high_refresh_rate_ns = 100 * 1_000_000 + + with patch('threading.Thread'): + monitor = ClusterTopologyMonitorImpl( + plugin_service_mock, topology_utils_mock, cluster_id, + initial_host, monitor_properties, instance_template, + refresh_rate_ns, high_refresh_rate_ns + ) + monitor._stop.set() + return monitor + + +class TestClusterTopologyMonitorImpl: + def test_force_refresh_with_cached_hosts_ignoring_requests(self, cluster_monitor): + expected_hosts = (HostInfo("writer.com", 5432, HostRole.WRITER),) + cluster_monitor._get_stored_hosts = MagicMock(return_value=expected_hosts) + cluster_monitor._ignore_new_topology_requests_end_time_nano = time.time_ns() + 10_000_000_000 + + result = cluster_monitor.force_refresh(False, 5) + assert result == expected_hosts + + def test_force_refresh_with_writer_verification(self, cluster_monitor): + cluster_monitor._monitoring_connection.set(MagicMock()) + cluster_monitor._is_verified_writer_connection = True + cluster_monitor._wait_till_topology_gets_updated = MagicMock(return_value=()) + + cluster_monitor.force_refresh(True, 5) + + assert cluster_monitor._monitoring_connection.get() is None + assert not cluster_monitor._is_verified_writer_connection + + def test_force_refresh_without_writer_verification(self, cluster_monitor): + expected_hosts = (HostInfo("writer.com", 5432, HostRole.WRITER),) + cluster_monitor._is_verified_writer_connection = True + cluster_monitor._wait_till_topology_gets_updated = MagicMock(return_value=expected_hosts) + + result = cluster_monitor.force_refresh(False, 5) + + assert result == expected_hosts + cluster_monitor._wait_till_topology_gets_updated.assert_called_once_with(5) + assert cluster_monitor._is_verified_writer_connection + + def test_force_refresh_with_connection_verified_writer(self, cluster_monitor): + connection_mock = MagicMock() + cluster_monitor._is_verified_writer_connection = True + expected_hosts = (HostInfo("writer.com", 5432, HostRole.WRITER),) + 
cluster_monitor._wait_till_topology_gets_updated = MagicMock(return_value=expected_hosts) + + result = cluster_monitor.force_refresh_with_connection(connection_mock, 5) + + assert result == expected_hosts + cluster_monitor._wait_till_topology_gets_updated.assert_called_once_with(5) + + def test_force_refresh_with_connection_not_verified(self, cluster_monitor): + connection_mock = MagicMock() + cluster_monitor._is_verified_writer_connection = False + expected_hosts = (HostInfo("writer.com", 5432, HostRole.WRITER),) + cluster_monitor._fetch_topology_and_update_cache = MagicMock(return_value=expected_hosts) + + result = cluster_monitor.force_refresh_with_connection(connection_mock, 5) + + assert result == expected_hosts + cluster_monitor._fetch_topology_and_update_cache.assert_called_once_with(connection_mock) + + def test_wait_till_topology_gets_updated_timeout_zero(self, cluster_monitor): + expected_hosts = (HostInfo("writer.com", 5432, HostRole.WRITER),) + cluster_monitor._get_stored_hosts = MagicMock(return_value=expected_hosts) + + result = cluster_monitor._wait_till_topology_gets_updated(0) + + assert result == expected_hosts + assert cluster_monitor._request_to_update_topology.is_set() + + def test_wait_till_topology_gets_updated_success(self, cluster_monitor): + initial_hosts = (HostInfo("old-writer.com", 5432, HostRole.WRITER),) + updated_hosts = (HostInfo("new-writer.com", 5432, HostRole.WRITER),) + + cluster_monitor._get_stored_hosts = MagicMock(side_effect=[initial_hosts, updated_hosts]) + + result = cluster_monitor._wait_till_topology_gets_updated(1) + + assert result == updated_hosts + + def test_wait_till_topology_gets_updated_timeout(self, cluster_monitor): + hosts = (HostInfo("writer.com", 5432, HostRole.WRITER),) + cluster_monitor._get_stored_hosts = MagicMock(return_value=hosts) + + with pytest.raises(TimeoutError, match="Topology has not been updated"): + cluster_monitor._wait_till_topology_gets_updated(0.001) + + def test_close(self, cluster_monitor): + cluster_monitor._monitoring_connection.set(MagicMock()) + cluster_monitor._host_threads_writer_connection.set(MagicMock()) + cluster_monitor._host_threads_reader_connection.set(MagicMock()) + cluster_monitor._close_host_monitors = MagicMock() + cluster_monitor._monitor_thread = MagicMock() + cluster_monitor._monitor_thread.is_alive.return_value = False + + cluster_monitor.close() + + assert cluster_monitor._stop.is_set() + assert cluster_monitor._request_to_update_topology.is_set() + cluster_monitor._close_host_monitors.assert_called_once() + assert cluster_monitor._monitoring_connection.get() is None + assert cluster_monitor._host_threads_writer_connection.get() is None + assert cluster_monitor._host_threads_reader_connection.get() is None + + def test_is_in_panic_mode(self, cluster_monitor): + cluster_monitor._monitoring_connection.set(None) + assert cluster_monitor._is_in_panic_mode() + + cluster_monitor._monitoring_connection.set(MagicMock()) + cluster_monitor._is_verified_writer_connection = False + assert cluster_monitor._is_in_panic_mode() + + cluster_monitor._monitoring_connection.set(MagicMock()) + cluster_monitor._is_verified_writer_connection = True + assert not cluster_monitor._is_in_panic_mode() + + def test_open_any_connection_and_update_topology_success(self, cluster_monitor): + expected_hosts = (HostInfo("writer.com", 5432, HostRole.WRITER),) + cluster_monitor._plugin_service.force_connect.return_value = MagicMock() + cluster_monitor._fetch_topology_and_update_cache = MagicMock(return_value=expected_hosts) 
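+ # The topology-cache update is stubbed out above, so the call below exercises only the connection-opening logic.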
+ + result = cluster_monitor._open_any_connection_and_update_topology() + + assert result == expected_hosts + assert cluster_monitor._monitoring_connection.get() is not None + + def test_open_any_connection_and_update_topology_connection_failure(self, cluster_monitor): + cluster_monitor._plugin_service.force_connect.side_effect = Exception("Connection failed") + + result = cluster_monitor._open_any_connection_and_update_topology() + + assert result == () + + def test_open_any_connection_and_update_topology_verifies_writer(self, cluster_monitor, topology_utils_mock): + expected_hosts = (HostInfo("writer", 5432, HostRole.WRITER),) + connection_mock = MagicMock() + cluster_monitor._writer_host_info.set(None) + cluster_monitor._plugin_service.force_connect.return_value = connection_mock + topology_utils_mock.get_writer_host_if_connected.return_value = "writer" + cluster_monitor._fetch_topology_and_update_cache = MagicMock(return_value=expected_hosts) + + result = cluster_monitor._open_any_connection_and_update_topology() + + assert result == expected_hosts + assert cluster_monitor._is_verified_writer_connection + assert cluster_monitor._writer_host_info.get() is not None + + +class TestHostMonitor: + @pytest.fixture + def monitor_impl_mock(self, plugin_service_mock, topology_utils_mock): + mock = MagicMock() + mock._plugin_service = plugin_service_mock + mock._topology_utils = topology_utils_mock + mock._monitoring_properties = Properties() + mock._host_threads_stop = MagicMock() + mock._host_threads_stop.is_set.return_value = False + mock._host_threads_writer_connection = MagicMock() + mock._host_threads_writer_connection.get.return_value = None + mock._host_threads_writer_connection.compare_and_set.return_value = True + mock._host_threads_reader_connection = MagicMock() + mock._host_threads_reader_connection.compare_and_set.return_value = True + mock._host_threads_latest_topology = MagicMock() + mock._close_connection = MagicMock() + mock._fetch_topology_and_update_cache = MagicMock() + mock._query_for_topology = MagicMock(return_value=(HostInfo("writer.com", 5432, HostRole.WRITER),)) + return mock + + def test_call_stop_signal_immediate(self, monitor_impl_mock): + host_info = HostInfo("reader.com", 5432, HostRole.READER) + monitor = HostMonitor(monitor_impl_mock, host_info, None) + monitor_impl_mock._host_threads_stop.is_set.return_value = True + + monitor() + + monitor_impl_mock._plugin_service.force_connect.assert_not_called() + + def test_call_connection_success_writer_detected(self, monitor_impl_mock, topology_utils_mock): + host_info = HostInfo("writer.com", 5432, HostRole.WRITER) + monitor = HostMonitor(monitor_impl_mock, host_info, None) + connection_mock = MagicMock() + monitor_impl_mock._plugin_service.force_connect.return_value = connection_mock + topology_utils_mock.get_writer_host_if_connected.return_value = "writer.com" + topology_utils_mock.get_host_role.return_value = HostRole.WRITER + + call_count = [0] + + def stop_after_checks(): + call_count[0] += 1 + # Stop after: initial check, connection attempt, writer check, role check + return call_count[0] > 4 + + monitor_impl_mock._host_threads_stop.is_set.side_effect = stop_after_checks + + with patch('time.sleep'): + monitor() + + monitor_impl_mock._host_threads_writer_connection.compare_and_set.assert_called_once_with(None, connection_mock) + monitor_impl_mock._fetch_topology_and_update_cache.assert_called_once_with(connection_mock) + monitor_impl_mock._host_threads_writer_host_info.set.assert_called_once_with(host_info) + + 
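+ # Note on the pattern above: _host_threads_stop.is_set is driven by a counting + # side_effect, so the otherwise-unbounded HostMonitor loop runs a deterministic + # number of iterations. A minimal sketch of the idiom (names illustrative): + # + # calls = [0] + # def stop_after_n(): + # calls[0] += 1 + # return calls[0] > 4 # False for the first four checks, then True + # stop_event.is_set.side_effect = stop_after_n +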
def test_call_connection_success_reader_detected(self, monitor_impl_mock, topology_utils_mock): + host_info = HostInfo("reader.com", 5432, HostRole.READER) + monitor = HostMonitor(monitor_impl_mock, host_info, None) + connection_mock = MagicMock() + monitor_impl_mock._plugin_service.force_connect.return_value = connection_mock + topology_utils_mock.get_writer_host_if_connected.return_value = None + + call_count = [0] + + def stop_after_iterations(): + call_count[0] += 1 + # Stop after: initial check, connection, writer check, reader logic, sleep + return call_count[0] > 5 + + monitor_impl_mock._host_threads_stop.is_set.side_effect = stop_after_iterations + + with patch('time.sleep'): + monitor() + + monitor_impl_mock._host_threads_reader_connection.compare_and_set.assert_called() + + def test_call_network_exception_retry(self, monitor_impl_mock): + host_info = HostInfo("reader.com", 5432, HostRole.READER) + monitor = HostMonitor(monitor_impl_mock, host_info, None) + monitor_impl_mock._plugin_service.force_connect.side_effect = Exception("Network error") + monitor_impl_mock._plugin_service.is_network_exception.return_value = True + + call_count = [0] + + def stop_after_retries(): + call_count[0] += 1 + # Allow multiple connection attempts + return call_count[0] > 10 + + monitor_impl_mock._host_threads_stop.is_set.side_effect = stop_after_retries + + with patch('time.sleep'): + monitor() + + assert monitor_impl_mock._plugin_service.force_connect.call_count >= 2 + + def test_call_login_exception_raises(self, monitor_impl_mock): + host_info = HostInfo("reader.com", 5432, HostRole.READER) + monitor = HostMonitor(monitor_impl_mock, host_info, None) + login_error = Exception("Login failed") + monitor_impl_mock._plugin_service.force_connect.side_effect = login_error + monitor_impl_mock._plugin_service.is_network_exception.return_value = False + monitor_impl_mock._plugin_service.is_login_exception.return_value = True + + # The login exception is caught and handled in the finally block, not raised + monitor() + + # Verify force_connect was called + monitor_impl_mock._plugin_service.force_connect.assert_called() + + def test_reader_thread_fetch_topology(self, monitor_impl_mock): + host_info = HostInfo("reader.com", 5432, HostRole.READER) + writer_info = HostInfo("writer.com", 5432, HostRole.WRITER) + monitor = HostMonitor(monitor_impl_mock, host_info, writer_info) + connection_mock = MagicMock() + expected_hosts = (HostInfo("writer.com", 5432, HostRole.WRITER),) + monitor_impl_mock._query_for_topology.return_value = expected_hosts + + monitor._reader_thread_fetch_topology(connection_mock) + + monitor_impl_mock._host_threads_latest_topology.set.assert_called_once_with(expected_hosts) + + def test_reader_thread_fetch_topology_writer_changed(self, monitor_impl_mock): + host_info = HostInfo("reader.com", 5432, HostRole.READER) + old_writer = HostInfo("old-writer.com", 5432, HostRole.WRITER) + monitor = HostMonitor(monitor_impl_mock, host_info, old_writer) + connection_mock = MagicMock() + new_writer = HostInfo("new-writer.com", 5432, HostRole.WRITER) + expected_hosts = (new_writer, HostInfo("reader.com", 5432, HostRole.READER)) + monitor_impl_mock._query_for_topology.return_value = expected_hosts + + monitor._reader_thread_fetch_topology(connection_mock) + + assert monitor._writer_changed + monitor_impl_mock._update_topology_cache.assert_called_once_with(expected_hosts) + + def test_reader_thread_fetch_topology_exception(self, monitor_impl_mock): + host_info = HostInfo("reader.com", 5432, 
HostRole.READER) + monitor = HostMonitor(monitor_impl_mock, host_info, None) + connection_mock = MagicMock() + monitor_impl_mock._query_for_topology.side_effect = Exception("Query failed") + + monitor._reader_thread_fetch_topology(connection_mock) + + monitor_impl_mock._host_threads_latest_topology.set.assert_not_called() diff --git a/tests/unit/test_connection_string_host_list_provider.py b/tests/unit/test_connection_string_host_list_provider.py index a8c59382..dbe049a5 100644 --- a/tests/unit/test_connection_string_host_list_provider.py +++ b/tests/unit/test_connection_string_host_list_provider.py @@ -62,3 +62,10 @@ def test_refresh(mock_provider_service, props): hosts = provider.refresh() assert 1 == len(hosts) + + +def test_force_monitoring_refresh(mock_provider_service, props): + provider = ConnectionStringHostListProvider(mock_provider_service, props) + + with pytest.raises(AwsWrapperError): + provider.force_monitoring_refresh(False, 10) diff --git a/tests/unit/test_failover_plugin.py b/tests/unit/test_failover_plugin.py index 3c6f4f0e..d3aa5428 100644 --- a/tests/unit/test_failover_plugin.py +++ b/tests/unit/test_failover_plugin.py @@ -147,8 +147,6 @@ def test_notify_host_list_changed_with_failover_disabled(plugin_service_mock, ho def test_notify_host_list_changed_with_valid_connection_not_in_topology(plugin_service_mock, - host_list_provider_service_mock, - init_host_provider_func_mock, host_mock): host_mock.url = "cluster-url/" aliases: FrozenSet[str] = frozenset("instance") @@ -174,8 +172,6 @@ def test_notify_host_list_changed_with_valid_connection_not_in_topology(plugin_s def test_update_topology( plugin_service_mock, - host_list_provider_service_mock, - init_host_provider_func_mock, driver_dialect_mock): properties = Properties() WrapperProperties.ENABLE_FAILOVER.set(properties, "False") @@ -214,7 +210,7 @@ def test_update_topology( force_refresh_mock.assert_not_called() -def test_failover_reader(plugin_service_mock, host_list_provider_service_mock, init_host_provider_func_mock): +def test_failover_reader(plugin_service_mock): type(plugin_service_mock).is_in_transaction = PropertyMock(return_value=False) plugin = FailoverPlugin(plugin_service_mock, Properties()) plugin._failover_mode = FailoverMode.READER_OR_WRITER @@ -226,7 +222,7 @@ def test_failover_reader(plugin_service_mock, host_list_provider_service_mock, i failover_reader_mock.assert_called_once_with(host_mock) -def test_failover_writer(plugin_service_mock, host_list_provider_service_mock, init_host_provider_func_mock): +def test_failover_writer(plugin_service_mock): type(plugin_service_mock).is_in_transaction = PropertyMock(return_value=True) plugin = FailoverPlugin(plugin_service_mock, Properties()) plugin._failover_mode = FailoverMode.STRICT_WRITER @@ -237,8 +233,7 @@ def test_failover_writer(plugin_service_mock, host_list_provider_service_mock, i failover_writer_mock.assert_called_once_with() -def test_failover_reader_with_valid_failed_host(plugin_service_mock, host_list_provider_service_mock, - init_host_provider_func_mock, conn_mock, reader_failover_handler_mock): +def test_failover_reader_with_valid_failed_host(plugin_service_mock, conn_mock, reader_failover_handler_mock): host: HostInfo = HostInfo("host") host.availability = HostAvailability.AVAILABLE host._aliases = ["alias1", "alias2"] @@ -259,8 +254,7 @@ def test_failover_reader_with_valid_failed_host(plugin_service_mock, host_list_p set_current_connection_mock.assert_called_with(conn_mock, host) -def 
test_failover_reader_with_no_failed_host(plugin_service_mock, host_list_provider_service_mock, - init_host_provider_func_mock, reader_failover_handler_mock): +def test_failover_reader_with_no_failed_host(plugin_service_mock, reader_failover_handler_mock): host: HostInfo = HostInfo("host") host.availability = HostAvailability.AVAILABLE host._aliases = ["alias1", "alias2"] @@ -282,8 +276,7 @@ def test_failover_reader_with_no_failed_host(plugin_service_mock, host_list_prov failover_reader_mock.assert_called_with(hosts, None) -def test_failover_writer_failed_failover_raises_error(plugin_service_mock, host_list_provider_service_mock, - init_host_provider_func_mock): +def test_failover_writer_failed_failover_raises_error(plugin_service_mock): writer_failover_handler_mock: WriterFailoverHandler = MagicMock() host: HostInfo = HostInfo("host") host._aliases = ["alias1", "alias2"] @@ -304,8 +297,7 @@ def test_failover_writer_failed_failover_raises_error(plugin_service_mock, host_ failover_writer_mock.assert_called_with(hosts) -def test_failover_writer_failed_failover_with_no_result(plugin_service_mock, host_list_provider_service_mock, - init_host_provider_func_mock, writer_failover_handler_mock): +def test_failover_writer_failed_failover_with_no_result(plugin_service_mock, writer_failover_handler_mock): host: HostInfo = HostInfo("host") host._aliases = ["alias1", "alias2"] hosts: Tuple[HostInfo, ...] = (host, ) @@ -335,8 +327,7 @@ def test_failover_writer_failed_failover_with_no_result(plugin_service_mock, hos get_topology_mock.assert_not_called() -def test_failover_writer_success(plugin_service_mock, host_list_provider_service_mock, init_host_provider_func_mock, - writer_failover_handler_mock): +def test_failover_writer_success(plugin_service_mock, writer_failover_handler_mock): host: HostInfo = HostInfo("host") host._aliases = ["alias1", "alias2"] hosts: Tuple[HostInfo, ...] 
= (host, ) @@ -356,8 +347,7 @@ def test_failover_writer_success(plugin_service_mock, host_list_provider_service failover_writer_mock.assert_called_with(hosts) -def test_invalid_current_connection_with_no_connection(plugin_service_mock, host_list_provider_service_mock, - init_host_provider_func_mock): +def test_invalid_current_connection_with_no_connection(plugin_service_mock): type(plugin_service_mock).current_connection = PropertyMock(return_value=None) is_in_transaction_mock = PropertyMock() @@ -369,8 +359,7 @@ def test_invalid_current_connection_with_no_connection(plugin_service_mock, host is_in_transaction_mock.assert_not_called() -def test_invalid_current_connection_in_transaction(plugin_service_mock, host_list_provider_service_mock, - init_host_provider_func_mock, conn_mock): +def test_invalid_current_connection_in_transaction(plugin_service_mock, conn_mock): type(plugin_service_mock).is_in_transaction = PropertyMock(return_value=True) type(plugin_service_mock).current_connection = PropertyMock(return_value=conn_mock) @@ -388,8 +377,7 @@ def test_invalid_current_connection_in_transaction(plugin_service_mock, host_lis pytest.fail("_invalidate_current_connection() raised unexpected error") -def test_invalidate_current_connection_not_in_transaction(plugin_service_mock, host_list_provider_service_mock, - init_host_provider_func_mock, conn_mock): +def test_invalidate_current_connection_not_in_transaction(plugin_service_mock, conn_mock): is_in_transaction_mock = PropertyMock(return_value=False) type(plugin_service_mock).is_in_transaction = is_in_transaction_mock type(plugin_service_mock).current_connection = PropertyMock(return_value=conn_mock) @@ -399,8 +387,7 @@ def test_invalidate_current_connection_not_in_transaction(plugin_service_mock, h is_in_transaction_mock.assert_called() -def test_invalidate_current_connection_with_open_connection(plugin_service_mock, host_list_provider_service_mock, - init_host_provider_func_mock, conn_mock, +def test_invalidate_current_connection_with_open_connection(plugin_service_mock, conn_mock, driver_dialect_mock): is_in_transaction_mock = PropertyMock(return_value=False) @@ -411,26 +398,21 @@ def test_invalidate_current_connection_with_open_connection(plugin_service_mock, plugin = FailoverPlugin(plugin_service_mock, Properties()) with mock.patch.object(driver_dialect_mock, "execute") as close_mock: - with mock.patch.object(driver_dialect_mock, "is_closed") as is_closed_mock: - driver_dialect_mock.is_closed.return_value = False - - plugin._invalidate_current_connection() - close_mock.assert_called_once() + plugin._invalidate_current_connection() + close_mock.assert_called_once() - conn_mock.close.side_effect = Error("test error") + conn_mock.close.side_effect = Error("test error") - try: - plugin._invalidate_current_connection() - except Error: - pytest.fail("_invalidate_current_connection() raised unexpected error") + try: + plugin._invalidate_current_connection() + except Error: + pytest.fail("_invalidate_current_connection() raised unexpected error") - is_closed_mock.assert_called() - assert is_closed_mock.call_count == 2 close_mock.assert_called() assert close_mock.call_count == 2 -def test_execute(plugin_service_mock, host_list_provider_service_mock, init_host_provider_func_mock, mock_sql_method): +def test_execute(plugin_service_mock, mock_sql_method): properties = Properties() WrapperProperties.ENABLE_FAILOVER.set(properties, "False") plugin = FailoverPlugin(plugin_service_mock, properties) diff --git a/tests/unit/test_failover_v2_plugin.py 
b/tests/unit/test_failover_v2_plugin.py new file mode 100644 index 00000000..e3614010 --- /dev/null +++ b/tests/unit/test_failover_v2_plugin.py @@ -0,0 +1,311 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Dict, Set +from unittest import mock +from unittest.mock import MagicMock + +import psycopg +import pytest + +from aws_advanced_python_wrapper.errors import ( + AwsWrapperError, FailoverFailedError, FailoverSuccessError, + TransactionResolutionUnknownError) +from aws_advanced_python_wrapper.failover_v2_plugin import ( + FailoverV2Plugin, ReaderFailoverResult) +from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole +from aws_advanced_python_wrapper.pep249 import Error +from aws_advanced_python_wrapper.utils.failover_mode import FailoverMode +from aws_advanced_python_wrapper.utils.notifications import HostEvent +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType + + +@pytest.fixture +def plugin_service_mock(): + mock = MagicMock() + mock.network_bound_methods = {"*"} + mock.current_host_info = HostInfo("writer.com", 5432, HostRole.WRITER) + mock.current_connection = MagicMock(spec=psycopg.Connection) + mock.driver_dialect.network_bound_methods = {"Connection.execute", "Connection.commit"} + mock.driver_dialect.is_closed.return_value = False + mock.is_network_exception.return_value = True + mock.is_in_transaction = False + mock.hosts = [ + HostInfo("writer.com", 5432, HostRole.WRITER), + HostInfo("reader1.com", 5432, HostRole.READER), + HostInfo("reader2.com", 5432, HostRole.READER) + ] + mock.all_hosts = mock.hosts + mock.get_telemetry_factory.return_value.open_telemetry_context.return_value = None + return mock + + +@pytest.fixture +def connection_mock(): + return MagicMock(spec=psycopg.Connection) + + +@pytest.fixture +def mock_sql_method(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def properties(): + props = Properties() + WrapperProperties.FAILOVER_TIMEOUT_SEC.set(props, "60") + WrapperProperties.TELEMETRY_FAILOVER_ADDITIONAL_TOP_TRACE.set(props, "false") + WrapperProperties.ENABLE_CONNECT_FAILOVER.set(props, "true") + return props + + +@pytest.fixture +def failover_v2_plugin(plugin_service_mock, properties): + return FailoverV2Plugin(plugin_service_mock, properties) + + +class TestFailoverV2Plugin: + def test_execute_direct_close_method(self, failover_v2_plugin, connection_mock, mock_sql_method): + failover_v2_plugin.execute(connection_mock, "Connection.close", mock_sql_method) + assert failover_v2_plugin._closed_explicitly + mock_sql_method.assert_called_once() + + def test_execute_with_exception_triggers_failover(self, failover_v2_plugin, connection_mock): + network_exception = Exception("Network error") + failover_v2_plugin._plugin_service.is_network_exception.return_value = True + 
+        failover_v2_plugin._deal_with_original_exception = MagicMock(side_effect=AwsWrapperError("test"))
+
+        def failing_func():
+            raise network_exception
+
+        with pytest.raises(AwsWrapperError):
+            failover_v2_plugin.execute(connection_mock, "Connection.execute", failing_func)
+
+        failover_v2_plugin._deal_with_original_exception.assert_called_once_with(network_exception)
+
+    def test_init_host_provider(self, failover_v2_plugin, mock_sql_method, properties):
+        host_list_provider_service = MagicMock()
+
+        failover_v2_plugin.init_host_provider(properties, host_list_provider_service, mock_sql_method)
+
+        assert failover_v2_plugin._host_list_provider_service == host_list_provider_service
+        mock_sql_method.assert_called_once()
+
+    def test_notify_host_list_changed(self, failover_v2_plugin):
+        changes: Dict[str, Set[HostEvent]] = {"host1": {HostEvent.HOST_DELETED}}
+        failover_v2_plugin._stale_dns_helper.notify_host_list_changed = MagicMock()
+
+        failover_v2_plugin.notify_host_list_changed(changes)
+
+        failover_v2_plugin._stale_dns_helper.notify_host_list_changed.assert_called_once_with(changes)
+
+    def test_is_failover_enabled(self, failover_v2_plugin):
+        failover_v2_plugin._rds_url_type = RdsUrlType.RDS_WRITER_CLUSTER
+        failover_v2_plugin._plugin_service.all_hosts = [HostInfo("host1", 5432)]
+        assert failover_v2_plugin._is_failover_enabled()
+
+        failover_v2_plugin._plugin_service.all_hosts = []
+        assert not failover_v2_plugin._is_failover_enabled()
+
+        failover_v2_plugin._rds_url_type = RdsUrlType.RDS_PROXY
+        assert not failover_v2_plugin._is_failover_enabled()
+
+    def test_invalid_invocation_on_closed_connection_not_explicit(self, failover_v2_plugin):
+        failover_v2_plugin._closed_explicitly = False
+        failover_v2_plugin._pick_new_connection = MagicMock(side_effect=FailoverSuccessError())
+
+        with pytest.raises(FailoverSuccessError):
+            failover_v2_plugin._invalid_invocation_on_closed_connection()
+
+        assert not failover_v2_plugin._is_closed
+
+    def test_invalid_invocation_on_closed_connection_explicit(self, failover_v2_plugin):
+        failover_v2_plugin._closed_explicitly = True
+
+        with pytest.raises(AwsWrapperError, match="No operations allowed after connection closed"):
+            failover_v2_plugin._invalid_invocation_on_closed_connection()
+
+    def test_deal_with_original_exception_network_error(self, failover_v2_plugin):
+        network_exception = Exception("Network error")
+        failover_v2_plugin._plugin_service.is_network_exception.return_value = True
+        failover_v2_plugin._invalidate_current_connection = MagicMock()
+        failover_v2_plugin._pick_new_connection = MagicMock()
+
+        with pytest.raises(AwsWrapperError):
+            failover_v2_plugin._deal_with_original_exception(network_exception)
+
+        failover_v2_plugin._invalidate_current_connection.assert_called_once()
+        failover_v2_plugin._pick_new_connection.assert_called_once()
+
+    def test_deal_with_original_exception_non_network_error(self, failover_v2_plugin):
+        non_network_exception = ValueError("Not a network error")
+        failover_v2_plugin._plugin_service.is_network_exception.return_value = False
+        failover_v2_plugin._invalidate_current_connection = MagicMock()
+
+        with pytest.raises(AwsWrapperError):
+            failover_v2_plugin._deal_with_original_exception(non_network_exception)
+
+        failover_v2_plugin._invalidate_current_connection.assert_not_called()
+
+    def test_failover_writer_mode(self, failover_v2_plugin):
+        failover_v2_plugin._failover_mode = FailoverMode.STRICT_WRITER
+        failover_v2_plugin._failover_writer = MagicMock()
+
+        failover_v2_plugin._failover()
+
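+        # In STRICT_WRITER mode, _failover() must dispatch to writer failover only.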
+        failover_v2_plugin._failover_writer.assert_called_once()
+
+    def test_failover_reader_mode(self, failover_v2_plugin):
+        failover_v2_plugin._failover_mode = FailoverMode.READER_OR_WRITER
+        failover_v2_plugin._failover_reader = MagicMock()
+
+        failover_v2_plugin._failover()
+
+        failover_v2_plugin._failover_reader.assert_called_once()
+
+    def test_failover_reader_success(self, failover_v2_plugin):
+        failover_v2_plugin._plugin_service.force_monitoring_refresh_host_list.return_value = True
+        reader_result = ReaderFailoverResult(MagicMock(), HostInfo("reader.com", 5432, HostRole.READER))
+        failover_v2_plugin._get_reader_failover_connection = MagicMock(return_value=reader_result)
+        failover_v2_plugin._throw_failover_success_exception = MagicMock(side_effect=FailoverSuccessError())
+
+        with pytest.raises(FailoverSuccessError):
+            failover_v2_plugin._failover_reader()
+
+        failover_v2_plugin._plugin_service.set_current_connection.assert_called_once_with(
+            reader_result.connection, reader_result.host_info)
+
+    def test_failover_reader_refresh_failed(self, failover_v2_plugin):
+        failover_v2_plugin._plugin_service.force_monitoring_refresh_host_list.return_value = False
+
+        with pytest.raises(FailoverFailedError):
+            failover_v2_plugin._failover_reader()
+
+    def test_get_reader_failover_connection_timeout(self, failover_v2_plugin):
+        failover_v2_plugin._failover_timeout_sec = 0.001
+        failover_v2_plugin._plugin_service.hosts = []
+
+        with pytest.raises(TimeoutError, match="Failover reader timeout"):
+            failover_v2_plugin._get_reader_failover_connection()
+
+    def test_throw_failover_success_exception_in_transaction(self, failover_v2_plugin):
+        failover_v2_plugin._is_in_transaction = True
+
+        with pytest.raises(TransactionResolutionUnknownError):
+            failover_v2_plugin._throw_failover_success_exception()
+
+        failover_v2_plugin._plugin_service.update_in_transaction.assert_called_once_with(False)
+
+    def test_failover_writer_success(self, failover_v2_plugin):
+        writer_host = HostInfo("new-writer.com", 5432, HostRole.WRITER)
+        failover_v2_plugin._plugin_service.force_monitoring_refresh_host_list.return_value = True
+        failover_v2_plugin._plugin_service.all_hosts = [writer_host]
+        failover_v2_plugin._plugin_service.hosts = [writer_host]
+        failover_v2_plugin._plugin_service.connect.return_value = MagicMock()
+        failover_v2_plugin._plugin_service.get_host_role.return_value = HostRole.WRITER
+        failover_v2_plugin._throw_failover_success_exception = MagicMock(side_effect=FailoverSuccessError())
+
+        with pytest.raises(FailoverSuccessError):
+            failover_v2_plugin._failover_writer()
+
+        failover_v2_plugin._plugin_service.set_current_connection.assert_called_once()
+
+    def test_failover_writer_no_writer_found(self, failover_v2_plugin):
+        reader_host = HostInfo("reader.com", 5432, HostRole.READER)
+        failover_v2_plugin._plugin_service.force_monitoring_refresh_host_list.return_value = True
+        failover_v2_plugin._plugin_service.all_hosts = [reader_host]
+
+        with pytest.raises(FailoverFailedError):
+            failover_v2_plugin._failover_writer()
+
+    def test_failover_writer_refresh_failed(self, failover_v2_plugin):
+        failover_v2_plugin._plugin_service.force_monitoring_refresh_host_list.return_value = False
+
+        with pytest.raises(FailoverFailedError):
+            failover_v2_plugin._failover_writer()
+
+    def test_failover_writer_connection_failed(self, failover_v2_plugin):
+        writer_host = HostInfo("new-writer.com", 5432, HostRole.WRITER)
+        failover_v2_plugin._plugin_service.force_monitoring_refresh_host_list.return_value = True
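+        # Topology refresh succeeds and a writer is visible, but the connect attempt
+        # below fails, so the failover is expected to end in FailoverFailedError.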
+        failover_v2_plugin._plugin_service.all_hosts = [writer_host]
+        failover_v2_plugin._plugin_service.hosts = [writer_host]
+        failover_v2_plugin._plugin_service.connect.side_effect = Exception("Connection failed")
+
+        with pytest.raises(FailoverFailedError):
+            failover_v2_plugin._failover_writer()
+
+    def test_failover_writer_connected_to_reader(self, failover_v2_plugin):
+        writer_host = HostInfo("new-writer.com", 5432, HostRole.WRITER)
+        failover_v2_plugin._plugin_service.force_monitoring_refresh_host_list.return_value = True
+        failover_v2_plugin._plugin_service.all_hosts = [writer_host]
+        failover_v2_plugin._plugin_service.hosts = [writer_host]
+        failover_v2_plugin._plugin_service.connect.return_value = MagicMock()
+        failover_v2_plugin._plugin_service.get_host_role.return_value = HostRole.READER
+
+        with pytest.raises(FailoverFailedError):
+            failover_v2_plugin._failover_writer()
+
+    def test_invalidate_current_connection_with_transaction(self, failover_v2_plugin, connection_mock):
+        failover_v2_plugin._plugin_service.current_connection = connection_mock
+        failover_v2_plugin._is_in_transaction = True
+        failover_v2_plugin._plugin_service.driver_dialect.is_closed.return_value = True
+
+        with mock.patch.object(failover_v2_plugin._plugin_service.driver_dialect, "execute") as close_mock:
+            failover_v2_plugin._invalidate_current_connection()
+            connection_mock.rollback.assert_called_once()
+            close_mock.assert_called_once()
+
+            connection_mock.close.side_effect = Error("test error")
+
+            try:
+                failover_v2_plugin._invalidate_current_connection()
+            except Error:
+                pytest.fail("_invalidate_current_connection() raised unexpected error")
+
+            close_mock.assert_called()
+            assert close_mock.call_count == 2
+
+    def test_pick_new_connection_proceeds_with_failover(self, failover_v2_plugin):
+        failover_v2_plugin._is_closed = False
+        failover_v2_plugin._closed_explicitly = False
+        failover_v2_plugin._failover = MagicMock()
+
+        failover_v2_plugin._pick_new_connection()
+
+        failover_v2_plugin._failover.assert_called_once()
+
+    def test_should_exception_trigger_connection_switch_network_exception(self, failover_v2_plugin):
+        failover_v2_plugin._rds_url_type = RdsUrlType.RDS_WRITER_CLUSTER
+        failover_v2_plugin._plugin_service.all_hosts = [HostInfo("host", 5432)]
+        failover_v2_plugin._plugin_service.is_network_exception.return_value = True
+
+        assert failover_v2_plugin._should_exception_trigger_connection_switch(Exception("network error"))
+
+    def test_should_exception_trigger_connection_switch_read_only_strict_writer(self, failover_v2_plugin):
+        failover_v2_plugin._rds_url_type = RdsUrlType.RDS_WRITER_CLUSTER
+        failover_v2_plugin._plugin_service.all_hosts = [HostInfo("host", 5432)]
+        failover_v2_plugin._failover_mode = FailoverMode.STRICT_WRITER
+        failover_v2_plugin._plugin_service.is_network_exception.return_value = False
+        failover_v2_plugin._plugin_service.is_read_only_connection_exception.return_value = True
+
+        assert failover_v2_plugin._should_exception_trigger_connection_switch(Exception("read only"))
+
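+    # RDS Proxy endpoints manage failover on the proxy side, so the plugin is
+    # expected to leave errors untouched rather than switch connections.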
+    def test_should_exception_trigger_connection_switch_failover_disabled(self, failover_v2_plugin):
+        failover_v2_plugin._rds_url_type = RdsUrlType.RDS_PROXY
+
+        assert not failover_v2_plugin._should_exception_trigger_connection_switch(Exception("any error"))
diff --git a/tests/unit/test_rds_host_list_provider.py b/tests/unit/test_rds_host_list_provider.py
index 12fd69e9..da814bcd 100644
--- a/tests/unit/test_rds_host_list_provider.py
+++ b/tests/unit/test_rds_host_list_provider.py
@@ -407,3 +407,11 @@ def test_get_topology_returns_last_writer(mocker, mock_provider_service, mock_co
     result = provider._get_topology(mock_conn, True)
     assert result.hosts[0].host == "expected_writer_host.xyz.us-east-2.rds.amazonaws.com"
     spy.assert_called_once()
+
+
+def test_force_monitoring_refresh(mock_provider_service, props):
+    topology_utils = AuroraTopologyUtils(AuroraPgDialect(), props)
+    provider = RdsHostListProvider(mock_provider_service, props, topology_utils)
+
+    with pytest.raises(AwsWrapperError):
+        provider.force_monitoring_refresh(True, 5)
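+    # The plain RdsHostListProvider has no topology monitor attached, so a forced
+    # monitoring refresh is presumably unsupported and should raise.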