Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
266 changes: 221 additions & 45 deletions cassandra/cluster.py

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions cassandra/datastax/insights/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def initialize_registry(insights_registry):
DCAwareRoundRobinPolicy,
TokenAwarePolicy,
WhiteListRoundRobinPolicy,
DynamicWhiteListRoundRobinPolicy,
HostFilterPolicy,
ConstantReconnectionPolicy,
ExponentialReconnectionPolicy,
Expand Down Expand Up @@ -80,6 +81,13 @@ def whitelist_round_robin_policy_insights_serializer(policy):
'options': {'allowed_hosts': policy._allowed_hosts}
}

@insights_registry.register_serializer_for(DynamicWhiteListRoundRobinPolicy)
def dynamic_whitelist_round_robin_policy_insights_serializer(policy):
    """
    Build the insights payload for a DynamicWhiteListRoundRobinPolicy.

    The payload carries the policy class name, its namespace, and the
    currently allowed host ids rendered as a tuple of strings.
    """
    policy_cls = policy.__class__
    type_name = policy_cls.__name__
    policy_namespace = namespace(policy_cls)
    # Stringify each id so the payload holds plain strings only.
    allowed_ids = tuple(map(str, policy._allowed_host_ids))
    return {
        'type': type_name,
        'namespace': policy_namespace,
        'options': {'allowed_host_ids': allowed_ids},
    }

