-
Notifications
You must be signed in to change notification settings - Fork 16
fix: update aurora connection tracker and fix writer host comparison #1081
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,13 +14,20 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import threading | ||
| import time | ||
| from threading import Thread | ||
| from typing import (TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Optional, | ||
| Set, Tuple) | ||
| from time import perf_counter_ns | ||
| from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, FrozenSet, | ||
| Optional, Set) | ||
|
|
||
| from aws_advanced_python_wrapper.utils.notifications import HostEvent | ||
| from aws_advanced_python_wrapper.utils.utils import Utils | ||
|
|
||
| if TYPE_CHECKING: | ||
| from aws_advanced_python_wrapper.driver_dialect import DriverDialect | ||
| from aws_advanced_python_wrapper.plugin_service import PluginService | ||
| from aws_advanced_python_wrapper.hostinfo import HostInfo | ||
| from aws_advanced_python_wrapper.pep249 import Connection | ||
|
|
||
| from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType | ||
|
|
@@ -29,7 +36,6 @@ | |
| from _weakrefset import WeakSet | ||
|
|
||
| from aws_advanced_python_wrapper.errors import FailoverError | ||
| from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole | ||
| from aws_advanced_python_wrapper.pep249_methods import DbApiMethod | ||
| from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory | ||
| from aws_advanced_python_wrapper.utils.log import Logger | ||
|
|
@@ -39,8 +45,51 @@ | |
|
|
||
|
|
||
| class OpenedConnectionTracker: | ||
| _opened_connections: Dict[str, WeakSet] = {} | ||
| _rds_utils = RdsUtils() | ||
| _opened_connections: ClassVar[Dict[str, WeakSet]] = {} | ||
| _rds_utils: ClassVar[RdsUtils] = RdsUtils() | ||
| _prune_thread: ClassVar[Optional[Thread]] = None | ||
| _shutdown_event: ClassVar[threading.Event] = threading.Event() | ||
| _safe_to_check_closed_classes: ClassVar[Set[str]] = {"psycopg"} | ||
| _default_sleep_time: ClassVar[int] = 30 | ||
|
|
||
| @classmethod | ||
| def _start_prune_thread(cls): | ||
| if cls._prune_thread is None or not cls._prune_thread.is_alive(): | ||
| cls._prune_thread = Thread(daemon=True, target=cls._prune_connections_loop) | ||
| cls._prune_thread.start() | ||
|
Comment on lines
+56
to
+59
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need a lock for this, it's more rare, but we could have 2 initial connections both calling this and starting 2 threads. |
||
|
|
||
| @classmethod | ||
| def _prune_connections_loop(cls): | ||
| while not cls._shutdown_event.is_set(): | ||
| try: | ||
| cls._prune_connections() | ||
| time.sleep(cls._default_sleep_time) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably do a wait on the event instead of a sleep: It's a daemon thread so it won't cause the program to hang, but just incase, this will help with making it exit gracefully. |
||
| except Exception: | ||
| pass | ||
|
|
||
| @classmethod | ||
| def _prune_connections(cls): | ||
| for host, conn_set in list(cls._opened_connections.items()): | ||
| # Remove dead references and closed connections | ||
| to_remove = [] | ||
| for conn in list(conn_set): | ||
| if conn is None: | ||
| to_remove.append(conn) | ||
| else: | ||
| try: | ||
| # The following classes do not check connection validity via a DB server call | ||
| # so it is safe to check whether connection is already closed. | ||
| if any(safe_class in conn.__module__ for safe_class in cls._safe_to_check_closed_classes) and conn.is_closed(): | ||
| to_remove.append(conn) | ||
| except Exception: | ||
| pass | ||
|
|
||
| for conn in to_remove: | ||
| conn_set.discard(conn) | ||
|
|
||
| # Remove empty connection sets | ||
| if not conn_set: | ||
| del cls._opened_connections[host] | ||
|
|
||
| def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection): | ||
| """ | ||
|
|
@@ -56,8 +105,8 @@ def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection): | |
| self._track_connection(host_info.as_alias(), conn) | ||
| return | ||
|
|
||
| instance_endpoint: Optional[str] = next((alias for alias in aliases if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))), | ||
| None) | ||
| instance_endpoint: Optional[str] = next( | ||
| (alias for alias in aliases if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))), None) | ||
| if not instance_endpoint: | ||
| logger.debug("OpenedConnectionTracker.UnableToPopulateOpenedConnectionSet") | ||
| return | ||
|
|
@@ -73,7 +122,7 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host: | |
| """ | ||
|
|
||
| if host_info: | ||
| self.invalidate_all_connections(host=frozenset(host_info.as_alias())) | ||
| self.invalidate_all_connections(host=frozenset([host_info.as_alias()])) | ||
| self.invalidate_all_connections(host=host_info.as_aliases()) | ||
| return | ||
|
|
||
|
|
@@ -94,21 +143,38 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host: | |
| self._log_connection_set(instance_endpoint, connection_set) | ||
| self._invalidate_connections(connection_set) | ||
|
|
||
| def _track_connection(self, instance_endpoint: str, conn: Connection): | ||
| connection_set: Optional[WeakSet] = self._opened_connections.get(instance_endpoint) | ||
| if connection_set is None: | ||
| connection_set = WeakSet() | ||
| connection_set.add(conn) | ||
| self._opened_connections[instance_endpoint] = connection_set | ||
| def remove_connection_tracking(self, host_info: HostInfo, connection: Connection | None): | ||
| if not connection: | ||
| return | ||
|
|
||
| if self._rds_utils.is_rds_instance(host_info.host): | ||
| host = host_info.as_alias() | ||
| else: | ||
| connection_set.add(conn) | ||
| host = next((alias for alias in host_info.as_aliases() | ||
| if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))), "") | ||
|
|
||
| if not host: | ||
| return | ||
|
|
||
| connection_set = self._opened_connections.get(host) | ||
| if connection_set: | ||
| connection_set.discard(connection) | ||
|
|
||
| def _track_connection(self, instance_endpoint: str, conn: Connection): | ||
| connection_set = self._opened_connections.setdefault(instance_endpoint, WeakSet()) | ||
| connection_set.add(conn) | ||
| self._start_prune_thread() | ||
| self.log_opened_connections() | ||
|
|
||
| @staticmethod | ||
| def _task(connection_set: WeakSet): | ||
| while connection_set is not None and len(connection_set) > 0: | ||
| conn_reference = connection_set.pop() | ||
| while connection_set is not None: | ||
| try: | ||
| conn_reference = connection_set.pop() | ||
| except KeyError: | ||
| # connection_set is empty | ||
| # use KeyError instead of len() to determine whether connection_set is empty to prevent data race | ||
| break | ||
|
|
||
| if conn_reference is None: | ||
| continue | ||
|
|
@@ -125,31 +191,28 @@ def _invalidate_connections(self, connection_set: WeakSet): | |
| invalidate_connection_thread.start() | ||
|
|
||
| def log_opened_connections(self): | ||
| msg = "" | ||
| msg_parts = [] | ||
| for key, conn_set in self._opened_connections.items(): | ||
| conn = "" | ||
| for item in list(conn_set): | ||
| conn += f"\n\t\t{item}" | ||
|
|
||
| msg += f"\t[{key} : {conn}]" | ||
| conn_parts = [f"\n\t\t{item}" for item in list(conn_set)] | ||
| conn = "".join(conn_parts) | ||
| msg_parts.append(f"\t[{key} : {conn}]") | ||
|
|
||
| msg = "".join(msg_parts) | ||
| return logger.debug("OpenedConnectionTracker.OpenedConnectionsTracked", msg) | ||
|
|
||
| def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]): | ||
| if conn_set is None or len(conn_set) == 0: | ||
| return | ||
|
|
||
| conn = "" | ||
| for item in list(conn_set): | ||
| conn += f"\n\t\t{item}" | ||
|
|
||
| conn_parts = [f"\n\t\t{item}" for item in list(conn_set)] | ||
| conn = "".join(conn_parts) | ||
| msg = host + f"[{conn}\n]" | ||
| logger.debug("OpenedConnectionTracker.InvalidatingConnections", msg) | ||
|
|
||
|
|
||
| class AuroraConnectionTrackerPlugin(Plugin): | ||
| _current_writer: Optional[HostInfo] = None | ||
| _need_update_current_writer: bool = False | ||
| _host_list_refresh_end_time_nano: ClassVar[int] = 0 | ||
| _TOPOLOGY_CHANGES_EXPECTED_TIME_NANO: ClassVar[int] = 3 * 60 * 1_000_000_000 # 3 minutes | ||
|
|
||
| @property | ||
| def subscribed_methods(self) -> Set[str]: | ||
|
|
@@ -164,6 +227,8 @@ def __init__(self, | |
| self._props = props | ||
| self._rds_utils = rds_utils | ||
| self._tracker = tracker | ||
| self._current_writer: Optional[HostInfo] = None | ||
| self._need_update_current_writer: bool = False | ||
| self._subscribed_methods: Set[str] = {DbApiMethod.CONNECT.method_name, | ||
| DbApiMethod.CONNECTION_CLOSE.method_name, | ||
| DbApiMethod.CONNECT.method_name, | ||
|
|
@@ -192,26 +257,66 @@ def connect( | |
| return conn | ||
|
|
||
| def execute(self, target: object, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any: | ||
| current_host = self._plugin_service.current_host_info | ||
| if self._current_writer is None or self._need_update_current_writer: | ||
| self._current_writer = self._get_writer(self._plugin_service.all_hosts) | ||
| self._current_writer = Utils.get_writer(self._plugin_service.all_hosts) | ||
| self._need_update_current_writer = False | ||
|
|
||
| try: | ||
| return execute_func() | ||
| if not method_name == DbApiMethod.CONNECTION_CLOSE.method_name: | ||
| local_host_list_refresh_end_time_nano = AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano | ||
| need_refresh_host_lists = False | ||
| if local_host_list_refresh_end_time_nano > 0: | ||
| if local_host_list_refresh_end_time_nano > perf_counter_ns(): | ||
| # The time specified in hostListRefreshThresholdTimeNano isn't yet reached. | ||
| # Need to continue to refresh host list. | ||
| need_refresh_host_lists = True | ||
| else: | ||
| # The time specified in hostListRefreshThresholdTimeNano is reached, and we can stop further refreshes | ||
| # of host list. If hostListRefreshThresholdTimeNano has changed while this thread processes the code, | ||
| # we can't override a new value in hostListRefreshThresholdTimeNano. | ||
| if AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano == local_host_list_refresh_end_time_nano: | ||
| AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano = 0 | ||
| if self._need_update_current_writer or need_refresh_host_lists: | ||
| # Calling this method may effectively close/abort a current connection | ||
| self._check_writer_changed(need_refresh_host_lists) | ||
|
|
||
| result = execute_func() | ||
| if method_name == DbApiMethod.CONNECTION_CLOSE.method_name: | ||
| self._tracker.remove_connection_tracking(current_host, self._plugin_service.current_connection) | ||
| return result | ||
|
|
||
| except Exception as e: | ||
| # Check that e is a FailoverError and that the writer has changed | ||
| if isinstance(e, FailoverError) and self._get_writer(self._plugin_service.all_hosts) != self._current_writer: | ||
| self._tracker.invalidate_all_connections(host_info=self._current_writer) | ||
| self._tracker.log_opened_connections() | ||
| self._need_update_current_writer = True | ||
| raise e | ||
| if isinstance(e, FailoverError): | ||
| AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano = ( | ||
| perf_counter_ns() + AuroraConnectionTrackerPlugin._TOPOLOGY_CHANGES_EXPECTED_TIME_NANO) | ||
| # Calling this method may effectively close/abort a current connection | ||
| self._check_writer_changed(True) | ||
| raise | ||
|
|
||
| def _check_writer_changed(self, need_refresh_host_lists: bool): | ||
| if need_refresh_host_lists: | ||
| self._plugin_service.refresh_host_list() | ||
|
|
||
| host_info_after_failover = Utils.get_writer(self._plugin_service.all_hosts) | ||
| if host_info_after_failover is None: | ||
| return | ||
|
|
||
| if self._current_writer is None: | ||
| self._current_writer = host_info_after_failover | ||
| self._need_update_current_writer = False | ||
| elif not self._current_writer.get_host_and_port() == host_info_after_failover.get_host_and_port(): | ||
| self._tracker.invalidate_all_connections(host_info=self._current_writer) | ||
| self._tracker.log_opened_connections() | ||
| self._current_writer = host_info_after_failover | ||
| self._need_update_current_writer = False | ||
|
|
||
| def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]: | ||
| for host in hosts: | ||
| if host.role == HostRole.WRITER: | ||
| return host | ||
| return None | ||
| def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]): | ||
| for node, node_changes in changes.items(): | ||
| if HostEvent.CONVERTED_TO_READER in node_changes: | ||
| self._tracker.invalidate_all_connections(host=frozenset([node])) | ||
| if HostEvent.CONVERTED_TO_WRITER in node_changes: | ||
| self._need_update_current_writer = True | ||
|
|
||
|
|
||
| class AuroraConnectionTrackerPluginFactory(PluginFactory): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is never set so thread is never shut down. Maybe we should have this in the release_resources method in the wrapper.py and change this in a release_resources() method or something like that).