Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.rds_utils import RdsUtils

logger = Logger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from aws_advanced_python_wrapper.utils.properties import (Properties,
WrapperProperties)
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.rds_utils import RdsUtils


class AuroraInitialConnectionStrategyPlugin(Plugin):
Expand All @@ -45,7 +45,7 @@ def subscribed_methods(self) -> Set[str]:

def __init__(self, plugin_service: PluginService):
super()
self._plugin_service = plugin_service
self._plugin_service: PluginService = plugin_service
self._rds_utils = RdsUtils()

def connect(self, target_driver_func: Callable, driver_dialect: DriverDialect, host_info: HostInfo, props: Properties,
Expand Down Expand Up @@ -207,6 +207,20 @@ def _get_reader(self, props: Properties) -> Optional[HostInfo]:
and strategy is not None
and self._plugin_service.accepts_strategy(HostRole.READER, strategy)):
try:
original_host = self._plugin_service.current_host_info
url_type = self._rds_utils.identify_rds_type(original_host.host) if original_host else None

if url_type and url_type.has_region:
aws_region = self._rds_utils.get_rds_region(original_host.host)
if aws_region:
hosts_in_region = []
for h in self._plugin_service.all_hosts:
h_region = self._rds_utils.get_rds_region(h.host)
if h_region and aws_region.lower() == h_region.lower():
hosts_in_region.append(h)
return self._plugin_service.get_host_info_by_strategy(
HostRole.READER, strategy, hosts_in_region)

return self._plugin_service.get_host_info_by_strategy(HostRole.READER, strategy)
except Exception:
# Host isn't found.
Expand Down
2 changes: 1 addition & 1 deletion aws_advanced_python_wrapper/blue_green_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.properties import (Properties,
WrapperProperties)
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.rds_utils import RdsUtils
from aws_advanced_python_wrapper.utils.telemetry.telemetry import \
TelemetryTraceLevel

Expand Down
57 changes: 52 additions & 5 deletions aws_advanced_python_wrapper/cluster_topology_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Dict, Optional

from aws_advanced_python_wrapper.errors import AwsWrapperError
from aws_advanced_python_wrapper.host_availability import HostAvailability
from aws_advanced_python_wrapper.hostinfo import HostInfo
from aws_advanced_python_wrapper.utils.atomic import AtomicReference
from aws_advanced_python_wrapper.utils.messages import Messages
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.rds_utils import RdsUtils
from aws_advanced_python_wrapper.utils.storage.storage_service import (
StorageService, Topology)
from aws_advanced_python_wrapper.utils.thread_safe_connection_holder import \
Expand All @@ -35,7 +36,7 @@
from aws_advanced_python_wrapper.pep249 import Connection
from aws_advanced_python_wrapper.plugin_service import PluginService
from aws_advanced_python_wrapper.utils.properties import Properties
from aws_advanced_python_wrapper.host_list_provider import TopologyUtils
from aws_advanced_python_wrapper.host_list_provider import TopologyUtils, GlobalAuroraTopologyUtils

from aws_advanced_python_wrapper.hostinfo import HostRole
from aws_advanced_python_wrapper.utils.log import Logger
Expand Down Expand Up @@ -316,9 +317,10 @@ def _open_any_connection_and_update_topology(self) -> Topology:
writer_host_info = self._initial_host_info
self._writer_host_info.set(writer_host_info)
else:
writer_host = self._instance_template.host.replace("?", writer_id)
port = self._instance_template.port \
if self._instance_template.is_port_specified() \
instance_template = self._get_instance_template(writer_id, conn)
writer_host = instance_template.host.replace("?", writer_id)
port = instance_template.port \
if instance_template.is_port_specified() \
else self._initial_host_info.port
writer_host_info = HostInfo(
writer_host,
Expand Down Expand Up @@ -438,6 +440,9 @@ def _query_for_topology(self, connection: Connection) -> Topology:
return hosts
return ()

def _get_instance_template(self, instance_id: str, connection: Connection) -> HostInfo:
return self._instance_template

def _update_topology_cache(self, hosts: Topology) -> None:
StorageService.set(self._cluster_id, hosts, Topology)
# Notify waiting threads
Expand Down Expand Up @@ -565,3 +570,45 @@ def _calculate_backoff_with_jitter(self, attempt: int) -> int:
backoff = ClusterTopologyMonitorImpl.INITIAL_BACKOFF_MS * (2 ** min(attempt, 6))
backoff = min(backoff, ClusterTopologyMonitorImpl.MAX_BACKOFF_MS)
return int(backoff * (0.5 + random.random() * 0.5))


class GlobalAuroraTopologyMonitor(ClusterTopologyMonitorImpl):
def __init__(
self,
plugin_service: PluginService,
topology_utils: GlobalAuroraTopologyUtils,
cluster_id: str,
initial_host_info: HostInfo,
props: Properties,
instance_template: HostInfo,
refresh_rate_ns: int,
high_refresh_rate_ns: int,
instance_templates_by_region: dict[str, HostInfo]
):
super().__init__(
plugin_service,
topology_utils,
cluster_id,
initial_host_info,
props,
instance_template,
refresh_rate_ns,
high_refresh_rate_ns
)
self._instance_templates_by_region = instance_templates_by_region
self._global_topology_utils = topology_utils

def _get_instance_template(self, instance_id: str, connection: Connection) -> HostInfo:
region = self._global_topology_utils.get_region(instance_id, connection)
if region:
instance_template = self._instance_templates_by_region.get(region)
if instance_template is None:
raise AwsWrapperError(
Messages.get_formatted("GlobalAuroraTopologyMonitor.cannotFindRegionTemplate", region))
return instance_template
return self._instance_template

def _query_for_topology(self, connection: Connection) -> Topology:
result = self._global_topology_utils.query_for_topology_with_regions(
connection, self._instance_templates_by_region)
return result if result is not None else ()
4 changes: 2 additions & 2 deletions aws_advanced_python_wrapper/custom_endpoint_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@

from enum import Enum

from boto3 import Session
from boto3 import Session # type: ignore

from aws_advanced_python_wrapper.pep249_methods import DbApiMethod
from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.properties import WrapperProperties
from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils
from aws_advanced_python_wrapper.utils.rds_utils import RdsUtils
from aws_advanced_python_wrapper.utils.sliding_expiration_cache_container import \
SlidingExpirationCacheContainer
from aws_advanced_python_wrapper.utils.telemetry.telemetry import (
Expand Down
Loading
Loading