25 changes: 15 additions & 10 deletions aws_advanced_python_wrapper/read_write_splitting_plugin.py
@@ -185,7 +185,6 @@ def _set_reader_connection(

def _initialize_writer_connection(self):
conn, writer_host = self._connection_handler.open_new_writer_connection(lambda x: self._plugin_service.connect(x, self._properties, self))

if conn is None:
self.log_and_raise_exception(
"ReadWriteSplittingPlugin.FailedToConnectToWriter"
@@ -280,13 +279,18 @@ def _switch_to_writer_connection(self):
# Already connected to the intended writer.
return

self._writer_host_info = self._connection_handler.get_writer_host_info()
self._in_read_write_split = True
if not self._is_connection_usable(self._writer_connection, driver_dialect):
self._initialize_writer_connection()
elif self._writer_connection is not None and self._writer_host_info is not None:
self._switch_current_connection_to(
self._writer_connection, self._writer_host_info
)
if self._connection_handler.can_host_be_used(self._writer_host_info):
self._switch_current_connection_to(
self._writer_connection, self._writer_host_info
)
else:
ReadWriteSplittingConnectionManager.log_and_raise_exception(
"ReadWriteSplittingPlugin.NoWriterFound")

if self._is_reader_conn_from_internal_pool:
self._close_connection_if_idle(self._reader_connection)
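
A minimal standalone sketch of the new switch-to-writer guard in this hunk; handler, writer_conn, and the callables are illustrative stand-ins for the wrapper's internals, not its actual classes or signatures.

from typing import Callable, Optional


class NoWriterError(Exception):
    """Raised when no usable writer host can be switched to (illustrative)."""


def switch_to_writer(
    handler,                                  # exposes get_writer_host_info() / can_host_be_used()
    writer_conn,                              # cached writer connection; may be None or stale
    is_usable: Callable[[object], bool],      # checks that a connection is still open
    open_new_writer: Callable[[], None],      # dials a fresh writer connection
    reuse: Callable[[object, object], None],  # switches the current connection in place
) -> None:
    writer_host: Optional[object] = handler.get_writer_host_info()  # may now return None
    if not is_usable(writer_conn):
        open_new_writer()                     # stale or closed: connect from scratch
    elif writer_conn is not None and writer_host is not None:
        if handler.can_host_be_used(writer_host):
            reuse(writer_conn, writer_host)   # reuse only while the host is still allowed
        else:
            # Host exists in the topology but is excluded, e.g. dropped from the custom endpoint.
            raise NoWriterError("ReadWriteSplittingPlugin.NoWriterFound")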
@@ -508,6 +512,10 @@ def refresh_and_store_host_list(
"""Refreshes the host list and then stores it."""
...

def get_writer_host_info(self) -> Optional[HostInfo]:
"""Get the current writer host info."""
...


class TopologyBasedConnectionHandler(ReadWriteConnectionHandler):
"""Topology based implementation of connection handling logic."""
@@ -538,7 +546,7 @@ def open_new_writer_connection(
self,
plugin_service_connect_func: Callable[[HostInfo], Connection],
) -> tuple[Optional[Connection], Optional[HostInfo]]:
writer_host = self._get_writer()
writer_host = self.get_writer_host_info()
if writer_host is None:
return None, None

@@ -621,7 +629,7 @@ def can_host_be_used(self, host_info: HostInfo) -> bool:

def has_no_readers(self) -> bool:
if len(self._hosts) == 1:
return self._get_writer() is not None
return self.get_writer_host_info() is not None
return False

def refresh_and_store_host_list(
@@ -657,14 +665,11 @@ def is_writer_host(self, current_host: HostInfo) -> bool:
def is_reader_host(self, current_host) -> bool:
return current_host.role == HostRole.READER

def _get_writer(self) -> Optional[HostInfo]:
def get_writer_host_info(self) -> Optional[HostInfo]:
for host in self._hosts:
if host.role == HostRole.WRITER:
return host

ReadWriteSplittingConnectionManager.log_and_raise_exception(
"ReadWriteSplittingPlugin.NoWriterFound"
)
return None
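
The rename from _get_writer also changes the contract: the lookup no longer raises NoWriterFound itself and simply returns None, leaving the error decision to callers. A simplified sketch of that split, using a stand-in Host type rather than the wrapper's HostInfo.

from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple


@dataclass
class Host:                      # simplified stand-in for the wrapper's HostInfo
    host: str
    is_writer: bool


def get_writer_host_info(hosts: List[Host]) -> Optional[Host]:
    # The lookup only reports what the topology contains; it no longer raises.
    for h in hosts:
        if h.is_writer:
            return h
    return None


def open_new_writer_connection(
        hosts: List[Host],
        connect: Callable[[Host], object]) -> Tuple[Optional[object], Optional[Host]]:
    writer = get_writer_host_info(hosts)
    if writer is None:
        return None, None        # caller-side handling, mirroring the diff
    return connect(writer), writer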


Original file line number Diff line number Diff line change
@@ -228,6 +228,9 @@ def is_reader_host(self, current_host: HostInfo) -> bool:
or current_host.url.casefold() == self._read_endpoint
)

def get_writer_host_info(self) -> Optional[HostInfo]:
return self._write_endpoint_host_info

def _create_host_info(self, endpoint: str, role: HostRole) -> HostInfo:
endpoint = endpoint.strip()
host = endpoint
7 changes: 6 additions & 1 deletion tests/integration/container/test_aws_secrets_manager.py
@@ -215,7 +215,12 @@ def test_failover_with_secrets_manager(
props.update({
"plugins": "failover,aws_secrets_manager",
"secrets_manager_secret_id": secret_name,
"secrets_manager_region": region
"secrets_manager_region": region,
"socket_timeout": 10,
"connect_timeout": 10,
"monitoring-connect_timeout": 5,
"monitoring-socket_timeout": 5,
"topology_refresh_ms": 10,
})

with AwsWrapperConnection.connect(
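For context, the failover test's connection properties end up roughly as below; the per-key comments are the editor's reading of the intent, not documented guarantees, and the secret and region values are placeholders.

props = {
    "plugins": "failover,aws_secrets_manager",
    "secrets_manager_secret_id": "<secret_name>",   # placeholder
    "secrets_manager_region": "<region>",           # placeholder
    "socket_timeout": 10,             # bound reads so a dead writer is noticed quickly
    "connect_timeout": 10,            # bound reconnect attempts during failover
    "monitoring-connect_timeout": 5,  # tighter limits for the monitoring connection
    "monitoring-socket_timeout": 5,
    "topology_refresh_ms": 10,        # refresh topology aggressively during the test
}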
186 changes: 147 additions & 39 deletions tests/integration/container/test_custom_endpoint.py
@@ -29,6 +29,7 @@
from aws_advanced_python_wrapper import AwsWrapperConnection
from aws_advanced_python_wrapper.errors import (FailoverSuccessError,
ReadWriteSplittingError)
from aws_advanced_python_wrapper.hostinfo import HostRole
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.properties import (Properties,
WrapperProperties)
@@ -195,76 +196,183 @@ def test_custom_endpoint_failover(self, test_driver: TestDriver, conn_utils, pro

conn.close()

def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes(
def _setup_custom_endpoint_role(self, target_driver_connect, conn_kwargs, rds_utils, host_role: HostRole):
self.logger.debug("Setting up custom endpoint instance with role: " + host_role.name)
props = {'plugins': ''}
original_writer = rds_utils.get_cluster_writer_instance_id()
failover_target = None
with AwsWrapperConnection.connect(target_driver_connect, **conn_kwargs, **props) as conn:
endpoint_members = self.endpoint_info["StaticMembers"]
original_instance_id = rds_utils.query_instance_id(conn)
self.logger.debug("Original instance id: " + original_instance_id)
assert original_instance_id in endpoint_members

if host_role == HostRole.WRITER:
if original_instance_id == original_writer:
self.logger.debug("Role is already " + host_role.name + ", no failover needed.")
return # Do nothing, no need to failover.
failover_target = original_instance_id
self.logger.debug("Failing over to get writer role...")
elif host_role == HostRole.READER:
if original_instance_id != original_writer:
self.logger.debug("Role is already " + host_role.name + ", no failover needed.")
return # Do nothing, no need to failover.
self.logger.debug("Failing over to get reader role...")

rds_utils.failover_cluster_and_wait_until_writer_changed(target_id=failover_target)

self.logger.debug("Verifying that new connection has role: " + host_role.name)
# Verify that new connection is now the correct role
with AwsWrapperConnection.connect(target_driver_connect, **conn_kwargs, **props) as conn:
endpoint_members = self.endpoint_info["StaticMembers"]
original_instance_id = rds_utils.query_instance_id(conn)
assert original_instance_id in endpoint_members

new_role = rds_utils.query_host_role(conn, TestEnvironment.get_current().get_engine())
assert new_role == host_role
self.logger.debug("Custom endpoint instance successfully set to role: " + host_role.name)

def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__with_reader_as_init_conn(
self, test_driver: TestDriver, conn_utils, props, rds_utils):
'''
Tests the following scenario:
1. Initially connect to a reader instance via the custom endpoint.
2. Attempt to switch to writer instance - should fail since the custom endpoint only has the reader instance.
3. Modify the custom endpoint to add the writer instance as a static member.
4. Switch to writer instance - should succeed.
5. Switch back to reader instance - should succeed.
6. Modify the custom endpoint to remove the writer instance as a static member.
7. Attempt to switch to writer instance - should fail since the custom endpoint no longer has the writer instance.
'''
target_driver_connect = DriverHelper.get_connect_func(test_driver)
kwargs = conn_utils.get_connect_params()
kwargs["host"] = self.endpoint_info["Endpoint"]
# This setting is not required for the test, but it allows us to also test re-creation of expired monitors since
# it takes more than 30 seconds to modify the cluster endpoint (usually around 140s).
props["custom_endpoint_idle_monitor_expiration_ms"] = 30_000
props["wait_for_custom_endpoint_info_timeout_ms"] = 30_000
conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props)

# Ensure that we are starting with a reader connection
self._setup_custom_endpoint_role(target_driver_connect, kwargs, rds_utils, HostRole.READER)

conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props)
endpoint_members = self.endpoint_info["StaticMembers"]
original_instance_id = rds_utils.query_instance_id(conn)
assert original_instance_id in endpoint_members
original_reader_id = rds_utils.query_instance_id(conn)
assert original_reader_id in endpoint_members

# Attempt to switch to an instance of the opposite role. This should fail since the custom endpoint consists
# only of the current host.
new_read_only_value = original_instance_id == rds_utils.get_cluster_writer_instance_id()
if new_read_only_value:
# We are connected to the writer. Attempting to switch to the reader will not work but will intentionally
# not throw an exception. In this scenario we log a warning and purposefully stick with the writer.
self.logger.debug("Initial connection is to the writer. Attempting to switch to reader...")
conn.read_only = new_read_only_value
new_instance_id = rds_utils.query_instance_id(conn)
assert new_instance_id == original_instance_id
else:
# We are connected to the reader. Attempting to switch to the writer will throw an exception.
self.logger.debug("Initial connection is to a reader. Attempting to switch to writer...")
with pytest.raises(ReadWriteSplittingError):
conn.read_only = new_read_only_value
self.logger.debug("Initial connection is to a reader. Attempting to switch to writer...")
with pytest.raises(ReadWriteSplittingError):
conn.read_only = False

instances = TestEnvironment.get_current().get_instances()
writer_id = rds_utils.get_cluster_writer_instance_id()
if original_instance_id == writer_id:
new_member = instances[1].get_instance_id()
else:
new_member = writer_id

rds_client = client('rds', region_name=TestEnvironment.get_current().get_aurora_region())
rds_client.modify_db_cluster_endpoint(
DBClusterEndpointIdentifier=self.endpoint_id,
StaticMembers=[original_instance_id, new_member]
StaticMembers=[original_reader_id, writer_id]
)

try:
self.wait_until_endpoint_has_members(rds_client, {original_instance_id, new_member})
self.wait_until_endpoint_has_members(rds_client, {original_reader_id, writer_id})

# We should now be able to switch to new_member.
conn.read_only = new_read_only_value
# We should now be able to switch to writer.
conn.read_only = False
new_instance_id = rds_utils.query_instance_id(conn)
assert new_instance_id == new_member
assert new_instance_id == writer_id

# Switch back to original instance
conn.read_only = not new_read_only_value
conn.read_only = True
new_instance_id = rds_utils.query_instance_id(conn)
assert new_instance_id == original_reader_id
finally:
# Remove the writer from the custom endpoint.
rds_client.modify_db_cluster_endpoint(
DBClusterEndpointIdentifier=self.endpoint_id,
StaticMembers=[original_instance_id])
self.wait_until_endpoint_has_members(rds_client, {original_instance_id})
StaticMembers=[original_reader_id])
self.wait_until_endpoint_has_members(rds_client, {original_reader_id})

# We should not be able to switch again because new_member was removed from the custom endpoint.
if new_read_only_value:
# We are connected to the writer. Attempting to switch to the reader will not work but will intentionally
# not throw an exception. In this scenario we log a warning and purposefully stick with the writer.
conn.read_only = new_read_only_value
# We are connected to the reader. Attempting to switch to the writer will throw an exception.
with pytest.raises(ReadWriteSplittingError):
conn.read_only = False

conn.close()
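
wait_until_endpoint_has_members is an existing test helper that is not part of this diff. A plausible shape for it, polling boto3's describe_db_cluster_endpoints, is sketched below; this is an assumption about the utility, written with a standalone signature rather than the method form used in the test.

import time


def wait_until_endpoint_has_members(rds_client, endpoint_id: str,
                                    expected_members: set, timeout_sec: int = 600) -> None:
    # Poll until the custom endpoint reports exactly the expected static members and is
    # available again; modifying a cluster endpoint typically takes a few minutes.
    deadline = time.time() + timeout_sec
    while time.time() < deadline:
        response = rds_client.describe_db_cluster_endpoints(
            DBClusterEndpointIdentifier=endpoint_id)
        endpoint = response["DBClusterEndpoints"][0]
        if set(endpoint["StaticMembers"]) == expected_members and endpoint["Status"] == "available":
            return
        time.sleep(10)
    raise TimeoutError(f"{endpoint_id} did not reach members {expected_members} in time")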

def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__with_writer_as_init_conn(
self, test_driver: TestDriver, conn_utils, props, rds_utils):
'''
Tests the following scenario:
1. Initially connect to the writer instance via the custom endpoint.
2. Attempt to switch to reader instance - should succeed, but will still use writer instance as reader.
3. Modify the custom endpoint to add a reader instance as a static member.
4. Switch to reader instance - should succeed.
5. Switch back to writer instance - should succeed.
6. Modify the custom endpoint to remove the reader instance as a static member.
7. Attempt to switch to reader instance - should fail since the custom endpoint no longer has the reader instance.
'''

target_driver_connect = DriverHelper.get_connect_func(test_driver)
kwargs = conn_utils.get_connect_params()
kwargs["host"] = self.endpoint_info["Endpoint"]
# This setting is not required for the test, but it allows us to also test re-creation of expired monitors since
# it takes more than 30 seconds to modify the cluster endpoint (usually around 140s).
props["custom_endpoint_idle_monitor_expiration_ms"] = 30_000
props["wait_for_custom_endpoint_info_timeout_ms"] = 30_000

# Ensure that we are starting with a writer connection
self._setup_custom_endpoint_role(target_driver_connect, kwargs, rds_utils, HostRole.WRITER)
conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props)

endpoint_members = self.endpoint_info["StaticMembers"]
original_writer_id = str(rds_utils.query_instance_id(conn))
assert original_writer_id in endpoint_members

# We are connected to the writer. Attempting to switch to the reader will not work but will intentionally
# not throw an exception. In this scenario we log a warning and purposefully stick with the writer.
self.logger.debug("Initial connection is to the writer. Attempting to switch to reader...")
conn.read_only = True
new_instance_id = rds_utils.query_instance_id(conn)
assert new_instance_id == original_writer_id

instances = TestEnvironment.get_current().get_instances()
writer_id = str(rds_utils.get_cluster_writer_instance_id())

reader_id_to_add = ""
# Get any reader id
for instance in instances:
if instance.get_instance_id() != writer_id:
reader_id_to_add = instance.get_instance_id()
break

rds_client = client('rds', region_name=TestEnvironment.get_current().get_aurora_region())
rds_client.modify_db_cluster_endpoint(
DBClusterEndpointIdentifier=self.endpoint_id,
StaticMembers=[original_writer_id, reader_id_to_add]
)

try:
self.wait_until_endpoint_has_members(rds_client, {original_writer_id, reader_id_to_add})
# We should now be able to switch to new_member.
conn.read_only = True
new_instance_id = rds_utils.query_instance_id(conn)
assert new_instance_id == original_instance_id
else:
# We are connected to the reader. Attempting to switch to the writer will throw an exception.
with pytest.raises(ReadWriteSplittingError):
conn.read_only = new_read_only_value
assert new_instance_id == reader_id_to_add

# Switch back to original instance
conn.read_only = False
finally:
# Remove the reader from the custom endpoint.
rds_client.modify_db_cluster_endpoint(
DBClusterEndpointIdentifier=self.endpoint_id,
StaticMembers=[original_writer_id])
self.wait_until_endpoint_has_members(rds_client, {original_writer_id})

# We should not be able to switch again because new_member was removed from the custom endpoint.
# We are connected to the writer. Attempting to switch to the reader will not work but will intentionally
# not throw an exception. In this scenario we log a warning and fallback to the writer.
conn.read_only = True
new_instance_id = rds_utils.query_instance_id(conn)
assert new_instance_id == original_writer_id

conn.close()
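
The "pick any reader" loop above can also be written with next(), which makes the no-reader case explicit; the ids below are stubs, not the test environment's API.

instances = ["instance-1", "instance-2", "instance-3"]  # stub instance ids
writer_id = "instance-1"

reader_id_to_add = next((i for i in instances if i != writer_id), None)
assert reader_id_to_add is not None, "cluster must contain at least one reader"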
20 changes: 20 additions & 0 deletions tests/integration/container/utils/rds_test_utility.py
@@ -36,6 +36,7 @@

from aws_advanced_python_wrapper.driver_info import DriverInfo
from aws_advanced_python_wrapper.errors import UnsupportedOperationError
from aws_advanced_python_wrapper.hostinfo import HostRole
from aws_advanced_python_wrapper.utils.log import Logger
from aws_advanced_python_wrapper.utils.messages import Messages
from .database_engine import DatabaseEngine
@@ -255,6 +256,25 @@ def query_instance_id(
raise RuntimeError(Messages.get_formatted(
"RdsTestUtility.MethodNotSupportedForDeployment", "query_instance_id", database_deployment))

def query_host_role(
self,
conn,
database_engine: DatabaseEngine) -> HostRole:
if database_engine == DatabaseEngine.MYSQL:
is_reader_query = "SELECT @@innodb_read_only"
elif database_engine == DatabaseEngine.PG:
is_reader_query = "SELECT pg_catalog.pg_is_in_recovery()"

with closing(conn.cursor()) as cursor:
cursor.execute(is_reader_query)
record = cursor.fetchone()
is_reader = record[0]

if is_reader in (1, True):
return HostRole.READER
else:
return HostRole.WRITER
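
The in (1, True) check exists because MySQL's @@innodb_read_only typically comes back as 0/1 while PostgreSQL's pg_is_in_recovery() returns a bool. A stand-alone sketch of that mapping, using a stand-in enum rather than the wrapper's HostRole:

from enum import Enum, auto


class Role(Enum):   # stand-in for the wrapper's HostRole
    WRITER = auto()
    READER = auto()


def to_role(is_reader_result) -> Role:
    return Role.READER if is_reader_result in (1, True) else Role.WRITER


assert to_role(1) is Role.READER       # MySQL reader
assert to_role(True) is Role.READER    # PostgreSQL reader
assert to_role(0) is Role.WRITER       # MySQL writer
assert to_role(False) is Role.WRITER   # PostgreSQL writer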

def _query_aurora_instance_id(self, conn: Connection, engine: DatabaseEngine) -> str:
if engine == DatabaseEngine.MYSQL:
sql = "SELECT @@aurora_server_id"