Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cassandra/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def disable_advanced_shard_aware(self, secs):
self.advanced_shardaware_block_until = max(time.time() + secs, self.advanced_shardaware_block_until)

def _get_shard_aware_endpoint(self):
if (self.advanced_shardaware_block_until and self.advanced_shardaware_block_until < time.time()) or \
if (self.advanced_shardaware_block_until and self.advanced_shardaware_block_until > time.time()) or \
self._session.cluster.shard_aware_options.disable_shardaware_port:
return None

Expand Down
138 changes: 89 additions & 49 deletions tests/unit/test_shard_aware.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import unittest

import logging
import time
from unittest.mock import MagicMock
from concurrent.futures import ThreadPoolExecutor

Expand All @@ -27,6 +28,45 @@
LOGGER = logging.getLogger(__name__)


class MockSession(MagicMock):
is_shutdown = False
keyspace = "ks1"

def __init__(self, is_ssl=False, *args, **kwargs):
super(MockSession, self).__init__(*args, **kwargs)
self.cluster = MagicMock()
if is_ssl:
self.cluster.ssl_options = {'some_ssl_options': True}
else:
self.cluster.ssl_options = None
self.cluster.shard_aware_options = ShardAwareOptions()
self.cluster.executor = ThreadPoolExecutor(max_workers=2)
self.cluster.signal_connection_failure = lambda *args, **kwargs: False
self.cluster.connection_factory = self.mock_connection_factory
self.connection_counter = 0
self.futures = []

def submit(self, fn, *args, **kwargs):
logging.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs)
if not self.is_shutdown:
f = self.cluster.executor.submit(fn, *args, **kwargs)
self.futures += [f]
return f

def mock_connection_factory(self, *args, **kwargs):
connection = MagicMock()
connection.is_shutdown = False
connection.is_defunct = False
connection.is_closed = False
connection.orphaned_threshold_reached = False
connection.endpoint = args[0]
sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045)
connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info)
self.connection_counter += 1

return connection


class TestShardAware(unittest.TestCase):
def test_parsing_and_calculating_shard_id(self):
"""
Expand Down Expand Up @@ -55,58 +95,58 @@ def test_advanced_shard_aware_port(self):
Test that on given a `shard_aware_port` on the OPTIONS message (ShardInfo class)
the next connections would be open using this port
"""
class MockSession(MagicMock):
is_shutdown = False
keyspace = "ks1"

def __init__(self, is_ssl=False, *args, **kwargs):
super(MockSession, self).__init__(*args, **kwargs)
self.cluster = MagicMock()
if is_ssl:
self.cluster.ssl_options = {'some_ssl_options': True}
else:
self.cluster.ssl_options = None
self.cluster.shard_aware_options = ShardAwareOptions()
self.cluster.executor = ThreadPoolExecutor(max_workers=2)
self.cluster.signal_connection_failure = lambda *args, **kwargs: False
self.cluster.connection_factory = self.mock_connection_factory
self.connection_counter = 0
self.futures = []

def submit(self, fn, *args, **kwargs):
logging.info("Scheduling %s with args: %s, kwargs: %s", fn, args, kwargs)
if not self.is_shutdown:
f = self.cluster.executor.submit(fn, *args, **kwargs)
self.futures += [f]
return f

def mock_connection_factory(self, *args, **kwargs):
connection = MagicMock()
connection.is_shutdown = False
connection.is_defunct = False
connection.is_closed = False
connection.orphaned_threshold_reached = False
connection.endpoint = args[0]
sharding_info = ShardingInfo(shard_id=1, shards_count=4, partitioner="", sharding_algorithm="", sharding_ignore_msb=0, shard_aware_port=19042, shard_aware_port_ssl=19045)
connection.features = ProtocolFeatures(shard_id=kwargs.get('shard_id', self.connection_counter), sharding_info=sharding_info)
self.connection_counter += 1

return connection

host = MagicMock()
host.endpoint = DefaultEndPoint("1.2.3.4")

for port, is_ssl in [(19042, False), (19045, True)]:
session = MockSession(is_ssl=is_ssl)
pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
for f in session.futures:
f.result()
assert len(pool._connections) == 4
for shard_id, connection in pool._connections.items():
assert connection.features.shard_id == shard_id
if shard_id == 0:
assert connection.endpoint == DefaultEndPoint("1.2.3.4")
else:
assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port)

session.cluster.executor.shutdown(wait=True)
try:
for f in session.futures:
f.result()
assert len(pool._connections) == 4
for shard_id, connection in pool._connections.items():
assert connection.features.shard_id == shard_id
if shard_id == 0:
assert connection.endpoint == DefaultEndPoint("1.2.3.4")
else:
assert connection.endpoint == DefaultEndPoint("1.2.3.4", port=port)
finally:
session.cluster.executor.shutdown(wait=True)

def test_advanced_shard_aware_cooldown(self):
"""
`disable_advanced_shard_aware` must suppress the shard-aware endpoint for
the duration of the cool-down window, then automatically restore it once
the deadline has passed. The hard-disable flag must suppress the endpoint
unconditionally.
"""
host = MagicMock()
host.endpoint = DefaultEndPoint("1.2.3.4")
session = MockSession(is_ssl=False)

pool = HostConnection(host=host, host_distance=HostDistance.REMOTE, session=session)
for f in session.futures:
f.result()

try:
# Baseline: shard-aware port is returned.
endpoint = pool._get_shard_aware_endpoint()
assert endpoint is not None
assert endpoint.port == 19042

# During the cool-down window `_get_shard_aware_endpoint` must return None.
pool.disable_advanced_shard_aware(600)
assert pool._get_shard_aware_endpoint() is None

# Once the deadline has passed, the shard-aware port must be used again.
pool.advanced_shardaware_block_until = time.time() - 1
endpoint = pool._get_shard_aware_endpoint()
assert endpoint is not None
assert endpoint.port == 19042

# The hard-disable flag must suppress the endpoint regardless of the timer.
session.cluster.shard_aware_options.disable_shardaware_port = True
assert pool._get_shard_aware_endpoint() is None
finally:
session.cluster.executor.shutdown(wait=True)
Loading