Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aws_advanced_python_wrapper/cluster_topology_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,8 @@ def __call__(self) -> None:

if is_writer:
try:
if self._monitor._topology_utils.get_host_role(
connection, self._monitor._plugin_service.driver_dialect) != HostRole.WRITER:
if self._monitor._plugin_service.get_host_role(
connection) != HostRole.WRITER:
is_writer = False
except Exception as ex:
logger.debug("HostMonitor.InvalidWriterQuery", ex)
Expand Down
130 changes: 113 additions & 17 deletions aws_advanced_python_wrapper/database_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@
from enum import Enum, auto

from aws_advanced_python_wrapper.errors import (AwsWrapperError,
QueryTimeoutError)
from aws_advanced_python_wrapper.hostinfo import HostInfo
QueryTimeoutError,
UnsupportedOperationError)
from aws_advanced_python_wrapper.hostinfo import HostInfo, HostRole
from aws_advanced_python_wrapper.thread_pool_container import \
ThreadPoolContainer
from aws_advanced_python_wrapper.utils.decorators import \
Expand Down Expand Up @@ -90,7 +91,6 @@ class TargetDriverType(Enum):
class TopologyAwareDatabaseDialect(Protocol):
_TOPOLOGY_QUERY: str
_HOST_ID_QUERY: str
_IS_READER_QUERY: str
_WRITER_HOST_QUERY: str

@property
Expand All @@ -101,10 +101,6 @@ def topology_query(self) -> str:
def host_id_query(self) -> str:
return self._HOST_ID_QUERY

@property
def is_reader_query(self) -> str:
return self._IS_READER_QUERY

@property
def writer_id_query(self) -> str:
return self._WRITER_HOST_QUERY
Expand Down Expand Up @@ -148,6 +144,16 @@ def host_alias_query(self) -> str:
def server_version_query(self) -> str:
...

@property
@abstractmethod
def host_id_query(self) -> str:
...

@property
@abstractmethod
def is_reader_query(self) -> str:
...

@property
@abstractmethod
def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
Expand Down Expand Up @@ -189,19 +195,30 @@ class MysqlDatabaseDialect(DatabaseDialect):
_DIALECT_UPDATE_CANDIDATES: Tuple[DialectCode, ...] = (
DialectCode.AURORA_MYSQL, DialectCode.GLOBAL_AURORA_MYSQL, DialectCode.MULTI_AZ_CLUSTER_MYSQL, DialectCode.RDS_MYSQL)
_exception_handler: Optional[ExceptionHandler] = None
_HOST_ID_EXPRESSION = "CONCAT(@@hostname, ':', @@port)"
_HOST_ID_QUERY = f"SELECT @@hostname AS host, {_HOST_ID_EXPRESSION} AS host_id"
_IS_READER_QUERY = "SELECT @@read_only"

@property
def default_port(self) -> int:
return 3306

@property
def host_alias_query(self) -> str:
return "SELECT CONCAT(@@hostname, ':', @@port)"
return f"SELECT {self._HOST_ID_EXPRESSION}"

@property
def server_version_query(self) -> str:
return "SHOW VARIABLES LIKE 'version_comment'"

@property
def host_id_query(self) -> str:
return self._HOST_ID_QUERY

@property
def is_reader_query(self) -> str:
return self._IS_READER_QUERY

@property
def exception_handler(self) -> Optional[ExceptionHandler]:
if MysqlDatabaseDialect._exception_handler is None:
Expand Down Expand Up @@ -240,19 +257,30 @@ class PgDatabaseDialect(DatabaseDialect):
_DIALECT_UPDATE_CANDIDATES: Tuple[DialectCode, ...] = (
DialectCode.AURORA_PG, DialectCode.GLOBAL_AURORA_PG, DialectCode.MULTI_AZ_CLUSTER_PG, DialectCode.RDS_PG)
_exception_handler: Optional[ExceptionHandler] = None
_HOST_ID_EXPRESSION = "pg_catalog.CONCAT(pg_catalog.inet_server_addr(), ':', pg_catalog.inet_server_port())"
_HOST_ID_QUERY = f"SELECT pg_catalog.inet_server_addr() AS host, {_HOST_ID_EXPRESSION} AS host_id"
_IS_READER_QUERY = "SELECT pg_catalog.pg_is_in_recovery()"

@property
def default_port(self) -> int:
return 5432

@property
def host_alias_query(self) -> str:
return "SELECT pg_catalog.CONCAT(pg_catalog.inet_server_addr(), ':', pg_catalog.inet_server_port())"
return f"SELECT {self._HOST_ID_EXPRESSION}"

@property
def server_version_query(self) -> str:
return "SELECT 'version', pg_catalog.VERSION()"

