Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 147 additions & 42 deletions aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,20 @@

from __future__ import annotations

import threading
import time
from threading import Thread
from typing import (TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Optional,
Set, Tuple)
from time import perf_counter_ns
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, FrozenSet,
Optional, Set)

from aws_advanced_python_wrapper.utils.notifications import HostEvent
from aws_advanced_python_wrapper.utils.utils import Utils

if TYPE_CHECKING:
from aws_advanced_python_wrapper.driver_dialect import DriverDialect
from aws_advanced_python_wrapper.plugin_service import PluginService
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.pep249 import Connection

from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
Expand All @@ -29,7 +36,6 @@
from _weakrefset import WeakSet

from aws_advanced_python_wrapper.errors import FailoverError
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.log import Logger
Expand All @@ -39,8 +45,51 @@


class OpenedConnectionTracker:
_opened_connections: Dict[str, WeakSet] = {}
_rds_utils = RdsUtils()
_opened_connections: ClassVar[Dict[str, WeakSet]] = {}
_rds_utils: ClassVar[RdsUtils] = RdsUtils()
_prune_thread: ClassVar[Optional[Thread]] = None
_shutdown_event: ClassVar[threading.Event] = threading.Event()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is never set so thread is never shut down. Maybe we should have this in the release_resources method in the wrapper.py and change this in a release_resources() method or something like that).

_safe_to_check_closed_classes: ClassVar[Set[str]] = {"psycopg"}
_default_sleep_time: ClassVar[int] = 30

@classmethod
def _start_prune_thread(cls):
if cls._prune_thread is None or not cls._prune_thread.is_alive():
cls._prune_thread = Thread(daemon=True, target=cls._prune_connections_loop)
cls._prune_thread.start()
Comment on lines +56 to +59
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a lock for this, it's more rare, but we could have 2 initial connections both calling this and starting 2 threads.


@classmethod
def _prune_connections_loop(cls):
while not cls._shutdown_event.is_set():
try:
cls._prune_connections()
time.sleep(cls._default_sleep_time)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably do a wait on the event instead of a sleep:
e.g.

if cls._shutdown_event.wait(timeout=cls._default_sleep_time):
            break

It's a daemon thread so it won't cause the program to hang, but just incase, this will help with making it exit gracefully.

except Exception:
pass

@classmethod
def _prune_connections(cls):
for host, conn_set in list(cls._opened_connections.items()):
# Remove dead references and closed connections
to_remove = []
for conn in list(conn_set):
if conn is None:
to_remove.append(conn)
else:
try:
# The following classes do not check connection validity via a DB server call
# so it is safe to check whether connection is already closed.
if any(safe_class in conn.__module__ for safe_class in cls._safe_to_check_closed_classes) and conn.is_closed():
to_remove.append(conn)
except Exception:
pass

for conn in to_remove:
conn_set.discard(conn)

# Remove empty connection sets
if not conn_set:
del cls._opened_connections[host]

def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection):
"""
Expand All @@ -56,8 +105,8 @@ def populate_opened_connection_set(self, host_info: HostInfo, conn: Connection):
self._track_connection(host_info.as_alias(), conn)
return

instance_endpoint: Optional[str] = next((alias for alias in aliases if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))),
None)
instance_endpoint: Optional[str] = next(
(alias for alias in aliases if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))), None)
if not instance_endpoint:
logger.debug("OpenedConnectionTracker.UnableToPopulateOpenedConnectionSet")
return
Expand All @@ -73,7 +122,7 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host:
"""

if host_info:
self.invalidate_all_connections(host=frozenset(host_info.as_alias()))
self.invalidate_all_connections(host=frozenset([host_info.as_alias()]))
self.invalidate_all_connections(host=host_info.as_aliases())
return