@insights_registry.register_serializer_for(HostFilterPolicy)
def host_filter_policy_insights_serializer(policy):
return {
Expand Down
71 changes: 71 additions & 0 deletions cassandra/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,16 @@ def check_supported(self):
"""
pass

def on_control_connection_host(self, host):
    """
    Hook invoked once the control connection has resolved which metadata
    host sits behind the endpoint it is currently using.

    The base implementation is a no-op. Policies that keep a dynamic host
    allowlist can override it to refresh their internal view of the
    cluster.
    """
    return None


class RoundRobinPolicy(LoadBalancingPolicy):
"""
Expand Down Expand Up @@ -540,6 +550,9 @@ def on_add(self, *args, **kwargs):
def on_remove(self, *args, **kwargs):
    """Forward the removal notification, arguments untouched, to the child policy."""
    child = self._child_policy
    return child.on_remove(*args, **kwargs)

def on_control_connection_host(self, host):
    """Pass the control-connection host notification on to the child policy."""
    child = self._child_policy
    return child.on_control_connection_host(host)


class WhiteListRoundRobinPolicy(RoundRobinPolicy):
"""
Expand Down Expand Up @@ -594,6 +607,58 @@ def on_add(self, host):
RoundRobinPolicy.on_add(self, host)


class DynamicWhiteListRoundRobinPolicy(RoundRobinPolicy):
    """
    A :class:`.RoundRobinPolicy` whose allowlist is driven by the control
    connection instead of being configured up front.

    Meant for proxy deployments in which the only host the driver can reach
    is the one currently behind the control connection endpoint. Every other
    discovered node is reported as :attr:`~.HostDistance.IGNORED` until the
    control connection resolves a different host.
    """

    def __init__(self):
        # The allowlist starts empty; it is filled in by
        # on_control_connection_host().
        self._allowed_host_ids = frozenset(())
        self._cluster = None
        RoundRobinPolicy.__init__(self)

    def _host_is_allowed(self, host):
        """Return True when ``host`` has a ``host_id`` present in the allowlist."""
        host_id = getattr(host, "host_id", None)
        return host_id in self._allowed_host_ids

    def _refresh_live_hosts(self, hosts):
        """Rebuild ``_live_hosts`` from ``hosts``, keeping allowed hosts not marked down."""
        kept = []
        for candidate in hosts:
            if not self._host_is_allowed(candidate):
                continue
            # Only an explicit ``False`` excludes a host; any other ``is_up``
            # value (including ``None``) is kept.
            if candidate.is_up is False:
                continue
            kept.append(candidate)
        self._live_hosts = frozenset(kept)

    def populate(self, cluster, hosts):
        """Remember ``cluster``, seed the live hosts, and pick a start position."""
        self._cluster = cluster
        self._refresh_live_hosts(hosts)
        live_count = len(self._live_hosts)
        # A random start only makes sense with more than one live host.
        self._position = randint(0, live_count - 1) if live_count > 1 else 0

    def distance(self, host):
        """Report LOCAL for allowed hosts and IGNORED for everything else."""
        if self._host_is_allowed(host):
            return HostDistance.LOCAL
        return HostDistance.IGNORED

    def on_up(self, host):
        """Propagate the up event to the base policy only for allowed hosts."""
        if not self._host_is_allowed(host):
            return
        RoundRobinPolicy.on_up(self, host)

    def on_add(self, host):
        """Propagate the add event to the base policy only for allowed hosts."""
        if not self._host_is_allowed(host):
            return
        RoundRobinPolicy.on_add(self, host)

    def on_control_connection_host(self, host):
        """
        Replace the allowlist with the ``host_id`` of ``host`` (or empty it
        when there is none) and recompute the live hosts, all while holding
        ``self._hosts_lock``.
        """
        with self._hosts_lock:
            host_id = getattr(host, "host_id", None)
            if host_id is None:
                self._allowed_host_ids = frozenset(())
            else:
                self._allowed_host_ids = frozenset((host_id,))
            cluster = self._cluster
            # Before populate() there is no cluster to read hosts from.
            if cluster is not None:
                self._refresh_live_hosts(cluster.metadata.all_hosts())


class HostFilterPolicy(LoadBalancingPolicy):
"""
A :class:`.LoadBalancingPolicy` subclass configured with a child policy,
Expand Down Expand Up @@ -654,6 +719,9 @@ def on_add(self, host, *args, **kwargs):
def on_remove(self, host, *args, **kwargs):
    """Hand the removal of ``host``, plus any extra arguments, to the child policy."""
    child = self._child_policy
    return child.on_remove(host, *args, **kwargs)

def on_control_connection_host(self, host):
    """Relay the control-connection host notification to the child policy."""
    delegate = self._child_policy
    return delegate.on_control_connection_host(host)

@property
def predicate(self):
"""
Expand Down Expand Up @@ -1322,6 +1390,9 @@ def on_add(self, *args, **kwargs):
def on_remove(self, *args, **kwargs):
    """Delegate the removal notification unchanged to the wrapped child policy."""
    delegate = self._child_policy
    return delegate.on_remove(*args, **kwargs)

def on_control_connection_host(self, host):
    """Delegate the control-connection host notification to the wrapped child policy."""
    wrapped = self._child_policy
    return wrapped.on_control_connection_host(host)


class DefaultLoadBalancingPolicy(WrapperPolicy):
"""
Expand Down
15 changes: 7 additions & 8 deletions cassandra/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def __init__(self, host, host_distance, session):

if self._keyspace:
first_connection.set_keyspace_blocking(self._keyspace)
if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable:
if first_connection.features.sharding_info and not self._session.is_shard_aware_disabled():
self.host.sharding_info = first_connection.features.sharding_info
self._open_connections_for_all_shards(first_connection.features.shard_id)
self.tablets_routing_v1 = first_connection.features.tablets_routing_v1
Expand All @@ -451,7 +451,7 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table
raise NoConnectionsAvailable()

shard_id = None
if not self._session.cluster.shard_aware_options.disable and self.host.sharding_info and routing_key:
if not self._session.is_shard_aware_disabled() and self.host.sharding_info and routing_key:
t = self._session.cluster.metadata.token_map.token_class.from_key(routing_key)

shard_id = None
Expand Down Expand Up @@ -554,15 +554,15 @@ def return_connection(self, connection, stream_was_orphaned=False):
if not connection.signaled_error:
log.debug("Defunct or closed connection (%s) returned to pool, potentially "
"marking host %s as down", id(connection), self.host)
is_down = self.host.signal_connection_failure(connection.last_error)
is_down = self._session._signal_connection_failure(self.host, connection.last_error)
connection.signaled_error = True

if self.shutdown_on_error and not is_down:
is_down = True

if is_down:
self.shutdown()
self._session.cluster.on_down(self.host, is_host_addition=False)
self._session._handle_pool_down(self.host, is_host_addition=False)
else:
connection.close()
with self._lock:
Expand Down Expand Up @@ -603,7 +603,7 @@ def _replace(self, connection):
try:
if connection.features.shard_id in self._connections:
del self._connections[connection.features.shard_id]
if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable:
if self.host.sharding_info and not self._session.is_shard_aware_disabled():
self._connecting.add(connection.features.shard_id)
self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id)
else:
Expand Down Expand Up @@ -678,7 +678,8 @@ def disable_advanced_shard_aware(self, secs):