@property
def host_id_query(self) -> str:
return self._HOST_ID_QUERY

@property
def is_reader_query(self) -> str:
return self._IS_READER_QUERY

@property
def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
return PgDatabaseDialect._DIALECT_UPDATE_CANDIDATES
Expand Down Expand Up @@ -298,10 +326,17 @@ def is_blue_green_status_available(self, conn: Connection) -> bool:
class RdsMysqlDialect(MysqlDatabaseDialect, BlueGreenDialect):
_DIALECT_UPDATE_CANDIDATES = (DialectCode.AURORA_MYSQL, DialectCode.GLOBAL_AURORA_MYSQL, DialectCode.MULTI_AZ_CLUSTER_MYSQL)

_HOST_ID_QUERY = ("SELECT id, SUBSTRING_INDEX(endpoint, '.', 1) "
"FROM mysql.rds_topology "
"WHERE id = @@server_id")
_BG_STATUS_QUERY = "SELECT version, endpoint, port, role, status FROM mysql.rds_topology"
_BG_STATUS_EXISTS_QUERY = \
"SELECT 1 AS tmp FROM information_schema.tables WHERE table_schema = 'mysql' AND table_name = 'rds_topology'"

@property
def host_id_query(self) -> str:
return self._HOST_ID_QUERY

def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
initial_transaction_status: bool = driver_dialect.is_in_transaction(conn)
try:
Expand Down Expand Up @@ -352,10 +387,17 @@ class RdsPgDialect(PgDatabaseDialect, BlueGreenDialect):
"WHERE name OPERATOR(pg_catalog.=) 'rds.extensions'")
_DIALECT_UPDATE_CANDIDATES = (DialectCode.AURORA_PG, DialectCode.GLOBAL_AURORA_PG, DialectCode.MULTI_AZ_CLUSTER_PG)

_HOST_ID_QUERY = ("SELECT id, SUBSTRING(endpoint FROM 0 FOR POSITION('.' IN endpoint)) "
"FROM rds_tools.show_topology() "
"WHERE id OPERATOR(pg_catalog.=) rds_tools.dbi_resource_id()")
_BG_STATUS_QUERY = (f"SELECT version, endpoint, port, role, status "
f"FROM rds_tools.show_topology('aws_advanced_python_wrapper-{DriverInfo.DRIVER_VERSION}')")
_BG_STATUS_EXISTS_QUERY = "SELECT 'rds_tools.show_topology'::regproc"

@property
def host_id_query(self) -> str:
return self._HOST_ID_QUERY

