diff --git a/cassandra/pool.py b/cassandra/pool.py index 227e1b5315..9e949c342c 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -677,7 +677,7 @@ def disable_advanced_shard_aware(self, secs): self.advanced_shardaware_block_until = max(time.time() + secs, self.advanced_shardaware_block_until) def _get_shard_aware_endpoint(self): - if (self.advanced_shardaware_block_until and self.advanced_shardaware_block_until < time.time()) or \ + if (self.advanced_shardaware_block_until and self.advanced_shardaware_block_until > time.time()) or \ self._session.cluster.shard_aware_options.disable_shardaware_port: return None diff --git a/tests/unit/test_shard_aware.py b/tests/unit/test_shard_aware.py index e7d26ae207..4b4c2c138d 100644 --- a/tests/unit/test_shard_aware.py +++ b/tests/unit/test_shard_aware.py @@ -15,6 +15,7 @@ import unittest import logging +import time from unittest.mock import MagicMock from concurrent.futures import ThreadPoolExecutor @@ -27,6 +28,45 @@ LOGGER = logging.getLogger(__name__) +class MockSession(MagicMock): + is_shutdown = False + keyspace = "ks1" + + def __init__(self, is_ssl=False, *args, **kwargs): + super(MockSession, self).__init__(*args, **kwargs) + self.cluster = MagicMock() + if is_ssl: + self.cluster.ssl_options = {'some_ssl_options': True} + else: + self.cluster.ssl_options = None + self.cluster.shard_aware_options = ShardAwareOptions() + self.cluster.executor = ThreadPoolExecutor(max_workers=2) + self.cluster.signal_connection_failure = lambda *args, **kwargs: False + self.cluster.connection_factory = self.mock_connection_factory + self.connection_counter = 0 + self.futures = [] + + def submit(self, fn, *args, **kwargs): + logging.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs) + if not self.is_shutdown: + f = self.cluster.executor.submit(fn, *args, **kwargs) + self.futures += [f] + return f + + def mock_connection_factory(self, *args, **kwargs): + connection = MagicMock() + connection.is_shutdown = False + connection.is_defunct = False + connection.is_closed = False + connection.orphaned_threshold_reached = False + connection.endpoint = args[0] + sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045) + connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info) + self.connection_counter += 1 + + return connection + + class TestShardAware(unittest.TestCase): def test_parsing_and_calculating_shard_id(self): """ @@ -55,58 +95,58 @@ def test_advanced_shard_aware_port(self): Test that on given a `shard_aware_port` on the OPTIONS message (ShardInfo class) the next connections would be open using this port """ - class MockSession(MagicMock): - is_shutdown = False - keyspace = "ks1" - - def __init__(self, is_ssl=False, *args, **kwargs): - super(MockSession, self).__init__(*args, **kwargs) - self.cluster = MagicMock() - if is_ssl: - self.cluster.ssl_options = {'some_ssl_options': True} - else: - self.cluster.ssl_options = None - self.cluster.shard_aware_options = ShardAwareOptions() - self.cluster.executor = ThreadPoolExecutor(max_workers=2) - self.cluster.signal_connection_failure = lambda *args, **kwargs: False - self.cluster.connection_factory = self.mock_connection_factory - self.connection_counter = 0 - self.futures = [] - - def submit(self, fn, *args, **kwargs): - logging.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs) - if not self.is_shutdown: - f = self.cluster.executor.submit(fn, *args, **kwargs) - self.futures += [f] - return f - - def mock_connection_factory(self, *args, **kwargs): - connection = MagicMock() - connection.is_shutdown = False - connection.is_defunct = False - connection.is_closed = False - connection.orphaned_threshold_reached = False - connection.endpoint = args[0] - sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045) - connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info) - self.connection_counter += 1 - - return connection - host = MagicMock() host.endpoint = DefaultEndPoint("1.2.3.4") for port, is_ssl in [(19042, False), (19045, True)]: session = MockSession(is_ssl=is_ssl) pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session) - for f in session.futures: - f.result() - assert len(pool._connections) == 4 - for shard_id, connection in pool._connections.items(): - assert connection.features.shard_id == shard_id - if shard_id == 0: - assert connection.endpoint == DefaultEndPoint("1.2.3.4") - else: - assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port) - - session.cluster.executor.shutdown(wait=True) + try: + for f in session.futures: + f.result() + assert len(pool._connections) == 4 + for shard_id, connection in pool._connections.items(): + assert connection.features.shard_id == shard_id + if shard_id == 0: + assert connection.endpoint == DefaultEndPoint("1.2.3.4") + else: + assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port) + finally: + session.cluster.executor.shutdown(wait=True) + + def test_advanced_shard_aware_cooldown(self): + """ + `disable_advanced_shard_aware` must suppress the shard-aware endpoint for + the duration of the cool-down window, then automatically restore it once + the deadline has passed. The hard-disable flag must suppress the endpoint + unconditionally. + """ + host = MagicMock() + host.endpoint = DefaultEndPoint("1.2.3.4") + session = MockSession(is_ssl=False) + + pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session) + for f in session.futures: + f.result() + + try: + # Baseline: shard-aware port is returned. + endpoint = pool._get_shard_aware_endpoint() + assert endpoint is not None + assert endpoint.port == 19042 + + # During the cool-down window `_get_shard_aware_endpoint` must return None. + pool.disable_advanced_shard_aware(600) + assert pool._get_shard_aware_endpoint() is None + + # Once the deadline has passed, the shard-aware port must be used again. + pool.advanced_shardaware_block_until = time.time() - 1 + endpoint = pool._get_shard_aware_endpoint() + assert endpoint is not None + assert endpoint.port == 19042 + + # The hard-disable flag must suppress the endpoint regardless of the timer. + session.cluster.shard_aware_options.disable_shardaware_port = True + assert pool._get_shard_aware_endpoint() is None + finally: + session.cluster.executor.shutdown(wait=True)