diff --git a/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py b/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py index 3a9254be..9a6dc23e 100644 --- a/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py +++ b/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py @@ -14,13 +14,20 @@ from __future__ import annotations +import threading +import time from threading import Thread -from typing import (TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Optional, - Set, Tuple) +from time import perf_counter_ns +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, FrozenSet, + Optional, Set) + +from aws_advanced_python_wrapper.utils.notifications import HostEvent +from aws_advanced_python_wrapper.utils.utils import Utils if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType @@ -29,7 +36,6 @@ from _weakrefset import WeakSet from aws_advanced_python_wrapper.errors import FailoverError -from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole from aws_advanced_python_wrapper.pep249_methods import DbApiMethod from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from aws_advanced_python_wrapper.utils.log import Logger @@ -39,8 +45,51 @@ class OpenedConnectionTracker: - _opened_connections: Dict[str, WeakSet] = {} - _rds_utils = RdsUtils() + _opened_connections: ClassVar[Dict[str, WeakSet]] = {} + _rds_utils: ClassVar[RdsUtils] = RdsUtils() + _prune_thread: ClassVar[Optional[Thread]] = None + _shutdown_event: ClassVar[threading.Event] = threading.Event() + _safe_to_check_closed_classes: ClassVar[Set[str]] = {"psycopg"} + _default_sleep_time: ClassVar[int] = 30 + + @classmethod + def _start_prune_thread(cls): + if cls._prune_thread is None or not cls._prune_thread.is_alive(): + cls._prune_thread = Thread(daemon=True, target=cls._prune_connections_loop) + cls._prune_thread.start() + + @classmethod + def _prune_connections_loop(cls): + while not cls._shutdown_event.is_set(): + try: + cls._prune_connections() + time.sleep(cls._default_sleep_time) + except Exception: + pass + + @classmethod + def _prune_connections(cls): + for host, conn_set in list(cls._opened_connections.items()): + # Remove dead references and closed connections + to_remove = [] + for conn in list(conn_set): + if conn is None: + to_remove.append(conn) + else: + try: + # The following classes do not check connection validity via a DB server call + # so it is safe to check whether connection is already closed. + if any(safe_class in conn.__module__ for safe_class in cls._safe_to_check_closed_classes) and conn.is_closed(): + to_remove.append(conn) + except Exception: + pass + + for conn in to_remove: + conn_set.discard(conn) + + # Remove empty connection sets + if not conn_set: + del cls._opened_connections[host] def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection): """ @@ -56,8 +105,8 @@ def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection): self._track_connection(host_info.as_alias(), conn) return - instance_endpoint: Optional[str] = next((alias for alias in aliases if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))), - None) + instance_endpoint: Optional[str] = next( + (alias for alias in aliases if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))), None) if not instance_endpoint: logger.debug("OpenedConnectionTracker.UnableToPopulateOpenedConnectionSet") return @@ -73,7 +122,7 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host: """ if host_info: - self.invalidate_all_connections(host=frozenset(host_info.as_alias())) + self.invalidate_all_connections(host=frozenset([host_info.as_alias()])) self.invalidate_all_connections(host=host_info.as_aliases()) return @@ -94,21 +143,38 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host: self._log_connection_set(instance_endpoint, connection_set) self._invalidate_connections(connection_set) - def _track_connection(self, instance_endpoint: str, conn: Connection): - connection_set: Optional[WeakSet] = self._opened_connections.get(instance_endpoint) - if connection_set is None: - connection_set = WeakSet() - connection_set.add(conn) - self._opened_connections[instance_endpoint] = connection_set + def remove_connection_tracking(self, host_info: HostInfo, connection: Connection | None): + if not connection: + return + + if self._rds_utils.is_rds_instance(host_info.host): + host = host_info.as_alias() else: - connection_set.add(conn) + host = next((alias for alias in host_info.as_aliases() + if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))), "") + + if not host: + return + + connection_set = self._opened_connections.get(host) + if connection_set: + connection_set.discard(connection) + def _track_connection(self, instance_endpoint: str, conn: Connection): + connection_set = self._opened_connections.setdefault(instance_endpoint, WeakSet()) + connection_set.add(conn) + self._start_prune_thread() self.log_opened_connections() @staticmethod def _task(connection_set: WeakSet): - while connection_set is not None and len(connection_set) > 0: - conn_reference = connection_set.pop() + while connection_set is not None: + try: + conn_reference = connection_set.pop() + except KeyError: + # connection_set is empty + # use KeyError instead of len() to determine whether connection_set is empty to prevent data race + break if conn_reference is None: continue @@ -125,31 +191,28 @@ def _invalidate_connections(self, connection_set: WeakSet): invalidate_connection_thread.start() def log_opened_connections(self): - msg = "" + msg_parts = [] for key, conn_set in self._opened_connections.items(): - conn = "" - for item in list(conn_set): - conn += f"\n\t\t{item}" - - msg += f"\t[{key} : {conn}]" + conn_parts = [f"\n\t\t{item}" for item in list(conn_set)] + conn = "".join(conn_parts) + msg_parts.append(f"\t[{key} : {conn}]") + msg = "".join(msg_parts) return logger.debug("OpenedConnectionTracker.OpenedConnectionsTracked", msg) def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]): if conn_set is None or len(conn_set) == 0: return - conn = "" - for item in list(conn_set): - conn += f"\n\t\t{item}" - + conn_parts = [f"\n\t\t{item}" for item in list(conn_set)] + conn = "".join(conn_parts) msg = host + f"[{conn}\n]" logger.debug("OpenedConnectionTracker.InvalidatingConnections", msg) class AuroraConnectionTrackerPlugin(Plugin): - _current_writer: Optional[HostInfo] = None - _need_update_current_writer: bool = False + _host_list_refresh_end_time_nano: ClassVar[int] = 0 + _TOPOLOGY_CHANGES_EXPECTED_TIME_NANO: ClassVar[int] = 3 * 60 * 1_000_000_000 # 3 minutes @property def subscribed_methods(self) -> Set[str]: @@ -164,6 +227,8 @@ def __init__(self, self._props = props self._rds_utils = rds_utils self._tracker = tracker + self._current_writer: Optional[HostInfo] = None + self._need_update_current_writer: bool = False self._subscribed_methods: Set[str] = {DbApiMethod.CONNECT.method_name, DbApiMethod.CONNECTION_CLOSE.method_name, DbApiMethod.CONNECT.method_name, @@ -192,26 +257,66 @@ def connect( return conn def execute(self, target: object, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any: + current_host = self._plugin_service.current_host_info if self._current_writer is None or self._need_update_current_writer: - self._current_writer = self._get_writer(self._plugin_service.all_hosts) + self._current_writer = Utils.get_writer(self._plugin_service.all_hosts) self._need_update_current_writer = False try: - return execute_func() + if not method_name == DbApiMethod.CONNECTION_CLOSE.method_name: + local_host_list_refresh_end_time_nano = AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano + need_refresh_host_lists = False + if local_host_list_refresh_end_time_nano > 0: + if local_host_list_refresh_end_time_nano > perf_counter_ns(): + # The time specified in hostListRefreshThresholdTimeNano isn't yet reached. + # Need to continue to refresh host list. + need_refresh_host_lists = True + else: + # The time specified in hostListRefreshThresholdTimeNano is reached, and we can stop further refreshes + # of host list. If hostListRefreshThresholdTimeNano has changed while this thread processes the code, + # we can't override a new value in hostListRefreshThresholdTimeNano. + if AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano == local_host_list_refresh_end_time_nano: + AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano = 0 + if self._need_update_current_writer or need_refresh_host_lists: + # Calling this method may effectively close/abort a current connection + self._check_writer_changed(need_refresh_host_lists) + + result = execute_func() + if method_name == DbApiMethod.CONNECTION_CLOSE.method_name: + self._tracker.remove_connection_tracking(current_host, self._plugin_service.current_connection) + return result except Exception as e: - # Check that e is a FailoverError and that the writer has changed - if isinstance(e, FailoverError) and self._get_writer(self._plugin_service.all_hosts) != self._current_writer: - self._tracker.invalidate_all_connections(host_info=self._current_writer) - self._tracker.log_opened_connections() - self._need_update_current_writer = True - raise e + if isinstance(e, FailoverError): + AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano = ( + perf_counter_ns() + AuroraConnectionTrackerPlugin._TOPOLOGY_CHANGES_EXPECTED_TIME_NANO) + # Calling this method may effectively close/abort a current connection + self._check_writer_changed(True) + raise + + def _check_writer_changed(self, need_refresh_host_lists: bool): + if need_refresh_host_lists: + self._plugin_service.refresh_host_list() + + host_info_after_failover = Utils.get_writer(self._plugin_service.all_hosts) + if host_info_after_failover is None: + return + + if self._current_writer is None: + self._current_writer = host_info_after_failover + self._need_update_current_writer = False + elif not self._current_writer.get_host_and_port() == host_info_after_failover.get_host_and_port(): + self._tracker.invalidate_all_connections(host_info=self._current_writer) + self._tracker.log_opened_connections() + self._current_writer = host_info_after_failover + self._need_update_current_writer = False - def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]: - for host in hosts: - if host.role == HostRole.WRITER: - return host - return None + def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]): + for node, node_changes in changes.items(): + if HostEvent.CONVERTED_TO_READER in node_changes: + self._tracker.invalidate_all_connections(host=frozenset([node])) + if HostEvent.CONVERTED_TO_WRITER in node_changes: + self._need_update_current_writer = True class AuroraConnectionTrackerPluginFactory(PluginFactory): diff --git a/aws_advanced_python_wrapper/failover_plugin.py b/aws_advanced_python_wrapper/failover_plugin.py index be47fc5f..6281c638 100644 --- a/aws_advanced_python_wrapper/failover_plugin.py +++ b/aws_advanced_python_wrapper/failover_plugin.py @@ -325,12 +325,12 @@ def _failover_writer(self): writer_host = self._get_writer(result.topology) allowed_hosts = self._plugin_service.hosts - allowed_hostnames = [host.host for host in allowed_hosts] - if writer_host.host not in allowed_hostnames: + allowed_hostnames = [host.get_host_and_port() for host in allowed_hosts] + if writer_host.get_host_and_port() not in allowed_hostnames: raise FailoverFailedError( Messages.get_formatted( "FailoverPlugin.NewWriterNotAllowed", - "" if writer_host is None else writer_host.host, + "" if writer_host is None else writer_host.get_host_and_port(), LogUtils.log_topology(allowed_hosts))) self._plugin_service.set_current_connection(result.new_connection, writer_host) diff --git a/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py b/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py index da963fcf..b4abd577 100644 --- a/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py +++ b/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py @@ -14,9 +14,9 @@ from __future__ import annotations +import time from copy import copy from dataclasses import dataclass -from datetime import datetime from threading import Event, Lock, Thread from time import sleep from typing import (TYPE_CHECKING, Callable, ClassVar, Dict, List, Optional, @@ -96,7 +96,7 @@ def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Op # Found a fastest host. Let's find it in the latest topology. for host in self._plugin_service.hosts: - if host == fastest_response_host: + if host.get_host_and_port() == fastest_response_host.get_host_and_port(): # found the fastest host in the topology return host # It seems that the fastest cached host isn't in the latest topology. @@ -196,7 +196,7 @@ def close(self): logger.debug("HostResponseTimeMonitor.Stopped", self._host_info.host) def _get_current_time(self): - return datetime.now().microsecond / 1000 # milliseconds + return time.perf_counter() * 1000 # milliseconds def run(self): context: TelemetryContext = self._telemetry_factory.open_telemetry_context( diff --git a/aws_advanced_python_wrapper/host_list_provider.py b/aws_advanced_python_wrapper/host_list_provider.py index 81409443..9bbf3abe 100644 --- a/aws_advanced_python_wrapper/host_list_provider.py +++ b/aws_advanced_python_wrapper/host_list_provider.py @@ -48,7 +48,7 @@ WrapperProperties) from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils -from aws_advanced_python_wrapper.utils.utils import LogUtils +from aws_advanced_python_wrapper.utils.utils import LogUtils, Utils logger = Logger(__name__) @@ -266,7 +266,6 @@ def _suggest_cluster_id(self, primary_cluster_id_hosts: Tuple[HostInfo, ...]): if not primary_cluster_id_hosts: return - primary_cluster_id_urls = {host.url for host in primary_cluster_id_hosts} for cluster_id, hosts in RdsHostListProvider._topology_cache.get_dict().items(): is_primary_cluster = RdsHostListProvider._is_primary_cluster_id_cache.get_with_default( cluster_id, False, self._suggested_cluster_id_refresh_ns) @@ -276,7 +275,7 @@ def _suggest_cluster_id(self, primary_cluster_id_hosts: Tuple[HostInfo, ...]): # The entry is non-primary for host in hosts: - if host.url in primary_cluster_id_urls: + if Utils.contains_host_and_port(primary_cluster_id_hosts, host.get_host_and_port()): # An instance URL in this topology cache entry matches an instance URL in the primary cluster entry. # The associated cluster ID should be updated to match the primary ID so that they can share # topology info. diff --git a/aws_advanced_python_wrapper/host_monitoring_plugin.py b/aws_advanced_python_wrapper/host_monitoring_plugin.py index 6b9271d7..69252d19 100644 --- a/aws_advanced_python_wrapper/host_monitoring_plugin.py +++ b/aws_advanced_python_wrapper/host_monitoring_plugin.py @@ -181,7 +181,7 @@ def _get_monitoring_host_info(self) -> HostInfo: if current_host_info is None: raise AwsWrapperError("HostMonitoringPlugin.HostInfoNone") self._monitoring_host_info = current_host_info - rds_type = self._rds_utils.identify_rds_type(self._monitoring_host_info.url) + rds_type = self._rds_utils.identify_rds_type(self._monitoring_host_info.host) try: if rds_type.is_rds_cluster: diff --git a/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py b/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py index b3dc881c..e20e9a7e 100644 --- a/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py +++ b/aws_advanced_python_wrapper/host_monitoring_v2_plugin.py @@ -131,7 +131,7 @@ def _get_monitoring_host_info(self) -> HostInfo: if current_host_info is None: raise AwsWrapperError(Messages.get("HostMonitoringV2Plugin.HostInfoNone")) self._monitoring_host_info = current_host_info - rds_url_type = self._rds_utils.identify_rds_type(self._monitoring_host_info.url) + rds_url_type = self._rds_utils.identify_rds_type(self._monitoring_host_info.host) try: if not rds_url_type.is_rds_cluster: diff --git a/aws_advanced_python_wrapper/hostinfo.py b/aws_advanced_python_wrapper/hostinfo.py index 7f151350..e0224751 100644 --- a/aws_advanced_python_wrapper/hostinfo.py +++ b/aws_advanced_python_wrapper/hostinfo.py @@ -88,10 +88,7 @@ def __copy__(self): @property def url(self): - if self.is_port_specified(): - return f"{self.host}:{self.port}" - else: - return self.host + return f"{self.as_alias()}/" @property def aliases(self) -> FrozenSet[str]: @@ -101,9 +98,12 @@ def aliases(self) -> FrozenSet[str]: def all_aliases(self) -> FrozenSet[str]: return frozenset(self._all_aliases) - def as_alias(self) -> str: + def get_host_and_port(self): return f"{self.host}:{self.port}" if self.is_port_specified() else self.host + def as_alias(self) -> str: + return self.get_host_and_port() + def add_alias(self, *aliases: str): if not aliases: return diff --git a/aws_advanced_python_wrapper/limitless_plugin.py b/aws_advanced_python_wrapper/limitless_plugin.py index 85cb15da..1b8cea4b 100644 --- a/aws_advanced_python_wrapper/limitless_plugin.py +++ b/aws_advanced_python_wrapper/limitless_plugin.py @@ -37,7 +37,7 @@ SlidingExpirationCacheWithCleanupThread from aws_advanced_python_wrapper.utils.telemetry.telemetry import ( TelemetryContext, TelemetryFactory, TelemetryTraceLevel) -from aws_advanced_python_wrapper.utils.utils import LogUtils +from aws_advanced_python_wrapper.utils.utils import LogUtils, Utils if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect @@ -343,7 +343,7 @@ def establish_connection(self, context: LimitlessContext) -> None: context.set_connection(context.get_connect_func()()) return - if context.get_host_info() in context.get_limitless_routers(): + if Utils.contains_host_and_port(tuple(context.get_limitless_routers()), context.get_host_info().get_host_and_port()): logger.debug(Messages.get_formatted("LimitlessRouterService.ConnectWithHost", context.get_host_info().host)) if context.get_connection() is None: try: diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index 789c364a..ca4e4af2 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -449,11 +449,11 @@ def current_host_info(self) -> HostInfo: next((host_info for host_info in all_hosts if host_info.role == HostRole.WRITER), None)) if self._current_host_info: allowed_hosts = self.hosts - if not Utils.contains_url(allowed_hosts, self._current_host_info.url): + if not Utils.contains_host_and_port(allowed_hosts, self._current_host_info.get_host_and_port()): raise AwsWrapperError( Messages.get_formatted( "PluginServiceImpl.CurrentHostNotAllowed", - self._current_host_info.url, LogUtils.log_topology(allowed_hosts))) + self._current_host_info.get_host_and_port(), LogUtils.log_topology(allowed_hosts))) else: allowed_hosts = self.hosts if len(allowed_hosts) > 0: diff --git a/aws_advanced_python_wrapper/read_write_splitting_plugin.py b/aws_advanced_python_wrapper/read_write_splitting_plugin.py index 1a23bea6..94e7c9b0 100644 --- a/aws_advanced_python_wrapper/read_write_splitting_plugin.py +++ b/aws_advanced_python_wrapper/read_write_splitting_plugin.py @@ -624,8 +624,8 @@ def get_verified_initial_connection( return current_conn def can_host_be_used(self, host_info: HostInfo) -> bool: - hostnames = [host_info.host for host_info in self._hosts] - return host_info.host in hostnames + hosts = [host_info.get_host_and_port() for host_info in self._hosts] + return host_info.get_host_and_port() in hosts def has_no_readers(self) -> bool: if len(self._hosts) == 1: diff --git a/aws_advanced_python_wrapper/stale_dns_plugin.py b/aws_advanced_python_wrapper/stale_dns_plugin.py index dc5efd83..c2d219a0 100644 --- a/aws_advanced_python_wrapper/stale_dns_plugin.py +++ b/aws_advanced_python_wrapper/stale_dns_plugin.py @@ -33,7 +33,7 @@ from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.notifications import HostEvent from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils -from aws_advanced_python_wrapper.utils.utils import LogUtils +from aws_advanced_python_wrapper.utils.utils import LogUtils, Utils logger = Logger(__name__) @@ -113,12 +113,12 @@ def get_verified_connection(self, is_initial_connection: bool, host_list_provide logger.debug("StaleDnsHelper.StaleDnsDetected", self._writer_host_info) allowed_hosts = self._plugin_service.hosts - allowed_hostnames = [host.host for host in allowed_hosts] - if self._writer_host_info.host not in allowed_hostnames: + + if not Utils.contains_host_and_port(tuple(allowed_hosts), self._writer_host_info.get_host_and_port()): raise AwsWrapperError( Messages.get_formatted( "StaleDnsHelper.CurrentWriterNotAllowed", - "" if self._writer_host_info is None else self._writer_host_info.host, + "" if self._writer_host_info is None else self._writer_host_info.get_host_and_port(), LogUtils.log_topology(allowed_hosts))) writer_conn: Connection = self._plugin_service.connect(self._writer_host_info, props) diff --git a/aws_advanced_python_wrapper/utils/utils.py b/aws_advanced_python_wrapper/utils/utils.py index ec599a54..e6690a41 100644 --- a/aws_advanced_python_wrapper/utils/utils.py +++ b/aws_advanced_python_wrapper/utils/utils.py @@ -20,6 +20,8 @@ from queue import Empty, Queue from typing import TYPE_CHECKING, Optional, Tuple +from aws_advanced_python_wrapper.hostinfo import HostRole + if TYPE_CHECKING: from aws_advanced_python_wrapper.hostinfo import HostInfo @@ -89,9 +91,16 @@ def initialize_class(full_class_name: str, *args): return None @staticmethod - def contains_url(hosts: Tuple[HostInfo, ...], url: str) -> bool: + def contains_host_and_port(hosts: Tuple[HostInfo, ...], host_and_port: str) -> bool: for host in hosts: - if host.url == url: + if host.get_host_and_port() == host_and_port: return True return False + + @staticmethod + def get_writer(hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]: + for host in hosts: + if host.role == HostRole.WRITER: + return host + return None diff --git a/aws_advanced_python_wrapper/writer_failover_handler.py b/aws_advanced_python_wrapper/writer_failover_handler.py index 7888df49..8e295883 100644 --- a/aws_advanced_python_wrapper/writer_failover_handler.py +++ b/aws_advanced_python_wrapper/writer_failover_handler.py @@ -300,7 +300,7 @@ def is_same(self, host_info: Optional[HostInfo], current_host_info: Optional[Hos if host_info is None or current_host_info is None: return False - return host_info.url == current_host_info.url + return host_info.get_host_and_port() == current_host_info.get_host_and_port() def connect_to_writer(self, writer_candidate: Optional[HostInfo]) -> bool: if self.is_same(writer_candidate, self._current_reader_host): diff --git a/tests/integration/container/test_blue_green_deployment.py b/tests/integration/container/test_blue_green_deployment.py index 8a732642..31d568e1 100644 --- a/tests/integration/container/test_blue_green_deployment.py +++ b/tests/integration/container/test_blue_green_deployment.py @@ -11,18 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. from __future__ import annotations diff --git a/tests/integration/container/utils/test_telemetry_info.py b/tests/integration/container/utils/test_telemetry_info.py index 24267dc8..e37941ef 100644 --- a/tests/integration/container/utils/test_telemetry_info.py +++ b/tests/integration/container/utils/test_telemetry_info.py @@ -11,18 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. + import typing from typing import Any, Dict diff --git a/tests/unit/test_aurora_connection_tracker.py b/tests/unit/test_aurora_connection_tracker.py index f9f30544..23551eb4 100644 --- a/tests/unit/test_aurora_connection_tracker.py +++ b/tests/unit/test_aurora_connection_tracker.py @@ -11,32 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. from __future__ import annotations from typing import TYPE_CHECKING import pytest +from _weakrefset import WeakSet from aws_advanced_python_wrapper.errors import FailoverError if TYPE_CHECKING: from aws_advanced_python_wrapper.pep249 import Connection -from aws_advanced_python_wrapper.aurora_connection_tracker_plugin import \ - AuroraConnectionTrackerPlugin +from aws_advanced_python_wrapper.aurora_connection_tracker_plugin import ( + AuroraConnectionTrackerPlugin, OpenedConnectionTracker) from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.utils.properties import Properties @@ -117,3 +106,41 @@ def test_invalidate_opened_connections( mock_tracker.invalidate_current_connection.assert_not_called() mock_tracker.invalidate_all_connections.assert_called_with(host_info=original_host) + + +def test_prune_connections_logs_none_connection(mocker, caplog): + """Test that pruning None connections logs the expected message.""" + + tracker = OpenedConnectionTracker() + mock_conn_set = mocker.MagicMock(spec=WeakSet) + mock_conn_set.__iter__ = mocker.MagicMock(return_value=iter([None])) + mock_conn_set.__len__ = mocker.MagicMock(return_value=1) + mock_conn_set.discard = mocker.MagicMock() + + OpenedConnectionTracker._opened_connections = {"test-host": mock_conn_set} + + with caplog.at_level("DEBUG"): + tracker._prune_connections() + + mock_conn_set.discard.assert_called_with(None) + + +def test_prune_connections_logs_closed_connection(mocker, caplog): + """Test that pruning closed connections logs the connection class name.""" + + mock_conn = mocker.MagicMock() + mock_conn.__module__ = "psycopg" + mock_conn.is_closed.return_value = True + + mock_conn_set = mocker.MagicMock(spec=WeakSet) + mock_conn_set.__iter__ = mocker.MagicMock(return_value=iter([mock_conn])) + mock_conn_set.__len__ = mocker.MagicMock(return_value=1) + mock_conn_set.discard = mocker.MagicMock() + + tracker = OpenedConnectionTracker() + OpenedConnectionTracker._opened_connections = {"test-host": mock_conn_set} + + with caplog.at_level("DEBUG"): + tracker._prune_connections() + + mock_conn_set.discard.assert_called_with(mock_conn) diff --git a/tests/unit/test_dialect.py b/tests/unit/test_dialect.py index fb7f658c..1f02b494 100644 --- a/tests/unit/test_dialect.py +++ b/tests/unit/test_dialect.py @@ -350,7 +350,7 @@ def test_query_for_dialect_no_update_candidates(mock_dialect, mock_conn, mock_dr assert mock_dialect == manager.query_for_dialect("url", HostInfo("host"), mock_conn, mock_driver_dialect) assert DialectCode.PG == manager._known_endpoint_dialects.get("url") - assert DialectCode.PG == manager._known_endpoint_dialects.get("host") + assert DialectCode.PG == manager._known_endpoint_dialects.get("host/") def test_query_for_dialect_pg(mock_conn, mock_cursor, mock_driver_dialect): @@ -364,7 +364,7 @@ def test_query_for_dialect_pg(mock_conn, mock_cursor, mock_driver_dialect): result = manager.query_for_dialect("url", HostInfo("host"), mock_conn, mock_driver_dialect) assert isinstance(result, AuroraPgDialect) assert DialectCode.AURORA_PG == manager._known_endpoint_dialects.get("url") - assert DialectCode.AURORA_PG == manager._known_endpoint_dialects.get("host") + assert DialectCode.AURORA_PG == manager._known_endpoint_dialects.get("host/") def test_query_for_dialect_mysql(mock_conn, mock_cursor, mock_driver_dialect): @@ -378,4 +378,4 @@ def test_query_for_dialect_mysql(mock_conn, mock_cursor, mock_driver_dialect): result = manager.query_for_dialect("url", HostInfo("host"), mock_conn, mock_driver_dialect) assert isinstance(result, AuroraMysqlDialect) assert DialectCode.AURORA_MYSQL == manager._known_endpoint_dialects.get("url") - assert DialectCode.AURORA_MYSQL == manager._known_endpoint_dialects.get("host") + assert DialectCode.AURORA_MYSQL == manager._known_endpoint_dialects.get("host/") diff --git a/tests/unit/test_host_monitoring_plugin.py b/tests/unit/test_host_monitoring_plugin.py index 4d47c121..faa8ffa2 100644 --- a/tests/unit/test_host_monitoring_plugin.py +++ b/tests/unit/test_host_monitoring_plugin.py @@ -173,7 +173,7 @@ def test_connect(mocker, plugin, host_info, props, mock_conn, mock_plugin_servic def test_notify_host_list_changed( mocker, plugin, host_info, props, mock_conn, mock_plugin_service, mock_monitor_service, host_events): mock_host_info = mocker.MagicMock() - mock_host_info.url = "instance-1.xyz.us-east-2.rds.amazonaws.com" + mock_host_info.host = "instance-1.xyz.us-east-2.rds.amazonaws.com" aliases = frozenset({"instance-1.xyz.us-east-2.rds.amazonaws.com", "alias1", "alias2"}) mock_host_info.all_aliases = aliases plugin._is_connection_initialized = True diff --git a/tests/unit/test_multi_az_rds_host_list_provider.py b/tests/unit/test_multi_az_rds_host_list_provider.py index 5e66efc3..d287f29c 100644 --- a/tests/unit/test_multi_az_rds_host_list_provider.py +++ b/tests/unit/test_multi_az_rds_host_list_provider.py @@ -365,7 +365,7 @@ def test_initialize__rds_proxy(mock_provider_service): props = Properties({"host": "my-cluster.proxy-xyz.us-east-2.rds.amazonaws.com", "port": 5432}) provider = create_provider(mock_provider_service, props) provider._initialize() - assert provider._cluster_id == "my-cluster.proxy-xyz.us-east-2.rds.amazonaws.com:5432" + assert provider._cluster_id == "my-cluster.proxy-xyz.us-east-2.rds.amazonaws.com:5432/" def test_query_for_topology__empty_writer_query_results( diff --git a/tests/unit/test_rds_host_list_provider.py b/tests/unit/test_rds_host_list_provider.py index 12fd69e9..7b697e8e 100644 --- a/tests/unit/test_rds_host_list_provider.py +++ b/tests/unit/test_rds_host_list_provider.py @@ -386,7 +386,7 @@ def test_initialize_rds_proxy(mock_provider_service): topology_utils = AuroraTopologyUtils(AuroraPgDialect(), props) provider = RdsHostListProvider(mock_provider_service, props, topology_utils) provider._initialize() - assert provider._cluster_id == "my-cluster.proxy-xyz.us-east-2.rds.amazonaws.com" + assert provider._cluster_id == "my-cluster.proxy-xyz.us-east-2.rds.amazonaws.com/" def test_get_topology_returns_last_writer(mocker, mock_provider_service, mock_conn, mock_cursor): diff --git a/tests/unit/test_secrets_manager_plugin.py b/tests/unit/test_secrets_manager_plugin.py index 87a15987..3ccefd1b 100644 --- a/tests/unit/test_secrets_manager_plugin.py +++ b/tests/unit/test_secrets_manager_plugin.py @@ -11,18 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. from __future__ import annotations