Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aws_advanced_python_wrapper/blue_green_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ def _init_host_list_provider(self):
logger.warning("BlueGreenStatusMonitor.HostInfoNone")
return

host_list_provider_supplier = self._plugin_service.database_dialect.get_host_list_provider_supplier()
host_list_provider_supplier = self._plugin_service.database_dialect.get_host_list_provider_supplier(self._plugin_service)
host_list_provider_service: HostListProviderService = cast('HostListProviderService', self._plugin_service)
self._host_list_provider = host_list_provider_supplier(host_list_provider_service, props_copy)

Expand Down
537 changes: 537 additions & 0 deletions aws_advanced_python_wrapper/cluster_topology_monitor.py

Large diffs are not rendered by default.

54 changes: 45 additions & 9 deletions aws_advanced_python_wrapper/database_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@
Protocol, Tuple, runtime_checkable)

from aws_advanced_python_wrapper.driver_info import DriverInfo
from aws_advanced_python_wrapper.failover_v2_plugin import FailoverV2Plugin
from aws_advanced_python_wrapper.host_list_provider import (
AuroraTopologyUtils, MultiAzTopologyUtils)
AuroraTopologyUtils, MonitoringRdsHostListProvider, MultiAzTopologyUtils)
from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType

if TYPE_CHECKING:
from aws_advanced_python_wrapper.pep249 import Connection
from .driver_dialect import DriverDialect
from .exception_handling import ExceptionHandler
from aws_advanced_python_wrapper.plugin_service import PluginService

from abc import ABC, abstractmethod
from concurrent.futures import TimeoutError
Expand Down Expand Up @@ -88,6 +90,7 @@ class TopologyAwareDatabaseDialect(Protocol):
_TOPOLOGY_QUERY: str
_HOST_ID_QUERY: str
_IS_READER_QUERY: str
_WRITER_HOST_QUERY: str

@property
def topology_query(self) -> str:
Expand All @@ -101,6 +104,10 @@ def host_id_query(self) -> str:
def is_reader_query(self) -> str:
return self._IS_READER_QUERY

@property
def writer_id_query(self) -> str:
return self._WRITER_HOST_QUERY


@runtime_checkable
class AuroraLimitlessDialect(Protocol):
Expand Down Expand Up @@ -147,7 +154,7 @@ def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
...

@abstractmethod
def get_host_list_provider_supplier(self) -> Callable:
def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable:
...

@abstractmethod
Expand Down Expand Up @@ -213,7 +220,7 @@ def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:

return False

def get_host_list_provider_supplier(self) -> Callable:
def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable:
return lambda provider_service, props: ConnectionStringHostListProvider(provider_service, props)

def prepare_conn_props(self, props: Properties):
Expand Down Expand Up @@ -261,7 +268,7 @@ def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:

return False

def get_host_list_provider_supplier(self) -> Callable:
def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable:
return lambda provider_service, props: ConnectionStringHostListProvider(provider_service, props)

def prepare_conn_props(self, props: Properties):
Expand Down Expand Up @@ -387,6 +394,9 @@ class AuroraMysqlDialect(MysqlDatabaseDialect, TopologyAwareDatabaseDialect, Blu
"OR SESSION_ID = 'MASTER_SESSION_ID' ")
_HOST_ID_QUERY = "SELECT @@aurora_server_id"
_IS_READER_QUERY = "SELECT @@innodb_read_only"
_WRITER_HOST_QUERY = \
("SELECT SERVER_ID FROM information_schema.replica_host_status "
"WHERE SESSION_ID = 'MASTER_SESSION_ID' AND SERVER_ID = @@aurora_server_id")

_BG_STATUS_QUERY = "SELECT version, endpoint, port, role, status FROM mysql.rds_topology"
_BG_STATUS_EXISTS_QUERY = \
Expand All @@ -410,7 +420,13 @@ def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:

return False

def get_host_list_provider_supplier(self) -> Callable:
def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable:
if plugin_service.is_plugin_in_use(FailoverV2Plugin):
return lambda provider_service, props: MonitoringRdsHostListProvider(
provider_service,
props, AuroraTopologyUtils(self, props),
plugin_service)

return lambda provider_service, props: RdsHostListProvider(provider_service, props, AuroraTopologyUtils(self, props))

