diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 9eace8810d..d601146aff 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -549,6 +549,10 @@ def on_remove(self, host): for p in self.profiles.values(): p.load_balancing_policy.on_remove(host) + def on_control_connection_host(self, host): + for p in self.profiles.values(): + p.load_balancing_policy.on_control_connection_host(host) + @property def default(self): """ @@ -1674,6 +1678,7 @@ def add_execution_profile(self, name, profile, pool_wait_timeout=5): self.profile_manager.profiles[name] = profile profile.load_balancing_policy.populate(self, self.metadata.all_hosts()) + profile.load_balancing_policy.on_control_connection_host(self.get_control_connection_host()) # on_up after populate allows things like DCA LBP to choose default local dc for host in filter(lambda h: h.is_up, self.metadata.all_hosts()): profile.load_balancing_policy.on_up(host) @@ -1746,6 +1751,17 @@ def connect(self, keyspace=None, wait_for_all_pools=False): established or attempted. Default is `False`, which means it will return when the first successful connection is established. Remaining pools are added asynchronously. """ + self._ensure_core_connections_setup() + + session = self._new_session(keyspace) + if wait_for_all_pools: + wait_futures(session._initial_connect_futures) + + self._set_default_dbaas_consistency(session) + + return session + + def _ensure_core_connections_setup(self): with self._lock: if self.is_shutdown: raise DriverException("Cluster is already shut down") @@ -1777,14 +1793,6 @@ def connect(self, keyspace=None, wait_for_all_pools=False): ) self._is_setup = True - session = self._new_session(keyspace) - if wait_for_all_pools: - wait_futures(session._initial_connect_futures) - - self._set_default_dbaas_consistency(session) - - return session - def _set_default_dbaas_consistency(self, session): if session.cluster.metadata.dbaas: for profile in self.profile_manager.profiles.values(): @@ -1805,14 +1813,18 @@ def get_all_pools(self): pools.extend(s.get_pools()) return pools + def _get_shard_aware_pools(self): + return [pool for pool in self.get_all_pools() if pool.host.sharding_info is not None] + def is_shard_aware(self): - return bool(self.get_all_pools()[0].host.sharding_info) + return bool(self._get_shard_aware_pools()) def shard_aware_stats(self): - if self.is_shard_aware(): + shard_aware_pools = self._get_shard_aware_pools() + if shard_aware_pools: return {str(pool.host.endpoint): {'shards_count': pool.host.sharding_info.shards_count, 'connected': len(pool._connections.keys())} - for pool in self.get_all_pools()} + for pool in shard_aware_pools} def shutdown(self): """ @@ -1857,11 +1869,88 @@ def _new_session(self, keyspace): self.sessions.add(session) return session + def _default_control_connection_endpoint_targets_host( + self, host, endpoint, attempts=3): + for _ in range(attempts): + connected_host_id = self._get_host_id_for_endpoint(endpoint) + if connected_host_id != host.host_id: + return False + return True + + def _get_host_id_for_endpoint(self, endpoint): + connection = None + try: + connection = self.connection_factory(endpoint) + response = connection.wait_for_response( + QueryMessage( + query="SELECT host_id FROM system.local WHERE key='local'", + consistency_level=ConsistencyLevel.ONE), + timeout=self.connect_timeout) + rows = dict_factory(response.column_names, response.parsed_rows) + if not rows: + return None + return rows[0].get("host_id") + except Exception: + log.debug( + "Failed verifying control 
connection endpoint %s", endpoint, + exc_info=True) + return None + finally: + if connection: + connection.close() + + def _get_control_connection_host_endpoint(self, control_host, connection_endpoint): + if connection_endpoint is not None and ( + connection_endpoint == control_host.endpoint or + not isinstance(connection_endpoint, DefaultEndPoint)): + return connection_endpoint + + host_endpoint = control_host.endpoint + if host_endpoint is not None and not isinstance(host_endpoint, DefaultEndPoint): + return host_endpoint + + if connection_endpoint is not None and self._default_control_connection_endpoint_targets_host( + control_host, connection_endpoint): + return connection_endpoint + + if connection_endpoint is not None: + return connection_endpoint + + return host_endpoint + def _session_register_user_types(self, session): for keyspace, type_map in self._user_types.items(): for udt_name, klass in type_map.items(): session.user_type_registered(keyspace, udt_name, klass) + def _update_host_endpoint(self, host, endpoint): + if host.endpoint == endpoint: + return + + was_up = host.is_up + reconnector = host.get_and_set_reconnection_handler(None) + if reconnector: + reconnector.cancel() + + self.profile_manager.on_down(host) + for session in tuple(self.sessions): + session.remove_pool(host) + if was_up: + for listener in self.listeners: + listener.on_down(host) + + old_endpoint = host.endpoint + host.endpoint = endpoint + self.metadata.update_host(host, old_endpoint) + if was_up: + self.profile_manager.on_up(host) + for session in tuple(self.sessions): + session.add_or_renew_pool(host, is_host_addition=False) + for listener in self.listeners: + listener.on_up(host) + else: + self._start_reconnector(host, is_host_addition=False) + def _cleanup_failed_on_up_handling(self, host): self.profile_manager.on_down(host) self.control_connection.on_down(host) @@ -1954,12 +2043,16 @@ def on_up(self, host): futures_lock = Lock() futures_results = [] callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock) + callback_futures = [] for session in tuple(self.sessions): future = session.add_or_renew_pool(host, is_host_addition=False) if future is not None: have_future = True - future.add_done_callback(callback) futures.add(future) + callback_futures.append(future) + + for future in callback_futures: + future.add_done_callback(callback) except Exception: log.exception("Unexpected failure handling node %s being marked up:", host) for future in futures: @@ -2030,8 +2123,7 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): if self._discount_down_events and self.profile_manager.distance(host) != HostDistance.IGNORED: connected = False for session in tuple(self.sessions): - pool_states = session.get_pool_state() - pool_state = pool_states.get(host) + pool_state = session.get_pool_state_for_host(host) if pool_state: connected |= pool_state['open_count'] > 0 if connected: @@ -2220,7 +2312,18 @@ def get_control_connection_host(self): """ connection = self.control_connection._connection endpoint = connection.endpoint if connection else None - return self.metadata.get_host(endpoint) if endpoint else None + if not endpoint: + return None + + host = self.metadata.get_host(endpoint) + if host is not None: + return host + + host_id = self.control_connection._current_host_id + if host_id is None: + return None + + return self.metadata.get_host_by_host_id(host_id) def refresh_schema_metadata(self, max_schema_agreement_wait=None): """ @@ -2924,8 +3027,11 @@ def 
_on_analytics_master_result(self, response, master_future, query_future): delimiter_index = addr.rfind(':') # assumes : - not robust, but that's what is being provided if delimiter_index > 0: addr = addr[:delimiter_index] - targeted_query = HostTargetingStatement(query_future.query, addr) - query_future.query_plan = query_future._load_balancer.make_query_plan(self.keyspace, targeted_query) + if query_future._host is not None: + query_future.query_plan = iter([query_future._host]) + else: + targeted_query = HostTargetingStatement(query_future.query, addr) + query_future.query_plan = query_future._load_balancer.make_query_plan(self.keyspace, targeted_query) except Exception: log.debug("Failed querying analytics master (request might not be routed optimally). " "Make sure the session is connecting to a graph analytics datacenter.", exc_info=True) @@ -3245,15 +3351,17 @@ def run_add_or_renew_pool(): new_pool = HostConnection(host, distance, self) except AuthenticationFailed as auth_exc: conn_exc = ConnectionException(str(auth_exc), endpoint=host) - self.cluster.signal_connection_failure(host, conn_exc, is_host_addition) + if self._signal_connection_failure(host, conn_exc): + self._handle_pool_down(host, is_host_addition) return False except Exception as conn_exc: log.warning("Failed to create connection pool for new host %s:", host, exc_info=conn_exc) # the host itself will still be marked down, so we need to pass # a special flag to make sure the reconnector is created - self.cluster.signal_connection_failure( - host, conn_exc, is_host_addition, expect_host_to_be_down=True) + if self._signal_connection_failure(host, conn_exc): + self._handle_pool_down( + host, is_host_addition, expect_host_to_be_down=True) return False previous = self._pools.get(host) @@ -3271,7 +3379,7 @@ def callback(pool, errors): set_keyspace_event.wait(self.cluster.connect_timeout) if not set_keyspace_event.is_set() or errors_returned: log.warning("Failed setting keyspace for pool after keyspace changed during connect: %s", errors_returned) - self.cluster.on_down(host, is_host_addition) + self._handle_pool_down(host, is_host_addition) new_pool.shutdown() self._lock.acquire() return False @@ -3410,9 +3518,21 @@ def submit(self, fn, *args, **kwargs): def get_pool_state(self): return dict((host, pool.get_state()) for host, pool in tuple(self._pools.items())) + def get_pool_state_for_host(self, host): + return self.get_pool_state().get(host) + def get_pools(self): return self._pools.values() + def _signal_connection_failure(self, host, connection_exc): + return host.signal_connection_failure(connection_exc) + + def _handle_pool_down(self, host, is_host_addition, expect_host_to_be_down=False): + self.cluster.on_down(host, is_host_addition, expect_host_to_be_down) + + def is_shard_aware_disabled(self): + return self.cluster.shard_aware_options.disable + def _validate_set_legacy_config(self, attr_name, value): if self.cluster._config_mode == _ConfigMode.PROFILES: raise ValueError("Cannot set Session.%s while using Configuration Profiles. Set this in a profile instead." 
% (attr_name,)) @@ -3545,6 +3665,7 @@ def __init__(self, cluster, timeout, self._reconnection_handler = None self._reconnection_lock = RLock() + self._current_host_id = None self._event_schedule_times = {} @@ -3569,6 +3690,9 @@ def _set_new_connection(self, conn): log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn) old.close() + for session in tuple(getattr(self._cluster, "sessions", ())): + session.update_created_pools() + def _try_connect_to_hosts(self): errors = {} @@ -3770,6 +3894,11 @@ def shutdown(self): if self._connection: self._connection.close() self._connection = None + self._current_host_id = None + try: + self._cluster.profile_manager.on_control_connection_host(None) + except ReferenceError: + pass def refresh_schema(self, force=False, **kwargs): try: @@ -3849,6 +3978,7 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, found_host_ids = set() found_endpoints = set() + local_host_id = None if local_result.parsed_rows: local_rows = dict_factory(local_result.column_names, local_result.parsed_rows) local_row = local_rows[0] @@ -3857,6 +3987,7 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, partitioner = local_row.get("partitioner") tokens = local_row.get("tokens", None) + local_host_id = local_row.get("host_id") peers_result.insert(0, local_row) @@ -3887,17 +4018,17 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, if host is None: host = self._cluster.metadata.get_host_by_host_id(host_id) - if host and host.endpoint != endpoint: - log.debug("[control connection] Updating host ip from %s to %s for (%s)", host.endpoint, endpoint, host_id) - reconnector = host.get_and_set_reconnection_handler(None) - if reconnector: - reconnector.cancel() - self._cluster.on_down(host, is_host_addition=False, expect_host_to_be_down=True) - - old_endpoint = host.endpoint - host.endpoint = endpoint - self._cluster.metadata.update_host(host, old_endpoint) - self._cluster.on_up(host) + if host: + target_endpoint = endpoint + if host_id == local_host_id: + target_endpoint = self._cluster._get_control_connection_host_endpoint(host, connection.endpoint) + if target_endpoint is None: + target_endpoint = endpoint + + if host.endpoint != target_endpoint: + log.debug("[control connection] Updating host endpoint from %s to %s for (%s)", + host.endpoint, target_endpoint, host_id) + self._cluster._update_host_endpoint(host, target_endpoint) if host is None: log.debug("[control connection] Found new host to connect to: %s", endpoint) @@ -3928,6 +4059,15 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, self._cluster.metadata.remove_host_by_host_id(old_host_id, old_host.endpoint) log.debug("[control connection] Finished fetching ring info") + current_host = None + if local_host_id in found_host_ids: + current_host = self._cluster.metadata.get_host_by_host_id(local_host_id) + if current_host is not None: + self._maybe_rebind_control_connection_host_endpoint(current_host, connection.endpoint) + current_host = self._cluster.metadata.get_host_by_host_id(local_host_id) + + self._current_host_id = local_host_id if current_host is not None else None + self._cluster.profile_manager.on_control_connection_host(current_host) if partitioner and should_rebuild_token_map: log.debug("[control connection] Rebuilding token map due to topology changes") self._cluster.metadata.rebuild_token_map(partitioner, token_map) @@ -3981,6 +4121,15 @@ def 
_update_location_info(self, host, datacenter, rack): self._cluster.profile_manager.on_up(host) return True + def _maybe_rebind_control_connection_host_endpoint(self, host, connection_endpoint): + target_endpoint = self._cluster._get_control_connection_host_endpoint(host, connection_endpoint) + if target_endpoint is None or target_endpoint == host.endpoint: + return + + log.debug("[control connection] Rebasing current host %s from %s to %s", + host.host_id, host.endpoint, target_endpoint) + self._cluster._update_host_endpoint(host, target_endpoint) + def _delay_for_event_type(self, event_type, delay_window): # this serves to order processing correlated events (received within the window) # the window and randomization still have the desired effect of skew across client instances @@ -4222,25 +4371,50 @@ def _signal_error(self): # try just signaling the cluster, as this will trigger a reconnect # as part of marking the host down if self._connection and self._connection.is_defunct: - host = self._cluster.metadata.get_host(self._connection.endpoint) + host = self._try_get_cluster_host() # host may be None if it's already been removed, but that indicates # that errors have already been reported, so we're fine if host: - self._cluster.signal_connection_failure( + is_down = self._cluster.signal_connection_failure( host, self._connection.last_error, is_host_addition=False) - return + if is_down: + return # if the connection is not defunct or the host already left, reconnect # manually self.reconnect() + def _try_get_cluster_host(self): + conn = self._connection + endpoint = conn.endpoint if conn else None + if not endpoint: + return None + + host = self._cluster.metadata.get_host(endpoint) + if host is not None: + return host + + host_id = self._current_host_id + if host_id is None: + return None + + return self._cluster.metadata.get_host_by_host_id(host_id) + def on_up(self, host): pass - def on_down(self, host): - + def _is_current_host(self, host): conn = self._connection - if conn and conn.endpoint == host.endpoint and \ + if conn is None or host is None: + return False + + if conn.endpoint == host.endpoint: + return True + + return self._current_host_id is not None and getattr(host, 'host_id', None) == self._current_host_id + + def on_down(self, host): + if self._is_current_host(host) and \ self._reconnection_handler is None: log.debug("[control connection] Control connection host (%s) is " "considered down, starting reconnection", host) @@ -4252,8 +4426,7 @@ def on_add(self, host, refresh_nodes=True): self.refresh_node_list_and_token_map(force_token_rebuild=True) def on_remove(self, host): - c = self._connection - if c and c.endpoint == host.endpoint: + if self._is_current_host(host): log.debug("[control connection] Control connection host (%s) is being removed. 
Reconnecting", host) # refresh will be done on reconnect self.reconnect() diff --git a/cassandra/datastax/insights/serializers.py b/cassandra/datastax/insights/serializers.py index 289c165e8a..270b5360a3 100644 --- a/cassandra/datastax/insights/serializers.py +++ b/cassandra/datastax/insights/serializers.py @@ -37,6 +37,7 @@ def initialize_registry(insights_registry): DCAwareRoundRobinPolicy, TokenAwarePolicy, WhiteListRoundRobinPolicy, + DynamicWhiteListRoundRobinPolicy, HostFilterPolicy, ConstantReconnectionPolicy, ExponentialReconnectionPolicy, @@ -80,6 +81,13 @@ def whitelist_round_robin_policy_insights_serializer(policy): 'options': {'allowed_hosts': policy._allowed_hosts} } + @insights_registry.register_serializer_for(DynamicWhiteListRoundRobinPolicy) + def dynamic_whitelist_round_robin_policy_insights_serializer(policy): + return {'type': policy.__class__.__name__, + 'namespace': namespace(policy.__class__), + 'options': {'allowed_host_ids': tuple(str(host_id) for host_id in policy._allowed_host_ids)} + } + @insights_registry.register_serializer_for(HostFilterPolicy) def host_filter_policy_insights_serializer(policy): return { diff --git a/cassandra/policies.py b/cassandra/policies.py index ceb5ebdc45..268a260812 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -166,6 +166,16 @@ def check_supported(self): """ pass + def on_control_connection_host(self, host): + """ + Called when the control connection resolves the metadata host behind + the endpoint it is currently using. + + Policies that maintain a dynamic host allowlist can override this to + update their internal view of the cluster. + """ + pass + class RoundRobinPolicy(LoadBalancingPolicy): """ @@ -540,6 +550,9 @@ def on_add(self, *args, **kwargs): def on_remove(self, *args, **kwargs): return self._child_policy.on_remove(*args, **kwargs) + def on_control_connection_host(self, host): + return self._child_policy.on_control_connection_host(host) + class WhiteListRoundRobinPolicy(RoundRobinPolicy): """ @@ -594,6 +607,58 @@ def on_add(self, host): RoundRobinPolicy.on_add(self, host) +class DynamicWhiteListRoundRobinPolicy(RoundRobinPolicy): + """ + A :class:`.RoundRobinPolicy` variant whose allowlist is updated from the + control connection. + + This is intended for proxy deployments where the driver can only reach the + host currently behind the control connection endpoint. The policy keeps + every other discovered node at :attr:`~.HostDistance.IGNORED` until the + control connection resolves a different host. 
+ """ + + def __init__(self): + self._allowed_host_ids = frozenset(()) + self._cluster = None + RoundRobinPolicy.__init__(self) + + def _host_is_allowed(self, host): + return getattr(host, "host_id", None) in self._allowed_host_ids + + def _refresh_live_hosts(self, hosts): + self._live_hosts = frozenset( + host for host in hosts + if self._host_is_allowed(host) and host.is_up is not False + ) + + def populate(self, cluster, hosts): + self._cluster = cluster + self._refresh_live_hosts(hosts) + if len(self._live_hosts) > 1: + self._position = randint(0, len(self._live_hosts) - 1) + else: + self._position = 0 + + def distance(self, host): + return HostDistance.LOCAL if self._host_is_allowed(host) else HostDistance.IGNORED + + def on_up(self, host): + if self._host_is_allowed(host): + RoundRobinPolicy.on_up(self, host) + + def on_add(self, host): + if self._host_is_allowed(host): + RoundRobinPolicy.on_add(self, host) + + def on_control_connection_host(self, host): + with self._hosts_lock: + allowed_host_id = getattr(host, "host_id", None) + self._allowed_host_ids = frozenset((allowed_host_id,)) if allowed_host_id is not None else frozenset(()) + if self._cluster is not None: + self._refresh_live_hosts(self._cluster.metadata.all_hosts()) + + class HostFilterPolicy(LoadBalancingPolicy): """ A :class:`.LoadBalancingPolicy` subclass configured with a child policy, @@ -654,6 +719,9 @@ def on_add(self, host, *args, **kwargs): def on_remove(self, host, *args, **kwargs): return self._child_policy.on_remove(host, *args, **kwargs) + def on_control_connection_host(self, host): + return self._child_policy.on_control_connection_host(host) + @property def predicate(self): """ @@ -1322,6 +1390,9 @@ def on_add(self, *args, **kwargs): def on_remove(self, *args, **kwargs): return self._child_policy.on_remove(*args, **kwargs) + def on_control_connection_host(self, host): + return self._child_policy.on_control_connection_host(host) + class DefaultLoadBalancingPolicy(WrapperPolicy): """ diff --git a/cassandra/pool.py b/cassandra/pool.py index 9e949c342c..097824be06 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -435,7 +435,7 @@ def __init__(self, host, host_distance, session): if self._keyspace: first_connection.set_keyspace_blocking(self._keyspace) - if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable: + if first_connection.features.sharding_info and not self._session.is_shard_aware_disabled(): self.host.sharding_info = first_connection.features.sharding_info self._open_connections_for_all_shards(first_connection.features.shard_id) self.tablets_routing_v1 = first_connection.features.tablets_routing_v1 @@ -451,7 +451,7 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table raise NoConnectionsAvailable() shard_id = None - if not self._session.cluster.shard_aware_options.disable and self.host.sharding_info and routing_key: + if not self._session.is_shard_aware_disabled() and self.host.sharding_info and routing_key: t = self._session.cluster.metadata.token_map.token_class.from_key(routing_key) shard_id = None @@ -554,7 +554,7 @@ def return_connection(self, connection, stream_was_orphaned=False): if not connection.signaled_error: log.debug("Defunct or closed connection (%s) returned to pool, potentially " "marking host %s as down", id(connection), self.host) - is_down = self.host.signal_connection_failure(connection.last_error) + is_down = self._session._signal_connection_failure(self.host, connection.last_error) 
connection.signaled_error = True if self.shutdown_on_error and not is_down: @@ -562,7 +562,7 @@ def return_connection(self, connection, stream_was_orphaned=False): if is_down: self.shutdown() - self._session.cluster.on_down(self.host, is_host_addition=False) + self._session._handle_pool_down(self.host, is_host_addition=False) else: connection.close() with self._lock: @@ -603,7 +603,7 @@ def _replace(self, connection): try: if connection.features.shard_id in self._connections: del self._connections[connection.features.shard_id] - if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable: + if self.host.sharding_info and not self._session.is_shard_aware_disabled(): self._connecting.add(connection.features.shard_id) self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id) else: @@ -678,7 +678,8 @@ def disable_advanced_shard_aware(self, secs): def _get_shard_aware_endpoint(self): if (self.advanced_shardaware_block_until and self.advanced_shardaware_block_until > time.time()) or \ - self._session.cluster.shard_aware_options.disable_shardaware_port: + self._session.cluster.shard_aware_options.disable_shardaware_port or \ + self._session.is_shard_aware_disabled(): return None endpoint = None @@ -920,5 +921,3 @@ def open_count(self): @property def _excess_connection_limit(self): return self.host.sharding_info.shards_count * self.max_excess_connections_per_shard_multiplier - - diff --git a/docs/api/cassandra/policies.rst b/docs/api/cassandra/policies.rst index 84d5575a40..2a24af8f9f 100644 --- a/docs/api/cassandra/policies.rst +++ b/docs/api/cassandra/policies.rst @@ -26,6 +26,9 @@ Load Balancing .. autoclass:: WhiteListRoundRobinPolicy :members: +.. autoclass:: DynamicWhiteListRoundRobinPolicy + :members: + .. autoclass:: TokenAwarePolicy :members: diff --git a/tests/integration/standard/test_client_routes.py b/tests/integration/standard/test_client_routes.py index 5a20421276..a9efcd65b4 100644 --- a/tests/integration/standard/test_client_routes.py +++ b/tests/integration/standard/test_client_routes.py @@ -40,8 +40,8 @@ from cassandra.cluster import Cluster from cassandra.client_routes import ClientRoutesConfig, ClientRouteProxy -from cassandra.connection import ClientRoutesEndPoint -from cassandra.policies import RoundRobinPolicy +from cassandra.connection import ClientRoutesEndPoint, ConnectionException +from cassandra.policies import DynamicWhiteListRoundRobinPolicy, RoundRobinPolicy from tests.integration import ( TestCluster, get_cluster, @@ -54,6 +54,28 @@ log = logging.getLogger(__name__) + +class ProxyOnlyReachableConnection(Cluster.connection_class): + """ + Simulates a private-link client that can reach only the proxy endpoint. + + The CCM node addresses are reachable from the local test runner, which means + the existing client-routes tests cannot reproduce bugs that only appear when + direct node IPs are private. This connection class rejects those direct node + addresses while still allowing the NLB address. 
+ """ + + @classmethod + def factory(cls, endpoint, timeout, host_conn=None, *args, **kwargs): + address, _ = endpoint.resolve() + if address.startswith("127.0.0."): + raise ConnectionException( + "Simulated private node address %s is unreachable from the client" % address, + endpoint=endpoint, + ) + return super().factory(endpoint, timeout, host_conn=host_conn, *args, **kwargs) + + class TcpProxy: """ A simple TCP proxy that forwards connections from a local listen port @@ -535,6 +557,57 @@ def teardown_module(): else: os.environ['SCYLLA_EXT_OPTS'] = _saved_scylla_ext_opts + +class TestProxyConnectivityWithoutClientRoutes(unittest.TestCase): + """ + Reproducer for connecting through a generic proxy when node addresses are + not reachable from the client. + + The initial control connection can reach the cluster through the proxy, but + the driver later tries to open pools to the discovered node addresses + directly. In a proxy-only environment that makes connect/query fail. + """ + + @classmethod + def setUpClass(cls): + cls.node_addrs = { + 1: "127.0.0.1", + 2: "127.0.0.2", + 3: "127.0.0.3", + } + cls.proxy_node_id = 1 + cls.nlb = NLBEmulator() + cls.nlb.start(cls.node_addrs) + + @classmethod + def tearDownClass(cls): + cls.nlb.stop() + + def _make_proxy_cluster(self): + return Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + port=self.nlb.node_port(self.proxy_node_id), + connection_class=ProxyOnlyReachableConnection, + load_balancing_policy=DynamicWhiteListRoundRobinPolicy(), + ) + + def test_dynamic_whitelist_session_succeeds_when_only_proxy_is_reachable(self): + cluster = self._make_proxy_cluster() + self.addCleanup(cluster.shutdown) + + session = cluster.connect() + row = session.execute( + "SELECT release_version FROM system.local WHERE key='local'" + ).one() + + self.assertIsNotNone(row) + pool_state = session.get_pool_state() + self.assertEqual(len(pool_state), 1) + + session_host = next(iter(pool_state)) + self.assertEqual(session_host.endpoint.address, NLBEmulator.LISTEN_HOST) + self.assertEqual(session_host.endpoint.port, self.nlb.node_port(self.proxy_node_id)) + @skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', scylla_version="2026.1.0") class TestGetHostPortMapping(unittest.TestCase): diff --git a/tests/unit/advanced/test_insights.py b/tests/unit/advanced/test_insights.py index ec9b918866..55ea05e8f4 100644 --- a/tests/unit/advanced/test_insights.py +++ b/tests/unit/advanced/test_insights.py @@ -17,6 +17,7 @@ import logging import sys +import uuid from unittest.mock import sentinel from cassandra import ConsistencyLevel @@ -37,6 +38,7 @@ DCAwareRoundRobinPolicy, TokenAwarePolicy, WhiteListRoundRobinPolicy, + DynamicWhiteListRoundRobinPolicy, HostFilterPolicy, ConstantReconnectionPolicy, ExponentialReconnectionPolicy, @@ -203,6 +205,14 @@ def test_whitelist_round_robin_policy(self): 'options': {'allowed_hosts': ('127.0.0.3',)}, 'type': 'WhiteListRoundRobinPolicy'} + def test_dynamic_whitelist_round_robin_policy(self): + policy = DynamicWhiteListRoundRobinPolicy() + host_id = uuid.uuid4() + policy._allowed_host_ids = (host_id,) + assert insights_registry.serialize(policy) == {'namespace': 'cassandra.policies', + 'options': {'allowed_host_ids': (str(host_id),)}, + 'type': 'DynamicWhiteListRoundRobinPolicy'} + def test_host_filter_policy(self): def my_predicate(s): return False diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 872d133b28..25c58b2407 100644 --- a/tests/unit/test_cluster.py +++ 
b/tests/unit/test_cluster.py @@ -15,7 +15,7 @@ import logging import socket - +from concurrent.futures import Future from unittest.mock import patch, Mock import uuid @@ -23,8 +23,10 @@ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT +from cassandra.connection import DefaultEndPoint, EndPoint from cassandra.pool import Host -from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy +from cassandra.policies import RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, \ + DynamicWhiteListRoundRobinPolicy, HostStateListener, SimpleConvictionPolicy from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory from tests.unit.utils import mock_session_pools from tests import connection_class @@ -33,6 +35,57 @@ log = logging.getLogger(__name__) + +class _HostAwareProxyEndPoint(EndPoint): + def __init__(self, address, affinity_key, port=9042): + self._address = address + self._affinity_key = affinity_key + self._port = port + + @property + def address(self): + return self._address + + @property + def port(self): + return self._port + + def resolve(self): + return self._address, self._port + + def __eq__(self, other): + return isinstance(other, _HostAwareProxyEndPoint) and \ + self.address == other.address and self.port == other.port and \ + self._affinity_key == other._affinity_key + + def __hash__(self): + return hash((self.address, self.port, self._affinity_key)) + + def __lt__(self, other): + if not isinstance(other, _HostAwareProxyEndPoint): + return NotImplemented + return (self.address, self.port, str(self._affinity_key)) < \ + (other.address, other.port, str(other._affinity_key)) + + +class _RecordingHostStateListener(HostStateListener): + + def __init__(self): + self.events = [] + + def on_up(self, host): + self.events.append(("up", host.address)) + + def on_down(self, host): + self.events.append(("down", host.address)) + + def on_add(self, host): + self.events.append(("add", host.address)) + + def on_remove(self, host): + self.events.append(("remove", host.address)) + + class ExceptionTypeTest(unittest.TestCase): def test_exception_types(self): @@ -171,6 +224,134 @@ def test_connection_factory_passes_compression_kwarg(self): assert factory.call_args.kwargs['compression'] == expected assert cluster.compression == expected + def test_get_control_connection_host_falls_back_to_host_id(self): + cluster = Cluster(contact_points=['127.0.0.1']) + host = Host(DefaultEndPoint('192.168.1.10'), SimpleConvictionPolicy, host_id=uuid.uuid4()) + + metadata = Mock() + metadata.get_host.return_value = None + metadata.get_host_by_host_id.return_value = host + cluster.metadata = metadata + + connection = Mock(endpoint=DefaultEndPoint('127.254.254.101', 9042)) + cluster.control_connection = Mock(_connection=connection, _current_host_id=host.host_id) + + assert cluster.get_control_connection_host() is host + metadata.get_host.assert_called_once_with(connection.endpoint) + metadata.get_host_by_host_id.assert_called_once_with(host.host_id) + + def test_update_host_endpoint_recreates_session_pools(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host(DefaultEndPoint("127.0.0.1"), 
SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_up() + cluster.metadata.add_or_return_host(host) + + remove_future = Future() + remove_future.set_result(None) + add_future = Future() + add_future.set_result(True) + + session = Mock() + session.remove_pool.return_value = remove_future + session.add_or_renew_pool.return_value = add_future + cluster.sessions.add(session) + + new_endpoint = DefaultEndPoint("127.0.0.2") + cluster._update_host_endpoint(host, new_endpoint) + + assert host.endpoint == new_endpoint + assert cluster.metadata.get_host(new_endpoint) is host + session.remove_pool.assert_called_once_with(host) + session.add_or_renew_pool.assert_called_once_with(host, is_host_addition=False) + + def test_update_host_endpoint_restarts_reconnector_for_down_host(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_down() + cluster.metadata.add_or_return_host(host) + + previous_reconnector = Mock() + host.get_and_set_reconnection_handler(previous_reconnector) + + session = Mock() + cluster.sessions.add(session) + cluster._start_reconnector = Mock() + + new_endpoint = DefaultEndPoint("127.0.0.2") + cluster._update_host_endpoint(host, new_endpoint) + + assert host.endpoint == new_endpoint + assert cluster.metadata.get_host(new_endpoint) is host + previous_reconnector.cancel.assert_called_once_with() + session.remove_pool.assert_called_once_with(host) + session.add_or_renew_pool.assert_not_called() + cluster._start_reconnector.assert_called_once_with(host, is_host_addition=False) + + def test_update_host_endpoint_notifies_listeners_for_live_host(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_up() + cluster.metadata.add_or_return_host(host) + + session = Mock() + cluster.sessions.add(session) + + listener = _RecordingHostStateListener() + cluster.register_listener(listener) + + new_endpoint = DefaultEndPoint("127.0.0.2") + cluster._update_host_endpoint(host, new_endpoint) + + assert listener.events == [("down", "127.0.0.1"), ("up", "127.0.0.2")] + session.remove_pool.assert_called_once_with(host) + session.add_or_renew_pool.assert_called_once_with(host, is_host_addition=False) + + def test_is_shard_aware_ignores_non_shard_aware_pools(self): + cluster = Cluster(contact_points=['127.0.0.1']) + + shard_pool = Mock() + shard_pool.host = Mock( + endpoint=DefaultEndPoint("127.0.0.1"), + sharding_info=Mock(shards_count=8)) + shard_pool._connections = {0: Mock(), 1: Mock()} + + control_pool = Mock() + control_pool.host = Mock( + endpoint=DefaultEndPoint("127.254.254.101"), + sharding_info=None) + control_pool._connections = {0: Mock()} + + cluster.get_all_pools = Mock(return_value=[control_pool, shard_pool]) + + assert cluster.is_shard_aware() is True + + def test_shard_aware_stats_ignores_non_shard_aware_pools(self): + cluster = Cluster(contact_points=['127.0.0.1']) + + shard_pool = Mock() + shard_pool.host = Mock( + endpoint=DefaultEndPoint("127.0.0.1"), + sharding_info=Mock(shards_count=8)) + shard_pool._connections = {0: Mock(), 1: Mock()} + + control_pool = Mock() + control_pool.host = Mock( + endpoint=DefaultEndPoint("127.254.254.101"), + sharding_info=None) + control_pool._connections = {0: Mock()} + + cluster.get_all_pools = 
Mock(return_value=[shard_pool, control_pool]) + + assert cluster.shard_aware_stats() == { + "127.0.0.1:9042": {"shards_count": 8, "connected": 2} + } + class SchedulerTest(unittest.TestCase): # TODO: this suite could be expanded; for now just adding a test covering a ticket @@ -194,6 +375,16 @@ def setUp(self): raise unittest.SkipTest('libev does not appear to be installed correctly') connection_class.initialize_reactor() + @staticmethod + def _completed_future(result): + future = Future() + future.set_result(result) + return future + + @staticmethod + def _proxy_endpoint(address, affinity_key, port=9042): + return _HostAwareProxyEndPoint(address, affinity_key, port) + # TODO: this suite could be expanded; for now just adding a test covering a PR @mock_session_pools def test_default_serial_consistency_level_ep(self, *_): @@ -281,6 +472,163 @@ def test_set_keyspace_escapes_quotes(self, *_): assert query == 'USE simple_ks', ( "Simple keyspace names should not be quoted, got: %r" % query) + def test_get_control_connection_host_endpoint_reuses_matching_default_endpoint(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + source_host.set_up() + connection_endpoint = DefaultEndPoint("127.254.254.101") + + verification_connection = Mock() + verification_connection.wait_for_response.return_value = Mock( + column_names=["host_id"], + parsed_rows=[(source_host.host_id,)]) + + with patch.object(cluster, "connection_factory", + return_value=verification_connection) as connection_factory: + endpoint = cluster._get_control_connection_host_endpoint(source_host, connection_endpoint) + + assert endpoint == connection_endpoint + assert connection_factory.call_count == 3 + assert verification_connection.close.call_count == 3 + + def test_get_control_connection_host_endpoint_prefers_host_aware_metadata_endpoint(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host_id = uuid.uuid4() + source_endpoint = self._proxy_endpoint("proxy.control.example", host_id) + source_host = Host(source_endpoint, SimpleConvictionPolicy, host_id=host_id) + source_host.set_up() + + endpoint = cluster._get_control_connection_host_endpoint( + source_host, DefaultEndPoint("127.254.254.101")) + + assert endpoint == source_endpoint + + def test_get_control_connection_host_endpoint_keeps_control_endpoint_when_verification_mismatches(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + source_host.set_up() + connection_endpoint = DefaultEndPoint("127.254.254.101") + + verification_connection = Mock() + verification_connection.wait_for_response.return_value = Mock( + column_names=["host_id"], + parsed_rows=[(uuid.uuid4(),)]) + + with patch.object(cluster, "connection_factory", + return_value=verification_connection) as connection_factory: + endpoint = cluster._get_control_connection_host_endpoint(source_host, connection_endpoint) + + assert endpoint == connection_endpoint + assert connection_factory.call_count == 1 + assert verification_connection.close.call_count == 1 + + def test_get_control_connection_host_endpoint_keeps_control_endpoint_when_verification_fails(self): + cluster = 
Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + source_host.set_up() + connection_endpoint = DefaultEndPoint("127.254.254.101") + + verification_connection = Mock() + verification_connection.wait_for_response.side_effect = RuntimeError("verification failed") + + with patch.object(cluster, "connection_factory", + return_value=verification_connection) as connection_factory: + endpoint = cluster._get_control_connection_host_endpoint(source_host, connection_endpoint) + + assert endpoint == connection_endpoint + assert connection_factory.call_count == 1 + assert verification_connection.close.call_count == 1 + + def test_analytics_master_lookup_keeps_explicit_host(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + target_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + target_host.set_up() + + with patch.object(Session, "_add_or_renew_pool_for_distance", + return_value=self._completed_future(True)): + session = Session(cluster, [target_host]) + + master_future = Mock() + master_future.result.return_value = [({'location': '127.0.0.99:8182'},)] + + query_future = Mock() + query_future._host = target_host + query_future.query = SimpleStatement("g.V()") + query_future._load_balancer = Mock() + query_future.send_request = Mock() + query_future.query_plan = iter(()) + + with patch.object(session, "submit", + side_effect=lambda fn, *args, **kwargs: self._completed_future(fn(*args, **kwargs))): + session._on_analytics_master_result(None, master_future, query_future) + + assert list(query_future.query_plan) == [target_host] + query_future._load_balancer.make_query_plan.assert_not_called() + query_future.send_request.assert_called_once_with() + + @mock_session_pools + def test_session_preserves_down_event_discounting_after_endpoint_update(self, *_): + class _DeterministicHashEndPoint(EndPoint): + def __init__(self, address, hash_value, port=9042): + self._address = address + self._hash_value = hash_value + self._port = port + + @property + def address(self): + return self._address + + @property + def port(self): + return self._port + + def resolve(self): + return self._address, self._port + + def __eq__(self, other): + return isinstance(other, _DeterministicHashEndPoint) and \ + self.address == other.address and self.port == other.port + + def __hash__(self): + return self._hash_value + + def __lt__(self, other): + return (self.address, self.port) < (other.address, other.port) + + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_up() + + session = Session(cluster, [host]) + cluster.sessions.add(session) + + pool = Mock() + pool.get_state.return_value = {"open_count": 1} + + host.endpoint = _DeterministicHashEndPoint("127.0.0.1", 1) + session._pools = {host: pool} + host.endpoint = _DeterministicHashEndPoint("127.0.0.2", 2) + + cluster.on_down_potentially_blocking = Mock() + + cluster.on_down(host, is_host_addition=False) + + assert host.is_up is True + cluster.on_down_potentially_blocking.assert_not_called() + assert session.get_pool_state_for_host(host) == {"open_count": 1} + class ProtocolVersionTests(unittest.TestCase): def test_protocol_downgrade_test(self): @@ 
-537,6 +885,26 @@ def test_no_profiles_same_name(self): with pytest.raises(ValueError): cluster.add_execution_profile('two', ExecutionProfile()) + def test_add_execution_profile_seeds_current_control_host(self): + cluster = Cluster(protocol_version=4) + self.addCleanup(cluster.shutdown) + + hosts = [ + Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + ] + for host in hosts: + host.set_up() + cluster.metadata.add_or_return_host(host) + + cluster.control_connection._connection = Mock(endpoint=hosts[1].endpoint) + cluster.control_connection._current_host_id = hosts[1].host_id + + profile = ExecutionProfile(load_balancing_policy=DynamicWhiteListRoundRobinPolicy()) + cluster.add_execution_profile('proxy', profile) + + assert list(profile.load_balancing_policy.make_query_plan()) == [hosts[1]] + def test_warning_on_no_lbp_with_contact_points_legacy_mode(self): """ Test that users are warned when they instantiate a Cluster object in diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 037d4a8888..99ae4282a6 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -102,6 +102,7 @@ class MockCluster(object): down_host = None contact_points = [] is_shutdown = False + sessions = () def __init__(self): self.metadata = MockMetadata() @@ -118,6 +119,14 @@ def add_host(self, endpoint, datacenter, rack, signal=False, refresh_nodes=True, self.added_hosts.append(host) return host, True + def _update_host_endpoint(self, host, endpoint): + old_endpoint = host.endpoint + host.endpoint = endpoint + self.metadata.update_host(host, old_endpoint) + + def _get_control_connection_host_endpoint(self, host, connection_endpoint): + return connection_endpoint + def remove_host(self, host): pass @@ -301,6 +310,60 @@ def test_wait_for_schema_agreement_none_timeout(self): cc._time = self.time assert cc.wait_for_schema_agreement() + def test_on_down_reconnects_when_current_host_matches_by_host_id(self): + self.control_connection._connection.endpoint = DefaultEndPoint("127.254.254.101") + self.control_connection._current_host_id = "uuid1" + self.control_connection.reconnect = Mock() + + self.control_connection.on_down(self.cluster.metadata.get_host_by_host_id("uuid1")) + + self.control_connection.reconnect.assert_called_once_with() + + def test_on_remove_reconnects_when_current_host_matches_by_host_id(self): + self.control_connection._connection.endpoint = DefaultEndPoint("127.254.254.101") + self.control_connection._current_host_id = "uuid1" + self.control_connection.reconnect = Mock() + self.control_connection.refresh_node_list_and_token_map = Mock() + + self.control_connection.on_remove(self.cluster.metadata.get_host_by_host_id("uuid1")) + + self.control_connection.reconnect.assert_called_once_with() + self.control_connection.refresh_node_list_and_token_map.assert_not_called() + + def test_signal_error_marks_current_host_down_when_current_host_matches_by_host_id(self): + host = self.cluster.metadata.get_host_by_host_id("uuid1") + error = RuntimeError("defunct") + + self.connection.endpoint = DefaultEndPoint("127.254.254.101") + self.connection.is_defunct = True + self.connection.last_error = error + self.control_connection._current_host_id = host.host_id + self.cluster.signal_connection_failure = Mock() + self.control_connection.reconnect = Mock() + + self.control_connection._signal_error() + + 
self.cluster.signal_connection_failure.assert_called_once_with( + host, error, is_host_addition=False) + self.control_connection.reconnect.assert_not_called() + + def test_signal_error_reconnects_when_current_host_conviction_is_deferred(self): + host = self.cluster.metadata.get_host_by_host_id("uuid1") + error = RuntimeError("defunct") + + self.connection.endpoint = DefaultEndPoint("127.254.254.101") + self.connection.is_defunct = True + self.connection.last_error = error + self.control_connection._current_host_id = host.host_id + self.cluster.signal_connection_failure = Mock(return_value=False) + self.control_connection.reconnect = Mock() + + self.control_connection._signal_error() + + self.cluster.signal_connection_failure.assert_called_once_with( + host, error, is_host_addition=False) + self.control_connection.reconnect.assert_called_once_with() + def test_refresh_nodes_and_tokens(self): self.control_connection.refresh_node_list_and_token_map() meta = self.cluster.metadata @@ -319,6 +382,32 @@ def test_refresh_nodes_and_tokens(self): assert self.connection.wait_for_responses.call_count == 1 + def test_refresh_nodes_and_tokens_rebinds_current_host_to_control_endpoint(self): + proxy_endpoint = DefaultEndPoint("127.254.254.101") + self.connection.endpoint = proxy_endpoint + self.connection.original_endpoint = proxy_endpoint + self.cluster.profile_manager.on_control_connection_host = Mock() + + self.control_connection.refresh_node_list_and_token_map() + + current_host = self.cluster.metadata.get_host_by_host_id("uuid1") + assert current_host.endpoint == proxy_endpoint + assert self.cluster.metadata.get_host(proxy_endpoint) is current_host + assert self.control_connection._current_host_id == "uuid1" + self.cluster.profile_manager.on_control_connection_host.assert_called_once_with(current_host) + + def test_refresh_nodes_and_tokens_skips_intermediate_endpoint_for_current_host(self): + proxy_endpoint = DefaultEndPoint("127.254.254.101") + self.connection.endpoint = proxy_endpoint + self.connection.original_endpoint = proxy_endpoint + self.control_connection.refresh_node_list_and_token_map() + + self.cluster._update_host_endpoint = Mock(wraps=self.cluster._update_host_endpoint) + + self.control_connection.refresh_node_list_and_token_map() + + assert self.cluster._update_host_endpoint.call_args_list == [] + def test_refresh_nodes_and_tokens_with_invalid_peers(self): def refresh_and_validate_added_hosts(): self.connection.wait_for_responses = Mock(return_value=_node_meta_results( diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index f92bb53785..b2109565ea 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -21,7 +21,7 @@ import unittest from threading import Thread, Event, Lock -from unittest.mock import Mock, NonCallableMagicMock, MagicMock +from unittest.mock import Mock, NonCallableMagicMock, MagicMock, patch from cassandra.cluster import Session, ShardAwareOptions from cassandra.connection import Connection @@ -42,6 +42,8 @@ class _PoolTests(unittest.TestCase): def make_session(self): session = NonCallableMagicMock(spec=Session, keyspace='foobarkeyspace', _trash=[]) + session._signal_connection_failure.return_value = False + session.is_shard_aware_disabled.return_value = False return session def test_borrow_and_return(self): @@ -143,8 +145,7 @@ def test_return_defunct_connection(self): pool.borrow_connection(timeout=0.01) conn.is_defunct = True - 
session.cluster.signal_connection_failure.return_value = False - host.signal_connection_failure.return_value = False + session._signal_connection_failure.return_value = False pool.return_connection(conn) # the connection should be closed a new creation scheduled @@ -165,19 +166,14 @@ def test_return_defunct_connection_on_down_host(self): pool.borrow_connection(timeout=0.01) conn.is_defunct = True - session.cluster.signal_connection_failure.return_value = True - host.signal_connection_failure.return_value = True + session._signal_connection_failure.return_value = True pool.return_connection(conn) - # the connection should be closed a new creation scheduled + # the connection should be closed and the pool should delegate down + # handling back to the session. assert conn.close.call_args - if self.PoolImpl is HostConnection: - # on shard aware implementation we use submit function regardless - assert host.signal_connection_failure.call_args - assert session.submit.called - else: - assert not session.submit.called - assert session.cluster.signal_connection_failure.call_args + session._signal_connection_failure.assert_called_once_with(host, conn.last_error) + session._handle_pool_down.assert_called_once_with(host, is_host_addition=False) assert pool.is_shutdown def test_return_closed_connection(self): @@ -192,8 +188,7 @@ def test_return_closed_connection(self): pool.borrow_connection(timeout=0.01) conn.is_closed = True - session.cluster.signal_connection_failure.return_value = False - host.signal_connection_failure.return_value = False + session._signal_connection_failure.return_value = False pool.return_connection(conn) # a new creation should be scheduled @@ -231,6 +226,29 @@ class HostConnectionTests(_PoolTests): PoolImpl = HostConnection uses_single_connection = True + def test_session_level_shard_aware_disable_skips_fanout(self): + host = Mock(spec=Host, address='ip1') + host.sharding_info = None + session = self.make_session() + session.is_shard_aware_disabled.return_value = True + + connection = HashableMock(spec=Connection, in_flight=0, is_defunct=False, + is_closed=False, max_request_id=100) + connection.features = ProtocolFeatures( + shard_id=0, + sharding_info=_ShardingInfo( + shard_id=0, shards_count=4, partitioner="", + sharding_algorithm="", sharding_ignore_msb=0, + shard_aware_port=19042, shard_aware_port_ssl=""), + tablets_routing_v1=False) + session.cluster.connection_factory.return_value = connection + + with patch.object(HostConnection, "_open_connections_for_all_shards") as open_shards: + pool = HostConnection(host, HostDistance.LOCAL, session) + + open_shards.assert_not_called() + assert pool.host.sharding_info is None + def test_fast_shutdown(self): class MockSession(MagicMock): is_shutdown = False diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 6142af1aa1..073612178b 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -27,7 +27,8 @@ from cassandra import ConsistencyLevel from cassandra.cluster import Cluster, ControlConnection from cassandra.metadata import Metadata -from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy, +from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy, + DynamicWhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy, TokenAwarePolicy, SimpleConvictionPolicy, HostDistance, ExponentialReconnectionPolicy, RetryPolicy, WriteType, @@ -1421,6 +1422,42 @@ def 
test_hosts_with_hostname(self): assert policy.distance(host) == HostDistance.LOCAL + +class DynamicWhiteListRoundRobinPolicyTest(unittest.TestCase): + + def test_control_connection_host_updates_allowed_host(self): + hosts = [ + Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + Host(DefaultEndPoint("127.0.0.3"), SimpleConvictionPolicy, host_id=uuid.uuid4()), + ] + for host in hosts: + host.set_up() + + cluster = Mock() + cluster.metadata.all_hosts.return_value = hosts + + policy = DynamicWhiteListRoundRobinPolicy() + policy.populate(cluster, hosts) + + assert list(policy.make_query_plan()) == [] + assert policy.distance(hosts[0]) == HostDistance.IGNORED + + policy.on_control_connection_host(hosts[1]) + + assert list(policy.make_query_plan()) == [hosts[1]] + assert policy.distance(hosts[0]) == HostDistance.IGNORED + assert policy.distance(hosts[1]) == HostDistance.LOCAL + + policy.on_down(hosts[1]) + assert list(policy.make_query_plan()) == [] + + policy.on_up(hosts[1]) + assert list(policy.make_query_plan()) == [hosts[1]] + + policy.on_control_connection_host(hosts[2]) + assert list(policy.make_query_plan()) == [hosts[2]] + def test_hosts_with_socket_hostname(self): hosts = [UnixSocketEndPoint('/tmp/scylla-workdir/cql.m')] policy = WhiteListRoundRobinPolicy(hosts)
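
Usage note (not part of the patch): a minimal sketch of driving the new `DynamicWhiteListRoundRobinPolicy` through an execution profile, assuming this diff is applied. The contact point `nlb.example.internal` and port 9042 are placeholders for whatever proxy/NLB address actually fronts the cluster; the integration test above exercises the same path via `NLBEmulator` and the legacy `load_balancing_policy` kwarg.

```python
from cassandra.cluster import Cluster, ExecutionProfile, EXEC_PROFILE_DEFAULT
from cassandra.policies import DynamicWhiteListRoundRobinPolicy

# The policy starts with an empty allowlist; once the control connection
# resolves the host behind the proxy, on_control_connection_host() marks
# that single host LOCAL and leaves every other discovered node IGNORED.
profile = ExecutionProfile(load_balancing_policy=DynamicWhiteListRoundRobinPolicy())

cluster = Cluster(
    contact_points=["nlb.example.internal"],  # hypothetical proxy/NLB address
    port=9042,
    execution_profiles={EXEC_PROFILE_DEFAULT: profile},
)
session = cluster.connect()

# Requests are routed only to the proxy-reachable host; if the control
# connection later rebinds to a different node, the allowlist follows it.
row = session.execute(
    "SELECT release_version FROM system.local WHERE key='local'"
).one()
cluster.shutdown()
```

Custom wrapper policies that do not inherit from `WrapperPolicy`, `TokenAwarePolicy`, or `HostFilterPolicy` would need to forward `on_control_connection_host` to their child policy themselves, mirroring the forwarding added in this diff.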