From cae11cd2800e33a31f0a1c7df60c6e3c4b3eff59 Mon Sep 17 00:00:00 2001 From: pranavdev022 Date: Sat, 18 Apr 2026 02:42:57 +0200 Subject: [PATCH 1/2] SPARK-56538: Add per-RPC gRPC deadlines to Spark Connect client Introduce RpcDeadlines configuration for Scala and Python clients with defaults per SPARK-56538. Apply deadlines on blocking unary RPCs and reattachable execute stream segments; omit deadline on non-reattachable ExecutePlan. Treat DEADLINE_EXCEEDED as non-retryable in the default retry policy; reattachable iterator recovers via RetryException. Add user-facing hints when deadlines fire on unary RPCs. Include JVM and Python tests. --- python/pyspark/sql/connect/client/artifact.py | 18 +- python/pyspark/sql/connect/client/core.py | 196 ++++++++++++- python/pyspark/sql/connect/client/reattach.py | 30 +- python/pyspark/sql/connect/client/retries.py | 5 + .../sql/tests/connect/client/test_client.py | 170 ++++++++++- .../connect/client/test_client_retries.py | 90 ++++++ .../sql/connect/client/ArtifactSuite.scala | 2 +- .../SparkConnectClientRetriesSuite.scala | 34 +++ .../client/SparkConnectClientSuite.scala | 272 +++++++++++++++++- .../CustomSparkConnectBlockingStub.scala | 36 ++- .../client/CustomSparkConnectStub.scala | 14 +- ...cutePlanResponseReattachableIterator.scala | 43 ++- .../client/GrpcExceptionConverter.scala | 38 ++- .../sql/connect/client/RetryPolicy.scala | 5 + .../sql/connect/client/RpcDeadlines.scala | 83 ++++++ .../connect/client/SparkConnectClient.scala | 8 +- .../client/SparkConnectStubState.scala | 17 +- .../sql/connect/SparkConnectServerTest.scala | 4 +- 18 files changed, 1007 insertions(+), 58 deletions(-) create mode 100644 sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RpcDeadlines.scala diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py index 277c313eef621..94879171b41f0 100644 --- a/python/pyspark/sql/connect/client/artifact.py +++ 
b/python/pyspark/sql/connect/client/artifact.py @@ -169,6 +169,8 @@ def __init__( session_id: str, channel: grpc.Channel, metadata: Iterable[Tuple[str, str]], + add_artifacts_timeout: Optional[float] = None, + artifact_status_timeout: Optional[float] = None, ): self._user_context = proto.UserContext() if user_id is not None: @@ -176,6 +178,8 @@ def __init__( self._stub = grpc_lib.SparkConnectServiceStub(channel) self._session_id = session_id self._metadata = metadata + self._add_artifacts_timeout = add_artifacts_timeout + self._artifact_status_timeout = artifact_status_timeout def _parse_artifacts( self, path_or_uri: str, pyfile: bool, archive: bool, file: bool @@ -288,7 +292,11 @@ def _retrieve_responses( self, requests: Iterator[proto.AddArtifactsRequest] ) -> proto.AddArtifactsResponse: """Separated for the testing purpose.""" - return self._stub.AddArtifacts(requests, metadata=self._metadata) + return self._stub.AddArtifacts( + requests, + metadata=self._metadata, + timeout=self._add_artifacts_timeout, + ) def _request_add_artifacts(self, requests: Iterator[proto.AddArtifactsRequest]) -> None: response: proto.AddArtifactsResponse = self._retrieve_responses(requests) @@ -428,7 +436,9 @@ def is_cached_artifact(self, hash: str) -> bool: user_context=self._user_context, session_id=self._session_id, names=[artifactName] ) resp: proto.ArtifactStatusesResponse = self._stub.ArtifactStatus( - request, metadata=self._metadata + request, + metadata=self._metadata, + timeout=self._artifact_status_timeout, ) status = resp.statuses.get(artifactName) return status.exists if status is not None else False @@ -446,7 +456,9 @@ def get_cached_artifacts(self, hashes: list[str]) -> set[str]: user_context=self._user_context, session_id=self._session_id, names=artifact_names ) resp: proto.ArtifactStatusesResponse = self._stub.ArtifactStatus( - request, metadata=self._metadata + request, + metadata=self._metadata, + timeout=self._artifact_status_timeout, ) cached = set() diff --git 
@dataclass
class RpcDeadlines:
    """Per-RPC timeout configuration for :class:`SparkConnectClient`.

    Each field controls the timeout (in seconds) for one gRPC call type. Set a field to
    ``None`` to disable the per-RPC timeout for that call. Use :meth:`RpcDeadlines.disabled` to
    create an instance with all timeouts disabled.

    Note on ``reattachable_execute_plan`` and ``reattach_execute``: these timeouts apply to each
    individual gRPC stream segment, not to the overall query execution lifetime. When a deadline
    fires, the server-side operation continues running; the client opens a new ReattachExecute
    stream to resume receiving results. Non-reattachable ExecutePlan has no deadline because a
    timeout there would kill the execution with no recovery path.

    Values are validated once, at construction time (``__post_init__``); assigning to a field
    afterwards is not re-validated.

    Raises
    ------
    PySparkTypeError
        If a field is neither ``None`` nor an ``int``/``float``. ``bool`` is rejected even
        though it is a subclass of ``int``, because ``True`` would otherwise be accepted as a
        one-second deadline.
    PySparkValueError
        If a numeric field is not finite or is not strictly positive.
    """

    reattachable_execute_plan: Optional[float] = 10 * 60  # 10 min
    reattach_execute: Optional[float] = 10 * 60  # 10 min
    analyze_plan: Optional[float] = 60 * 60  # 1 hour
    add_artifacts: Optional[float] = 60 * 60  # 1 hour
    config: Optional[float] = 10 * 60  # 10 min
    interrupt: Optional[float] = 10 * 60  # 10 min
    release_session: Optional[float] = 10 * 60  # 10 min
    artifact_status: Optional[float] = 10 * 60  # 10 min
    clone_session: Optional[float] = 10 * 60  # 10 min
    get_status: Optional[float] = 10 * 60  # 10 min
    fetch_error_details: Optional[float] = 10 * 60  # 10 min

    def __post_init__(self) -> None:
        """Validate that every configured timeout is ``None`` or a finite positive number."""
        for field in fields(self):
            value = getattr(self, field.name)
            if value is None:
                # None means "no deadline for this RPC"; nothing to validate.
                continue
            # bool passes isinstance(value, int), so it has to be excluded explicitly;
            # a boolean here is almost certainly a caller mistake, not a 1s/0s timeout.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise PySparkTypeError(
                    errorClass="NOT_EXPECTED_TYPE",
                    messageParameters={
                        "arg_name": f"RpcDeadlines.{field.name}",
                        "expected_type": "int, float, or None",
                        "arg_type": type(value).__name__,
                    },
                )
            fv = float(value)
            if not math.isfinite(fv) or fv <= 0:
                raise PySparkValueError(
                    message=(
                        f"RpcDeadlines.{field.name} must be a finite positive number or None, "
                        f"got {value!r}"
                    ),
                )

    @classmethod
    def disabled(cls) -> "RpcDeadlines":
        """Create an :class:`RpcDeadlines` with all per-RPC timeouts disabled.

        Use this when you want to rely solely on server-side or network-layer timeouts.

        The keyword arguments are derived from the dataclass fields rather than listed by
        hand, so a deadline field added later is disabled here automatically.
        """
        return cls(**{field.name: None for field in fields(cls)})