def _get_shard_aware_endpoint(self):
if (self.advanced_shardaware_block_until and self.advanced_shardaware_block_until > time.time()) or \
self._session.cluster.shard_aware_options.disable_shardaware_port:
self._session.cluster.shard_aware_options.disable_shardaware_port or \
self._session.is_shard_aware_disabled():
return None

endpoint = None
Expand Down Expand Up @@ -920,5 +921,3 @@ def open_count(self):
@property
def _excess_connection_limit(self):
    """Host shard count multiplied by the per-shard excess-connection multiplier."""
    shard_count = self.host.sharding_info.shards_count
    return shard_count * self.max_excess_connections_per_shard_multiplier


3 changes: 3 additions & 0 deletions docs/api/cassandra/policies.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ Load Balancing
.. autoclass:: WhiteListRoundRobinPolicy
:members:

.. autoclass:: DynamicWhiteListRoundRobinPolicy
:members:

.. autoclass:: TokenAwarePolicy
:members:

Expand Down
77 changes: 75 additions & 2 deletions tests/integration/standard/test_client_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@

from cassandra.cluster import Cluster
from cassandra.client_routes import ClientRoutesConfig, ClientRouteProxy
from cassandra.connection import ClientRoutesEndPoint
from cassandra.policies import RoundRobinPolicy
from cassandra.connection import ClientRoutesEndPoint, ConnectionException
from cassandra.policies import DynamicWhiteListRoundRobinPolicy, RoundRobinPolicy
from tests.integration import (
TestCluster,
get_cluster,
Expand All @@ -54,6 +54,28 @@

log = logging.getLogger(__name__)


class ProxyOnlyReachableConnection(Cluster.connection_class):
    """
    Simulates a private-link client that can reach only the proxy endpoint.

    The CCM node addresses are reachable from the local test runner, which means
    the existing client-routes tests cannot reproduce bugs that only appear when
    direct node IPs are private. This connection class rejects those direct node
    addresses while still allowing the NLB address.
    """

    @classmethod
    def factory(cls, endpoint, timeout, host_conn=None, *args, **kwargs):
        """
        Open a connection unless ``endpoint`` resolves to a direct node address.

        :param endpoint: endpoint to connect to; its ``resolve()`` result
            supplies the address that is checked.
        :param timeout: forwarded unchanged to the parent factory.
        :param host_conn: forwarded unchanged to the parent factory.
        :raises ConnectionException: when the resolved address starts with
            ``127.0.0.``, i.e. one of the simulated private node addresses.
        :return: whatever the parent class ``factory`` returns.
        """
        address, _ = endpoint.resolve()
        if address.startswith("127.0.0."):
            raise ConnectionException(
                "Simulated private node address %s is unreachable from the client" % address,
                endpoint=endpoint,
            )
        # Forward host_conn positionally. Passing it as a keyword ahead of
        # *args lets the first extra positional argument bind to host_conn too,
        # which raises "got multiple values for argument 'host_conn'".
        # NOTE(review): assumes the parent factory takes host_conn as its third
        # positional parameter, mirroring this override's signature - confirm.
        return super().factory(endpoint, timeout, host_conn, *args, **kwargs)


class TcpProxy:
"""
A simple TCP proxy that forwards connections from a local listen port
Expand Down Expand Up @@ -535,6 +557,57 @@ def teardown_module():
else:
os.environ['SCYLLA_EXT_OPTS'] = _saved_scylla_ext_opts


class TestProxyConnectivityWithoutClientRoutes(unittest.TestCase):
    """
    Reproducer for connecting through a generic proxy when the node addresses
    themselves cannot be reached from the client.

    The initial control connection gets to the cluster through the proxy, but
    the driver afterwards tries to open pools directly to the discovered node
    addresses. In a proxy-only environment that makes connect/query fail.
    """

    @classmethod
    def setUpClass(cls):
        # Node id -> loopback address of the node (127.0.0.1 .. 127.0.0.3).
        cls.node_addrs = {node_id: "127.0.0.%d" % node_id for node_id in (1, 2, 3)}
        cls.proxy_node_id = 1
        emulator = NLBEmulator()
        cls.nlb = emulator
        emulator.start(cls.node_addrs)

    @classmethod
    def tearDownClass(cls):
        cls.nlb.stop()

    def _make_proxy_cluster(self):
        """Create a Cluster that connects via the NLB port of the proxied node."""
        proxy_port = self.nlb.node_port(self.proxy_node_id)
        return Cluster(
            contact_points=[NLBEmulator.LISTEN_HOST],
            port=proxy_port,
            connection_class=ProxyOnlyReachableConnection,
            load_balancing_policy=DynamicWhiteListRoundRobinPolicy(),
        )

    def test_dynamic_whitelist_session_succeeds_when_only_proxy_is_reachable(self):
        cluster = self._make_proxy_cluster()
        self.addCleanup(cluster.shutdown)

        session = cluster.connect()
        query = "SELECT release_version FROM system.local WHERE key='local'"
        result_row = session.execute(query).one()
        self.assertIsNotNone(result_row)

        # Exactly one pool is expected: the host reached through the proxy.
        pools = session.get_pool_state()
        self.assertEqual(len(pools), 1)

        pooled_endpoint = next(iter(pools)).endpoint
        self.assertEqual(pooled_endpoint.address, NLBEmulator.LISTEN_HOST)
        self.assertEqual(pooled_endpoint.port, self.nlb.node_port(self.proxy_node_id))

@skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported',
scylla_version="2026.1.0")
class TestGetHostPortMapping(unittest.TestCase):
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/advanced/test_insights.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import logging
import sys
import uuid
from unittest.mock import sentinel

from cassandra import ConsistencyLevel
Expand All @@ -37,6 +38,7 @@
DCAwareRoundRobinPolicy,
TokenAwarePolicy,
WhiteListRoundRobinPolicy,
DynamicWhiteListRoundRobinPolicy,
HostFilterPolicy,
ConstantReconnectionPolicy,
ExponentialReconnectionPolicy,
Expand Down Expand Up @@ -203,6 +205,14 @@ def test_whitelist_round_robin_policy(self):
'options': {'allowed_hosts': ('127.0.0.3',)},
'type': 'WhiteListRoundRobinPolicy'}

def test_dynamic_whitelist_round_robin_policy(self):
    """The serializer reports the allowed host ids as a tuple of strings."""
    policy = DynamicWhiteListRoundRobinPolicy()
    host_id = uuid.uuid4()
    # Use a frozenset, the container type the policy itself assigns to this
    # attribute, so the serializer is exercised against the real state shape.
    policy._allowed_host_ids = frozenset((host_id,))
    assert insights_registry.serialize(policy) == {'namespace': 'cassandra.policies',
                                                   'options': {'allowed_host_ids': (str(host_id),)},
                                                   'type': 'DynamicWhiteListRoundRobinPolicy'}

def test_host_filter_policy(self):
def my_predicate(s):
return False
Expand Down
Loading
Loading