@property
Expand Down Expand Up @@ -449,6 +465,10 @@ class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect, AuroraLim
_BG_STATUS_QUERY = (f"SELECT version, endpoint, port, role, status "
f"FROM pg_catalog.get_blue_green_fast_switchover_metadata('aws_advanced_python_wrapper-{DriverInfo.DRIVER_VERSION}')")
_BG_STATUS_EXISTS_QUERY = "SELECT 'pg_catalog.get_blue_green_fast_switchover_metadata'::regproc"
_WRITER_HOST_QUERY = \
("SELECT SERVER_ID FROM pg_catalog.aurora_replica_status() "
"WHERE SESSION_ID OPERATOR(pg_catalog.=) 'MASTER_SESSION_ID' "
"AND SERVER_ID OPERATOR(pg_catalog.=) pg_catalog.aurora_db_instance_identifier()")

@property
def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
Expand Down Expand Up @@ -483,7 +503,13 @@ def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:

return False

def get_host_list_provider_supplier(self) -> Callable:
def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable:
if plugin_service.is_plugin_in_use(FailoverV2Plugin):
return lambda provider_service, props: MonitoringRdsHostListProvider(
provider_service,
props,
AuroraTopologyUtils(self, props), plugin_service)

return lambda provider_service, props: RdsHostListProvider(provider_service, props, AuroraTopologyUtils(self, props))

@property
Expand Down Expand Up @@ -533,7 +559,12 @@ def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:

return False

def get_host_list_provider_supplier(self) -> Callable:
def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable:
if plugin_service.is_plugin_in_use(FailoverV2Plugin):
return lambda provider_service, props: MonitoringRdsHostListProvider(
provider_service, props,
MultiAzTopologyUtils(self, props, self._WRITER_HOST_QUERY, self._WRITER_HOST_COLUMN_INDEX), plugin_service)

return lambda provider_service, props: RdsHostListProvider(
provider_service,
props,
Expand Down Expand Up @@ -588,7 +619,12 @@ def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:

return False

def get_host_list_provider_supplier(self) -> Callable:
def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable:
if plugin_service.is_plugin_in_use(FailoverV2Plugin):
return lambda provider_service, props: MonitoringRdsHostListProvider(
provider_service, props,
MultiAzTopologyUtils(self, props, self._WRITER_HOST_QUERY), plugin_service)

return lambda provider_service, props: RdsHostListProvider(
provider_service,
props,
Expand Down Expand Up @@ -629,7 +665,7 @@ def exception_handler(self) -> Optional[ExceptionHandler]:
def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
return False

def get_host_list_provider_supplier(self) -> Callable:
def get_host_list_provider_supplier(self, plugin_service: PluginService) -> Callable:
return lambda provider_service, props: ConnectionStringHostListProvider(provider_service, props)

def prepare_conn_props(self, props: Properties):
Expand Down
16 changes: 16 additions & 0 deletions aws_advanced_python_wrapper/exception_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,15 @@ def is_login_exception(self, error: Optional[Exception] = None, sql_state: Optio
"""
pass

def is_read_only_connection_exception(self, error: Optional[Exception] = None, sql_state: Optional[str] = None) -> bool:
"""
Checks whether the given error is caused by failing to authenticate the user.
:param error: The error raised by the target driver.
:param sql_state: The SQL State associated with the error.
:return: True if the error is caused by a login issue, False otherwise.
"""
pass


class ExceptionManager:
custom_handler: Optional[ExceptionHandler] = None
Expand All @@ -67,6 +76,13 @@ def is_login_exception(self, dialect: Optional[DatabaseDialect], error: Optional
return handler.is_login_exception(error=error, sql_state=sql_state)
return False

def is_read_only_connection_exception(self, dialect: Optional[DatabaseDialect], error: Optional[Exception] = None,
sql_state: Optional[str] = None) -> bool:
handler = self._get_handler(dialect)
if handler is not None:
return handler.is_read_only_connection_exception(error=error, sql_state=sql_state)
return False

def _get_handler(self, dialect: Optional[DatabaseDialect]) -> Optional[ExceptionHandler]:
if dialect is None:
return None
Expand Down
13 changes: 5 additions & 8 deletions aws_advanced_python_wrapper/failover_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,7 @@ def _update_topology(self, force_update: bool):
if not self._is_failover_enabled():
return

conn = self._plugin_service.current_connection
driver_dialect = self._plugin_service.driver_dialect
if conn is None or (driver_dialect is not None and driver_dialect.is_closed(conn)):
if self._plugin_service.current_connection is None:
return

if force_update:
Expand Down Expand Up @@ -377,11 +375,10 @@ def _invalidate_current_connection(self):
except Exception:
pass

if not driver_dialect.is_closed(conn):
try:
return driver_dialect.execute(DbApiMethod.CONNECTION_CLOSE.method_name, lambda: conn.close())
except Exception:
pass
try:
return driver_dialect.execute(DbApiMethod.CONNECTION_CLOSE.method_name, lambda: conn.close())
except Exception:
pass

return None

Expand Down
Loading
Loading