Expand All @@ -94,21 +143,38 @@ def invalidate_all_connections(self, host_info: Optional[HostInfo] = None, host:
self._log_connection_set(instance_endpoint, connection_set)
self._invalidate_connections(connection_set)

def _track_connection(self, instance_endpoint: str, conn: Connection):
connection_set: Optional[WeakSet] = self._opened_connections.get(instance_endpoint)
if connection_set is None:
connection_set = WeakSet()
connection_set.add(conn)
self._opened_connections[instance_endpoint] = connection_set
def remove_connection_tracking(self, host_info: HostInfo, connection: Connection | None):
if not connection:
return

if self._rds_utils.is_rds_instance(host_info.host):
host = host_info.as_alias()
else:
connection_set.add(conn)
host = next((alias for alias in host_info.as_aliases()
if self._rds_utils.is_rds_instance(self._rds_utils.remove_port(alias))), "")

if not host:
return

connection_set = self._opened_connections.get(host)
if connection_set:
connection_set.discard(connection)

def _track_connection(self, instance_endpoint: str, conn: Connection):
connection_set = self._opened_connections.setdefault(instance_endpoint, WeakSet())
connection_set.add(conn)
self._start_prune_thread()
self.log_opened_connections()

@staticmethod
def _task(connection_set: WeakSet):
while connection_set is not None and len(connection_set) > 0:
conn_reference = connection_set.pop()
while connection_set is not None:
try:
conn_reference = connection_set.pop()
except KeyError:
# connection_set is empty
# use KeyError instead of len() to determine whether connection_set is empty to prevent data race
break

if conn_reference is None:
continue
Expand All @@ -125,31 +191,28 @@ def _invalidate_connections(self, connection_set: WeakSet):
invalidate_connection_thread.start()

def log_opened_connections(self):
msg = ""
msg_parts = []
for key, conn_set in self._opened_connections.items():
conn = ""
for item in list(conn_set):
conn += f"\n\t\t{item}"

msg += f"\t[{key} : {conn}]"
conn_parts = [f"\n\t\t{item}" for item in list(conn_set)]
conn = "".join(conn_parts)
msg_parts.append(f"\t[{key} : {conn}]")

msg = "".join(msg_parts)
return logger.debug("OpenedConnectionTracker.OpenedConnectionsTracked", msg)

def _log_connection_set(self, host: str, conn_set: Optional[WeakSet]):
if conn_set is None or len(conn_set) == 0:
return

conn = ""
for item in list(conn_set):
conn += f"\n\t\t{item}"

conn_parts = [f"\n\t\t{item}" for item in list(conn_set)]
conn = "".join(conn_parts)
msg = host + f"[{conn}\n]"
logger.debug("OpenedConnectionTracker.InvalidatingConnections", msg)


class AuroraConnectionTrackerPlugin(Plugin):
_current_writer: Optional[HostInfo] = None
_need_update_current_writer: bool = False
_host_list_refresh_end_time_nano: ClassVar[int] = 0
_TOPOLOGY_CHANGES_EXPECTED_TIME_NANO: ClassVar[int] = 3 * 60 * 1_000_000_000 # 3 minutes

@property
def subscribed_methods(self) -> Set[str]:
Expand All @@ -164,6 +227,8 @@ def __init__(self,
self._props = props
self._rds_utils = rds_utils
self._tracker = tracker
self._current_writer: Optional[HostInfo] = None
self._need_update_current_writer: bool = False
self._subscribed_methods: Set[str] = {DbApiMethod.CONNECT.method_name,
DbApiMethod.CONNECTION_CLOSE.method_name,
DbApiMethod.CONNECT.method_name,
Expand Down Expand Up @@ -192,26 +257,66 @@ def connect(
return conn

def execute(self, target: object, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any:
current_host = self._plugin_service.current_host_info
if self._current_writer is None or self._need_update_current_writer:
self._current_writer = self._get_writer(self._plugin_service.all_hosts)
self._current_writer = Utils.get_writer(self._plugin_service.all_hosts)
self._need_update_current_writer = False

try:
return execute_func()
if not method_name == DbApiMethod.CONNECTION_CLOSE.method_name:
local_host_list_refresh_end_time_nano = AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano
need_refresh_host_lists = False
if local_host_list_refresh_end_time_nano > 0:
if local_host_list_refresh_end_time_nano > perf_counter_ns():
# The time specified in hostListRefreshThresholdTimeNano isn't yet reached.
# Need to continue to refresh host list.
need_refresh_host_lists = True
else:
# The time specified in hostListRefreshThresholdTimeNano is reached, and we can stop further refreshes
# of host list. If hostListRefreshThresholdTimeNano has changed while this thread processes the code,
# we can't override a new value in hostListRefreshThresholdTimeNano.
if AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano == local_host_list_refresh_end_time_nano:
AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano = 0
if self._need_update_current_writer or need_refresh_host_lists:
# Calling this method may effectively close/abort a current connection
self._check_writer_changed(need_refresh_host_lists)

result = execute_func()
if method_name == DbApiMethod.CONNECTION_CLOSE.method_name:
self._tracker.remove_connection_tracking(current_host, self._plugin_service.current_connection)
return result

except Exception as e:
# Check that e is a FailoverError and that the writer has changed
if isinstance(e, FailoverError) and self._get_writer(self._plugin_service.all_hosts) != self._current_writer:
self._tracker.invalidate_all_connections(host_info=self._current_writer)
self._tracker.log_opened_connections()
self._need_update_current_writer = True
raise e
if isinstance(e, FailoverError):
AuroraConnectionTrackerPlugin._host_list_refresh_end_time_nano = (
perf_counter_ns() + AuroraConnectionTrackerPlugin._TOPOLOGY_CHANGES_EXPECTED_TIME_NANO)
# Calling this method may effectively close/abort a current connection
self._check_writer_changed(True)
raise

def _check_writer_changed(self, need_refresh_host_lists: bool):
if need_refresh_host_lists:
self._plugin_service.refresh_host_list()

host_info_after_failover = Utils.get_writer(self._plugin_service.all_hosts)
if host_info_after_failover is None:
return

if self._current_writer is None:
self._current_writer = host_info_after_failover
self._need_update_current_writer = False
elif not self._current_writer.get_host_and_port() == host_info_after_failover.get_host_and_port():
self._tracker.invalidate_all_connections(host_info=self._current_writer)
self._tracker.log_opened_connections()
self._current_writer = host_info_after_failover
self._need_update_current_writer = False

def _get_writer(self, hosts: Tuple[HostInfo, ...]) -> Optional[HostInfo]:
for host in hosts:
if host.role == HostRole.WRITER:
return host
return None
def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]):
for node, node_changes in changes.items():
if HostEvent.CONVERTED_TO_READER in node_changes:
self._tracker.invalidate_all_connections(host=frozenset([node]))
if HostEvent.CONVERTED_TO_WRITER in node_changes:
self._need_update_current_writer = True


class AuroraConnectionTrackerPluginFactory(PluginFactory):
Expand Down
6 changes: 3 additions & 3 deletions aws_advanced_python_wrapper/failover_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,12 +325,12 @@ def _failover_writer(self):

writer_host = self._get_writer(result.topology)
allowed_hosts = self._plugin_service.hosts
allowed_hostnames = [host.host for host in allowed_hosts]
if writer_host.host not in allowed_hostnames:
allowed_hostnames = [host.get_host_and_port() for host in allowed_hosts]
if writer_host.get_host_and_port() not in allowed_hostnames:
raise FailoverFailedError(
Messages.get_formatted(
"FailoverPlugin.NewWriterNotAllowed",
"<null>" if writer_host is None else writer_host.host,
"<null>" if writer_host is None else writer_host.get_host_and_port(),
LogUtils.log_topology(allowed_hosts)))

self._plugin_service.set_current_connection(result.new_connection, writer_host)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

from __future__ import annotations

import time
from copy import copy
from dataclasses import dataclass
from datetime import datetime
from threading import Event, Lock, Thread
from time import sleep
from typing import (TYPE_CHECKING, Callable, ClassVar, Dict, List, Optional,
Expand Down Expand Up @@ -96,7 +96,7 @@ def get_host_info_by_strategy(self, role: HostRole, strategy: str, host_list: Op

# Found a fastest host. Let's find it in the latest topology.
for host in self._plugin_service.hosts:
if host == fastest_response_host:
if host.get_host_and_port() == fastest_response_host.get_host_and_port():
# found the fastest host in the topology
return host
# It seems that the fastest cached host isn't in the latest topology.
Expand Down Expand Up @@ -196,7 +196,7 @@ def close(self):
logger.debug("HostResponseTimeMonitor.Stopped", self._host_info.host)

def _get_current_time(self):
return datetime.now().microsecond / 1000 # milliseconds
return time.perf_counter() * 1000 # milliseconds

def run(self):
context: TelemetryContext = self._telemetry_factory.open_telemetry_context(
Expand Down
5 changes: 2 additions & 3 deletions aws_advanced_python_wrapper/host_list_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
WrapperProperties)
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.utils import LogUtils
from aws_advanced_python_wrapper.utils.utils import LogUtils, Utils

logger = Logger(__name__)

Expand Down Expand Up @@ -266,7 +266,6 @@ def _suggest_cluster_id(self, primary_cluster_id_hosts: Tuple[HostInfo, ...]):
if not primary_cluster_id_hosts:
return

primary_cluster_id_urls = {host.url for host in primary_cluster_id_hosts}
for cluster_id, hosts in RdsHostListProvider._topology_cache.get_dict().items():
is_primary_cluster = RdsHostListProvider._is_primary_cluster_id_cache.get_with_default(
cluster_id, False, self._suggested_cluster_id_refresh_ns)
Expand All @@ -276,7 +275,7 @@ def _suggest_cluster_id(self, primary_cluster_id_hosts: Tuple[HostInfo, ...]):

# The entry is non-primary
for host in hosts:
if host.url in primary_cluster_id_urls:
if Utils.contains_host_and_port(primary_cluster_id_hosts, host.get_host_and_port()):
# An instance URL in this topology cache entry matches an instance URL in the primary cluster entry.
# The associated cluster ID should be updated to match the primary ID so that they can share
# topology info.
Expand Down
2 changes: 1 addition & 1 deletion aws_advanced_python_wrapper/host_monitoring_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def _get_monitoring_host_info(self) -> HostInfo:
if current_host_info is None:
raise AwsWrapperError("HostMonitoringPlugin.HostInfoNone")
self._monitoring_host_info = current_host_info
rds_type = self._rds_utils.identify_rds_type(self._monitoring_host_info.url)
rds_type = self._rds_utils.identify_rds_type(self._monitoring_host_info.host)

try:
if rds_type.is_rds_cluster:
Expand Down
2 changes: 1 addition & 1 deletion aws_advanced_python_wrapper/host_monitoring_v2_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _get_monitoring_host_info(self) -> HostInfo:
if current_host_info is None:
raise AwsWrapperError(Messages.get("HostMonitoringV2Plugin.HostInfoNone"))
self._monitoring_host_info = current_host_info
rds_url_type = self._rds_utils.identify_rds_type(self._monitoring_host_info.url)
rds_url_type = self._rds_utils.identify_rds_type(self._monitoring_host_info.host)

try:
if not rds_url_type.is_rds_cluster:
Expand Down
10 changes: 5 additions & 5 deletions aws_advanced_python_wrapper/hostinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,7 @@ def __copy__(self):

@property
def url(self):
if self.is_port_specified():
return f"{self.host}:{self.port}"
else:
return self.host
return f"{self.as_alias()}/"

@property
def aliases(self) -> FrozenSet[str]:
Expand All @@ -101,9 +98,12 @@ def aliases(self) -> FrozenSet[str]:
def all_aliases(self) -> FrozenSet[str]:
return frozenset(self._all_aliases)

def as_alias(self) -> str:
def get_host_and_port(self):
return f"{self.host}:{self.port}" if self.is_port_specified() else self.host

def as_alias(self) -> str:
return self.get_host_and_port()

def add_alias(self, *aliases: str):
if not aliases:
return
Expand Down
Loading