diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 9eace8810d..420d94200a 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -23,7 +23,7 @@ from binascii import hexlify from collections import defaultdict from collections.abc import Mapping -from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, wait as wait_futures +from concurrent.futures import ThreadPoolExecutor, FIRST_COMPLETED, Future, wait as wait_futures from copy import copy from functools import partial, reduce, wraps from itertools import groupby, count, chain @@ -1746,6 +1746,45 @@ def connect(self, keyspace=None, wait_for_all_pools=False): established or attempted. Default is `False`, which means it will return when the first successful connection is established. Remaining pools are added asynchronously. """ + self._ensure_core_connections_setup() + + session = self._new_session(keyspace) + if wait_for_all_pools: + wait_futures(session._initial_connect_futures) + + self._set_default_dbaas_consistency(session) + + return session + + def connect_to_control_host(self, keyspace=None, wait_for_all_pools=False): + """ + Creates and returns a new :class:`~.Session` pinned to the current + control-connection host. + + This is intended for proxy deployments where the control connection + reaches a node through an address different from the node's + broadcast RPC address. The returned session creates a pool only for + that control-connection node and routes all requests through the + control-connection endpoint instead of the node's broadcast endpoint. + If the active control connection is using a shared load-balancer + address, the driver first verifies that the address consistently + reaches the control-connection node; if it does not, and the cluster + exposes no host-specific endpoint (for example through SNI or client + routes), this method raises :class:`~cassandra.DriverException`. + + `keyspace` and `wait_for_all_pools` behave the same way as in + :meth:`connect`. 
+ """ + self._ensure_core_connections_setup() + + session = self._new_control_host_session(keyspace) + if wait_for_all_pools: + wait_futures(session._initial_connect_futures) + + self._set_default_dbaas_consistency(session) + + return session + + def _ensure_core_connections_setup(self): with self._lock: if self.is_shutdown: raise DriverException("Cluster is already shut down") @@ -1777,14 +1816,6 @@ def connect(self, keyspace=None, wait_for_all_pools=False): ) self._is_setup = True - session = self._new_session(keyspace) - if wait_for_all_pools: - wait_futures(session._initial_connect_futures) - - self._set_default_dbaas_consistency(session) - - return session - def _set_default_dbaas_consistency(self, session): if session.cluster.metadata.dbaas: for profile in self.profile_manager.profiles.values(): @@ -1805,14 +1836,18 @@ def get_all_pools(self): pools.extend(s.get_pools()) return pools + def _get_shard_aware_pools(self): + return [pool for pool in self.get_all_pools() if pool.host.sharding_info is not None] + def is_shard_aware(self): - return bool(self.get_all_pools()[0].host.sharding_info) + return bool(self._get_shard_aware_pools()) def shard_aware_stats(self): - if self.is_shard_aware(): + shard_aware_pools = self._get_shard_aware_pools() + if shard_aware_pools: return {str(pool.host.endpoint): {'shards_count': pool.host.sharding_info.shards_count, 'connected': len(pool._connections.keys())} - for pool in self.get_all_pools()} + for pool in shard_aware_pools} def shutdown(self): """ @@ -1857,6 +1892,101 @@ def _new_session(self, keyspace): self.sessions.add(session) return session + def _new_control_host_session(self, keyspace): + connection = self.control_connection._connection + if connection is None: + raise DriverException("Control connection is not established") + + control_host = self.get_control_connection_host() + if control_host is None: + raise DriverException( + "Unable to resolve the current control connection host from metadata") + + pinned_endpoint = self._get_control_host_session_endpoint(control_host, connection.endpoint) + pinned_host = self._clone_host_with_endpoint(control_host, pinned_endpoint, host_cls=_ControlHost) + session = _ControlHostSession(self, control_host, pinned_host, keyspace) + self._session_register_user_types(session) + self.sessions.add(session) + return session + + def _control_host_session_can_reuse_endpoint(self, host, endpoint): + if endpoint is None: + return False + if endpoint == host.endpoint: + return True + if not isinstance(endpoint, DefaultEndPoint): + return True + return self._control_host_session_default_endpoint_targets_host( + host, endpoint) + + def _control_host_session_default_endpoint_targets_host( + self, host, endpoint, attempts=3): + for _ in range(attempts): + connected_host_id = self._get_host_id_for_endpoint(endpoint) + if connected_host_id != host.host_id: + return False + return True + + def _get_host_id_for_endpoint(self, endpoint): + connection = None + try: + connection = self.connection_factory(endpoint) + response = connection.wait_for_response( + QueryMessage( + query="SELECT host_id FROM system.local WHERE key='local'", + consistency_level=ConsistencyLevel.ONE), + timeout=self.connect_timeout) + rows = dict_factory(response.column_names, response.parsed_rows) + if not rows: + return None + return rows[0].get("host_id") + except Exception: + log.debug( + "Failed verifying control-host endpoint %s", endpoint, + exc_info=True) + return None + finally: + if connection: + connection.close() + + def 
_get_control_host_session_endpoint(self, control_host, connection_endpoint): + if connection_endpoint is not None and ( + connection_endpoint == control_host.endpoint or + not isinstance(connection_endpoint, DefaultEndPoint)): + return connection_endpoint + + host_endpoint = control_host.endpoint + if host_endpoint is not None and not isinstance(host_endpoint, DefaultEndPoint): + return host_endpoint + + if self._control_host_session_can_reuse_endpoint(control_host, connection_endpoint): + return connection_endpoint + + raise DriverException( + "connect_to_control_host() requires a host-specific control endpoint for %s; " + "the active control connection is using shared endpoint %s" + % (control_host, connection_endpoint)) + + def _clone_host_with_endpoint(self, host, endpoint, host_cls=Host): + cloned_host = host_cls( + endpoint, + self.conviction_policy_factory, + datacenter=host.datacenter, + rack=host.rack, + host_id=host.host_id) + cloned_host.is_up = host.is_up + cloned_host.broadcast_address = host.broadcast_address + cloned_host.broadcast_port = host.broadcast_port + cloned_host.broadcast_rpc_address = endpoint.address + cloned_host.broadcast_rpc_port = endpoint.port + cloned_host.listen_address = host.listen_address + cloned_host.listen_port = host.listen_port + cloned_host.release_version = host.release_version + cloned_host.dse_version = host.dse_version + cloned_host.dse_workload = host.dse_workload + cloned_host.dse_workloads = host.dse_workloads + return cloned_host + def _session_register_user_types(self, session): for keyspace, type_map in self._user_types.items(): for udt_name, klass in type_map.items(): @@ -1954,12 +2084,16 @@ def on_up(self, host): futures_lock = Lock() futures_results = [] callback = partial(self._on_up_future_completed, host, futures, futures_results, futures_lock) + callback_futures = [] for session in tuple(self.sessions): future = session.add_or_renew_pool(host, is_host_addition=False) if future is not None: have_future = True - future.add_done_callback(callback) futures.add(future) + callback_futures.append(future) + + for future in callback_futures: + future.add_done_callback(callback) except Exception: log.exception("Unexpected failure handling node %s being marked up:", host) for future in futures: @@ -2030,8 +2164,7 @@ def on_down(self, host, is_host_addition, expect_host_to_be_down=False): if self._discount_down_events and self.profile_manager.distance(host) != HostDistance.IGNORED: connected = False for session in tuple(self.sessions): - pool_states = session.get_pool_state() - pool_state = pool_states.get(host) + pool_state = session.get_pool_state_for_host(host) if pool_state: connected |= pool_state['open_count'] > 0 if connected: @@ -2220,7 +2353,18 @@ def get_control_connection_host(self): """ connection = self.control_connection._connection endpoint = connection.endpoint if connection else None - return self.metadata.get_host(endpoint) if endpoint else None + if not endpoint: + return None + + host = self.metadata.get_host(endpoint) + if host is not None: + return host + + host_id = self.control_connection._current_host_id + if host_id is None: + return None + + return self.metadata.get_host_by_host_id(host_id) def refresh_schema_metadata(self, max_schema_agreement_wait=None): """ @@ -2383,6 +2527,39 @@ def _prepare_all_queries(self, host): if connection: connection.close() + def _prepare_query_on_all_hosts(self, query, excluded_host, keyspace=None): + excluded_host_id = getattr(excluded_host, "host_id", None) + for host in 
tuple(self.metadata.all_hosts()): + if host == excluded_host or ( + excluded_host_id is not None and host.host_id == excluded_host_id): + continue + if not host.is_up: + continue + + connection = None + try: + connection = self.connection_factory(host.endpoint) + response = connection.wait_for_response( + PrepareMessage(query=query, keyspace=keyspace), + timeout=5.0) + if getattr(response, "query_id", None) is None: + log.debug( + "Got unexpected response when preparing query on host %s: %r", + host, response) + except OperationTimedOut as timeout: + log.warning( + "Timed out trying to prepare query on host %s: %s", + host, timeout) + except (ConnectionException, socket.error) as exc: + log.warning( + "Error trying to prepare query on host %s: %r", + host, exc) + except Exception: + log.exception("Error trying to prepare query on host %s", host) + finally: + if connection: + connection.close() + def add_prepared(self, query_id, prepared_statement): with self._prepared_statement_lock: self._prepared_statements[query_id] = prepared_statement @@ -2924,8 +3101,11 @@ def _on_analytics_master_result(self, response, master_future, query_future): delimiter_index = addr.rfind(':') # assumes : - not robust, but that's what is being provided if delimiter_index > 0: addr = addr[:delimiter_index] - targeted_query = HostTargetingStatement(query_future.query, addr) - query_future.query_plan = query_future._load_balancer.make_query_plan(self.keyspace, targeted_query) + if query_future._host is not None: + query_future.query_plan = iter([query_future._host]) + else: + targeted_query = HostTargetingStatement(query_future.query, addr) + query_future.query_plan = query_future._load_balancer.make_query_plan(self.keyspace, targeted_query) except Exception: log.debug("Failed querying analytics master (request might not be routed optimally). " "Make sure the session is connecting to a graph analytics datacenter.", exc_info=True) @@ -3232,11 +3412,7 @@ def __del__(self): # when cluster.shutdown() is called explicitly. pass - def add_or_renew_pool(self, host, is_host_addition): - """ - For internal use only. 
- """ - distance = self._profile_manager.distance(host) + def _add_or_renew_pool_for_distance(self, host, distance, is_host_addition): if distance == HostDistance.IGNORED: return None @@ -3245,15 +3421,17 @@ def run_add_or_renew_pool(): new_pool = HostConnection(host, distance, self) except AuthenticationFailed as auth_exc: conn_exc = ConnectionException(str(auth_exc), endpoint=host) - self.cluster.signal_connection_failure(host, conn_exc, is_host_addition) + if self._signal_connection_failure(host, conn_exc): + self._handle_pool_down(host, is_host_addition) return False except Exception as conn_exc: log.warning("Failed to create connection pool for new host %s:", host, exc_info=conn_exc) # the host itself will still be marked down, so we need to pass # a special flag to make sure the reconnector is created - self.cluster.signal_connection_failure( - host, conn_exc, is_host_addition, expect_host_to_be_down=True) + if self._signal_connection_failure(host, conn_exc): + self._handle_pool_down( + host, is_host_addition, expect_host_to_be_down=True) return False previous = self._pools.get(host) @@ -3271,7 +3449,7 @@ def callback(pool, errors): set_keyspace_event.wait(self.cluster.connect_timeout) if not set_keyspace_event.is_set() or errors_returned: log.warning("Failed setting keyspace for pool after keyspace changed during connect: %s", errors_returned) - self.cluster.on_down(host, is_host_addition) + self._handle_pool_down(host, is_host_addition) new_pool.shutdown() self._lock.acquire() return False @@ -3286,6 +3464,13 @@ def callback(pool, errors): return self.submit(run_add_or_renew_pool) + def add_or_renew_pool(self, host, is_host_addition): + """ + For internal use only. + """ + distance = self._profile_manager.distance(host) + return self._add_or_renew_pool_for_distance(host, distance, is_host_addition) + def remove_pool(self, host): pool = self._pools.pop(host, None) if pool: @@ -3410,9 +3595,21 @@ def submit(self, fn, *args, **kwargs): def get_pool_state(self): return dict((host, pool.get_state()) for host, pool in tuple(self._pools.items())) + def get_pool_state_for_host(self, host): + return self.get_pool_state().get(host) + def get_pools(self): return self._pools.values() + def _signal_connection_failure(self, host, connection_exc): + return host.signal_connection_failure(connection_exc) + + def _handle_pool_down(self, host, is_host_addition, expect_host_to_be_down=False): + self.cluster.on_down(host, is_host_addition, expect_host_to_be_down) + + def is_shard_aware_disabled(self): + return self.cluster.shard_aware_options.disable + def _validate_set_legacy_config(self, attr_name, value): if self.cluster._config_mode == _ConfigMode.PROFILES: raise ValueError("Cannot set Session.%s while using Configuration Profiles. Set this in a profile instead." % (attr_name,)) @@ -3420,6 +3617,377 @@ def _validate_set_legacy_config(self, attr_name, value): self.cluster._config_mode = _ConfigMode.LEGACY +class _ControlHost(Host): + """ + Host clone used by control-host sessions. + + Connection failures on the proxy endpoint should only affect the pinned + session pool. The source host remains responsible for global cluster + liveness, load-balancing membership, and reconnection handling. + """ + + +class _ControlHostReconnectionHandler(_ReconnectionHandler): + """ + Reconnects a control-host session pool without notifying the shared + cluster state for the source host. 
+ """ + + def __init__(self, session, host, *args, **kwargs): + _ReconnectionHandler.__init__(self, *args, **kwargs) + self.session = weakref.proxy(session) + self.host = host + + def try_reconnect(self): + return self.session.cluster.connection_factory(self.host.endpoint) + + def on_reconnection(self, connection): + self.session._on_control_host_reconnection() + + def on_exception(self, exc, next_delay): + if isinstance(exc, AuthenticationFailed): + return False + log.warning("Error attempting to reconnect control-host session to %s, scheduling retry in %s seconds: %s", + self.host, next_delay, exc) + log.debug("Control-host session reconnection error details", exc_info=True) + return True + + +class _ControlHostSession(Session): + """ + Session variant pinned to a single host and endpoint. + + This is used for proxy deployments where the control connection reaches a + node through an endpoint that differs from the node metadata advertised by + the cluster. All requests are explicitly targeted to the pinned host so the + normal load-balancing query plan does not reintroduce other hosts. + """ + + def __init__(self, cluster, source_host, pinned_host, keyspace=None): + self._source_host = source_host + self._pinned_host = pinned_host + self._pinned_host_rebind_error = None + super(_ControlHostSession, self).__init__(cluster, [pinned_host], keyspace) + + def _ensure_valid_pinned_host(self): + if self._pinned_host_rebind_error is not None: + raise DriverException(self._pinned_host_rebind_error) + + def _replace_pinned_host(self, source_host, endpoint): + stale_pool = None + old_pinned_host = None + new_pinned_host = self.cluster._clone_host_with_endpoint( + source_host, endpoint, host_cls=_ControlHost) + + with self._lock: + if self._pinned_host.host_id == source_host.host_id and self._pinned_host.endpoint == endpoint: + self._pinned_host_rebind_error = None + return + + old_pinned_host = self._pinned_host + stale_pool = self._pools.pop(old_pinned_host, None) + self._pinned_host = new_pinned_host + self._pinned_host_rebind_error = None + + self._cancel_control_host_reconnector(old_pinned_host) + if stale_pool is not None: + self.submit(stale_pool.shutdown) + + def _invalidate_pinned_host(self, error_message): + stale_pool = None + with self._lock: + self._pinned_host_rebind_error = error_message + stale_pool = self._pools.pop(self._pinned_host, None) + + self._cancel_control_host_reconnector() + if stale_pool is not None: + self.submit(stale_pool.shutdown) + + def _resolve_pinned_host(self, host): + if host is None: + return self._pinned_host + if host == self._pinned_host: + return self._pinned_host + if isinstance(host, Host) and host.host_id == self._pinned_host.host_id: + return self._pinned_host + raise ValueError( + "This session is pinned to control host %s; explicit host %s is not supported" + % (self._pinned_host, host)) + + def _resolve_source_host(self, host): + control_host = self.cluster.get_control_connection_host() + if host is None or host == self._pinned_host: + return control_host or self._source_host + if host == self._source_host: + return control_host or self._source_host + if isinstance(host, Host): + if control_host is not None and host.host_id == control_host.host_id: + return control_host + if host.host_id == self._source_host.host_id: + return control_host or host + raise ValueError( + "This session is pinned to control host %s; source host %s is not supported" + % (self._source_host, host)) + + def _refresh_current_control_endpoint(self, source_host): + connection = 
self.cluster.control_connection._connection + endpoint = connection.endpoint if connection else None + if not endpoint: + return source_host + if self._pinned_host.host_id == source_host.host_id and \ + self._pinned_host.endpoint == endpoint: + return source_host + + try: + endpoint = self.cluster._get_control_host_session_endpoint(source_host, endpoint) + except DriverException as exc: + self._invalidate_pinned_host(str(exc)) + return None + + self._replace_pinned_host(source_host, endpoint) + + return source_host + + def _sync_from_source_host(self, host): + if host is not None and not isinstance(host, Host): + return None + try: + source_host = self._resolve_source_host(host) + except ValueError: + return None + + source_host = self._refresh_current_control_endpoint(source_host) + if source_host is None: + return None + self._source_host = source_host + self._pinned_host.set_location_info(source_host.datacenter, source_host.rack) + self._pinned_host.release_version = source_host.release_version + self._pinned_host.dse_version = source_host.dse_version + self._pinned_host.dse_workload = source_host.dse_workload + self._pinned_host.dse_workloads = source_host.dse_workloads + if source_host.is_up: + self._pinned_host.set_up() + elif source_host.is_up is False: + self._pinned_host.set_down() + else: + self._pinned_host.is_up = None + return source_host + + def _pinned_host_distance(self): + distance = self._profile_manager.distance(self._pinned_host) + if distance != HostDistance.IGNORED: + return distance + return self._profile_manager.distance(self._source_host) + + @staticmethod + def _completed_future(result): + future = Future() + future.set_result(result) + return future + + def _completed_future_after(self, future): + if future is None: + return self._completed_future(True) + + completed_future = Future() + + def callback(inner_future): + try: + inner_future.result() + except Exception: + log.warning("Unexpected failure while refreshing control-host session pool for %s", + self._pinned_host, exc_info=True) + completed_future.set_result(True) + + future.add_done_callback(callback) + return completed_future + + def _is_current_pinned_host(self, host): + return host is self._pinned_host + + def _is_stale_pinned_host(self, host): + return isinstance(host, _ControlHost) and not self._is_current_pinned_host(host) + + def _refresh_pinned_pool_for_source_host(self, is_host_addition): + distance = self._pinned_host_distance() + pool = self._pools.get(self._pinned_host) + future = None + + if not pool or pool.is_shutdown: + if distance != HostDistance.IGNORED: + future = self._add_or_renew_pool_for_distance( + self._pinned_host, distance, is_host_addition) + elif distance != pool.host_distance: + if distance == HostDistance.IGNORED: + future = super(_ControlHostSession, self).remove_pool(self._pinned_host) + else: + pool.host_distance = distance + + return self._completed_future_after(future) + + def get_pool_state_for_host(self, host): + if not self._is_current_pinned_host(host): + return None + return super(_ControlHostSession, self).get_pool_state_for_host(self._pinned_host) + + def _create_response_future(self, query, parameters, trace, custom_payload, + timeout, execution_profile=EXEC_PROFILE_DEFAULT, + paging_state=None, host=None): + self._ensure_valid_pinned_host() + pinned_host = self._resolve_pinned_host(host) + return super(_ControlHostSession, self)._create_response_future( + query, parameters, trace, custom_payload, timeout, + execution_profile=execution_profile, + 
paging_state=paging_state, host=pinned_host) + + def prepare(self, query, custom_payload=None, keyspace=None): + self._ensure_valid_pinned_host() + message = PrepareMessage(query=query, keyspace=keyspace) + future = ResponseFuture( + self, message, query=None, timeout=self.default_timeout, + host=self._pinned_host) + try: + future.send_request() + response = future.result().one() + except Exception: + log.exception("Error preparing query:") + raise + + prepared_keyspace = keyspace if keyspace else None + prepared_statement = PreparedStatement.from_message( + response.query_id, response.bind_metadata, response.pk_indexes, + self.cluster.metadata, query, prepared_keyspace, + self._protocol_version, response.column_metadata, + response.result_metadata_id, response.is_lwt, + self.cluster.column_encryption_policy) + prepared_statement.custom_payload = future.custom_payload + + self.cluster.add_prepared(response.query_id, prepared_statement) + # Control-host sessions are intentionally limited to the pinned + # endpoint. Fan-out through cluster metadata can reintroduce + # unreachable private node addresses in proxy-only deployments. + + return prepared_statement + + def add_or_renew_pool(self, host, is_host_addition): + tracked_source_host_id = getattr(self._source_host, "host_id", None) + is_pinned_host = self._is_current_pinned_host(host) + if not self._sync_from_source_host(host): + return None + if not is_pinned_host: + if getattr(host, "host_id", None) == tracked_source_host_id: + # Source-host UP/ADD handling should not be gated on the + # control-host pool outcome, but the pinned pool still needs to + # be reopened if it was previously removed. + return self._refresh_pinned_pool_for_source_host(is_host_addition) + return self._completed_future(True) + return self._add_or_renew_pool_for_distance( + self._pinned_host, self._pinned_host_distance(), is_host_addition) + + def remove_pool(self, host): + if not self._sync_from_source_host(host): + return None + if not self._is_current_pinned_host(host): + return None + return super(_ControlHostSession, self).remove_pool(self._pinned_host) + + def update_created_pools(self): + futures = set() + source_host = self._sync_from_source_host(None) + if not source_host: + return futures + distance = self._pinned_host_distance() + pool = self._pools.get(self._pinned_host) + future = None + if not pool or pool.is_shutdown: + if distance != HostDistance.IGNORED and source_host.is_up in (True, None): + future = self._add_or_renew_pool_for_distance( + self._pinned_host, distance, False) + elif distance != pool.host_distance: + if distance == HostDistance.IGNORED: + future = super(_ControlHostSession, self).remove_pool(self._pinned_host) + else: + pool.host_distance = distance + if future: + futures.add(future) + return futures + + def _signal_connection_failure(self, host, connection_exc): + if self._is_stale_pinned_host(host): + return False + if not self._is_current_pinned_host(host): + return super(_ControlHostSession, self)._signal_connection_failure(host, connection_exc) + return self._pinned_host.conviction_policy.add_failure(connection_exc) + + def _handle_pool_down(self, host, is_host_addition, expect_host_to_be_down=False): + if self._is_stale_pinned_host(host): + return + if not self._is_current_pinned_host(host): + super(_ControlHostSession, self)._handle_pool_down( + host, is_host_addition, expect_host_to_be_down) + return + + with self._pinned_host.lock: + was_up = self._pinned_host.is_up + self._pinned_host.set_down() + if (not was_up and not 
expect_host_to_be_down) or self._pinned_host.is_currently_reconnecting(): + return + + future = super(_ControlHostSession, self).remove_pool(self._pinned_host) + if future: + future.add_done_callback(lambda f: self._start_control_host_reconnector()) + else: + self._start_control_host_reconnector() + + def _start_control_host_reconnector(self): + if self.is_shutdown or self._source_host.is_up is False: + return + + schedule = self.cluster.reconnection_policy.new_schedule() + reconnector = _ControlHostReconnectionHandler( + self, self._pinned_host, + self.cluster.scheduler, schedule, lambda: None) + + old_reconnector = self._pinned_host.get_and_set_reconnection_handler(reconnector) + if old_reconnector: + log.debug("Old control-host reconnector found for %s, cancelling", self._pinned_host) + old_reconnector.cancel() + + log.debug("Starting control-host reconnector for %s", self._pinned_host) + reconnector.start() + + def _cancel_control_host_reconnector(self, host=None): + target_host = self._pinned_host if host is None else host + reconnector = target_host.get_and_set_reconnection_handler(None) + if reconnector: + reconnector.cancel() + + def _on_control_host_reconnection(self): + self._cancel_control_host_reconnector() + self._sync_from_source_host(None) + self.update_created_pools() + + def on_down(self, host): + if not self._sync_from_source_host(host): + return + self._cancel_control_host_reconnector() + future = super(_ControlHostSession, self).remove_pool(self._pinned_host) + if future: + future.add_done_callback(lambda f: self.update_created_pools()) + else: + self.update_created_pools() + + def on_remove(self, host): + self.on_down(host) + + def shutdown(self): + self._cancel_control_host_reconnector() + super(_ControlHostSession, self).shutdown() + + def is_shard_aware_disabled(self): + return True + + class UserTypeDoesNotExist(Exception): """ An attempt was made to use a user-defined type that does not exist. 
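A minimal usage sketch of the new `connect_to_control_host()` API (the proxy address below is hypothetical; the behavior shown follows this patch — the returned session opens a single pool for the control-connection host and routes every request through that endpoint):

    from cassandra.cluster import Cluster

    # The only contact point is the proxy/NLB address fronting the cluster;
    # direct node addresses may be unreachable from this client.
    cluster = Cluster(contact_points=["proxy.example.internal"], port=9042)

    # Unlike connect(), this pins the session to the control-connection host
    # instead of opening pools to every node discovered in metadata.
    session = cluster.connect_to_control_host(wait_for_all_pools=True)
    row = session.execute(
        "SELECT release_version FROM system.local WHERE key='local'").one()
    print(row.release_version)

    cluster.shutdown()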
@@ -3545,6 +4113,7 @@ def __init__(self, cluster, timeout, self._reconnection_handler = None self._reconnection_lock = RLock() + self._current_host_id = None self._event_schedule_times = {} @@ -3569,6 +4138,10 @@ def _set_new_connection(self, conn): log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn) old.close() + for session in tuple(getattr(self._cluster, "sessions", ())): + if isinstance(session, _ControlHostSession): + session.update_created_pools() + def _try_connect_to_hosts(self): errors = {} @@ -3770,6 +4343,7 @@ def shutdown(self): if self._connection: self._connection.close() self._connection = None + self._current_host_id = None def refresh_schema(self, force=False, **kwargs): try: @@ -3849,6 +4423,7 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, found_host_ids = set() found_endpoints = set() + local_host_id = None if local_result.parsed_rows: local_rows = dict_factory(local_result.column_names, local_result.parsed_rows) local_row = local_rows[0] @@ -3857,6 +4432,7 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, partitioner = local_row.get("partitioner") tokens = local_row.get("tokens", None) + local_host_id = local_row.get("host_id") peers_result.insert(0, local_row) @@ -3928,6 +4504,7 @@ def _refresh_node_list_and_token_map(self, connection, preloaded_results=None, self._cluster.metadata.remove_host_by_host_id(old_host_id, old_host.endpoint) log.debug("[control connection] Finished fetching ring info") + self._current_host_id = local_host_id if local_host_id in found_host_ids else None if partitioner and should_rebuild_token_map: log.debug("[control connection] Rebuilding token map due to topology changes") self._cluster.metadata.rebuild_token_map(partitioner, token_map) @@ -4222,25 +4799,50 @@ def _signal_error(self): # try just signaling the cluster, as this will trigger a reconnect # as part of marking the host down if self._connection and self._connection.is_defunct: - host = self._cluster.metadata.get_host(self._connection.endpoint) + host = self._try_get_cluster_host() # host may be None if it's already been removed, but that indicates # that errors have already been reported, so we're fine if host: - self._cluster.signal_connection_failure( + is_down = self._cluster.signal_connection_failure( host, self._connection.last_error, is_host_addition=False) - return + if is_down: + return # if the connection is not defunct or the host already left, reconnect # manually self.reconnect() + def _try_get_cluster_host(self): + conn = self._connection + endpoint = conn.endpoint if conn else None + if not endpoint: + return None + + host = self._cluster.metadata.get_host(endpoint) + if host is not None: + return host + + host_id = self._current_host_id + if host_id is None: + return None + + return self._cluster.metadata.get_host_by_host_id(host_id) + def on_up(self, host): pass - def on_down(self, host): - + def _is_current_host(self, host): conn = self._connection - if conn and conn.endpoint == host.endpoint and \ + if conn is None or host is None: + return False + + if conn.endpoint == host.endpoint: + return True + + return self._current_host_id is not None and getattr(host, 'host_id', None) == self._current_host_id + + def on_down(self, host): + if self._is_current_host(host) and \ self._reconnection_handler is None: log.debug("[control connection] Control connection host (%s) is " "considered down, starting reconnection", host) @@ -4252,8 +4854,7 @@ def on_add(self, host, 
refresh_nodes=True): self.refresh_node_list_and_token_map(force_token_rebuild=True) def on_remove(self, host): - c = self._connection - if c and c.endpoint == host.endpoint: + if self._is_current_host(host): log.debug("[control connection] Control connection host (%s) is being removed. Reconnecting", host) # refresh will be done on reconnect self.reconnect() diff --git a/cassandra/pool.py b/cassandra/pool.py index 9e949c342c..097824be06 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -435,7 +435,7 @@ def __init__(self, host, host_distance, session): if self._keyspace: first_connection.set_keyspace_blocking(self._keyspace) - if first_connection.features.sharding_info and not self._session.cluster.shard_aware_options.disable: + if first_connection.features.sharding_info and not self._session.is_shard_aware_disabled(): self.host.sharding_info = first_connection.features.sharding_info self._open_connections_for_all_shards(first_connection.features.shard_id) self.tablets_routing_v1 = first_connection.features.tablets_routing_v1 @@ -451,7 +451,7 @@ def _get_connection_for_routing_key(self, routing_key=None, keyspace=None, table raise NoConnectionsAvailable() shard_id = None - if not self._session.cluster.shard_aware_options.disable and self.host.sharding_info and routing_key: + if not self._session.is_shard_aware_disabled() and self.host.sharding_info and routing_key: t = self._session.cluster.metadata.token_map.token_class.from_key(routing_key) shard_id = None @@ -554,7 +554,7 @@ def return_connection(self, connection, stream_was_orphaned=False): if not connection.signaled_error: log.debug("Defunct or closed connection (%s) returned to pool, potentially " "marking host %s as down", id(connection), self.host) - is_down = self.host.signal_connection_failure(connection.last_error) + is_down = self._session._signal_connection_failure(self.host, connection.last_error) connection.signaled_error = True if self.shutdown_on_error and not is_down: @@ -562,7 +562,7 @@ def return_connection(self, connection, stream_was_orphaned=False): if is_down: self.shutdown() - self._session.cluster.on_down(self.host, is_host_addition=False) + self._session._handle_pool_down(self.host, is_host_addition=False) else: connection.close() with self._lock: @@ -603,7 +603,7 @@ def _replace(self, connection): try: if connection.features.shard_id in self._connections: del self._connections[connection.features.shard_id] - if self.host.sharding_info and not self._session.cluster.shard_aware_options.disable: + if self.host.sharding_info and not self._session.is_shard_aware_disabled(): self._connecting.add(connection.features.shard_id) self._session.submit(self._open_connection_to_missing_shard, connection.features.shard_id) else: @@ -678,7 +678,8 @@ def disable_advanced_shard_aware(self, secs): def _get_shard_aware_endpoint(self): if (self.advanced_shardaware_block_until and self.advanced_shardaware_block_until > time.time()) or \ - self._session.cluster.shard_aware_options.disable_shardaware_port: + self._session.cluster.shard_aware_options.disable_shardaware_port or \ + self._session.is_shard_aware_disabled(): return None endpoint = None @@ -920,5 +921,3 @@ def open_count(self): @property def _excess_connection_limit(self): return self.host.sharding_info.shards_count * self.max_excess_connections_per_shard_multiplier - - diff --git a/docs/api/cassandra/cluster.rst b/docs/api/cassandra/cluster.rst index 51f03f3d97..1ae4f8ef34 100644 --- a/docs/api/cassandra/cluster.rst +++ b/docs/api/cassandra/cluster.rst @@ -78,6 +78,8 
@@ Clusters and Sessions .. automethod:: connect + .. automethod:: connect_to_control_host + .. automethod:: shutdown .. automethod:: register_user_type diff --git a/tests/integration/standard/test_client_routes.py b/tests/integration/standard/test_client_routes.py index 5a20421276..ac424e8128 100644 --- a/tests/integration/standard/test_client_routes.py +++ b/tests/integration/standard/test_client_routes.py @@ -38,9 +38,9 @@ import json as _json import urllib.request -from cassandra.cluster import Cluster +from cassandra.cluster import Cluster, NoHostAvailable from cassandra.client_routes import ClientRoutesConfig, ClientRouteProxy -from cassandra.connection import ClientRoutesEndPoint +from cassandra.connection import ClientRoutesEndPoint, ConnectionException from cassandra.policies import RoundRobinPolicy from tests.integration import ( TestCluster, @@ -54,6 +54,28 @@ log = logging.getLogger(__name__) + +class ProxyOnlyReachableConnection(Cluster.connection_class): + """ + Simulates a private-link client that can reach only the proxy endpoint. + + The CCM node addresses are reachable from the local test runner, which means + the existing client-routes tests cannot reproduce bugs that only appear when + direct node IPs are private. This connection class rejects those direct node + addresses while still allowing the NLB address. + """ + + @classmethod + def factory(cls, endpoint, timeout, host_conn=None, *args, **kwargs): + address, _ = endpoint.resolve() + if address.startswith("127.0.0."): + raise ConnectionException( + "Simulated private node address %s is unreachable from the client" % address, + endpoint=endpoint, + ) + return super().factory(endpoint, timeout, host_conn=host_conn, *args, **kwargs) + + class TcpProxy: """ A simple TCP proxy that forwards connections from a local listen port @@ -535,6 +557,64 @@ def teardown_module(): else: os.environ['SCYLLA_EXT_OPTS'] = _saved_scylla_ext_opts + +class TestProxyConnectivityWithoutClientRoutes(unittest.TestCase): + """ + Reproducer for connecting through a generic proxy when node addresses are + not reachable from the client. + + The initial control connection can reach the cluster through the proxy, but + the driver later tries to open pools to the discovered node addresses + directly. In a proxy-only environment those direct connections fail, which + makes connect() and subsequent queries fail. 
+ """ + + @classmethod + def setUpClass(cls): + cls.node_addrs = { + 1: "127.0.0.1", + 2: "127.0.0.2", + 3: "127.0.0.3", + } + cls.proxy_node_id = 1 + cls.nlb = NLBEmulator() + cls.nlb.start(cls.node_addrs) + + @classmethod + def tearDownClass(cls): + cls.nlb.stop() + + def _make_proxy_cluster(self): + return Cluster( + contact_points=[NLBEmulator.LISTEN_HOST], + port=self.nlb.node_port(self.proxy_node_id), + connection_class=ProxyOnlyReachableConnection, + load_balancing_policy=RoundRobinPolicy(), + ) + + def test_default_session_fails_when_only_proxy_is_reachable(self): + cluster = self._make_proxy_cluster() + self.addCleanup(cluster.shutdown) + + with self.assertRaises(NoHostAvailable): + cluster.connect() + + def test_control_host_session_succeeds_when_only_proxy_is_reachable(self): + cluster = self._make_proxy_cluster() + self.addCleanup(cluster.shutdown) + + session = cluster.connect_to_control_host() + row = session.execute( + "SELECT release_version FROM system.local WHERE key='local'" + ).one() + + self.assertIsNotNone(row) + pool_state = session.get_pool_state() + self.assertEqual(len(pool_state), 1) + + session_host = next(iter(pool_state)) + self.assertEqual(session_host.endpoint.address, NLBEmulator.LISTEN_HOST) + self.assertEqual(session_host.endpoint.port, self.nlb.node_port(self.proxy_node_id)) + @skip_scylla_version_lt(reason='scylladb/scylladb#26992 - system.client_routes is not yet supported', scylla_version="2026.1.0") class TestGetHostPortMapping(unittest.TestCase): diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index 872d133b28..ed47c3a60f 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -15,16 +15,19 @@ import logging import socket - +from concurrent.futures import Future +from threading import RLock from unittest.mock import patch, Mock import uuid from cassandra import ConsistencyLevel, DriverException, Timeout, Unavailable, RequestExecutionException, ReadTimeout, WriteTimeout, CoordinationFailure, ReadFailure, WriteFailure, FunctionFailure, AlreadyExists,\ InvalidRequest, Unauthorized, AuthenticationFailed, OperationTimedOut, UnsupportedOperation, RequestValidationException, ConfigurationException, ProtocolVersion from cassandra.cluster import _Scheduler, Session, Cluster, default_lbp_factory, \ - ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT -from cassandra.pool import Host -from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy + ExecutionProfile, _ConfigMode, EXEC_PROFILE_DEFAULT, _ControlHostSession, ResponseFuture, NoHostAvailable +from cassandra.connection import DefaultEndPoint, EndPoint +from cassandra.pool import Host, HostConnection +from cassandra.policies import HostDistance, RetryPolicy, RoundRobinPolicy, DowngradingConsistencyRetryPolicy, SimpleConvictionPolicy, WhiteListRoundRobinPolicy +from cassandra.protocol import QueryMessage from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory from tests.unit.utils import mock_session_pools from tests import connection_class @@ -33,6 +36,39 @@ log = logging.getLogger(__name__) + +class _HostAwareProxyEndPoint(EndPoint): + def __init__(self, address, affinity_key, port=9042): + self._address = address + self._affinity_key = affinity_key + self._port = port + + @property + def address(self): + return self._address + + @property + def port(self): + return self._port + + def resolve(self): + return self._address, self._port + + def __eq__(self, other): + 
return isinstance(other, _HostAwareProxyEndPoint) and \ + self.address == other.address and self.port == other.port and \ + self._affinity_key == other._affinity_key + + def __hash__(self): + return hash((self.address, self.port, self._affinity_key)) + + def __lt__(self, other): + if not isinstance(other, _HostAwareProxyEndPoint): + return NotImplemented + return (self.address, self.port, str(self._affinity_key)) < \ + (other.address, other.port, str(other._affinity_key)) + + class ExceptionTypeTest(unittest.TestCase): def test_exception_types(self): @@ -171,6 +207,62 @@ def test_connection_factory_passes_compression_kwarg(self): assert factory.call_args.kwargs['compression'] == expected assert cluster.compression == expected + def test_get_control_connection_host_falls_back_to_host_id(self): + cluster = Cluster(contact_points=['127.0.0.1']) + host = Host(DefaultEndPoint('192.168.1.10'), SimpleConvictionPolicy, host_id=uuid.uuid4()) + + metadata = Mock() + metadata.get_host.return_value = None + metadata.get_host_by_host_id.return_value = host + cluster.metadata = metadata + + connection = Mock(endpoint=DefaultEndPoint('127.254.254.101', 9042)) + cluster.control_connection = Mock(_connection=connection, _current_host_id=host.host_id) + + assert cluster.get_control_connection_host() is host + metadata.get_host.assert_called_once_with(connection.endpoint) + metadata.get_host_by_host_id.assert_called_once_with(host.host_id) + + def test_is_shard_aware_ignores_non_shard_aware_pools(self): + cluster = Cluster(contact_points=['127.0.0.1']) + + shard_pool = Mock() + shard_pool.host = Mock( + endpoint=DefaultEndPoint("127.0.0.1"), + sharding_info=Mock(shards_count=8)) + shard_pool._connections = {0: Mock(), 1: Mock()} + + control_pool = Mock() + control_pool.host = Mock( + endpoint=DefaultEndPoint("127.254.254.101"), + sharding_info=None) + control_pool._connections = {0: Mock()} + + cluster.get_all_pools = Mock(return_value=[control_pool, shard_pool]) + + assert cluster.is_shard_aware() is True + + def test_shard_aware_stats_ignores_non_shard_aware_pools(self): + cluster = Cluster(contact_points=['127.0.0.1']) + + shard_pool = Mock() + shard_pool.host = Mock( + endpoint=DefaultEndPoint("127.0.0.1"), + sharding_info=Mock(shards_count=8)) + shard_pool._connections = {0: Mock(), 1: Mock()} + + control_pool = Mock() + control_pool.host = Mock( + endpoint=DefaultEndPoint("127.254.254.101"), + sharding_info=None) + control_pool._connections = {0: Mock()} + + cluster.get_all_pools = Mock(return_value=[shard_pool, control_pool]) + + assert cluster.shard_aware_stats() == { + "127.0.0.1:9042": {"shards_count": 8, "connected": 2} + } + class SchedulerTest(unittest.TestCase): # TODO: this suite could be expanded; for now just adding a test covering a ticket @@ -194,6 +286,25 @@ def setUp(self): raise unittest.SkipTest('libev does not appear to be installed correctly') connection_class.initialize_reactor() + @staticmethod + def _completed_future(result): + future = Future() + future.set_result(result) + return future + + @staticmethod + def _proxy_endpoint(address, affinity_key, port=9042): + return _HostAwareProxyEndPoint(address, affinity_key, port) + + def _make_control_host_session(self, cluster, source_host, endpoint="127.254.254.101"): + endpoint = endpoint if isinstance(endpoint, EndPoint) else self._proxy_endpoint(endpoint, source_host.host_id) + cluster.control_connection = Mock( + _connection=Mock(endpoint=endpoint)) + cluster.get_control_connection_host = Mock(return_value=source_host) + with 
patch.object(Session, "_add_or_renew_pool_for_distance", + return_value=self._completed_future(True)): + return cluster._new_control_host_session(None) + # TODO: this suite could be expanded; for now just adding a test covering a PR @mock_session_pools def test_default_serial_consistency_level_ep(self, *_): @@ -281,6 +392,883 @@ def test_set_keyspace_escapes_quotes(self, *_): assert query == 'USE simple_ks', ( "Simple keyspace names should not be quoted, got: %r" % query) + def test_control_host_session_prepare_targets_pinned_host(self): + cluster = Cluster( + load_balancing_policy=RoundRobinPolicy(), + protocol_version=4, + prepare_on_all_hosts=False) + self.addCleanup(cluster.shutdown) + + host_id = uuid.uuid4() + source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=host_id) + source_host.set_up() + pinned_host = Host(DefaultEndPoint("127.254.254.101"), SimpleConvictionPolicy, host_id=host_id) + pinned_host.set_up() + + with patch.object(Session, "_add_or_renew_pool_for_distance", + return_value=self._completed_future(True)): + session = _ControlHostSession(cluster, source_host, pinned_host) + + prepare_response = Mock( + query_id=b"prepared-id", + bind_metadata=[], + pk_indexes=[], + column_metadata=[], + result_metadata_id=None, + is_lwt=False) + response_rows = Mock() + response_rows.one.return_value = prepare_response + + response_future = Mock() + response_future.result.return_value = response_rows + response_future.custom_payload = {"prepared-by": b"control-host"} + + prepared_statement = Mock() + with patch.object(cluster, "add_prepared") as add_prepared, \ + patch("cassandra.cluster.ResponseFuture", return_value=response_future) as response_future_cls, \ + patch("cassandra.cluster.PreparedStatement.from_message", + return_value=prepared_statement): + prepared = session.prepare("SELECT release_version FROM system.local") + + assert prepared is prepared_statement + response_future_cls.assert_called_once() + assert response_future_cls.call_args.kwargs["host"] is pinned_host + add_prepared.assert_called_once_with(prepare_response.query_id, prepared_statement) + assert prepared_statement.custom_payload == response_future.custom_payload + + def test_control_host_session_prepare_skips_raw_host_fanout_on_all_hosts(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + source_host = Host( + DefaultEndPoint("127.0.0.1"), + SimpleConvictionPolicy, + host_id=uuid.uuid4()) + source_host.set_up() + pinned_host = Host( + DefaultEndPoint("127.254.254.101"), + SimpleConvictionPolicy, + host_id=source_host.host_id) + pinned_host.set_up() + + with patch.object(Session, "_add_or_renew_pool_for_distance", + return_value=self._completed_future(True)): + session = _ControlHostSession(cluster, source_host, pinned_host) + + prepare_response = Mock( + query_id=b"prepared-id", + bind_metadata=[], + pk_indexes=[], + column_metadata=[], + result_metadata_id=None, + is_lwt=False) + response_rows = Mock() + response_rows.one.return_value = prepare_response + + response_future = Mock() + response_future.result.return_value = response_rows + response_future.custom_payload = {"prepared-by": b"control-host"} + response_future._current_host = pinned_host + + prepared_statement = Mock( + query_string="SELECT release_version FROM system.local") + + with patch.object(cluster, "add_prepared") as add_prepared, \ + patch.object(cluster, "_prepare_query_on_all_hosts") as prepare_all, \ + 
patch("cassandra.cluster.ResponseFuture", + return_value=response_future) as response_future_cls, \ + patch("cassandra.cluster.PreparedStatement.from_message", + return_value=prepared_statement): + prepared = session.prepare("SELECT release_version FROM system.local") + + assert prepared is prepared_statement + response_future_cls.assert_called_once() + assert response_future_cls.call_args.kwargs["host"] is pinned_host + add_prepared.assert_called_once_with(prepare_response.query_id, prepared_statement) + prepare_all.assert_not_called() + + def test_connect_to_control_host_rejects_shared_default_endpoint(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + source_host.set_up() + + cluster.control_connection = Mock( + _connection=Mock(endpoint=DefaultEndPoint("127.254.254.101"))) + cluster.get_control_connection_host = Mock(return_value=source_host) + + verification_connection = Mock() + verification_connection.wait_for_response.return_value = Mock( + column_names=["host_id"], + parsed_rows=[(uuid.uuid4(),)]) + + with patch.object(cluster, "connection_factory", + return_value=verification_connection), \ + pytest.raises(DriverException, match="host-specific control endpoint"): + cluster._new_control_host_session(None) + + assert verification_connection.close.call_count == 1 + + def test_connect_to_control_host_accepts_host_specific_default_endpoint(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + source_host = Host( + DefaultEndPoint("127.0.0.1"), + SimpleConvictionPolicy, + host_id=uuid.uuid4()) + source_host.set_up() + + connection_endpoint = DefaultEndPoint("proxy.control.example") + cluster.control_connection = Mock( + _connection=Mock(endpoint=connection_endpoint)) + cluster.get_control_connection_host = Mock(return_value=source_host) + + verification_connection = Mock() + verification_connection.wait_for_response.return_value = Mock( + column_names=["host_id"], + parsed_rows=[(source_host.host_id,)]) + + with patch.object(cluster, "connection_factory", + return_value=verification_connection) as connection_factory, \ + patch.object(Session, "_add_or_renew_pool_for_distance", + return_value=self._completed_future(True)): + session = cluster._new_control_host_session(None) + + assert session._pinned_host.endpoint == connection_endpoint + assert connection_factory.call_count == 3 + assert verification_connection.close.call_count == 3 + + def test_control_host_session_uses_host_aware_metadata_endpoint_when_control_connection_is_shared(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host_id = uuid.uuid4() + source_endpoint = self._proxy_endpoint("proxy.control.example", host_id) + source_host = Host(source_endpoint, SimpleConvictionPolicy, host_id=host_id) + source_host.set_up() + + cluster.control_connection = Mock( + _connection=Mock(endpoint=DefaultEndPoint("127.254.254.101"))) + cluster.get_control_connection_host = Mock(return_value=source_host) + + with patch.object(Session, "_add_or_renew_pool_for_distance", + return_value=self._completed_future(True)): + session = cluster._new_control_host_session(None) + + assert session._pinned_host.endpoint == source_endpoint + + def test_control_host_session_analytics_master_lookup_keeps_pinned_host(self): + cluster = 
Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host_id = uuid.uuid4() + source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=host_id) + source_host.set_up() + pinned_host = Host(self._proxy_endpoint("proxy.control.example", host_id), + SimpleConvictionPolicy, host_id=host_id) + pinned_host.set_up() + + with patch.object(Session, "_add_or_renew_pool_for_distance", + return_value=self._completed_future(True)): + session = _ControlHostSession(cluster, source_host, pinned_host) + + master_future = Mock() + master_future.result.return_value = [({'location': '127.0.0.99:8182'},)] + + query_future = Mock() + query_future._host = pinned_host + query_future.query = SimpleStatement("g.V()") + query_future._load_balancer = Mock() + query_future.send_request = Mock() + query_future.query_plan = iter(()) + + with patch.object(session, "submit", + side_effect=lambda fn, *args, **kwargs: self._completed_future(fn(*args, **kwargs))): + session._on_analytics_master_result(None, master_future, query_future) + + assert list(query_future.query_plan) == [pinned_host] + query_future._load_balancer.make_query_plan.assert_not_called() + query_future.send_request.assert_called_once_with() + + def test_control_host_session_preserves_source_host_distance(self): + cluster = Cluster( + load_balancing_policy=WhiteListRoundRobinPolicy(["127.0.0.1"]), + protocol_version=4) + self.addCleanup(cluster.shutdown) + + host_id = uuid.uuid4() + source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=host_id) + source_host.set_up() + pinned_host = Host(DefaultEndPoint("127.254.254.101"), SimpleConvictionPolicy, host_id=host_id) + pinned_host.set_up() + + with patch.object(Session, "_add_or_renew_pool_for_distance", + return_value=self._completed_future(True)) as add_pool: + _ControlHostSession(cluster, source_host, pinned_host) + + add_pool.assert_called_once_with(pinned_host, HostDistance.LOCAL, False) + + def test_control_host_session_prefers_pinned_host_distance_for_proxy_whitelist(self): + cluster = Cluster( + load_balancing_policy=WhiteListRoundRobinPolicy(["127.254.254.101"]), + protocol_version=4) + self.addCleanup(cluster.shutdown) + + host_id = uuid.uuid4() + source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=host_id) + source_host.set_up() + pinned_host = Host(DefaultEndPoint("127.254.254.101"), SimpleConvictionPolicy, host_id=host_id) + pinned_host.set_up() + + with patch.object(Session, "_add_or_renew_pool_for_distance", + return_value=self._completed_future(True)) as add_pool: + _ControlHostSession(cluster, source_host, pinned_host) + + add_pool.assert_called_once_with(pinned_host, HostDistance.LOCAL, False) + + def test_control_host_session_reopens_pinned_pool_during_source_host_lifecycle(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host_id = uuid.uuid4() + source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=host_id) + source_host.set_up() + pinned_host = Host(DefaultEndPoint("127.254.254.101"), SimpleConvictionPolicy, host_id=host_id) + pinned_host.set_up() + + with patch.object(Session, "_add_or_renew_pool_for_distance", + return_value=self._completed_future(True)): + session = _ControlHostSession(cluster, source_host, pinned_host) + + with patch.object(session, "_add_or_renew_pool_for_distance", + return_value=self._completed_future(False)) as add_pool: + 
future = session.add_or_renew_pool(source_host, is_host_addition=False) + + assert future.result() is True + add_pool.assert_called_once_with(session._pinned_host, HostDistance.LOCAL, False) + + def test_control_host_session_waits_for_all_on_up_futures_before_marking_source_host_up(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host_id = uuid.uuid4() + source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=host_id) + source_host.set_down() + + control_session = self._make_control_host_session(cluster, source_host) + + cluster._prepare_all_queries = Mock() + cluster._start_reconnector = Mock() + cluster.control_connection.on_up = Mock() + cluster.control_connection.on_down = Mock() + + regular_session = Mock(spec=Session) + regular_session.remove_pool.return_value = None + regular_session.add_or_renew_pool.return_value = self._completed_future(False) + regular_session.update_created_pools.return_value = set() + + cluster.sessions = [control_session, regular_session] + + with patch.object(control_session, "_add_or_renew_pool_for_distance", + return_value=self._completed_future(True)) as add_pool: + cluster.on_up(source_host) + + assert source_host.is_up is False + cluster.control_connection.on_up.assert_called_once_with(source_host) + cluster.control_connection.on_down.assert_called_once_with(source_host) + regular_session.add_or_renew_pool.assert_called_once_with(source_host, is_host_addition=False) + regular_session.update_created_pools.assert_not_called() + cluster._start_reconnector.assert_called_once_with(source_host, is_host_addition=False) + add_pool.assert_called_once_with(control_session._pinned_host, HostDistance.LOCAL, False) + + @mock_session_pools + def test_session_preserves_down_event_discounting_after_endpoint_update(self, *_): + class _DeterministicHashEndPoint(EndPoint): + def __init__(self, address, hash_value, port=9042): + self._address = address + self._hash_value = hash_value + self._port = port + + @property + def address(self): + return self._address + + @property + def port(self): + return self._port + + def resolve(self): + return self._address, self._port + + def __eq__(self, other): + return isinstance(other, _DeterministicHashEndPoint) and \ + self.address == other.address and self.port == other.port + + def __hash__(self): + return self._hash_value + + def __lt__(self, other): + return (self.address, self.port) < (other.address, other.port) + + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4()) + host.set_up() + + session = Session(cluster, [host]) + cluster.sessions.add(session) + + pool = Mock() + pool.get_state.return_value = {"open_count": 1} + + host.endpoint = _DeterministicHashEndPoint("127.0.0.1", 1) + session._pools = {host: pool} + host.endpoint = _DeterministicHashEndPoint("127.0.0.2", 2) + + cluster.on_down_potentially_blocking = Mock() + + cluster.on_down(host, is_host_addition=False) + + assert host.is_up is True + cluster.on_down_potentially_blocking.assert_not_called() + assert session.get_pool_state_for_host(host) == {"open_count": 1} + + def test_control_host_session_does_not_mask_source_host_down_events(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host_id = uuid.uuid4() + source_host = 
Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=host_id) + source_host.set_up() + + control_session = self._make_control_host_session(cluster, source_host) + pool = Mock() + pool.get_state.return_value = {"open_count": 1} + control_session._pools[control_session._pinned_host] = pool + + regular_session = Mock(spec=Session) + regular_session.get_pool_state_for_host.return_value = None + cluster.sessions = {control_session, regular_session} + + cluster.on_down_potentially_blocking = Mock() + + cluster.on_down(source_host, is_host_addition=False) + + assert source_host.is_up is False + cluster.on_down_potentially_blocking.assert_called_once_with(source_host, False) + assert control_session.get_pool_state_for_host(source_host) is None + + def test_control_host_session_does_not_mask_source_host_down_events_with_shared_host_aware_endpoint(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host_id = uuid.uuid4() + source_endpoint = self._proxy_endpoint("proxy.control.example", host_id) + source_host = Host(source_endpoint, SimpleConvictionPolicy, host_id=host_id) + source_host.set_up() + + cluster.control_connection = Mock( + _connection=Mock(endpoint=source_endpoint)) + cluster.get_control_connection_host = Mock(return_value=source_host) + with patch.object(Session, "_add_or_renew_pool_for_distance", + return_value=self._completed_future(True)): + control_session = cluster._new_control_host_session(None) + + pool = Mock() + pool.get_state.return_value = {"open_count": 1} + control_session._pools[control_session._pinned_host] = pool + + regular_session = Mock(spec=Session) + regular_session.get_pool_state_for_host.return_value = None + cluster.sessions = {control_session, regular_session} + + cluster.on_down_potentially_blocking = Mock() + + cluster.on_down(source_host, is_host_addition=False) + + assert source_host.is_up is False + cluster.on_down_potentially_blocking.assert_called_once_with(source_host, False) + assert control_session.get_pool_state_for_host(source_host) is None + + def test_control_host_session_proxy_failures_schedule_backoff_reconnects(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host_id = uuid.uuid4() + source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=host_id) + source_host.set_up() + + session = self._make_control_host_session(cluster, source_host) + pool = Mock(is_shutdown=False) + session._pools[session._pinned_host] = pool + + cluster.on_down = Mock() + cluster.scheduler = Mock() + cluster.reconnection_policy = Mock(new_schedule=Mock(return_value=iter([1.25]))) + + def submit_sync(fn, *args, **kwargs): + return self._completed_future(fn(*args, **kwargs)) + + with patch.object(session, "submit", side_effect=submit_sync): + session._handle_pool_down(session._pinned_host, is_host_addition=False) + + pool.shutdown.assert_called_once_with() + cluster.on_down.assert_not_called() + cluster.scheduler.schedule.assert_called_once() + assert cluster.scheduler.schedule.call_args.args[0] == 1.25 + assert session._pinned_host.is_currently_reconnecting() is True + + def test_control_host_session_pool_failures_do_not_block_source_host_up(self): + cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4) + self.addCleanup(cluster.shutdown) + + host_id = uuid.uuid4() + source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=host_id) 
+        source_host.set_down()
+
+        control_session = self._make_control_host_session(cluster, source_host)
+
+        cluster.control_connection = Mock(on_up=Mock(), on_down=Mock())
+        cluster._prepare_all_queries = Mock()
+
+        regular_session = Mock(spec=Session)
+        regular_session.remove_pool.return_value = None
+        regular_session.add_or_renew_pool.return_value = self._completed_future(True)
+        regular_session.update_created_pools.return_value = set()
+
+        cluster.sessions = {control_session, regular_session}
+
+        with patch.object(control_session, "_add_or_renew_pool_for_distance",
+                          return_value=self._completed_future(False)) as add_pool:
+            cluster.on_up(source_host)
+
+        assert source_host.is_up is True
+        cluster.control_connection.on_up.assert_called_once_with(source_host)
+        cluster.control_connection.on_down.assert_not_called()
+        regular_session.add_or_renew_pool.assert_called_once_with(source_host, is_host_addition=False)
+        regular_session.update_created_pools.assert_called_once_with()
+        assert add_pool.call_count == 2
+        assert add_pool.call_args_list[0].args == (control_session._pinned_host, HostDistance.LOCAL, False)
+        assert add_pool.call_args_list[1].args == (control_session._pinned_host, HostDistance.LOCAL, False)
+
+    def test_control_host_session_pool_failures_do_not_block_source_host_up_with_shared_host_aware_endpoint(self):
+        cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4)
+        self.addCleanup(cluster.shutdown)
+
+        host_id = uuid.uuid4()
+        source_endpoint = self._proxy_endpoint("proxy.control.example", host_id)
+        source_host = Host(source_endpoint, SimpleConvictionPolicy, host_id=host_id)
+        source_host.set_down()
+
+        cluster.control_connection = Mock(
+            _connection=Mock(endpoint=source_endpoint),
+            on_up=Mock(),
+            on_down=Mock())
+        cluster.get_control_connection_host = Mock(return_value=source_host)
+        with patch.object(Session, "_add_or_renew_pool_for_distance",
+                          return_value=self._completed_future(True)):
+            control_session = cluster._new_control_host_session(None)
+
+        cluster._prepare_all_queries = Mock()
+
+        regular_session = Mock(spec=Session)
+        regular_session.remove_pool.return_value = None
+        regular_session.add_or_renew_pool.return_value = self._completed_future(True)
+        regular_session.update_created_pools.return_value = set()
+
+        cluster.sessions = {control_session, regular_session}
+
+        with patch.object(control_session, "_add_or_renew_pool_for_distance",
+                          return_value=self._completed_future(False)) as add_pool:
+            cluster.on_up(source_host)
+
+        assert source_host.is_up is True
+        cluster.control_connection.on_up.assert_called_once_with(source_host)
+        cluster.control_connection.on_down.assert_not_called()
+        regular_session.add_or_renew_pool.assert_called_once_with(source_host, is_host_addition=False)
+        regular_session.update_created_pools.assert_called_once_with()
+        assert add_pool.call_count == 2
+        assert add_pool.call_args_list[0].args == (control_session._pinned_host, HostDistance.LOCAL, False)
+        assert add_pool.call_args_list[1].args == (control_session._pinned_host, HostDistance.LOCAL, False)
+
+    def test_control_host_session_disables_shard_aware_fanout(self):
+        cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4)
+        self.addCleanup(cluster.shutdown)
+
+        host_id = uuid.uuid4()
+        source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=host_id)
+        source_host.set_up()
+        proxy_endpoint = self._proxy_endpoint("127.254.254.101", host_id)
+
+        cluster.control_connection = Mock(
+            _connection=Mock(endpoint=proxy_endpoint))
+        cluster.get_control_connection_host = Mock(return_value=source_host)
+
+        first_connection = Mock(is_defunct=False, is_closed=False)
+        first_connection.features = Mock(
+            shard_id=0,
+            sharding_info=Mock(shard_aware_port=19042, shard_aware_port_ssl=None),
+            tablets_routing_v1=False)
+        cluster.connection_factory = Mock(return_value=first_connection)
+
+        with patch.object(HostConnection, "_open_connections_for_all_shards") as open_shards:
+            session = cluster._new_control_host_session(None)
+
+        open_shards.assert_not_called()
+        assert session.is_shard_aware_disabled() is True
+        assert session._pinned_host.sharding_info is None
+
+    def test_control_host_session_update_created_pools_resyncs_source_host_state(self):
+        cluster = Cluster(
+            load_balancing_policy=WhiteListRoundRobinPolicy(["127.0.0.1"]),
+            protocol_version=4)
+        self.addCleanup(cluster.shutdown)
+
+        host_id = uuid.uuid4()
+        source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=host_id)
+        source_host.set_up()
+        pinned_host = Host(DefaultEndPoint("127.254.254.101"), SimpleConvictionPolicy, host_id=host_id)
+        pinned_host.set_down()
+
+        with patch.object(Session, "_add_or_renew_pool_for_distance",
+                          return_value=self._completed_future(True)):
+            session = _ControlHostSession(cluster, source_host, pinned_host)
+
+        session._pinned_host.set_down()
+        source_host.set_up()
+
+        with patch.object(session, "_add_or_renew_pool_for_distance",
+                          return_value=self._completed_future(True)) as add_pool:
+            session.update_created_pools()
+
+        assert session._pinned_host.is_up is True
+        add_pool.assert_called_once_with(session._pinned_host, HostDistance.LOCAL, False)
+
+    def test_control_host_session_accepts_replacement_control_host_during_on_up(self):
+        cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4)
+        self.addCleanup(cluster.shutdown)
+
+        source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4())
+        source_host.set_up()
+        replacement_host = Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy, host_id=uuid.uuid4())
+        replacement_host.set_up()
+
+        session = self._make_control_host_session(cluster, source_host, endpoint="127.254.254.101")
+        replacement_endpoint = self._proxy_endpoint("127.254.254.102", replacement_host.host_id)
+        cluster.control_connection._connection.endpoint = replacement_endpoint
+        cluster.get_control_connection_host.return_value = replacement_host
+
+        with patch.object(session, "_add_or_renew_pool_for_distance") as add_pool:
+            future = session.add_or_renew_pool(replacement_host, is_host_addition=False)
+
+        assert future.result() is True
+        assert session._source_host is replacement_host
+        assert session._pinned_host.endpoint == replacement_endpoint
+        add_pool.assert_not_called()
+
+    def test_control_host_session_on_down_rebinds_source_host_after_failover(self):
+        cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4)
+        self.addCleanup(cluster.shutdown)
+
+        source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4())
+        source_host.set_up()
+        replacement_host = Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy, host_id=uuid.uuid4())
+        replacement_host.set_up()
+
+        old_endpoint = self._proxy_endpoint("127.254.254.101", source_host.host_id)
+        new_endpoint = self._proxy_endpoint("127.254.254.102", replacement_host.host_id)
+        cluster.control_connection = Mock(_connection=Mock(endpoint=old_endpoint))
+        cluster.get_control_connection_host = Mock(
+            side_effect=lambda: source_host
+            if cluster.control_connection._connection.endpoint == old_endpoint
+            else replacement_host)
+
+        with patch.object(Session, "_add_or_renew_pool_for_distance",
+                          return_value=self._completed_future(True)):
+            session = cluster._new_control_host_session(None)
+
+        pool = Mock(is_shutdown=False, host_distance=HostDistance.LOCAL)
+        session._pools[session._pinned_host] = pool
+        cluster.control_connection._connection.endpoint = new_endpoint
+
+        with patch.object(session, "submit",
+                          side_effect=lambda fn, *args, **kwargs: self._completed_future(fn(*args, **kwargs))), \
+                patch.object(session, "_add_or_renew_pool_for_distance",
+                             return_value=self._completed_future(True)) as add_pool:
+            session.on_down(source_host)
+
+        assert session._source_host is replacement_host
+        assert session._pinned_host.endpoint == new_endpoint
+        assert session._pinned_host.host_id == replacement_host.host_id
+        assert session._pools.get(session._pinned_host) is None
+        pool.shutdown.assert_called_once_with()
+        add_pool.assert_called_once_with(session._pinned_host, HostDistance.LOCAL, False)
+
+    def test_control_host_session_update_created_pools_rebinds_pinned_endpoint(self):
+        cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4)
+        self.addCleanup(cluster.shutdown)
+
+        host_id = uuid.uuid4()
+        source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=host_id)
+        source_host.set_up()
+
+        session = self._make_control_host_session(cluster, source_host, endpoint="127.254.254.101")
+        pool = Mock(is_shutdown=False, host_distance=HostDistance.LOCAL)
+        pool.get_state.return_value = {"open_count": 1}
+        session._pools[session._pinned_host] = pool
+
+        new_endpoint = self._proxy_endpoint("127.254.254.102", host_id)
+        cluster.control_connection._connection.endpoint = new_endpoint
+
+        with patch.object(session, "submit",
+                          side_effect=lambda fn, *args, **kwargs: self._completed_future(fn(*args, **kwargs))), \
+                patch.object(session, "_add_or_renew_pool_for_distance",
+                             return_value=self._completed_future(True)) as add_pool:
+            session.update_created_pools()
+
+        assert session._pinned_host.endpoint == new_endpoint
+        assert session._pools.get(session._pinned_host) is None
+        assert session.get_pool_state_for_host(session._pinned_host) is None
+        pool.shutdown.assert_called_once_with()
+        add_pool.assert_called_once_with(session._pinned_host, HostDistance.LOCAL, False)
+
+    def test_control_host_session_update_created_pools_rebinds_source_host_after_failover(self):
+        cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4)
+        self.addCleanup(cluster.shutdown)
+
+        source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4())
+        source_host.set_up()
+        replacement_host = Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy, host_id=uuid.uuid4())
+        replacement_host.set_up()
+
+        session = self._make_control_host_session(cluster, source_host, endpoint="127.254.254.101")
+        pool = Mock(is_shutdown=False, host_distance=HostDistance.LOCAL)
+        session._pools[session._pinned_host] = pool
+
+        replacement_endpoint = self._proxy_endpoint("127.254.254.102", replacement_host.host_id)
+        cluster.control_connection._connection.endpoint = replacement_endpoint
+        cluster.get_control_connection_host.return_value = replacement_host
+
+        with patch.object(session, "submit",
+                          side_effect=lambda fn, *args, **kwargs: self._completed_future(fn(*args, **kwargs))), \
+                patch.object(session, "_add_or_renew_pool_for_distance",
+                             return_value=self._completed_future(True)) as add_pool:
+            session.update_created_pools()
+
+        assert session._source_host is replacement_host
+        assert session._pinned_host.endpoint == replacement_endpoint
+        assert session._pinned_host.host_id == replacement_host.host_id
+        assert session._pools.get(session._pinned_host) is None
+        pool.shutdown.assert_called_once_with()
+        add_pool.assert_called_once_with(session._pinned_host, HostDistance.LOCAL, False)
+
+    def test_control_host_session_rebind_rejects_shared_default_endpoint_after_failover(self):
+        cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4)
+        self.addCleanup(cluster.shutdown)
+
+        source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4())
+        source_host.set_up()
+        replacement_host = Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy, host_id=uuid.uuid4())
+        replacement_host.set_up()
+
+        session = self._make_control_host_session(cluster, source_host, endpoint="127.254.254.101")
+        original_pinned_host = session._pinned_host
+        pool = Mock(is_shutdown=False, host_distance=HostDistance.LOCAL)
+        session._pools[original_pinned_host] = pool
+
+        cluster.control_connection._connection.endpoint = DefaultEndPoint("127.254.254.200")
+        cluster.get_control_connection_host.return_value = replacement_host
+
+        with patch.object(session, "submit",
+                          side_effect=lambda fn, *args, **kwargs: self._completed_future(fn(*args, **kwargs))), \
+                patch.object(session, "_add_or_renew_pool_for_distance") as add_pool:
+            assert session.update_created_pools() == set()
+
+        assert session._pinned_host is original_pinned_host
+        assert session._pools.get(original_pinned_host) is None
+        pool.shutdown.assert_called_once_with()
+        add_pool.assert_not_called()
+
+        with pytest.raises(DriverException, match="host-specific control endpoint"):
+            session.prepare("SELECT release_version FROM system.local")
+
+    def test_control_host_session_rebind_replaces_pinned_host_without_retargeting_inflight_requests(self):
+        cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4)
+        self.addCleanup(cluster.shutdown)
+
+        source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4())
+        source_host.set_up()
+        replacement_host = Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy, host_id=uuid.uuid4())
+        replacement_host.set_up()
+
+        session = self._make_control_host_session(cluster, source_host, endpoint="127.254.254.101")
+        original_pinned_host = session._pinned_host
+
+        original_pool = Mock(is_shutdown=False, host_distance=HostDistance.LOCAL)
+        original_connection = Mock()
+        original_connection._requests = {}
+        original_connection.lock = RLock()
+        original_connection.orphaned_request_ids = set()
+        original_connection.orphaned_threshold = 10
+        original_connection.orphaned_threshold_reached = False
+        original_pool.borrow_connection.return_value = (original_connection, 1)
+        session._pools[original_pinned_host] = original_pool
+
+        query = SimpleStatement("SELECT release_version FROM system.local")
+        timeout_message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
+        timeout_future = ResponseFuture(
+            session, timeout_message, query, None, host=original_pinned_host)
+        timeout_future.send_request()
+        original_connection._requests[timeout_future._req_id] = object()
+
+        replacement_endpoint = self._proxy_endpoint("127.254.254.102", replacement_host.host_id)
+        cluster.control_connection._connection.endpoint = replacement_endpoint
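+        # Simulate a control-connection failover: the control connection now
+        # reports a different proxy endpoint owned by the replacement host.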
+        cluster.get_control_connection_host.return_value = replacement_host
+
+        with patch.object(session, "submit",
+                          side_effect=lambda fn, *args, **kwargs: self._completed_future(fn(*args, **kwargs))), \
+                patch.object(session, "_add_or_renew_pool_for_distance",
+                             return_value=self._completed_future(True)) as add_pool:
+            session.update_created_pools()
+
+        assert session._pinned_host is not original_pinned_host
+        assert session._pinned_host.endpoint == replacement_endpoint
+        assert session._pinned_host.host_id == replacement_host.host_id
+        assert session._pools.get(original_pinned_host) is None
+        original_pool.shutdown.assert_called_once_with()
+        add_pool.assert_called_once_with(session._pinned_host, HostDistance.LOCAL, False)
+
+        replacement_pool = Mock(is_shutdown=False)
+        session._pools[session._pinned_host] = replacement_pool
+
+        timeout_future._on_timeout()
+        replacement_pool.return_connection.assert_not_called()
+        with pytest.raises(OperationTimedOut):
+            timeout_future.result()
+
+        retry_message = QueryMessage(query=query, consistency_level=ConsistencyLevel.ONE)
+        retry_future = ResponseFuture(
+            session, retry_message, query, None, host=original_pinned_host)
+        retry_future._retry_task(True, original_pinned_host)
+
+        replacement_pool.borrow_connection.assert_not_called()
+        assert isinstance(retry_future._final_exception, NoHostAvailable)
+
+    def test_control_host_session_rebind_ignores_stale_pinned_host_failures(self):
+        cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4)
+        self.addCleanup(cluster.shutdown)
+
+        source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4())
+        source_host.set_up()
+
+        session = self._make_control_host_session(cluster, source_host, endpoint="127.254.254.101")
+        original_pinned_host = session._pinned_host
+
+        replacement_endpoint = self._proxy_endpoint("127.254.254.102", source_host.host_id)
+        cluster.control_connection._connection.endpoint = replacement_endpoint
+
+        with patch.object(session, "submit",
+                          side_effect=lambda fn, *args, **kwargs: self._completed_future(fn(*args, **kwargs))), \
+                patch.object(session, "_add_or_renew_pool_for_distance",
+                             return_value=self._completed_future(True)):
+            session.update_created_pools()
+
+        assert session._pinned_host is not original_pinned_host
+
+        with patch.object(Session, "_signal_connection_failure", return_value=True) as base_signal_failure:
+            assert session._signal_connection_failure(original_pinned_host, Exception("stale failure")) is False
+
+        with patch.object(Session, "_handle_pool_down") as base_handle_pool_down, \
+                patch.object(session, "_start_control_host_reconnector") as start_reconnector:
+            session._handle_pool_down(original_pinned_host, is_host_addition=False)
+
+        base_signal_failure.assert_not_called()
+        base_handle_pool_down.assert_not_called()
+        start_reconnector.assert_not_called()
+
+    def test_control_connection_failover_resyncs_control_host_sessions(self):
+        cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4)
+        self.addCleanup(cluster.shutdown)
+
+        source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=uuid.uuid4())
+        source_host.set_up()
+        replacement_host = Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy, host_id=uuid.uuid4())
+        replacement_host.set_up()
+
+        old_endpoint = self._proxy_endpoint("127.254.254.101", source_host.host_id)
+        new_endpoint = self._proxy_endpoint("127.254.254.102", replacement_host.host_id)
+        old_connection = Mock(endpoint=old_endpoint)
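+        # Two mock control connections model the endpoints before and after
+        # the failover performed by _set_new_connection() below.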
+        new_connection = Mock(endpoint=new_endpoint)
+        cluster.control_connection._connection = old_connection
+        cluster.get_control_connection_host = Mock(
+            side_effect=lambda: source_host
+            if cluster.control_connection._connection.endpoint == old_endpoint
+            else replacement_host)
+
+        with patch.object(Session, "_add_or_renew_pool_for_distance",
+                          return_value=self._completed_future(True)):
+            session = cluster._new_control_host_session(None)
+
+        pool = Mock(is_shutdown=False, host_distance=HostDistance.LOCAL)
+        session._pools[session._pinned_host] = pool
+
+        with patch.object(session, "submit",
+                          side_effect=lambda fn, *args, **kwargs: self._completed_future(fn(*args, **kwargs))), \
+                patch.object(session, "_add_or_renew_pool_for_distance",
+                             return_value=self._completed_future(True)) as add_pool:
+            cluster.control_connection._set_new_connection(new_connection)
+
+        old_connection.close.assert_called_once_with()
+        assert session._source_host is replacement_host
+        assert session._pinned_host.endpoint == new_endpoint
+        assert session._pinned_host.host_id == replacement_host.host_id
+        assert session._pools.get(session._pinned_host) is None
+        pool.shutdown.assert_called_once_with()
+        add_pool.assert_called_once_with(session._pinned_host, HostDistance.LOCAL, False)
+
+    def test_control_host_session_source_host_up_failure_keeps_reopened_pinned_pool(self):
+        cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), protocol_version=4)
+        self.addCleanup(cluster.shutdown)
+
+        host_id = uuid.uuid4()
+        source_host = Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy, host_id=host_id)
+        source_host.set_down()
+
+        control_session = self._make_control_host_session(cluster, source_host)
+
+        cluster.control_connection = Mock(
+            on_up=Mock(),
+            on_down=Mock(),
+            _connection=cluster.control_connection._connection)
+        cluster._prepare_all_queries = Mock()
+
+        regular_session = Mock(spec=Session)
+        regular_session.remove_pool.return_value = None
+        regular_session.add_or_renew_pool.return_value = self._completed_future(False)
+        regular_session.update_created_pools.return_value = set()
+
+        cluster.sessions = {control_session, regular_session}
+
+        reopened_pool = Mock(is_shutdown=False, host_distance=HostDistance.LOCAL)
+
+        def reopen_pool(host, distance, is_host_addition):
+            control_session._pools[host] = reopened_pool
+            return self._completed_future(True)
+
+        with patch.object(control_session, "_add_or_renew_pool_for_distance",
+                          side_effect=reopen_pool) as add_pool:
+            cluster.on_up(source_host)
+
+        assert source_host.is_up is False
+        assert control_session._pools.get(control_session._pinned_host) is reopened_pool
+        cluster.control_connection.on_up.assert_called_once_with(source_host)
+        cluster.control_connection.on_down.assert_called_once_with(source_host)
+        regular_session.add_or_renew_pool.assert_called_once_with(source_host, is_host_addition=False)
+        add_pool.assert_called_once_with(control_session._pinned_host, HostDistance.LOCAL, False)
+
 
 class ProtocolVersionTests(unittest.TestCase):
 
     def test_protocol_downgrade_test(self):
diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py
index 037d4a8888..7d8ecf59ff 100644
--- a/tests/unit/test_control_connection.py
+++ b/tests/unit/test_control_connection.py
@@ -301,6 +301,60 @@ def test_wait_for_schema_agreement_none_timeout(self):
         cc._time = self.time
         assert cc.wait_for_schema_agreement()
 
+    def test_on_down_reconnects_when_current_host_matches_by_host_id(self):
+        self.control_connection._connection.endpoint = DefaultEndPoint("127.254.254.101")
DefaultEndPoint("127.254.254.101") + self.control_connection._current_host_id = "uuid1" + self.control_connection.reconnect = Mock() + + self.control_connection.on_down(self.cluster.metadata.get_host_by_host_id("uuid1")) + + self.control_connection.reconnect.assert_called_once_with() + + def test_on_remove_reconnects_when_current_host_matches_by_host_id(self): + self.control_connection._connection.endpoint = DefaultEndPoint("127.254.254.101") + self.control_connection._current_host_id = "uuid1" + self.control_connection.reconnect = Mock() + self.control_connection.refresh_node_list_and_token_map = Mock() + + self.control_connection.on_remove(self.cluster.metadata.get_host_by_host_id("uuid1")) + + self.control_connection.reconnect.assert_called_once_with() + self.control_connection.refresh_node_list_and_token_map.assert_not_called() + + def test_signal_error_marks_current_host_down_when_current_host_matches_by_host_id(self): + host = self.cluster.metadata.get_host_by_host_id("uuid1") + error = RuntimeError("defunct") + + self.connection.endpoint = DefaultEndPoint("127.254.254.101") + self.connection.is_defunct = True + self.connection.last_error = error + self.control_connection._current_host_id = host.host_id + self.cluster.signal_connection_failure = Mock() + self.control_connection.reconnect = Mock() + + self.control_connection._signal_error() + + self.cluster.signal_connection_failure.assert_called_once_with( + host, error, is_host_addition=False) + self.control_connection.reconnect.assert_not_called() + + def test_signal_error_reconnects_when_current_host_conviction_is_deferred(self): + host = self.cluster.metadata.get_host_by_host_id("uuid1") + error = RuntimeError("defunct") + + self.connection.endpoint = DefaultEndPoint("127.254.254.101") + self.connection.is_defunct = True + self.connection.last_error = error + self.control_connection._current_host_id = host.host_id + self.cluster.signal_connection_failure = Mock(return_value=False) + self.control_connection.reconnect = Mock() + + self.control_connection._signal_error() + + self.cluster.signal_connection_failure.assert_called_once_with( + host, error, is_host_addition=False) + self.control_connection.reconnect.assert_called_once_with() + def test_refresh_nodes_and_tokens(self): self.control_connection.refresh_node_list_and_token_map() meta = self.cluster.metadata diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index f92bb53785..b2109565ea 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -21,7 +21,7 @@ import unittest from threading import Thread, Event, Lock -from unittest.mock import Mock, NonCallableMagicMock, MagicMock +from unittest.mock import Mock, NonCallableMagicMock, MagicMock, patch from cassandra.cluster import Session, ShardAwareOptions from cassandra.connection import Connection @@ -42,6 +42,8 @@ class _PoolTests(unittest.TestCase): def make_session(self): session = NonCallableMagicMock(spec=Session, keyspace='foobarkeyspace', _trash=[]) + session._signal_connection_failure.return_value = False + session.is_shard_aware_disabled.return_value = False return session def test_borrow_and_return(self): @@ -143,8 +145,7 @@ def test_return_defunct_connection(self): pool.borrow_connection(timeout=0.01) conn.is_defunct = True - session.cluster.signal_connection_failure.return_value = False - host.signal_connection_failure.return_value = False + session._signal_connection_failure.return_value = False pool.return_connection(conn) # 
@@ -165,19 +166,14 @@ def test_return_defunct_connection_on_down_host(self):
         pool.borrow_connection(timeout=0.01)
 
         conn.is_defunct = True
-        session.cluster.signal_connection_failure.return_value = True
-        host.signal_connection_failure.return_value = True
+        session._signal_connection_failure.return_value = True
         pool.return_connection(conn)
 
-        # the connection should be closed a new creation scheduled
+        # the connection should be closed and the pool should delegate down
+        # handling back to the session.
         assert conn.close.call_args
-        if self.PoolImpl is HostConnection:
-            # on shard aware implementation we use submit function regardless
-            assert host.signal_connection_failure.call_args
-            assert session.submit.called
-        else:
-            assert not session.submit.called
-            assert session.cluster.signal_connection_failure.call_args
+        session._signal_connection_failure.assert_called_once_with(host, conn.last_error)
+        session._handle_pool_down.assert_called_once_with(host, is_host_addition=False)
         assert pool.is_shutdown
 
     def test_return_closed_connection(self):
@@ -192,8 +188,7 @@ def test_return_closed_connection(self):
         pool.borrow_connection(timeout=0.01)
 
         conn.is_closed = True
-        session.cluster.signal_connection_failure.return_value = False
-        host.signal_connection_failure.return_value = False
+        session._signal_connection_failure.return_value = False
         pool.return_connection(conn)
 
         # a new creation should be scheduled
@@ -231,6 +226,29 @@ class HostConnectionTests(_PoolTests):
     PoolImpl = HostConnection
     uses_single_connection = True
 
+    def test_session_level_shard_aware_disable_skips_fanout(self):
+        host = Mock(spec=Host, address='ip1')
+        host.sharding_info = None
+        session = self.make_session()
+        session.is_shard_aware_disabled.return_value = True
+
+        connection = HashableMock(spec=Connection, in_flight=0, is_defunct=False,
+                                  is_closed=False, max_request_id=100)
+        connection.features = ProtocolFeatures(
+            shard_id=0,
+            sharding_info=_ShardingInfo(
+                shard_id=0, shards_count=4, partitioner="",
+                sharding_algorithm="", sharding_ignore_msb=0,
+                shard_aware_port=19042, shard_aware_port_ssl=""),
+            tablets_routing_v1=False)
+        session.cluster.connection_factory.return_value = connection
+
+        with patch.object(HostConnection, "_open_connections_for_all_shards") as open_shards:
+            pool = HostConnection(host, HostDistance.LOCAL, session)
+
+        open_shards.assert_not_called()
+        assert pool.host.sharding_info is None
+
     def test_fast_shutdown(self):
         class MockSession(MagicMock):
             is_shutdown = False