def is_dialect(self, conn: Connection, driver_dialect: DriverDialect) -> bool:
initial_transaction_status: bool = driver_dialect.is_in_transaction(conn)
if not super().is_dialect(conn, driver_dialect):
Expand Down Expand Up @@ -401,7 +443,7 @@ class AuroraMysqlDialect(MysqlDatabaseDialect, TopologyAwareDatabaseDialect, Blu
"FROM information_schema.replica_host_status "
"WHERE time_to_sec(timediff(now(), LAST_UPDATE_TIMESTAMP)) <= 300 "
"OR SESSION_ID = 'MASTER_SESSION_ID' ")
_HOST_ID_QUERY = "SELECT @@aurora_server_id"
_HOST_ID_QUERY = "SELECT @@aurora_server_id, @@aurora_server_id"
_IS_READER_QUERY = "SELECT @@innodb_read_only"
_WRITER_HOST_QUERY = \
("SELECT SERVER_ID FROM information_schema.replica_host_status "
Expand All @@ -411,6 +453,10 @@ class AuroraMysqlDialect(MysqlDatabaseDialect, TopologyAwareDatabaseDialect, Blu
_BG_STATUS_EXISTS_QUERY = \
"SELECT 1 AS tmp FROM information_schema.tables WHERE table_schema = 'mysql' AND table_name = 'rds_topology'"

@property
def is_reader_query(self) -> str:
return self._IS_READER_QUERY

@property
def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
return AuroraMysqlDialect._DIALECT_UPDATE_CANDIDATES
Expand Down Expand Up @@ -465,8 +511,7 @@ class AuroraPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect, AuroraLim
"OR SESSION_ID OPERATOR(pg_catalog.=) 'MASTER_SESSION_ID' "
"OR LAST_UPDATE_TIMESTAMP IS NULL")

_HOST_ID_QUERY = "SELECT pg_catalog.aurora_db_instance_identifier()"
_IS_READER_QUERY = "SELECT pg_catalog.pg_is_in_recovery()"
_HOST_ID_QUERY = "SELECT pg_catalog.aurora_db_instance_identifier(), pg_catalog.aurora_db_instance_identifier()"
_LIMITLESS_ROUTER_ENDPOINT_QUERY = "SELECT router_endpoint, load FROM pg_catalog.aurora_limitless_router_endpoints()"

_BG_STATUS_QUERY = (f"SELECT version, endpoint, port, role, status "
Expand Down Expand Up @@ -644,8 +689,9 @@ class MultiAzClusterMysqlDialect(MysqlDatabaseDialect, TopologyAwareDatabaseDial
_TOPOLOGY_QUERY = "SELECT id, endpoint, port FROM mysql.rds_topology"
_WRITER_HOST_QUERY = "SHOW REPLICA STATUS"
_WRITER_HOST_COLUMN_INDEX = 39
_HOST_ID_QUERY = "SELECT @@server_id"
_IS_READER_QUERY = "SELECT @@read_only"
_HOST_ID_QUERY = ("SELECT id, SUBSTRING_INDEX(endpoint, '.', 1) "
"FROM mysql.rds_topology "
"WHERE id = @@server_id")

@property
def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
Expand Down Expand Up @@ -701,8 +747,9 @@ class MultiAzClusterPgDialect(PgDatabaseDialect, TopologyAwareDatabaseDialect):
f"SELECT id, endpoint, port FROM rds_tools.show_topology('aws_python_driver-{DriverInfo.DRIVER_VERSION}')"
_WRITER_HOST_QUERY = \
"SELECT multi_az_db_cluster_source_dbi_resource_id FROM rds_tools.multi_az_db_cluster_source_dbi_resource_id()"
_HOST_ID_QUERY = "SELECT dbi_resource_id FROM rds_tools.dbi_resource_id()"
_IS_READER_QUERY = "SELECT pg_catalog.pg_is_in_recovery()"
_HOST_ID_QUERY = ("SELECT id, SUBSTRING(endpoint FROM 0 FOR POSITION('.' IN endpoint)) "
"FROM rds_tools.show_topology() "
"WHERE id OPERATOR(pg_catalog.=) rds_tools.dbi_resource_id()")
_exception_handler: Optional[ExceptionHandler] = None

@property
Expand Down Expand Up @@ -761,6 +808,16 @@ def host_alias_query(self) -> str:
def server_version_query(self) -> str:
return ""

@property
def host_id_query(self) -> str:
raise UnsupportedOperationError(
Messages.get_formatted("UnknownDialect.UnsupportedMethod", "host_id_query"))

@property
def is_reader_query(self) -> str:
raise UnsupportedOperationError(
Messages.get_formatted("UnknownDialect.UnsupportedMethod", "is_reader_query"))

@property
def dialect_update_candidates(self) -> Optional[Tuple[DialectCode, ...]]:
return UnknownDatabaseDialect._DIALECT_UPDATE_CANDIDATES
Expand Down Expand Up @@ -984,4 +1041,43 @@ def check_existence_queries(conn: Connection, existence_queries: Tuple[str, ...]
return False

return True
# not do we need to add the transaction try catch here or is it better to surround the calling method

@staticmethod
def get_host_role(conn: Connection, driver_dialect: DriverDialect, is_reader_query: str,
thread_pool, timeout_sec: float) -> HostRole:
try:
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
thread_pool, timeout_sec, driver_dialect, conn)(DialectUtils._execute_is_reader_query)
result = cursor_execute_func_with_timeout(conn, is_reader_query)
if result is not None:
is_reader = bool(result[0])
return HostRole.READER if is_reader else HostRole.WRITER
except TimeoutError as e:
raise QueryTimeoutError(Messages.get("DialectUtils.GetHostRoleTimeout")) from e
except Exception as e:
raise AwsWrapperError(Messages.get("DialectUtils.ErrorGettingHostRole")) from e

raise AwsWrapperError(Messages.get("DialectUtils.ErrorGettingHostRole"))

@staticmethod
def _execute_is_reader_query(conn: Connection, is_reader_query: str):
with closing(conn.cursor()) as cursor:
cursor.execute(is_reader_query)
return cursor.fetchone()

@staticmethod
def get_instance_id(conn: Connection, driver_dialect: DriverDialect, instance_id_query: str,
thread_pool, timeout_sec: float) -> Optional[Tuple[str, str]]:
cursor_execute_func_with_timeout = preserve_transaction_status_with_timeout(
thread_pool, timeout_sec, driver_dialect, conn)(DialectUtils._execute_instance_id_query)
result = cursor_execute_func_with_timeout(conn, instance_id_query)
if result is not None and len(result) >= 2:
return (str(result[0]), str(result[1]))

return None

@staticmethod
def _execute_instance_id_query(conn: Connection, instance_id_query: str):
with closing(conn.cursor()) as cursor:
cursor.execute(instance_id_query)
return cursor.fetchone()
Loading
Loading