@@ -647,6 +725,7 @@ def __init__( session_hooks: Optional[list["SparkSession.Hook"]] = None, allow_arrow_batch_chunking: bool = True, preferred_arrow_chunk_size: Optional[int] = None, + rpc_deadlines: Optional[RpcDeadlines] = None, ): """ Creates a new SparkSession for the Spark Connect interface. @@ -694,6 +773,9 @@ def __init__( results. The server will attempt to use this size if it is set and within the valid range ([1KB, max batch size on server]). Otherwise, the server's maximum batch size is used. + rpc_deadlines : RpcDeadlines, optional + Per-RPC gRPC call timeouts in seconds. Defaults follow SPARK-56538; use + :meth:`RpcDeadlines.disabled` to turn off all per-RPC deadlines. """ self.thread_local = threading.local() @@ -729,8 +811,39 @@ def __init__( self._channel = self._builder.toChannel() self._closed = False self._internal_stub = grpc_lib.SparkConnectServiceStub(self._channel) + self._rpc_deadlines: RpcDeadlines = ( + rpc_deadlines if rpc_deadlines is not None else RpcDeadlines() + ) + d = self._rpc_deadlines + configured = [ + (name, val) + for name, val in [ + ("reattachableExecutePlan", d.reattachable_execute_plan), + ("reattachExecute", d.reattach_execute), + ("analyzePlan", d.analyze_plan), + ("addArtifacts", d.add_artifacts), + ("config", d.config), + ("interrupt", d.interrupt), + ("releaseSession", d.release_session), + ("artifactStatus", d.artifact_status), + ("cloneSession", d.clone_session), + ("getStatus", d.get_status), + ("fetchErrorDetails", d.fetch_error_details), + ] + if val is not None + ] + if configured: + logger.info( + "Spark Connect RPC deadlines: " + + ", ".join(f"{name}: {val}s" for name, val in configured) + ) self._artifact_manager = ArtifactManager( - self._user_id, self._session_id, self._channel, self._builder.metadata() + self._user_id, + self._session_id, + self._channel, + self._builder.metadata(), + add_artifacts_timeout=self._rpc_deadlines.add_artifacts, + artifact_status_timeout=self._rpc_deadlines.artifact_status, ) 
self._use_reattachable_execute = use_reattachable_execute self._allow_arrow_batch_chunking = allow_arrow_batch_chunking @@ -1450,7 +1563,11 @@ def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult: try: for attempt in self._retrying(): with attempt: - resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata()) + resp = self._stub.AnalyzePlan( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.analyze_plan, + ) self._verify_response_integrity(resp) return AnalyzeResult.fromProto(resp) raise SparkConnectException("Invalid state during retry exception handling.") @@ -1479,7 +1596,12 @@ def handle_response(b: pb2.ExecutePlanResponse) -> None: if self._use_reattachable_execute: # Don't use retryHandler - own retry handling is inside. generator = ExecutePlanResponseReattachableIterator( - req, self._stub, self._retrying, self._builder.metadata() + req, + self._stub, + self._retrying, + self._builder.metadata(), + reattachable_execute_plan_timeout=self._rpc_deadlines.reattachable_execute_plan, + reattach_execute_timeout=self._rpc_deadlines.reattach_execute, ) try: for b in generator: @@ -1689,7 +1811,12 @@ def handle_response( if self._use_reattachable_execute: # Don't use retryHandler - own retry handling is inside. 
generator = ExecutePlanResponseReattachableIterator( - req, self._stub, self._retrying, self._builder.metadata() + req, + self._stub, + self._retrying, + self._builder.metadata(), + reattachable_execute_plan_timeout=self._rpc_deadlines.reattachable_execute_plan, + reattach_execute_timeout=self._rpc_deadlines.reattach_execute, ) try: for b in generator: @@ -1841,7 +1968,11 @@ def config(self, operation: pb2.ConfigRequest.Operation) -> ConfigResult: try: for attempt in self._retrying(): with attempt: - resp = self._stub.Config(req, metadata=self._builder.metadata()) + resp = self._stub.Config( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.config, + ) self._verify_response_integrity(resp) return ConfigResult.fromProto(resp) raise SparkConnectException("Invalid state during retry exception handling.") @@ -1883,7 +2014,11 @@ def interrupt_all(self) -> Optional[List[str]]: try: for attempt in self._retrying(): with attempt: - resp = self._stub.Interrupt(req, metadata=self._builder.metadata()) + resp = self._stub.Interrupt( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.interrupt, + ) self._verify_response_integrity(resp) return list(resp.interrupted_ids) raise SparkConnectException("Invalid state during retry exception handling.") @@ -1895,7 +2030,11 @@ def interrupt_tag(self, tag: str) -> Optional[List[str]]: try: for attempt in self._retrying(): with attempt: - resp = self._stub.Interrupt(req, metadata=self._builder.metadata()) + resp = self._stub.Interrupt( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.interrupt, + ) self._verify_response_integrity(resp) return list(resp.interrupted_ids) raise SparkConnectException("Invalid state during retry exception handling.") @@ -1907,7 +2046,11 @@ def interrupt_operation(self, op_id: str) -> Optional[List[str]]: try: for attempt in self._retrying(): with attempt: - resp = self._stub.Interrupt(req, metadata=self._builder.metadata()) + resp = 
self._stub.Interrupt( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.interrupt, + ) self._verify_response_integrity(resp) return list(resp.interrupted_ids) raise SparkConnectException("Invalid state during retry exception handling.") @@ -1923,7 +2066,11 @@ def release_session(self) -> None: try: for attempt in self._retrying(): with attempt: - resp = self._stub.ReleaseSession(req, metadata=self._builder.metadata()) + resp = self._stub.ReleaseSession( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.release_session, + ) self._verify_response_integrity(resp) return raise SparkConnectException("Invalid state during retry exception handling.") @@ -1975,7 +2122,11 @@ def _get_operation_statuses( try: for attempt in self._retrying(): with attempt: - resp = self._stub.GetStatus(req, metadata=self._builder.metadata()) + resp = self._stub.GetStatus( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.get_status, + ) self._verify_response_integrity(resp) return resp raise SparkConnectException("Invalid state during retry exception handling.") @@ -2100,7 +2251,11 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet req.user_context.user_id = self._user_id self._update_request_with_user_context_extensions(req) try: - return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata()) + return self._stub.FetchErrorDetails( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.fetch_error_details, + ) except grpc.RpcError: return None @@ -2142,6 +2297,17 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn: # https://grpc.github.io/grpc/python/grpc.html#grpc.UnaryUnaryMultiCallable.__call__ error: grpc.Call = cast(grpc.Call, rpc_error) status_code: grpc.StatusCode = error.code() + if status_code == grpc.StatusCode.DEADLINE_EXCEEDED: + raise SparkConnectGrpcException( + message=( + f"{error}: RPC deadline exceeded. 
" + "The client applies per-RPC timeouts to prevent silent hangs " + "on broken connections. Deadlines can be configured via the " + "rpc_deadlines parameter of SparkConnectClient. To disable all " + "deadlines: SparkConnectClient(url, rpc_deadlines=RpcDeadlines.disabled())." + ), + grpc_status_code=status_code, + ) from None status: Optional[Status] = rpc_status.from_call(error) if status: for d in status.details: @@ -2466,7 +2632,9 @@ def clone(self, new_session_id: Optional[str] = None) -> "SparkConnectClient": for attempt in self._retrying(): with attempt: response: pb2.CloneSessionResponse = self._stub.CloneSession( - request, metadata=self._builder.metadata() + request, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.clone_session, ) # Assert that the returned session ID matches the requested ID if one was provided @@ -2485,6 +2653,8 @@ def clone(self, new_session_id: Optional[str] = None) -> "SparkConnectClient": connection=new_connection, user_id=self._user_id, use_reattachable_execute=self._use_reattachable_execute, + session_hooks=self._session_hooks, + rpc_deadlines=self._rpc_deadlines, ) # Ensure the session ID is correctly set from the response new_client._session_id = response.new_session_id diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 5750dd045a919..6babe7a8e4289 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -72,6 +72,8 @@ def __init__( stub: grpc_lib.SparkConnectServiceStub, retrying: Callable[[], Retrying], metadata: Iterable[Tuple[str, str]], + reattachable_execute_plan_timeout: Optional[float] = None, + reattach_execute_timeout: Optional[float] = None, ): self._request = request self._retrying = retrying @@ -86,6 +88,8 @@ def __init__( self._operation_id = str(uuid.uuid4()) self._stub = stub + self._reattachable_execute_plan_timeout = reattachable_execute_plan_timeout + 
self._reattach_execute_timeout = reattach_execute_timeout request.request_options.append( pb2.ExecutePlanRequest.RequestOption( reattach_options=pb2.ReattachOptions(reattachable=True) @@ -108,7 +112,11 @@ def __init__( self._metadata = metadata with disable_gc(): self._iterator: Optional[Iterator[pb2.ExecutePlanResponse]] = iter( - self._stub.ExecutePlan(self._initial_request, metadata=metadata) + self._stub.ExecutePlan( + self._initial_request, + metadata=metadata, + timeout=self._reattachable_execute_plan_timeout, + ) ) # Current item from this iterator. @@ -256,7 +264,9 @@ def _call_iter(self, iter_fun: Callable) -> Any: # we get a new iterator with ReattachExecute if it was unset. self._iterator = iter( self._stub.ReattachExecute( - self._create_reattach_execute_request(), metadata=self._metadata + self._create_reattach_execute_request(), + metadata=self._metadata, + timeout=self._reattach_execute_timeout, ) ) @@ -283,9 +293,23 @@ def _call_iter(self, iter_fun: Callable) -> Any: ) # Try a new ExecutePlan, and throw upstream for retry. self._iterator = iter( - self._stub.ExecutePlan(self._initial_request, metadata=self._metadata) + self._stub.ExecutePlan( + self._initial_request, + metadata=self._metadata, + timeout=self._reattachable_execute_plan_timeout, + ) ) raise RetryException() + elif e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + # The per-RPC deadline fired. The server-side operation is still alive; clear + # the iterator and raise RetryException so the retry loop opens a fresh + # ReattachExecute stream with a new deadline countdown to resume results. + logger.debug( + f"Deadline exceeded on stream for operation {self._operation_id}; " + f"will reattach. (last response: {self._last_returned_response_id})" + ) + self._iterator = None + raise RetryException() from e else: # Remove the iterator, so that a new one will be created after retry. 
self._iterator = None diff --git a/python/pyspark/sql/connect/client/retries.py b/python/pyspark/sql/connect/client/retries.py index 898d976f2628e..86ed004c40fe9 100644 --- a/python/pyspark/sql/connect/client/retries.py +++ b/python/pyspark/sql/connect/client/retries.py @@ -368,6 +368,11 @@ def can_retry(self, e: BaseException) -> bool: # All errors messages containing `RetryInfo` should be retried. return True + # DEADLINE_EXCEEDED on the reattachable execute path is handled directly in + # ExecutePlanResponseReattachableIterator, which converts it to RetryException so the + # server-side operation continues and a fresh ReattachExecute is issued. We do not retry + # other RPCs on deadline: those are non-idempotent or the deadline signals a genuine + # timeout that retrying won't fix. Keep this in sync with RetryPolicy.scala. return False diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 85fbafe227284..3c580fdf54b2b 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -32,6 +32,7 @@ import pandas as pd import pyarrow as pa from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder + from pyspark.sql.connect.client.core import RpcDeadlines from pyspark.sql.connect.client.retries import ( Retrying, DefaultPolicy, @@ -125,7 +126,6 @@ def ReleaseExecute(self, req: proto.ReleaseExecuteRequest, *args, **kwargs): if req.HasField("release_all"): self.release_calls += 1 elif req.HasField("release_until"): - print("increment") self.release_until_calls += 1 class MockService: @@ -154,7 +154,7 @@ def __init__(self, session_id: str, operation_statuses=None): operation_statuses = self.DEFAULT_OPERATION_STATUSES self._operation_statuses = {s.operation_id: s for s in operation_statuses} - def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): + def ExecutePlan(self, req: 
proto.ExecutePlanRequest, metadata, timeout=None): self.req = req self.client_user_context_extensions = list(req.user_context.extensions) resp = proto.ExecutePlanResponse() @@ -175,14 +175,14 @@ def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): resp.arrow_batch.row_count = 2 return [resp] - def Interrupt(self, req: proto.InterruptRequest, metadata): + def Interrupt(self, req: proto.InterruptRequest, metadata, timeout=None): self.req = req self.client_user_context_extensions = list(req.user_context.extensions) resp = proto.InterruptResponse() resp.session_id = self._session_id return resp - def Config(self, req: proto.ConfigRequest, metadata): + def Config(self, req: proto.ConfigRequest, metadata, timeout=None): self.req = req self.client_user_context_extensions = list(req.user_context.extensions) resp = proto.ConfigResponse() @@ -197,7 +197,7 @@ def Config(self, req: proto.ConfigRequest, metadata): pair.value = req.operation.get_with_default.pairs[0].value or "true" return resp - def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata): + def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata, timeout=None): self.req = req self.client_user_context_extensions = list(req.user_context.extensions) resp = proto.AnalyzePlanResponse() @@ -206,7 +206,7 @@ def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata): resp.semantic_hash.result = 12345 return resp - def GetStatus(self, req: proto.GetStatusRequest, metadata): + def GetStatus(self, req: proto.GetStatusRequest, metadata, timeout=None): self.req = req self.client_user_context_extensions = list(req.user_context.extensions) self.received_custom_server_session_id = req.client_observed_server_side_session_id @@ -706,6 +706,111 @@ def test_get_operations_statuses_with_request_extensions(self): resp.extensions[0].Unpack(resp_echoed) self.assertEqual(resp_echoed.value, "request_extension") + def test_analyze_plan_short_deadline_fires_then_succeeds_after_disabling(self): + """With a short 
deadline the call fails; after disabling deadlines it succeeds.""" + + class CapturingMock(MockService): + """Captures the timeout passed by the client; raises DEADLINE_EXCEEDED if set.""" + + def __init__(self, session_id): + super().__init__(session_id) + self.captured_timeout = "not_called" + + def AnalyzePlan(self, req, metadata, timeout=None): + self.captured_timeout = timeout + if timeout is not None: + raise TestException("deadline exceeded", grpc.StatusCode.DEADLINE_EXCEEDED) + return super().AnalyzePlan(req, metadata, timeout=timeout) + + client_with_deadline = SparkConnectClient( + "sc://foo/", + use_reattachable_execute=False, + rpc_deadlines=RpcDeadlines(analyze_plan=0.050), + retry_policy=dict(max_retries=0), + ) + mock_with_deadline = CapturingMock(session_id=client_with_deadline._session_id) + client_with_deadline._stub = mock_with_deadline + with self.assertRaises(SparkConnectGrpcException) as cm: + client_with_deadline._analyze("schema", plan=proto.Plan()) + self.assertEqual(cm.exception.getGrpcStatusCode(), grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(mock_with_deadline.captured_timeout, 0.050) + + client_disabled = SparkConnectClient( + "sc://foo/", + use_reattachable_execute=False, + rpc_deadlines=RpcDeadlines.disabled(), + retry_policy=dict(max_retries=0), + ) + mock_disabled = CapturingMock(session_id=client_disabled._session_id) + client_disabled._stub = mock_disabled + client_disabled._analyze("schema", plan=proto.Plan()) + self.assertIsNone(mock_disabled.captured_timeout) + + def test_each_rpc_receives_configured_deadline(self): + """Every RPC that accepts a deadline should forward it as timeout to the stub.""" + + class TimeoutCapturingMock(MockService): + """Records the timeout kwarg for each RPC call.""" + + def __init__(self, session_id): + super().__init__(session_id) + self.captured_timeouts = {} + + def AnalyzePlan(self, req, metadata, timeout=None): + self.captured_timeouts["AnalyzePlan"] = timeout + return 
super().AnalyzePlan(req, metadata, timeout=timeout) + + def Config(self, req, metadata, timeout=None): + self.captured_timeouts["Config"] = timeout + return super().Config(req, metadata, timeout=timeout) + + def Interrupt(self, req, metadata, timeout=None): + self.captured_timeouts["Interrupt"] = timeout + return super().Interrupt(req, metadata, timeout=timeout) + + def ReleaseSession(self, req, metadata, timeout=None): + self.captured_timeouts["ReleaseSession"] = timeout + resp = proto.ReleaseSessionResponse() + resp.session_id = self._session_id + return resp + + def GetStatus(self, req, metadata, timeout=None): + self.captured_timeouts["GetStatus"] = timeout + return super().GetStatus(req, metadata, timeout=timeout) + + deadlines = RpcDeadlines( + analyze_plan=11.0, + config=22.0, + interrupt=33.0, + release_session=44.0, + get_status=55.0, + ) + client = SparkConnectClient( + "sc://foo/", + use_reattachable_execute=False, + rpc_deadlines=deadlines, + retry_policy=dict(max_retries=0), + ) + mock = TimeoutCapturingMock(session_id=client._session_id) + client._stub = mock + + client._analyze("schema", plan=proto.Plan()) + self.assertEqual(mock.captured_timeouts["AnalyzePlan"], 11.0) + + op = proto.ConfigRequest.Operation() + op.get.keys.append("spark.sql.shuffle.partitions") + client.config(op) + self.assertEqual(mock.captured_timeouts["Config"], 22.0) + + client.interrupt_all() + self.assertEqual(mock.captured_timeouts["Interrupt"], 33.0) + + client.release_session() + self.assertEqual(mock.captured_timeouts["ReleaseSession"], 44.0) + + client._get_operation_statuses() + self.assertEqual(mock.captured_timeouts["GetStatus"], 55.0) + @unittest.skipIf(not should_test_connect, connect_requirement_message) class SparkConnectClientReattachTestCase(unittest.TestCase): @@ -881,6 +986,59 @@ def test_observed_session_id(self): reattach = ite._create_reattach_execute_request() self.assertEqual(reattach.client_observed_server_side_session_id, session_id) + def 
test_deadline_exceeded_triggers_reattach(self): + """DEADLINE_EXCEEDED mid-stream on ExecutePlan should trigger a ReattachExecute.""" + + def deadline_exceeded(): + raise TestException("deadline", grpc.StatusCode.DEADLINE_EXCEEDED) + + stub = self._stub_with( + [self.response, deadline_exceeded], + [self.response, self.finished], + ) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) + for _ in ite: + pass + + def check(): + self.assertEqual(1, stub.execute_calls) + self.assertEqual(1, stub.attach_calls) + self.assertEqual(1, stub.release_calls) + + eventually(timeout=1, catch_assertions=True)(check)() + + def test_deadline_exceeded_mid_stream_completes_successfully(self): + """After a mid-stream DEADLINE_EXCEEDED, reattach resumes and all responses are collected.""" + + response2 = proto.ExecutePlanResponse(response_id="2") + response3 = proto.ExecutePlanResponse(response_id="3") + + def deadline_exceeded(): + raise TestException("deadline", grpc.StatusCode.DEADLINE_EXCEEDED) + + finished = proto.ExecutePlanResponse( + result_complete=proto.ExecutePlanResponse.ResultComplete(), + response_id="final", + ) + stub = self._stub_with( + [self.response, response2, deadline_exceeded], + [response3, finished], + ) + collected = [] + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) + for r in ite: + if not r.HasField("result_complete"): + collected.append(r.response_id) + + self.assertEqual(collected, ["1", "2", "3"]) + + def check(): + self.assertEqual(1, stub.execute_calls) + self.assertEqual(1, stub.attach_calls) + self.assertEqual(1, stub.release_calls) + + eventually(timeout=1, catch_assertions=True)(check)() + def test_server_unreachable(self): # DNS resolution should fail for "foo". This error is a retriable UNAVAILABLE error. 
client = SparkConnectClient( diff --git a/python/pyspark/sql/tests/connect/client/test_client_retries.py b/python/pyspark/sql/tests/connect/client/test_client_retries.py index c509b4db26543..5e212f39902f9 100644 --- a/python/pyspark/sql/tests/connect/client/test_client_retries.py +++ b/python/pyspark/sql/tests/connect/client/test_client_retries.py @@ -27,6 +27,7 @@ from google.rpc import status_pb2 from google.rpc import error_details_pb2 from pyspark.sql.connect.client import SparkConnectClient + from pyspark.sql.connect.client.core import RpcDeadlines from pyspark.sql.connect.client.retries import ( Retrying, DefaultPolicy, @@ -235,6 +236,95 @@ def test_return_to_exponential_backoff(self): ] self.assertListsAlmostEqual(sleep_tracker.times, expected_times, delta=policy.jitter) + def test_deadline_exceeded_not_retryable_by_default_policy(self): + client = SparkConnectClient("sc://foo/;token=bar") + policy = get_client_policies_map(client).get(DefaultPolicy) + self.assertIsNotNone(policy) + self.assertFalse( + policy.can_retry(TestException("deadline", code=grpc.StatusCode.DEADLINE_EXCEEDED)) + ) + + def test_deadline_exceeded_not_retried_by_retry_handler(self): + client = SparkConnectClient("sc://foo/;token=bar") + sleep_tracker = SleepTimeTracker() + tries = 0 + with self.assertRaises(TestException): + for attempt in Retrying(client._retry_policies, sleep=sleep_tracker.sleep): + with attempt: + tries += 1 + raise TestException("d", code=grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(tries, 1) + self.assertEqual(len(sleep_tracker.times), 0) + + def test_deadline_exceeded_is_not_retried_for_non_retryable_codes(self): + # Sanity check: ABORTED is not retried by DefaultPolicy (unless matching cluster message) + policy = DefaultPolicy() + self.assertFalse( + policy.can_retry(TestException("some aborted error", code=grpc.StatusCode.ABORTED)) + ) + + def test_deadline_exceeded_exception_message_contains_configuration_hint(self): + """DEADLINE_EXCEEDED exceptions 
should carry a hint about how to adjust timeouts.""" + from pyspark.errors.exceptions.connect import SparkConnectGrpcException + + client = SparkConnectClient("sc://foo/;token=bar", retry_policy=dict(max_retries=0)) + err = TestException("deadline exceeded", code=grpc.StatusCode.DEADLINE_EXCEEDED) + with self.assertRaises(SparkConnectGrpcException) as cm: + client._handle_rpc_error(err) + self.assertIn("RPC deadline exceeded", str(cm.exception)) + self.assertIn("RpcDeadlines.disabled()", str(cm.exception)) + self.assertEqual(cm.exception.getGrpcStatusCode(), grpc.StatusCode.DEADLINE_EXCEEDED) + + def test_rpc_deadlines_disabled_sets_all_fields_to_none(self): + """RpcDeadlines.disabled() should create an instance with every field set to None.""" + d = RpcDeadlines.disabled() + self.assertIsNone(d.reattachable_execute_plan) + self.assertIsNone(d.reattach_execute) + self.assertIsNone(d.analyze_plan) + self.assertIsNone(d.add_artifacts) + self.assertIsNone(d.config) + self.assertIsNone(d.interrupt) + self.assertIsNone(d.release_session) + self.assertIsNone(d.artifact_status) + self.assertIsNone(d.clone_session) + self.assertIsNone(d.get_status) + self.assertIsNone(d.fetch_error_details) + + def test_rpc_deadlines_defaults_are_set(self): + """RpcDeadlines() default instance should have documented timeout values.""" + d = RpcDeadlines() + self.assertEqual(d.reattachable_execute_plan, 10 * 60) + self.assertEqual(d.reattach_execute, 10 * 60) + self.assertEqual(d.analyze_plan, 60 * 60) + self.assertEqual(d.add_artifacts, 60 * 60) + self.assertEqual(d.config, 10 * 60) + self.assertEqual(d.interrupt, 10 * 60) + self.assertEqual(d.release_session, 10 * 60) + self.assertEqual(d.artifact_status, 10 * 60) + self.assertEqual(d.clone_session, 10 * 60) + self.assertEqual(d.get_status, 10 * 60) + self.assertEqual(d.fetch_error_details, 10 * 60) + + def test_rpc_deadlines_rejects_non_positive_values(self): + """RpcDeadlines should raise ValueError for any non-None field that is <= 
0.""" + from dataclasses import replace + + with self.assertRaises(ValueError): + replace(RpcDeadlines(), analyze_plan=-1.0) + with self.assertRaises(ValueError): + replace(RpcDeadlines(), config=0) + with self.assertRaises(ValueError): + replace(RpcDeadlines(), reattachable_execute_plan=-0.001) + with self.assertRaises(ValueError): + replace(RpcDeadlines(), config=float("nan")) + with self.assertRaises(ValueError): + replace(RpcDeadlines(), config=float("inf")) + with self.assertRaises(ValueError): + replace(RpcDeadlines(), config=float("-inf")) + # None (disabled) and positive values should be accepted without error. + replace(RpcDeadlines(), analyze_plan=None) + replace(RpcDeadlines(), analyze_plan=0.001) + if __name__ == "__main__": from pyspark.testing import main diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala index a6ccf39886e22..bbe3fd16a9693 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala @@ -57,7 +57,7 @@ class ArtifactSuite extends ConnectFunSuite { private def createArtifactManager(): Unit = { channel = InProcessChannelBuilder.forName(getClass.getName).directExecutor().build() - state = new SparkConnectStubState(channel, RetryPolicy.defaultPolicies()) + state = new SparkConnectStubState(channel, Configuration()) bstub = new CustomSparkConnectBlockingStub(channel, state) stub = new CustomSparkConnectStub(channel, state) artifactManager = new ArtifactManager(Configuration(), "", bstub, stub) diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientRetriesSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientRetriesSuite.scala index 
7ea01e34ec88a..595f48b9f5485 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientRetriesSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientRetriesSuite.scala @@ -275,4 +275,38 @@ class SparkConnectClientRetriesSuite extends ConnectFunSuite with Eventually { policy.initialBackoff.toMillis * math.pow(policy.backoffMultiplier, i + 2).toLong) assertLongSequencesAlmostEqual(st.times, expectedSleeps, delta = policy.jitter.toMillis) } + + test("DEADLINE_EXCEEDED is not retryable by defaultPolicy") { + // DEADLINE_EXCEEDED must not be retried via canRetry. The reattachable execute path + // handles it separately by converting it to RetryException in the iterator. + val exception = new StatusRuntimeException(Status.DEADLINE_EXCEEDED) + val canRetry = RetryPolicy.defaultPolicy().canRetry + assert(canRetry(exception) == false) + } + + test("DEADLINE_EXCEEDED is not retried by retry handler") { + // Verify the function is called exactly once and the exception propagates immediately. 
+ val dummyFn = new DummyFn(new StatusRuntimeException(Status.DEADLINE_EXCEEDED), numFails = 1) + val retryPolicies = RetryPolicy.defaultPolicies() + val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {}) + intercept[StatusRuntimeException] { + retryHandler.retry { dummyFn.fn() } + } + assert(dummyFn.counter == 1) + } + + test("RpcDeadlines.disabled creates an instance with all deadlines set to None") { + val disabled = RpcDeadlines.disabled + assert(disabled.reattachableExecutePlan.isEmpty) + assert(disabled.reattachExecute.isEmpty) + assert(disabled.analyzePlan.isEmpty) + assert(disabled.addArtifacts.isEmpty) + assert(disabled.config.isEmpty) + assert(disabled.interrupt.isEmpty) + assert(disabled.releaseSession.isEmpty) + assert(disabled.artifactStatus.isEmpty) + assert(disabled.cloneSession.isEmpty) + assert(disabled.getStatus.isEmpty) + assert(disabled.fetchErrorDetails.isEmpty) + } } diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 22feaff1c77f1..1fc6c13e146f9 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.connect.client import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Base64, UUID} -import java.util.concurrent.TimeUnit +import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable +import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters._ import com.google.protobuf.{Any => PAny, StringValue} @@ -824,10 +825,278 @@ class SparkConnectClientSuite extends ConnectFunSuite { assert(!headerInterceptor.headers.exists(_.containsKey(key))) } } + + test("analyzePlan deadline fires on 
slow server") { + val latch = new CountDownLatch(1) + val slowService = new DummySparkConnectService { + override def analyzePlan( + request: AnalyzePlanRequest, + responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = { + latch.await(5, TimeUnit.SECONDS) + super.analyzePlan(request, responseObserver) + } + } + server = NettyServerBuilder.forPort(0).addService(slowService).build().start() + service = slowService + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .rpcDeadlines(RpcDeadlines(analyzePlan = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) + .retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "NoRetry")) + .build() + try { + val ex = intercept[SparkException] { + client.analyze(proto.AnalyzePlanRequest.newBuilder().setSessionId("abc123").build()) + } + assert(ex.getCause.isInstanceOf[StatusRuntimeException]) + assert( + ex.getCause.asInstanceOf[StatusRuntimeException].getStatus.getCode == + Status.Code.DEADLINE_EXCEEDED) + } finally { + latch.countDown() + } + } + + test("analyzePlan with short deadline fires, then succeeds after disabling deadlines") { + val latch = new CountDownLatch(1) + val slowService = new DummySparkConnectService { + override def analyzePlan( + request: AnalyzePlanRequest, + responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = { + latch.await(5, TimeUnit.SECONDS) + super.analyzePlan(request, responseObserver) + } + } + server = NettyServerBuilder.forPort(0).addService(slowService).build().start() + service = slowService + val noRetry = RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "NoRetry") + + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .rpcDeadlines(RpcDeadlines(analyzePlan = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) + .retryPolicy(noRetry) + .build() + val ex = intercept[SparkException] { + 
client.analyze(proto.AnalyzePlanRequest.newBuilder().setSessionId("abc123").build()) + } + assert( + ex.getCause.asInstanceOf[StatusRuntimeException].getStatus.getCode == + Status.Code.DEADLINE_EXCEEDED) + client.shutdown() + latch.countDown() + + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .rpcDeadlines(RpcDeadlines.disabled) + .retryPolicy(noRetry) + .build() + client.analyze(proto.AnalyzePlanRequest.newBuilder().setSessionId("abc123").build()) + } + + test("reattachable execute recovers from mid-stream deadline via ReattachExecute") { + val executeLatch = new CountDownLatch(1) + @volatile var reattachCalled = false + val midStreamService = new DummySparkConnectService { + override def executePlan( + request: ExecutePlanRequest, + responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + val sessionId = request.getSessionId + val operationId = + if (request.hasOperationId) request.getOperationId + else UUID.randomUUID().toString + val serverSideSessionId = "srv-deadline-test" + responseObserver.onNext( + ExecutePlanResponse + .newBuilder() + .setSessionId(sessionId) + .setServerSideSessionId(serverSideSessionId) + .setOperationId(operationId) + .setResponseId("r1") + .build()) + executeLatch.await(5, TimeUnit.SECONDS) + responseObserver.onCompleted() + } + + override def reattachExecute( + request: proto.ReattachExecuteRequest, + responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + reattachCalled = true + val sessionId = request.getSessionId + val operationId = request.getOperationId + val serverSideSessionId = "srv-deadline-test" + responseObserver.onNext( + ExecutePlanResponse + .newBuilder() + .setSessionId(sessionId) + .setServerSideSessionId(serverSideSessionId) + .setOperationId(operationId) + .setResponseId("r2") + .build()) + responseObserver.onNext( + ExecutePlanResponse + .newBuilder() + .setSessionId(sessionId) + .setServerSideSessionId(serverSideSessionId) + 
.setOperationId(operationId) + .setResponseId("r3") + .setResultComplete(proto.ExecutePlanResponse.ResultComplete.newBuilder().build()) + .build()) + responseObserver.onCompleted() + } + } + server = NettyServerBuilder.forPort(0).addService(midStreamService).build().start() + service = midStreamService + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .rpcDeadlines( + RpcDeadlines(reattachableExecutePlan = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) + .enableReattachableExecute() + .build() + + val iter = client.execute(buildPlan("select 1")) + val reattachableIter = ExecutePlanResponseReattachableIterator.fromIterator(iter) + iter.foreach(_ => ()) + executeLatch.countDown() + + assert(reattachableIter.resultComplete, "iterator should complete after reattach") + assert(reattachCalled, "ReattachExecute should have been called") + } + + test("non-reattachable executePlan has no client-side deadline") { + val latch = new CountDownLatch(1) + val slowService = new DummySparkConnectService { + override def executePlan( + request: ExecutePlanRequest, + responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + latch.await(150, TimeUnit.MILLISECONDS) + super.executePlan(request, responseObserver) + } + } + server = NettyServerBuilder.forPort(0).addService(slowService).build().start() + service = slowService + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .disableReattachableExecute() + .rpcDeadlines( + RpcDeadlines(reattachableExecutePlan = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) + .retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "NoRetry")) + .build() + val iter = client.execute(buildPlan("select 1")) + iter.foreach(_ => ()) + } + + test("config deadline fires on slow server") { + val latch = new CountDownLatch(1) + val slowService = new DummySparkConnectService { + override def config( + request: 
proto.ConfigRequest, + responseObserver: StreamObserver[proto.ConfigResponse]): Unit = { + latch.await(5, TimeUnit.SECONDS) + super.config(request, responseObserver) + } + } + server = NettyServerBuilder.forPort(0).addService(slowService).build().start() + service = slowService + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .rpcDeadlines(RpcDeadlines(config = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) + .retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "NoRetry")) + .build() + val op = proto.ConfigRequest.Operation + .newBuilder() + .setGetOption( + proto.ConfigRequest.GetOption.newBuilder().addKeys("spark.sql.shuffle.partitions")) + .build() + val ex = intercept[SparkException] { + client.config(op) + } + assert(ex.getCause.isInstanceOf[StatusRuntimeException]) + assert( + ex.getCause.asInstanceOf[StatusRuntimeException].getStatus.getCode == + Status.Code.DEADLINE_EXCEEDED) + latch.countDown() + } + + test("interrupt deadline fires on slow server") { + val latch = new CountDownLatch(1) + val slowService = new DummySparkConnectService { + override def interrupt( + request: proto.InterruptRequest, + responseObserver: StreamObserver[proto.InterruptResponse]): Unit = { + latch.await(5, TimeUnit.SECONDS) + super.interrupt(request, responseObserver) + } + } + server = NettyServerBuilder.forPort(0).addService(slowService).build().start() + service = slowService + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .rpcDeadlines(RpcDeadlines(interrupt = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) + .retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "NoRetry")) + .build() + val ex = intercept[SparkException] { + client.interruptAll() + } + assert(ex.getCause.isInstanceOf[StatusRuntimeException]) + assert( + ex.getCause.asInstanceOf[StatusRuntimeException].getStatus.getCode == + 
Status.Code.DEADLINE_EXCEEDED) + latch.countDown() + } + + test("releaseSession deadline fires on slow server") { + val latch = new CountDownLatch(1) + val slowService = new DummySparkConnectService { + override def releaseSession( + request: proto.ReleaseSessionRequest, + responseObserver: StreamObserver[proto.ReleaseSessionResponse]): Unit = { + latch.await(5, TimeUnit.SECONDS) + responseObserver.onNext( + proto.ReleaseSessionResponse + .newBuilder() + .setSessionId(request.getSessionId) + .build()) + responseObserver.onCompleted() + } + } + server = NettyServerBuilder.forPort(0).addService(slowService).build().start() + service = slowService + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .rpcDeadlines( + RpcDeadlines(releaseSession = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) + .retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "NoRetry")) + .build() + val ex = intercept[SparkException] { + client.releaseSession() + } + assert(ex.getCause.isInstanceOf[StatusRuntimeException]) + assert( + ex.getCause.asInstanceOf[StatusRuntimeException].getStatus.getCode == + Status.Code.DEADLINE_EXCEEDED) + latch.countDown() + } + + // Note: artifactStatus, cloneSession, and getStatus use the same withDeadline() helper as + // analyzePlan/config/interrupt above; addArtifacts and fetchErrorDetails set theirs in + // CustomSparkConnectStub and GrpcExceptionConverter. All are harder to unit-test in isolation. 
+ + test("SPARK-56538: RpcDeadlines.disabled has no configured deadlines in toString") { + assert(RpcDeadlines.disabled.toString === "RpcDeadlines(all disabled)") + } } class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase { + @volatile var analyzePlanAwait: Option[CountDownLatch] = None + private var inputPlan: proto.Plan = _ private val inputArtifactRequests: mutable.ListBuffer[AddArtifactsRequest] = mutable.ListBuffer.empty @@ -908,6 +1177,7 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer override def analyzePlan( request: AnalyzePlanRequest, responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = { + analyzePlanAwait.foreach(_.await(5, TimeUnit.SECONDS)) // Reply with a dummy response using the same client ID val requestSessionId = request.getSessionId synchronized { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala index ad39a3fb29f24..dde1e81701476 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala @@ -16,9 +16,12 @@ */ package org.apache.spark.sql.connect.client +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters._ -import io.grpc.ManagedChannel +import io.grpc.{Deadline, ManagedChannel} import org.apache.spark.connect.proto._ import org.apache.spark.sql.util.CloseableIterator @@ -29,11 +32,19 @@ private[connect] class CustomSparkConnectBlockingStub( private val stub = SparkConnectServiceGrpc.newBlockingStub(channel) + private def withDeadline( + d: Option[FiniteDuration]): SparkConnectServiceGrpc.SparkConnectServiceBlockingStub = + d.map(dur => 
stub.withDeadline(Deadline.after(dur.toMillis, TimeUnit.MILLISECONDS))) + .getOrElse(stub) + private val retryHandler = stubState.retryHandler // GrpcExceptionConverter with a GRPC stub for fetching error details from server. private val grpcExceptionConverter = stubState.exceptionConverter + // Non-reattachable executePlan intentionally has no deadline: a timeout here would kill the + // server-side execution with no way to recover (there is no ReattachExecute for this path). + // Use reattachable execution for long-running queries that need deadline protection. def executePlan(request: ExecutePlanRequest): CloseableIterator[ExecutePlanResponse] = { grpcExceptionConverter.convert( request.getSessionId, @@ -63,8 +74,13 @@ private[connect] class CustomSparkConnectBlockingStub( request.getUserContext, request.getClientType, stubState.responseValidator.wrapIterator( - // ExecutePlanResponseReattachableIterator does all retries by itself, don't wrap it here - new ExecutePlanResponseReattachableIterator(request, channel, stubState.retryHandler))) + // Reattachable iterator retries internally; omit RetryIterator wrapper here. 
+ new ExecutePlanResponseReattachableIterator( + request, + channel, + stubState.retryHandler, + stubState.rpcDeadlines.reattachableExecutePlan, + stubState.rpcDeadlines.reattachExecute))) } } @@ -75,7 +91,7 @@ private[connect] class CustomSparkConnectBlockingStub( request.getClientType) { retryHandler.retry { stubState.responseValidator.verifyResponse { - stub.analyzePlan(request) + withDeadline(stubState.rpcDeadlines.analyzePlan).analyzePlan(request) } } } @@ -88,7 +104,7 @@ private[connect] class CustomSparkConnectBlockingStub( request.getClientType) { retryHandler.retry { stubState.responseValidator.verifyResponse { - stub.config(request) + withDeadline(stubState.rpcDeadlines.config).config(request) } } } @@ -101,7 +117,7 @@ private[connect] class CustomSparkConnectBlockingStub( request.getClientType) { retryHandler.retry { stubState.responseValidator.verifyResponse { - stub.interrupt(request) + withDeadline(stubState.rpcDeadlines.interrupt).interrupt(request) } } } @@ -114,7 +130,7 @@ private[connect] class CustomSparkConnectBlockingStub( request.getClientType) { retryHandler.retry { stubState.responseValidator.verifyResponse { - stub.releaseSession(request) + withDeadline(stubState.rpcDeadlines.releaseSession).releaseSession(request) } } } @@ -127,7 +143,7 @@ private[connect] class CustomSparkConnectBlockingStub( request.getClientType) { retryHandler.retry { stubState.responseValidator.verifyResponse { - stub.artifactStatus(request) + withDeadline(stubState.rpcDeadlines.artifactStatus).artifactStatus(request) } } } @@ -140,7 +156,7 @@ private[connect] class CustomSparkConnectBlockingStub( request.getClientType) { retryHandler.retry { stubState.responseValidator.verifyResponse { - stub.cloneSession(request) + withDeadline(stubState.rpcDeadlines.cloneSession).cloneSession(request) } } } @@ -153,7 +169,7 @@ private[connect] class CustomSparkConnectBlockingStub( request.getClientType) { retryHandler.retry { stubState.responseValidator.verifyResponse { - 
stub.getStatus(request) + withDeadline(stubState.rpcDeadlines.getStatus).getStatus(request) } } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala index 187c2842a0bc8..4edd84ed5de76 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala @@ -16,7 +16,9 @@ */ package org.apache.spark.sql.connect.client -import io.grpc.ManagedChannel +import java.util.concurrent.TimeUnit + +import io.grpc.{Deadline, ManagedChannel} import io.grpc.stub.StreamObserver import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, SparkConnectServiceGrpc} @@ -29,7 +31,15 @@ private[client] class CustomSparkConnectStub( def addArtifacts(responseObserver: StreamObserver[AddArtifactsResponse]) : StreamObserver[AddArtifactsRequest] = { + // Evaluated on every call (including each RetryStreamObserver retry) so each attempt gets a + // fresh Deadline object with a new absolute expiry, not one shared from the first attempt. 
+ def freshStub: SparkConnectServiceGrpc.SparkConnectServiceStub = + stubState.rpcDeadlines.addArtifacts + .map(d => stub.withDeadline(Deadline.after(d.toMillis, TimeUnit.MILLISECONDS))) + .getOrElse(stub) stubState.responseValidator.wrapStreamObserver( - stubState.retryHandler.RetryStreamObserver(responseObserver, stub.addArtifacts)) + stubState.retryHandler.RetryStreamObserver( + responseObserver, + (obs: StreamObserver[AddArtifactsResponse]) => freshStub.addArtifacts(obs))) } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala index 131a2e77cc431..abc7be4861a8e 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.connect.client import java.util.UUID +import java.util.concurrent.TimeUnit +import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal -import io.grpc.{ManagedChannel, StatusRuntimeException} +import io.grpc.{Deadline, ManagedChannel, Status, StatusRuntimeException} import io.grpc.protobuf.StatusProto import io.grpc.stub.StreamObserver @@ -53,7 +55,9 @@ import org.apache.spark.sql.util.WrappedCloseableIterator class ExecutePlanResponseReattachableIterator( request: proto.ExecutePlanRequest, channel: ManagedChannel, - retryHandler: GrpcRetryHandler) + retryHandler: GrpcRetryHandler, + reattachableExecutePlanDeadline: Option[FiniteDuration] = None, + reattachExecuteDeadline: Option[FiniteDuration] = None) extends WrappedCloseableIterator[proto.ExecutePlanResponse] with Logging { @@ -80,6 +84,12 @@ class ExecutePlanResponseReattachableIterator( private val 
rawBlockingStub = proto.SparkConnectServiceGrpc.newBlockingStub(channel) private val rawAsyncStub = proto.SparkConnectServiceGrpc.newStub(channel) + private def stubWithDeadline(deadline: Option[FiniteDuration]) + : proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub = + deadline + .map(d => rawBlockingStub.withDeadline(Deadline.after(d.toMillis, TimeUnit.MILLISECONDS))) + .getOrElse(rawBlockingStub) + private val initialRequest: proto.ExecutePlanRequest = request .toBuilder() .addRequestOptions( @@ -104,7 +114,10 @@ class ExecutePlanResponseReattachableIterator( // throw error on first iter.hasNext() or iter.next() // Visible for testing. private[connect] var iter: Option[java.util.Iterator[proto.ExecutePlanResponse]] = - Some(rawBlockingStub.executePlan(initialRequest)) + Some(stubWithDeadline(reattachableExecutePlanDeadline).executePlan(initialRequest)) + + // When true, an empty iterator triggers a fresh ExecutePlan instead of ReattachExecute. + private var restartExecutionOnNextRetry: Boolean = false // Server side session ID, used to detect if the server side session changed. This is set upon // receiving the first response from the server. @@ -230,8 +243,13 @@ class ExecutePlanResponseReattachableIterator( */ private def callIter[V](iterFun: java.util.Iterator[proto.ExecutePlanResponse] => V) = { try { - if (iter.isEmpty) { - iter = Some(rawBlockingStub.reattachExecute(createReattachExecuteRequest())) + if (iter.isEmpty && restartExecutionOnNextRetry) { + iter = Some(stubWithDeadline(reattachableExecutePlanDeadline).executePlan(initialRequest)) + restartExecutionOnNextRetry = false + } else if (iter.isEmpty) { + iter = Some( + stubWithDeadline(reattachExecuteDeadline) + .reattachExecute(createReattachExecuteRequest())) } iterFun(iter.get) } catch { @@ -248,7 +266,20 @@ class ExecutePlanResponseReattachableIterator( ex) } // Try a new ExecutePlan, and throw upstream for retry. 
- iter = Some(rawBlockingStub.executePlan(initialRequest)) + iter = None + restartExecutionOnNextRetry = true + val error = new RetryException() + error.addSuppressed(ex) + throw error + case ex: StatusRuntimeException if ex.getStatus.getCode == Status.Code.DEADLINE_EXCEEDED => + // The per-RPC deadline fired. The server-side operation is still alive; we clear the + // iterator and raise RetryException so the outer retry loop opens a fresh + // ReattachExecute stream (a new per-RPC deadline countdown) to resume receiving results. + logDebug( + s"Deadline exceeded on stream for operation $operationId; will reattach. " + + s"(last response: $lastReturnedResponseId)") + iter = None + restartExecutionOnNextRetry = false // defensive: deadline != operation lost val error = new RetryException() error.addSuppressed(ex) throw error diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index 6e5304b8cc772..b52261e5efde5 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.connect.client import java.time.DateTimeException +import java.util.concurrent.TimeUnit +import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag import com.google.rpc.ErrorInfo -import io.grpc.{ManagedChannel, StatusRuntimeException} +import io.grpc.{Deadline, ManagedChannel, Status, StatusRuntimeException} import io.grpc.protobuf.StatusProto import org.json4s.{DefaultFormats, Formats} import org.json4s.jackson.JsonMethods @@ -49,10 +51,18 @@ import org.apache.spark.util.ArrayImplicits._ * the ErrorInfo is missing, the exception will be constructed based on the 
StatusRuntimeException * itself. */ -private[client] class GrpcExceptionConverter(channel: ManagedChannel) extends Logging { +private[client] class GrpcExceptionConverter( + channel: ManagedChannel, + fetchErrorDetailsDeadline: Option[FiniteDuration] = None) + extends Logging { import GrpcExceptionConverter._ - val grpcStub = SparkConnectServiceGrpc.newBlockingStub(channel) + private val grpcStub = SparkConnectServiceGrpc.newBlockingStub(channel) + + private def stubWithDeadline: SparkConnectServiceGrpc.SparkConnectServiceBlockingStub = + fetchErrorDetailsDeadline + .map(d => grpcStub.withDeadline(Deadline.after(d.toMillis, TimeUnit.MILLISECONDS))) + .getOrElse(grpcStub) def convert[T](sessionId: String, userContext: UserContext, clientType: String)(f: => T): T = { try { @@ -108,7 +118,7 @@ private[client] class GrpcExceptionConverter(channel: ManagedChannel) extends Lo } try { - val errorDetailsResponse = grpcStub.fetchErrorDetails( + val errorDetailsResponse = stubWithDeadline.fetchErrorDetails( FetchErrorDetailsRequest .newBuilder() .setSessionId(sessionId) @@ -160,11 +170,25 @@ private[client] class GrpcExceptionConverter(channel: ManagedChannel) extends Lo } // If no ErrorInfo is found, create a SparkException based on the StatusRuntimeException. + val (message, cause) = if (ex.getStatus.getCode == Status.Code.DEADLINE_EXCEEDED) { + val msg = s"${ex.toString}: RPC deadline exceeded. Deadlines can be configured via " + + "SparkConnectClient.Builder.rpcDeadlines(). To disable all deadlines: " + + "SparkConnectClient.builder().rpcDeadlines(RpcDeadlines.disabled).build()" + // For DEADLINE_EXCEEDED, we pass `ex` itself as the cause rather than `ex.getCause`. + // StatusRuntimeException.getCause() returns status.getCause(), which is always null for + // client-side deadline fires (gRPC constructs the status without a wrapped cause). 
Using + // ex.getCause would produce a SparkException with cause = null, losing the gRPC status + // code and description from the exception chain. Passing ex preserves full context and + // allows callers to programmatically inspect the status code via getCause().getStatus(). + (msg, ex) + } else { + (ex.toString, ex.getCause) + } new SparkException( - message = ex.toString, - cause = ex.getCause, + message = message, + cause = cause, errorClass = Some("CONNECT_CLIENT_UNEXPECTED_MISSING_SQL_STATE"), - messageParameters = Map("message" -> ex.toString), + messageParameters = Map("message" -> message), context = Array.empty) } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala index 5b5c4b517923e..b396ee96a4e5b 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala @@ -169,6 +169,11 @@ object RetryPolicy extends Logging { return true } + // DEADLINE_EXCEEDED on the reattachable execute path is handled directly in + // ExecutePlanResponseReattachableIterator, which converts it to RetryException so the + // server-side operation continues and a fresh ReattachExecute is issued. We do not retry + // other RPCs on deadline: those are non-idempotent or the deadline signals a genuine + // timeout that retrying won't fix. 
false case _ => false } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RpcDeadlines.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RpcDeadlines.scala new file mode 100644 index 0000000000000..e15c68091760a --- /dev/null +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RpcDeadlines.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client + +import scala.concurrent.duration.{DurationInt, FiniteDuration} + +/** + * Per-RPC deadline configuration. Each field controls the deadline for one gRPC call. Set a field + * to None to disable the deadline for that call. Use RpcDeadlines.disabled to create an instance + * with all deadlines disabled. + * + * Note on reattachableExecutePlan and reattachExecute: these deadlines apply to each individual + * gRPC stream segment, not to the overall query execution lifetime. When a deadline fires, the + * server-side operation continues running; the client simply opens a new ReattachExecute stream + * to resume receiving results. This avoids hanging connections while preserving execution. 
+ * + * Non-reattachable ExecutePlan has no deadline because a timeout there would kill the execution + * with no recovery path. ReleaseExecute has no deadline (fire-and-forget cleanup). + */ +private[sql] case class RpcDeadlines( + reattachableExecutePlan: Option[FiniteDuration] = Some(10.minutes), + reattachExecute: Option[FiniteDuration] = Some(10.minutes), + analyzePlan: Option[FiniteDuration] = Some(1.hour), + addArtifacts: Option[FiniteDuration] = Some(1.hour), + config: Option[FiniteDuration] = Some(10.minutes), + interrupt: Option[FiniteDuration] = Some(10.minutes), + releaseSession: Option[FiniteDuration] = Some(10.minutes), + artifactStatus: Option[FiniteDuration] = Some(10.minutes), + cloneSession: Option[FiniteDuration] = Some(10.minutes), + getStatus: Option[FiniteDuration] = Some(10.minutes), + fetchErrorDetails: Option[FiniteDuration] = Some(10.minutes)) { + + // Validate all fields: each must be a positive duration or None. + private lazy val namedFields: Seq[(String, Option[FiniteDuration])] = + productElementNames.toSeq.zip( + productIterator.map(_.asInstanceOf[Option[FiniteDuration]]).toSeq) + + namedFields.foreach { case (name, opt) => + opt.foreach(d => + require(d.toMillis > 0, s"RpcDeadlines.$name must be a positive duration, got $d")) + } + + override def toString: String = { + val configured = namedFields.collect { case (name, Some(d)) => s"$name=$d" } + if (configured.isEmpty) "RpcDeadlines(all disabled)" + else s"RpcDeadlines(${configured.mkString(", ")})" + } +} + +private[sql] object RpcDeadlines { + + /** + * Creates an RpcDeadlines with all deadlines disabled (no per-RPC timeout on any call). Use + * this when you want to rely solely on server-side or network-layer timeouts. 
+ */ + val disabled: RpcDeadlines = RpcDeadlines( + reattachableExecutePlan = None, + reattachExecute = None, + analyzePlan = None, + addArtifacts = None, + config = None, + interrupt = None, + releaseSession = None, + artifactStatus = None, + cloneSession = None, + getStatus = None, + fetchErrorDetails = None) +} diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index d9b9ba35b5e6c..cab3540a59b1f 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -54,7 +54,7 @@ private[sql] class SparkConnectClient( private val userContext: UserContext = configuration.userContext - private[this] val stubState = new SparkConnectStubState(channel, configuration.retryPolicies) + private[this] val stubState = new SparkConnectStubState(channel, configuration) private[this] val bstub = new CustomSparkConnectBlockingStub(channel, stubState) private[this] val stub = @@ -804,6 +804,11 @@ object SparkConnectClient { retryPolicy(List(policy)) } + def rpcDeadlines(deadlines: RpcDeadlines): Builder = { + _configuration = _configuration.copy(rpcDeadlines = deadlines) + this + } + private object URIParams { val PARAM_USER_ID = "user_id" val PARAM_USE_SSL = "use_ssl" @@ -1037,6 +1042,7 @@ object SparkConnectClient { userAgent: String = genUserAgent( sys.env.getOrElse("SPARK_CONNECT_USER_AGENT", DEFAULT_USER_AGENT)), retryPolicies: Seq[RetryPolicy] = RetryPolicy.defaultPolicies(), + rpcDeadlines: RpcDeadlines = RpcDeadlines(), useReattachableExecute: Boolean = true, interceptors: List[ClientInterceptor] = List.empty, sessionId: Option[String] = None, diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala index 2ec9ecad90309..365f21e57a900 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala @@ -26,14 +26,25 @@ import org.apache.spark.internal.Logging // that the same stub instance is used for all requests from the same client. In addition, // this class provides access to the commonly configured retry policy and exception conversion // logic. -class SparkConnectStubState(channel: ManagedChannel, retryPolicies: Seq[RetryPolicy]) +class SparkConnectStubState( + channel: ManagedChannel, + val configuration: SparkConnectClient.Configuration) extends Logging { + val rpcDeadlines: RpcDeadlines = configuration.rpcDeadlines + + { + if (log.isInfoEnabled) { + logInfo(s"Spark Connect RPC deadlines: $rpcDeadlines") + } + } + // Manages the retry handler logic used by the stubs. - lazy val retryHandler = new GrpcRetryHandler(retryPolicies) + lazy val retryHandler = new GrpcRetryHandler(configuration.retryPolicies) // Responsible to convert the GRPC Status exceptions into Spark exceptions. - lazy val exceptionConverter: GrpcExceptionConverter = new GrpcExceptionConverter(channel) + lazy val exceptionConverter: GrpcExceptionConverter = + new GrpcExceptionConverter(channel, configuration.rpcDeadlines.fetchErrorDetails) // Provides a helper for validating the responses processed by the stub. 
lazy val responseValidator = new ResponseValidator() diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala index 748e7623aba3a..d4ecbca28e9dc 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala @@ -280,9 +280,9 @@ trait SparkConnectServerTest extends SharedSparkSession { protected def withCustomBlockingStub( retryPolicies: Seq[RetryPolicy] = RetryPolicy.defaultPolicies())( f: CustomSparkConnectBlockingStub => Unit): Unit = { - val conf = SparkConnectClient.Configuration(port = serverPort) + val conf = SparkConnectClient.Configuration(port = serverPort, retryPolicies = retryPolicies) val channel = conf.createChannel() - val stubState = new SparkConnectStubState(channel, retryPolicies) + val stubState = new SparkConnectStubState(channel, conf) val bstub = new CustomSparkConnectBlockingStub(channel, stubState) try f(bstub) finally { From 98b59203205de47ff4beb4fc82d508cb1f82067a Mon Sep 17 00:00:00 2001 From: pranavdev022 Date: Sun, 26 Apr 2026 20:16:54 +0200 Subject: [PATCH 2/2] apply fixes --- python/pyspark/sql/connect/client/core.py | 58 ++--- .../connect/client/test_client_retries.py | 4 - .../client/SparkConnectClientSuite.scala | 201 +++++++----------- .../CustomSparkConnectBlockingStub.scala | 3 + ...cutePlanResponseReattachableIterator.scala | 12 +- .../sql/connect/client/RpcDeadlines.scala | 2 +- .../client/SparkConnectStubState.scala | 6 +- 7 files changed, 95 insertions(+), 191 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 7c62d7af63cd7..5253fffd9d4d0 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -22,7 +22,6 @@ ] import atexit -import 
math from dataclasses import dataclass, fields import pyspark @@ -113,7 +112,6 @@ from pyspark.errors import ( PySparkAssertionError, PySparkNotImplementedError, - PySparkTypeError, PySparkValueError, ) from pyspark.sql.connect.shell.progress import Progress, ProgressHandler, from_proto @@ -129,7 +127,7 @@ PYSPARK_ROOT = os.path.dirname(pyspark.__file__) -@dataclass +@dataclass(frozen=True) class RpcDeadlines: """Per-RPC timeout configuration for :class:`SparkConnectClient`. @@ -159,24 +157,13 @@ class RpcDeadlines: def __post_init__(self) -> None: for field in fields(self): value = getattr(self, field.name) - if value is not None: - if not isinstance(value, (int, float)): - raise PySparkTypeError( - errorClass="NOT_EXPECTED_TYPE", - messageParameters={ - "arg_name": f"RpcDeadlines.{field.name}", - "expected_type": "int, float, or None", - "arg_type": type(value).__name__, - }, - ) - fv = float(value) - if not math.isfinite(fv) or fv <= 0: - raise PySparkValueError( - message=( - f"RpcDeadlines.{field.name} must be a finite positive number or None, " - f"got {value!r}" - ), - ) + if value is not None and value <= 0: + raise PySparkValueError( + message=( + f"RpcDeadlines.{field.name} must be a positive number or None, " + f"got {value!r}" + ), + ) @classmethod def disabled(cls) -> "RpcDeadlines": @@ -774,8 +761,9 @@ def __init__( The server will attempt to use this size if it is set and within the valid range ([1KB, max batch size on server]). Otherwise, the server's maximum batch size is used. rpc_deadlines : RpcDeadlines, optional - Per-RPC gRPC call timeouts in seconds. Defaults follow SPARK-56538; use - :meth:`RpcDeadlines.disabled` to turn off all per-RPC deadlines. + Per-RPC gRPC call timeouts in seconds (10 min for most RPCs, + 1 hour for analyze/addArtifacts, none for non-reattachable execute). + Use :meth:`RpcDeadlines.disabled` to turn off all deadlines. 
""" self.thread_local = threading.local() @@ -814,29 +802,7 @@ def __init__( self._rpc_deadlines: RpcDeadlines = ( rpc_deadlines if rpc_deadlines is not None else RpcDeadlines() ) - d = self._rpc_deadlines - configured = [ - (name, val) - for name, val in [ - ("reattachableExecutePlan", d.reattachable_execute_plan), - ("reattachExecute", d.reattach_execute), - ("analyzePlan", d.analyze_plan), - ("addArtifacts", d.add_artifacts), - ("config", d.config), - ("interrupt", d.interrupt), - ("releaseSession", d.release_session), - ("artifactStatus", d.artifact_status), - ("cloneSession", d.clone_session), - ("getStatus", d.get_status), - ("fetchErrorDetails", d.fetch_error_details), - ] - if val is not None - ] - if configured: - logger.info( - "Spark Connect RPC deadlines: " - + ", ".join(f"{name}: {val}s" for name, val in configured) - ) + logger.info("Spark Connect RPC deadlines: %s", self._rpc_deadlines) self._artifact_manager = ArtifactManager( self._user_id, self._session_id, diff --git a/python/pyspark/sql/tests/connect/client/test_client_retries.py b/python/pyspark/sql/tests/connect/client/test_client_retries.py index 5e212f39902f9..aa8612e1e232f 100644 --- a/python/pyspark/sql/tests/connect/client/test_client_retries.py +++ b/python/pyspark/sql/tests/connect/client/test_client_retries.py @@ -315,10 +315,6 @@ def test_rpc_deadlines_rejects_non_positive_values(self): replace(RpcDeadlines(), config=0) with self.assertRaises(ValueError): replace(RpcDeadlines(), reattachable_execute_plan=-0.001) - with self.assertRaises(ValueError): - replace(RpcDeadlines(), config=float("nan")) - with self.assertRaises(ValueError): - replace(RpcDeadlines(), config=float("inf")) with self.assertRaises(ValueError): replace(RpcDeadlines(), config=float("-inf")) # None (disabled) and positive values should be accepted without error. 
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 1fc6c13e146f9..2f4d95984083c 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -826,32 +826,73 @@ class SparkConnectClientSuite extends ConnectFunSuite { } } - test("analyzePlan deadline fires on slow server") { + test("one-shot RPC deadlines fire on slow server") { val latch = new CountDownLatch(1) val slowService = new DummySparkConnectService { override def analyzePlan( request: AnalyzePlanRequest, responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = { - latch.await(5, TimeUnit.SECONDS) + latch.await(250, TimeUnit.MILLISECONDS) super.analyzePlan(request, responseObserver) } + override def config( + request: proto.ConfigRequest, + responseObserver: StreamObserver[proto.ConfigResponse]): Unit = { + latch.await(250, TimeUnit.MILLISECONDS) + super.config(request, responseObserver) + } + override def interrupt( + request: proto.InterruptRequest, + responseObserver: StreamObserver[proto.InterruptResponse]): Unit = { + latch.await(250, TimeUnit.MILLISECONDS) + super.interrupt(request, responseObserver) + } + override def releaseSession( + request: proto.ReleaseSessionRequest, + responseObserver: StreamObserver[proto.ReleaseSessionResponse]): Unit = { + latch.await(250, TimeUnit.MILLISECONDS) + responseObserver.onNext( + proto.ReleaseSessionResponse + .newBuilder() + .setSessionId(request.getSessionId) + .build()) + responseObserver.onCompleted() + } } server = NettyServerBuilder.forPort(0).addService(slowService).build().start() service = slowService + val d = FiniteDuration(50, TimeUnit.MILLISECONDS) client = SparkConnectClient .builder() 
.connectionString(s"sc://localhost:${server.getPort}") - .rpcDeadlines(RpcDeadlines(analyzePlan = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) + .rpcDeadlines( + RpcDeadlines( + analyzePlan = Some(d), + config = Some(d), + interrupt = Some(d), + releaseSession = Some(d))) .retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "NoRetry")) .build() - try { - val ex = intercept[SparkException] { - client.analyze(proto.AnalyzePlanRequest.newBuilder().setSessionId("abc123").build()) - } + + def expectDeadlineExceeded(thunk: => Any): Unit = { + val ex = intercept[SparkException] { thunk } assert(ex.getCause.isInstanceOf[StatusRuntimeException]) assert( ex.getCause.asInstanceOf[StatusRuntimeException].getStatus.getCode == Status.Code.DEADLINE_EXCEEDED) + } + + try { + expectDeadlineExceeded( + client.analyze(proto.AnalyzePlanRequest.newBuilder().setSessionId("abc123").build())) + val configOp = proto.ConfigRequest.Operation + .newBuilder() + .setGetOption( + proto.ConfigRequest.GetOption.newBuilder().addKeys("spark.sql.shuffle.partitions")) + .build() + expectDeadlineExceeded(client.config(configOp)) + expectDeadlineExceeded(client.interruptAll()) + expectDeadlineExceeded(client.releaseSession()) } finally { latch.countDown() } @@ -863,7 +904,7 @@ class SparkConnectClientSuite extends ConnectFunSuite { override def analyzePlan( request: AnalyzePlanRequest, responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = { - latch.await(5, TimeUnit.SECONDS) + latch.await(250, TimeUnit.MILLISECONDS) super.analyzePlan(request, responseObserver) } } @@ -877,14 +918,17 @@ class SparkConnectClientSuite extends ConnectFunSuite { .rpcDeadlines(RpcDeadlines(analyzePlan = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) .retryPolicy(noRetry) .build() - val ex = intercept[SparkException] { - client.analyze(proto.AnalyzePlanRequest.newBuilder().setSessionId("abc123").build()) + try { + val ex = intercept[SparkException] { + 
client.analyze(proto.AnalyzePlanRequest.newBuilder().setSessionId("abc123").build()) + } + assert( + ex.getCause.asInstanceOf[StatusRuntimeException].getStatus.getCode == + Status.Code.DEADLINE_EXCEEDED) + client.shutdown() + } finally { + latch.countDown() } - assert( - ex.getCause.asInstanceOf[StatusRuntimeException].getStatus.getCode == - Status.Code.DEADLINE_EXCEEDED) - client.shutdown() - latch.countDown() client = SparkConnectClient .builder() @@ -915,7 +959,7 @@ class SparkConnectClientSuite extends ConnectFunSuite { .setOperationId(operationId) .setResponseId("r1") .build()) - executeLatch.await(5, TimeUnit.SECONDS) + executeLatch.await(250, TimeUnit.MILLISECONDS) responseObserver.onCompleted() } @@ -956,13 +1000,16 @@ class SparkConnectClientSuite extends ConnectFunSuite { .enableReattachableExecute() .build() - val iter = client.execute(buildPlan("select 1")) - val reattachableIter = ExecutePlanResponseReattachableIterator.fromIterator(iter) - iter.foreach(_ => ()) - executeLatch.countDown() + try { + val iter = client.execute(buildPlan("select 1")) + val reattachableIter = ExecutePlanResponseReattachableIterator.fromIterator(iter) + iter.foreach(_ => ()) - assert(reattachableIter.resultComplete, "iterator should complete after reattach") - assert(reattachCalled, "ReattachExecute should have been called") + assert(reattachableIter.resultComplete, "iterator should complete after reattach") + assert(reattachCalled, "ReattachExecute should have been called") + } finally { + executeLatch.countDown() + } } test("non-reattachable executePlan has no client-side deadline") { @@ -971,7 +1018,7 @@ class SparkConnectClientSuite extends ConnectFunSuite { override def executePlan( request: ExecutePlanRequest, responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { - latch.await(150, TimeUnit.MILLISECONDS) + latch.await(250, TimeUnit.MILLISECONDS) super.executePlan(request, responseObserver) } } @@ -985,109 +1032,14 @@ class SparkConnectClientSuite extends 
ConnectFunSuite { RpcDeadlines(reattachableExecutePlan = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) .retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "NoRetry")) .build() - val iter = client.execute(buildPlan("select 1")) - iter.foreach(_ => ()) - } - - test("config deadline fires on slow server") { - val latch = new CountDownLatch(1) - val slowService = new DummySparkConnectService { - override def config( - request: proto.ConfigRequest, - responseObserver: StreamObserver[proto.ConfigResponse]): Unit = { - latch.await(5, TimeUnit.SECONDS) - super.config(request, responseObserver) - } - } - server = NettyServerBuilder.forPort(0).addService(slowService).build().start() - service = slowService - client = SparkConnectClient - .builder() - .connectionString(s"sc://localhost:${server.getPort}") - .rpcDeadlines(RpcDeadlines(config = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) - .retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "NoRetry")) - .build() - val op = proto.ConfigRequest.Operation - .newBuilder() - .setGetOption( - proto.ConfigRequest.GetOption.newBuilder().addKeys("spark.sql.shuffle.partitions")) - .build() - val ex = intercept[SparkException] { - client.config(op) - } - assert(ex.getCause.isInstanceOf[StatusRuntimeException]) - assert( - ex.getCause.asInstanceOf[StatusRuntimeException].getStatus.getCode == - Status.Code.DEADLINE_EXCEEDED) - latch.countDown() - } - - test("interrupt deadline fires on slow server") { - val latch = new CountDownLatch(1) - val slowService = new DummySparkConnectService { - override def interrupt( - request: proto.InterruptRequest, - responseObserver: StreamObserver[proto.InterruptResponse]): Unit = { - latch.await(5, TimeUnit.SECONDS) - super.interrupt(request, responseObserver) - } - } - server = NettyServerBuilder.forPort(0).addService(slowService).build().start() - service = slowService - client = SparkConnectClient - .builder() - 
.connectionString(s"sc://localhost:${server.getPort}") - .rpcDeadlines(RpcDeadlines(interrupt = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) - .retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "NoRetry")) - .build() - val ex = intercept[SparkException] { - client.interruptAll() - } - assert(ex.getCause.isInstanceOf[StatusRuntimeException]) - assert( - ex.getCause.asInstanceOf[StatusRuntimeException].getStatus.getCode == - Status.Code.DEADLINE_EXCEEDED) - latch.countDown() - } - - test("releaseSession deadline fires on slow server") { - val latch = new CountDownLatch(1) - val slowService = new DummySparkConnectService { - override def releaseSession( - request: proto.ReleaseSessionRequest, - responseObserver: StreamObserver[proto.ReleaseSessionResponse]): Unit = { - latch.await(5, TimeUnit.SECONDS) - responseObserver.onNext( - proto.ReleaseSessionResponse - .newBuilder() - .setSessionId(request.getSessionId) - .build()) - responseObserver.onCompleted() - } - } - server = NettyServerBuilder.forPort(0).addService(slowService).build().start() - service = slowService - client = SparkConnectClient - .builder() - .connectionString(s"sc://localhost:${server.getPort}") - .rpcDeadlines( - RpcDeadlines(releaseSession = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) - .retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "NoRetry")) - .build() - val ex = intercept[SparkException] { - client.releaseSession() + try { + val iter = client.execute(buildPlan("select 1")) + iter.foreach(_ => ()) + } finally { + latch.countDown() } - assert(ex.getCause.isInstanceOf[StatusRuntimeException]) - assert( - ex.getCause.asInstanceOf[StatusRuntimeException].getStatus.getCode == - Status.Code.DEADLINE_EXCEEDED) - latch.countDown() } - // Note: artifactStatus, addArtifacts, cloneSession, getStatus, and fetchErrorDetails deadlines - // use the same withDeadline() mechanism as analyzePlan/config/interrupt above. 
They are - // exercised through ArtifactManager or internal flows that are harder to unit-test in isolation. - test("SPARK-56538: RpcDeadlines.disabled has no configured deadlines in toString") { assert(RpcDeadlines.disabled.toString === "RpcDeadlines(all disabled)") } @@ -1095,8 +1047,6 @@ class SparkConnectClientSuite extends ConnectFunSuite { class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase { - @volatile var analyzePlanAwait: Option[CountDownLatch] = None - private var inputPlan: proto.Plan = _ private val inputArtifactRequests: mutable.ListBuffer[AddArtifactsRequest] = mutable.ListBuffer.empty @@ -1177,7 +1127,6 @@ class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectSer override def analyzePlan( request: AnalyzePlanRequest, responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = { - analyzePlanAwait.foreach(_.await()) // Reply with a dummy response using the same client ID val requestSessionId = request.getSessionId synchronized { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala index dde1e81701476..a4406c2a68fda 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala @@ -32,6 +32,9 @@ private[connect] class CustomSparkConnectBlockingStub( private val stub = SparkConnectServiceGrpc.newBlockingStub(channel) + // Build a fresh deadline-bound stub on every call so each retry gets a full budget. + // gRPC deadlines are absolute (fixed at stub construction); a shared stub would let + // later attempts inherit an already-elapsed deadline and time out prematurely. 
private def withDeadline( d: Option[FiniteDuration]): SparkConnectServiceGrpc.SparkConnectServiceBlockingStub = d.map(dur => stub.withDeadline(Deadline.after(dur.toMillis, TimeUnit.MILLISECONDS))) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala index abc7be4861a8e..e6428270c20ba 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala @@ -116,9 +116,6 @@ class ExecutePlanResponseReattachableIterator( private[connect] var iter: Option[java.util.Iterator[proto.ExecutePlanResponse]] = Some(stubWithDeadline(reattachableExecutePlanDeadline).executePlan(initialRequest)) - // When true, an empty iterator triggers a fresh ExecutePlan instead of ReattachExecute. - private var restartExecutionOnNextRetry: Boolean = false - // Server side session ID, used to detect if the server side session changed. This is set upon // receiving the first response from the server. private var serverSideSessionId: Option[String] = None @@ -243,10 +240,7 @@ class ExecutePlanResponseReattachableIterator( */ private def callIter[V](iterFun: java.util.Iterator[proto.ExecutePlanResponse] => V) = { try { - if (iter.isEmpty && restartExecutionOnNextRetry) { - iter = Some(stubWithDeadline(reattachableExecutePlanDeadline).executePlan(initialRequest)) - restartExecutionOnNextRetry = false - } else if (iter.isEmpty) { + if (iter.isEmpty) { iter = Some( stubWithDeadline(reattachExecuteDeadline) .reattachExecute(createReattachExecuteRequest())) @@ -266,8 +260,7 @@ class ExecutePlanResponseReattachableIterator( ex) } // Try a new ExecutePlan, and throw upstream for retry. 
- iter = None - restartExecutionOnNextRetry = true + iter = Some(stubWithDeadline(reattachableExecutePlanDeadline).executePlan(initialRequest)) val error = new RetryException() error.addSuppressed(ex) throw error @@ -279,7 +272,6 @@ class ExecutePlanResponseReattachableIterator( s"Deadline exceeded on stream for operation $operationId; will reattach. " + s"(last response: $lastReturnedResponseId)") iter = None - restartExecutionOnNextRetry = false // defensive: deadline != operation lost val error = new RetryException() error.addSuppressed(ex) throw error diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RpcDeadlines.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RpcDeadlines.scala index e15c68091760a..0742a206b44c5 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RpcDeadlines.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RpcDeadlines.scala @@ -46,7 +46,7 @@ private[sql] case class RpcDeadlines( fetchErrorDetails: Option[FiniteDuration] = Some(10.minutes)) { // Validate all fields: each must be a positive duration or None. 
- private lazy val namedFields: Seq[(String, Option[FiniteDuration])] = + private val namedFields: Seq[(String, Option[FiniteDuration])] = productElementNames.toSeq.zip( productIterator.map(_.asInstanceOf[Option[FiniteDuration]]).toSeq) diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala index 365f21e57a900..4d2d15073f222 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala @@ -33,10 +33,8 @@ class SparkConnectStubState( val rpcDeadlines: RpcDeadlines = configuration.rpcDeadlines - { - if (log.isInfoEnabled) { - logInfo(s"Spark Connect RPC deadlines: $rpcDeadlines") - } + if (log.isInfoEnabled) { + logInfo(s"Spark Connect RPC deadlines: $rpcDeadlines") } // Manages the retry handler logic used by the stubs.