diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py index 277c313eef621..94879171b41f0 100644 --- a/python/pyspark/sql/connect/client/artifact.py +++ b/python/pyspark/sql/connect/client/artifact.py @@ -169,6 +169,8 @@ def __init__( session_id: str, channel: grpc.Channel, metadata: Iterable[Tuple[str, str]], + add_artifacts_timeout: Optional[float] = None, + artifact_status_timeout: Optional[float] = None, ): self._user_context = proto.UserContext() if user_id is not None: @@ -176,6 +178,8 @@ def __init__( self._stub = grpc_lib.SparkConnectServiceStub(channel) self._session_id = session_id self._metadata = metadata + self._add_artifacts_timeout = add_artifacts_timeout + self._artifact_status_timeout = artifact_status_timeout def _parse_artifacts( self, path_or_uri: str, pyfile: bool, archive: bool, file: bool @@ -288,7 +292,11 @@ def _retrieve_responses( self, requests: Iterator[proto.AddArtifactsRequest] ) -> proto.AddArtifactsResponse: """Separated for the testing purpose.""" - return self._stub.AddArtifacts(requests, metadata=self._metadata) + return self._stub.AddArtifacts( + requests, + metadata=self._metadata, + timeout=self._add_artifacts_timeout, + ) def _request_add_artifacts(self, requests: Iterator[proto.AddArtifactsRequest]) -> None: response: proto.AddArtifactsResponse = self._retrieve_responses(requests) @@ -428,7 +436,9 @@ def is_cached_artifact(self, hash: str) -> bool: user_context=self._user_context, session_id=self._session_id, names=[artifactName] ) resp: proto.ArtifactStatusesResponse = self._stub.ArtifactStatus( - request, metadata=self._metadata + request, + metadata=self._metadata, + timeout=self._artifact_status_timeout, ) status = resp.statuses.get(artifactName) return status.exists if status is not None else False @@ -446,7 +456,9 @@ def get_cached_artifacts(self, hashes: list[str]) -> set[str]: user_context=self._user_context, session_id=self._session_id, names=artifact_names ) resp: proto.ArtifactStatusesResponse = self._stub.ArtifactStatus( - request, metadata=self._metadata + request, + metadata=self._metadata, + timeout=self._artifact_status_timeout, ) cached = set() diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 1a36699b78815..5253fffd9d4d0 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -17,10 +17,12 @@ __all__ = [ "ChannelBuilder", "DefaultChannelBuilder", + "RpcDeadlines", "SparkConnectClient", ] import atexit +from dataclasses import dataclass, fields import pyspark from pyspark.sql.connect.proto.base_pb2 import FetchErrorDetailsResponse @@ -107,7 +109,11 @@ from pyspark.sql.types import DataType, StructType from pyspark.util import PythonEvalType from pyspark.storagelevel import StorageLevel -from pyspark.errors import PySparkValueError, PySparkAssertionError, PySparkNotImplementedError +from pyspark.errors import ( + PySparkAssertionError, + PySparkNotImplementedError, + PySparkValueError, +) from pyspark.sql.connect.shell.progress import Progress, ProgressHandler, from_proto if TYPE_CHECKING: @@ -121,6 +127,65 @@ PYSPARK_ROOT = os.path.dirname(pyspark.__file__) +@dataclass(frozen=True) +class RpcDeadlines: + """Per-RPC timeout configuration for :class:`SparkConnectClient`. + + Each field controls the timeout (in seconds as a float) for one gRPC call type. Set a field to + ``None`` to disable the per-RPC timeout for that call. 
Use :meth:`RpcDeadlines.disabled` to + create an instance with all timeouts disabled. + + Note on ``reattachable_execute_plan`` and ``reattach_execute``: these timeouts apply to each + individual gRPC stream segment, not to the overall query execution lifetime. When a deadline + fires, the server-side operation continues running; the client opens a new ReattachExecute + stream to resume receiving results. Non-reattachable ExecutePlan has no deadline because a + timeout there would kill the execution with no recovery path. + """ + + reattachable_execute_plan: Optional[float] = 10 * 60 # 10 min + reattach_execute: Optional[float] = 10 * 60 # 10 min + analyze_plan: Optional[float] = 60 * 60 # 1 hour + add_artifacts: Optional[float] = 60 * 60 # 1 hour + config: Optional[float] = 10 * 60 # 10 min + interrupt: Optional[float] = 10 * 60 # 10 min + release_session: Optional[float] = 10 * 60 # 10 min + artifact_status: Optional[float] = 10 * 60 # 10 min + clone_session: Optional[float] = 10 * 60 # 10 min + get_status: Optional[float] = 10 * 60 # 10 min + fetch_error_details: Optional[float] = 10 * 60 # 10 min + + def __post_init__(self) -> None: + for field in fields(self): + value = getattr(self, field.name) + if value is not None and value <= 0: + raise PySparkValueError( + message=( + f"RpcDeadlines.{field.name} must be a positive number or None, " + f"got {value!r}" + ), + ) + + @classmethod + def disabled(cls) -> "RpcDeadlines": + """Create an :class:`RpcDeadlines` with all per-RPC timeouts disabled. + + Use this when you want to rely solely on server-side or network-layer timeouts. + """ + return cls( + reattachable_execute_plan=None, + reattach_execute=None, + analyze_plan=None, + add_artifacts=None, + config=None, + interrupt=None, + release_session=None, + artifact_status=None, + clone_session=None, + get_status=None, + fetch_error_details=None, + ) + + def _import_zstandard_if_available() -> Optional[Any]: """ Import zstandard if available, otherwise return None. @@ -647,6 +712,7 @@ def __init__( session_hooks: Optional[list["SparkSession.Hook"]] = None, allow_arrow_batch_chunking: bool = True, preferred_arrow_chunk_size: Optional[int] = None, + rpc_deadlines: Optional[RpcDeadlines] = None, ): """ Creates a new SparkSession for the Spark Connect interface. @@ -694,6 +760,10 @@ def __init__( results. The server will attempt to use this size if it is set and within the valid range ([1KB, max batch size on server]). Otherwise, the server's maximum batch size is used. + rpc_deadlines : RpcDeadlines, optional + Per-RPC gRPC call timeouts in seconds (10 min for most RPCs, + 1 hour for analyze/addArtifacts, none for non-reattachable execute). + Use :meth:`RpcDeadlines.disabled` to turn off all deadlines. 
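+
+        Examples
+        --------
+        A sketch of configuring deadlines; the connection string below is a
+        placeholder::
+
+            from pyspark.sql.connect.client import SparkConnectClient
+            from pyspark.sql.connect.client.core import RpcDeadlines
+
+            # Override selected deadlines; unspecified fields keep their defaults.
+            client = SparkConnectClient(
+                "sc://localhost",
+                rpc_deadlines=RpcDeadlines(analyze_plan=30 * 60, config=60.0),
+            )
+
+            # Disable every per-RPC deadline and rely on server-side timeouts.
+            client = SparkConnectClient(
+                "sc://localhost", rpc_deadlines=RpcDeadlines.disabled()
+            )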
""" self.thread_local = threading.local() @@ -729,8 +799,17 @@ def __init__( self._channel = self._builder.toChannel() self._closed = False self._internal_stub = grpc_lib.SparkConnectServiceStub(self._channel) + self._rpc_deadlines: RpcDeadlines = ( + rpc_deadlines if rpc_deadlines is not None else RpcDeadlines() + ) + logger.info("Spark Connect RPC deadlines: %s", self._rpc_deadlines) self._artifact_manager = ArtifactManager( - self._user_id, self._session_id, self._channel, self._builder.metadata() + self._user_id, + self._session_id, + self._channel, + self._builder.metadata(), + add_artifacts_timeout=self._rpc_deadlines.add_artifacts, + artifact_status_timeout=self._rpc_deadlines.artifact_status, ) self._use_reattachable_execute = use_reattachable_execute self._allow_arrow_batch_chunking = allow_arrow_batch_chunking @@ -1450,7 +1529,11 @@ def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult: try: for attempt in self._retrying(): with attempt: - resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata()) + resp = self._stub.AnalyzePlan( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.analyze_plan, + ) self._verify_response_integrity(resp) return AnalyzeResult.fromProto(resp) raise SparkConnectException("Invalid state during retry exception handling.") @@ -1479,7 +1562,12 @@ def handle_response(b: pb2.ExecutePlanResponse) -> None: if self._use_reattachable_execute: # Don't use retryHandler - own retry handling is inside. generator = ExecutePlanResponseReattachableIterator( - req, self._stub, self._retrying, self._builder.metadata() + req, + self._stub, + self._retrying, + self._builder.metadata(), + reattachable_execute_plan_timeout=self._rpc_deadlines.reattachable_execute_plan, + reattach_execute_timeout=self._rpc_deadlines.reattach_execute, ) try: for b in generator: @@ -1689,7 +1777,12 @@ def handle_response( if self._use_reattachable_execute: # Don't use retryHandler - own retry handling is inside. 
generator = ExecutePlanResponseReattachableIterator( - req, self._stub, self._retrying, self._builder.metadata() + req, + self._stub, + self._retrying, + self._builder.metadata(), + reattachable_execute_plan_timeout=self._rpc_deadlines.reattachable_execute_plan, + reattach_execute_timeout=self._rpc_deadlines.reattach_execute, ) try: for b in generator: @@ -1841,7 +1934,11 @@ def config(self, operation: pb2.ConfigRequest.Operation) -> ConfigResult: try: for attempt in self._retrying(): with attempt: - resp = self._stub.Config(req, metadata=self._builder.metadata()) + resp = self._stub.Config( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.config, + ) self._verify_response_integrity(resp) return ConfigResult.fromProto(resp) raise SparkConnectException("Invalid state during retry exception handling.") @@ -1883,7 +1980,11 @@ def interrupt_all(self) -> Optional[List[str]]: try: for attempt in self._retrying(): with attempt: - resp = self._stub.Interrupt(req, metadata=self._builder.metadata()) + resp = self._stub.Interrupt( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.interrupt, + ) self._verify_response_integrity(resp) return list(resp.interrupted_ids) raise SparkConnectException("Invalid state during retry exception handling.") @@ -1895,7 +1996,11 @@ def interrupt_tag(self, tag: str) -> Optional[List[str]]: try: for attempt in self._retrying(): with attempt: - resp = self._stub.Interrupt(req, metadata=self._builder.metadata()) + resp = self._stub.Interrupt( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.interrupt, + ) self._verify_response_integrity(resp) return list(resp.interrupted_ids) raise SparkConnectException("Invalid state during retry exception handling.") @@ -1907,7 +2012,11 @@ def interrupt_operation(self, op_id: str) -> Optional[List[str]]: try: for attempt in self._retrying(): with attempt: - resp = self._stub.Interrupt(req, metadata=self._builder.metadata()) + resp = self._stub.Interrupt( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.interrupt, + ) self._verify_response_integrity(resp) return list(resp.interrupted_ids) raise SparkConnectException("Invalid state during retry exception handling.") @@ -1923,7 +2032,11 @@ def release_session(self) -> None: try: for attempt in self._retrying(): with attempt: - resp = self._stub.ReleaseSession(req, metadata=self._builder.metadata()) + resp = self._stub.ReleaseSession( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.release_session, + ) self._verify_response_integrity(resp) return raise SparkConnectException("Invalid state during retry exception handling.") @@ -1975,7 +2088,11 @@ def _get_operation_statuses( try: for attempt in self._retrying(): with attempt: - resp = self._stub.GetStatus(req, metadata=self._builder.metadata()) + resp = self._stub.GetStatus( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.get_status, + ) self._verify_response_integrity(resp) return resp raise SparkConnectException("Invalid state during retry exception handling.") @@ -2100,7 +2217,11 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet req.user_context.user_id = self._user_id self._update_request_with_user_context_extensions(req) try: - return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata()) + return self._stub.FetchErrorDetails( + req, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.fetch_error_details, + ) except 
grpc.RpcError: return None @@ -2142,6 +2263,17 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn: # https://grpc.github.io/grpc/python/grpc.html#grpc.UnaryUnaryMultiCallable.__call__ error: grpc.Call = cast(grpc.Call, rpc_error) status_code: grpc.StatusCode = error.code() + if status_code == grpc.StatusCode.DEADLINE_EXCEEDED: + raise SparkConnectGrpcException( + message=( + f"{error}: RPC deadline exceeded. " + "The client applies per-RPC timeouts to prevent silent hangs " + "on broken connections. Deadlines can be configured via the " + "rpc_deadlines parameter of SparkConnectClient. To disable all " + "deadlines: SparkConnectClient(url, rpc_deadlines=RpcDeadlines.disabled())." + ), + grpc_status_code=status_code, + ) from None status: Optional[Status] = rpc_status.from_call(error) if status: for d in status.details: @@ -2466,7 +2598,9 @@ def clone(self, new_session_id: Optional[str] = None) -> "SparkConnectClient": for attempt in self._retrying(): with attempt: response: pb2.CloneSessionResponse = self._stub.CloneSession( - request, metadata=self._builder.metadata() + request, + metadata=self._builder.metadata(), + timeout=self._rpc_deadlines.clone_session, ) # Assert that the returned session ID matches the requested ID if one was provided @@ -2485,6 +2619,8 @@ def clone(self, new_session_id: Optional[str] = None) -> "SparkConnectClient": connection=new_connection, user_id=self._user_id, use_reattachable_execute=self._use_reattachable_execute, + session_hooks=self._session_hooks, + rpc_deadlines=self._rpc_deadlines, ) # Ensure the session ID is correctly set from the response new_client._session_id = response.new_session_id diff --git a/python/pyspark/sql/connect/client/reattach.py b/python/pyspark/sql/connect/client/reattach.py index 5750dd045a919..6babe7a8e4289 100644 --- a/python/pyspark/sql/connect/client/reattach.py +++ b/python/pyspark/sql/connect/client/reattach.py @@ -72,6 +72,8 @@ def __init__( stub: grpc_lib.SparkConnectServiceStub, retrying: Callable[[], Retrying], metadata: Iterable[Tuple[str, str]], + reattachable_execute_plan_timeout: Optional[float] = None, + reattach_execute_timeout: Optional[float] = None, ): self._request = request self._retrying = retrying @@ -86,6 +88,8 @@ def __init__( self._operation_id = str(uuid.uuid4()) self._stub = stub + self._reattachable_execute_plan_timeout = reattachable_execute_plan_timeout + self._reattach_execute_timeout = reattach_execute_timeout request.request_options.append( pb2.ExecutePlanRequest.RequestOption( reattach_options=pb2.ReattachOptions(reattachable=True) @@ -108,7 +112,11 @@ def __init__( self._metadata = metadata with disable_gc(): self._iterator: Optional[Iterator[pb2.ExecutePlanResponse]] = iter( - self._stub.ExecutePlan(self._initial_request, metadata=metadata) + self._stub.ExecutePlan( + self._initial_request, + metadata=metadata, + timeout=self._reattachable_execute_plan_timeout, + ) ) # Current item from this iterator. @@ -256,7 +264,9 @@ def _call_iter(self, iter_fun: Callable) -> Any: # we get a new iterator with ReattachExecute if it was unset. self._iterator = iter( self._stub.ReattachExecute( - self._create_reattach_execute_request(), metadata=self._metadata + self._create_reattach_execute_request(), + metadata=self._metadata, + timeout=self._reattach_execute_timeout, ) ) @@ -283,9 +293,23 @@ def _call_iter(self, iter_fun: Callable) -> Any: ) # Try a new ExecutePlan, and throw upstream for retry. 
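+                # The fresh ExecutePlan also restarts the per-RPC deadline countdown, so
+                # the retried attempt gets a full reattachable_execute_plan_timeout budget.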
self._iterator = iter( - self._stub.ExecutePlan(self._initial_request, metadata=self._metadata) + self._stub.ExecutePlan( + self._initial_request, + metadata=self._metadata, + timeout=self._reattachable_execute_plan_timeout, + ) ) raise RetryException() + elif e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + # The per-RPC deadline fired. The server-side operation is still alive; clear + # the iterator and raise RetryException so the retry loop opens a fresh + # ReattachExecute stream with a new deadline countdown to resume results. + logger.debug( + f"Deadline exceeded on stream for operation {self._operation_id}; " + f"will reattach. (last response: {self._last_returned_response_id})" + ) + self._iterator = None + raise RetryException() from e else: # Remove the iterator, so that a new one will be created after retry. self._iterator = None diff --git a/python/pyspark/sql/connect/client/retries.py b/python/pyspark/sql/connect/client/retries.py index 898d976f2628e..86ed004c40fe9 100644 --- a/python/pyspark/sql/connect/client/retries.py +++ b/python/pyspark/sql/connect/client/retries.py @@ -368,6 +368,11 @@ def can_retry(self, e: BaseException) -> bool: # All errors messages containing `RetryInfo` should be retried. return True + # DEADLINE_EXCEEDED on the reattachable execute path is handled directly in + # ExecutePlanResponseReattachableIterator, which converts it to RetryException so the + # server-side operation continues and a fresh ReattachExecute is issued. We do not retry + # other RPCs on deadline: those are non-idempotent or the deadline signals a genuine + # timeout that retrying won't fix. Keep this in sync with RetryPolicy.scala. return False diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 85fbafe227284..3c580fdf54b2b 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -32,6 +32,7 @@ import pandas as pd import pyarrow as pa from pyspark.sql.connect.client import SparkConnectClient, DefaultChannelBuilder + from pyspark.sql.connect.client.core import RpcDeadlines from pyspark.sql.connect.client.retries import ( Retrying, DefaultPolicy, @@ -125,7 +126,6 @@ def ReleaseExecute(self, req: proto.ReleaseExecuteRequest, *args, **kwargs): if req.HasField("release_all"): self.release_calls += 1 elif req.HasField("release_until"): - print("increment") self.release_until_calls += 1 class MockService: @@ -154,7 +154,7 @@ def __init__(self, session_id: str, operation_statuses=None): operation_statuses = self.DEFAULT_OPERATION_STATUSES self._operation_statuses = {s.operation_id: s for s in operation_statuses} - def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): + def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata, timeout=None): self.req = req self.client_user_context_extensions = list(req.user_context.extensions) resp = proto.ExecutePlanResponse() @@ -175,14 +175,14 @@ def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): resp.arrow_batch.row_count = 2 return [resp] - def Interrupt(self, req: proto.InterruptRequest, metadata): + def Interrupt(self, req: proto.InterruptRequest, metadata, timeout=None): self.req = req self.client_user_context_extensions = list(req.user_context.extensions) resp = proto.InterruptResponse() resp.session_id = self._session_id return resp - def Config(self, req: proto.ConfigRequest, metadata): + def Config(self, req: proto.ConfigRequest, metadata, timeout=None): self.req = 
req self.client_user_context_extensions = list(req.user_context.extensions) resp = proto.ConfigResponse() @@ -197,7 +197,7 @@ def Config(self, req: proto.ConfigRequest, metadata): pair.value = req.operation.get_with_default.pairs[0].value or "true" return resp - def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata): + def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata, timeout=None): self.req = req self.client_user_context_extensions = list(req.user_context.extensions) resp = proto.AnalyzePlanResponse() @@ -206,7 +206,7 @@ def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata): resp.semantic_hash.result = 12345 return resp - def GetStatus(self, req: proto.GetStatusRequest, metadata): + def GetStatus(self, req: proto.GetStatusRequest, metadata, timeout=None): self.req = req self.client_user_context_extensions = list(req.user_context.extensions) self.received_custom_server_session_id = req.client_observed_server_side_session_id @@ -706,6 +706,111 @@ def test_get_operations_statuses_with_request_extensions(self): resp.extensions[0].Unpack(resp_echoed) self.assertEqual(resp_echoed.value, "request_extension") + def test_analyze_plan_short_deadline_fires_then_succeeds_after_disabling(self): + """With a short deadline the call fails; after disabling deadlines it succeeds.""" + + class CapturingMock(MockService): + """Captures the timeout passed by the client; raises DEADLINE_EXCEEDED if set.""" + + def __init__(self, session_id): + super().__init__(session_id) + self.captured_timeout = "not_called" + + def AnalyzePlan(self, req, metadata, timeout=None): + self.captured_timeout = timeout + if timeout is not None: + raise TestException("deadline exceeded", grpc.StatusCode.DEADLINE_EXCEEDED) + return super().AnalyzePlan(req, metadata, timeout=timeout) + + client_with_deadline = SparkConnectClient( + "sc://foo/", + use_reattachable_execute=False, + rpc_deadlines=RpcDeadlines(analyze_plan=0.050), + retry_policy=dict(max_retries=0), + ) + mock_with_deadline = CapturingMock(session_id=client_with_deadline._session_id) + client_with_deadline._stub = mock_with_deadline + with self.assertRaises(SparkConnectGrpcException) as cm: + client_with_deadline._analyze("schema", plan=proto.Plan()) + self.assertEqual(cm.exception.getGrpcStatusCode(), grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(mock_with_deadline.captured_timeout, 0.050) + + client_disabled = SparkConnectClient( + "sc://foo/", + use_reattachable_execute=False, + rpc_deadlines=RpcDeadlines.disabled(), + retry_policy=dict(max_retries=0), + ) + mock_disabled = CapturingMock(session_id=client_disabled._session_id) + client_disabled._stub = mock_disabled + client_disabled._analyze("schema", plan=proto.Plan()) + self.assertIsNone(mock_disabled.captured_timeout) + + def test_each_rpc_receives_configured_deadline(self): + """Every RPC that accepts a deadline should forward it as timeout to the stub.""" + + class TimeoutCapturingMock(MockService): + """Records the timeout kwarg for each RPC call.""" + + def __init__(self, session_id): + super().__init__(session_id) + self.captured_timeouts = {} + + def AnalyzePlan(self, req, metadata, timeout=None): + self.captured_timeouts["AnalyzePlan"] = timeout + return super().AnalyzePlan(req, metadata, timeout=timeout) + + def Config(self, req, metadata, timeout=None): + self.captured_timeouts["Config"] = timeout + return super().Config(req, metadata, timeout=timeout) + + def Interrupt(self, req, metadata, timeout=None): + self.captured_timeouts["Interrupt"] = timeout + return 
super().Interrupt(req, metadata, timeout=timeout) + + def ReleaseSession(self, req, metadata, timeout=None): + self.captured_timeouts["ReleaseSession"] = timeout + resp = proto.ReleaseSessionResponse() + resp.session_id = self._session_id + return resp + + def GetStatus(self, req, metadata, timeout=None): + self.captured_timeouts["GetStatus"] = timeout + return super().GetStatus(req, metadata, timeout=timeout) + + deadlines = RpcDeadlines( + analyze_plan=11.0, + config=22.0, + interrupt=33.0, + release_session=44.0, + get_status=55.0, + ) + client = SparkConnectClient( + "sc://foo/", + use_reattachable_execute=False, + rpc_deadlines=deadlines, + retry_policy=dict(max_retries=0), + ) + mock = TimeoutCapturingMock(session_id=client._session_id) + client._stub = mock + + client._analyze("schema", plan=proto.Plan()) + self.assertEqual(mock.captured_timeouts["AnalyzePlan"], 11.0) + + op = proto.ConfigRequest.Operation() + op.get.keys.append("spark.sql.shuffle.partitions") + client.config(op) + self.assertEqual(mock.captured_timeouts["Config"], 22.0) + + client.interrupt_all() + self.assertEqual(mock.captured_timeouts["Interrupt"], 33.0) + + client.release_session() + self.assertEqual(mock.captured_timeouts["ReleaseSession"], 44.0) + + client._get_operation_statuses() + self.assertEqual(mock.captured_timeouts["GetStatus"], 55.0) + @unittest.skipIf(not should_test_connect, connect_requirement_message) class SparkConnectClientReattachTestCase(unittest.TestCase): @@ -881,6 +986,59 @@ def test_observed_session_id(self): reattach = ite._create_reattach_execute_request() self.assertEqual(reattach.client_observed_server_side_session_id, session_id) + def test_deadline_exceeded_triggers_reattach(self): + """DEADLINE_EXCEEDED mid-stream on ExecutePlan should trigger a ReattachExecute.""" + + def deadline_exceeded(): + raise TestException("deadline", grpc.StatusCode.DEADLINE_EXCEEDED) + + stub = self._stub_with( + [self.response, deadline_exceeded], + [self.response, self.finished], + ) + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) + for _ in ite: + pass + + def check(): + self.assertEqual(1, stub.execute_calls) + self.assertEqual(1, stub.attach_calls) + self.assertEqual(1, stub.release_calls) + + eventually(timeout=1, catch_assertions=True)(check)() + + def test_deadline_exceeded_mid_stream_completes_successfully(self): + """After a mid-stream DEADLINE_EXCEEDED, reattach resumes and all responses are collected.""" + + response2 = proto.ExecutePlanResponse(response_id="2") + response3 = proto.ExecutePlanResponse(response_id="3") + + def deadline_exceeded(): + raise TestException("deadline", grpc.StatusCode.DEADLINE_EXCEEDED) + + finished = proto.ExecutePlanResponse( + result_complete=proto.ExecutePlanResponse.ResultComplete(), + response_id="final", + ) + stub = self._stub_with( + [self.response, response2, deadline_exceeded], + [response3, finished], + ) + collected = [] + ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.retrying, []) + for r in ite: + if not r.HasField("result_complete"): + collected.append(r.response_id) + + self.assertEqual(collected, ["1", "2", "3"]) + + def check(): + self.assertEqual(1, stub.execute_calls) + self.assertEqual(1, stub.attach_calls) + self.assertEqual(1, stub.release_calls) + + eventually(timeout=1, catch_assertions=True)(check)() + def test_server_unreachable(self): # DNS resolution should fail for "foo". This error is a retriable UNAVAILABLE error. 
client = SparkConnectClient( diff --git a/python/pyspark/sql/tests/connect/client/test_client_retries.py b/python/pyspark/sql/tests/connect/client/test_client_retries.py index c509b4db26543..aa8612e1e232f 100644 --- a/python/pyspark/sql/tests/connect/client/test_client_retries.py +++ b/python/pyspark/sql/tests/connect/client/test_client_retries.py @@ -27,6 +27,7 @@ from google.rpc import status_pb2 from google.rpc import error_details_pb2 from pyspark.sql.connect.client import SparkConnectClient + from pyspark.sql.connect.client.core import RpcDeadlines from pyspark.sql.connect.client.retries import ( Retrying, DefaultPolicy, @@ -235,6 +236,91 @@ def test_return_to_exponential_backoff(self): ] self.assertListsAlmostEqual(sleep_tracker.times, expected_times, delta=policy.jitter) + def test_deadline_exceeded_not_retryable_by_default_policy(self): + client = SparkConnectClient("sc://foo/;token=bar") + policy = get_client_policies_map(client).get(DefaultPolicy) + self.assertIsNotNone(policy) + self.assertFalse( + policy.can_retry(TestException("deadline", code=grpc.StatusCode.DEADLINE_EXCEEDED)) + ) + + def test_deadline_exceeded_not_retried_by_retry_handler(self): + client = SparkConnectClient("sc://foo/;token=bar") + sleep_tracker = SleepTimeTracker() + tries = 0 + with self.assertRaises(TestException): + for attempt in Retrying(client._retry_policies, sleep=sleep_tracker.sleep): + with attempt: + tries += 1 + raise TestException("d", code=grpc.StatusCode.DEADLINE_EXCEEDED) + self.assertEqual(tries, 1) + self.assertEqual(len(sleep_tracker.times), 0) + + def test_deadline_exceeded_is_not_retried_for_non_retryable_codes(self): + # Sanity check: ABORTED is not retried by DefaultPolicy (unless matching cluster message) + policy = DefaultPolicy() + self.assertFalse( + policy.can_retry(TestException("some aborted error", code=grpc.StatusCode.ABORTED)) + ) + + def test_deadline_exceeded_exception_message_contains_configuration_hint(self): + """DEADLINE_EXCEEDED exceptions should carry a hint about how to adjust timeouts.""" + from pyspark.errors.exceptions.connect import SparkConnectGrpcException + + client = SparkConnectClient("sc://foo/;token=bar", retry_policy=dict(max_retries=0)) + err = TestException("deadline exceeded", code=grpc.StatusCode.DEADLINE_EXCEEDED) + with self.assertRaises(SparkConnectGrpcException) as cm: + client._handle_rpc_error(err) + self.assertIn("RPC deadline exceeded", str(cm.exception)) + self.assertIn("RpcDeadlines.disabled()", str(cm.exception)) + self.assertEqual(cm.exception.getGrpcStatusCode(), grpc.StatusCode.DEADLINE_EXCEEDED) + + def test_rpc_deadlines_disabled_sets_all_fields_to_none(self): + """RpcDeadlines.disabled() should create an instance with every field set to None.""" + d = RpcDeadlines.disabled() + self.assertIsNone(d.reattachable_execute_plan) + self.assertIsNone(d.reattach_execute) + self.assertIsNone(d.analyze_plan) + self.assertIsNone(d.add_artifacts) + self.assertIsNone(d.config) + self.assertIsNone(d.interrupt) + self.assertIsNone(d.release_session) + self.assertIsNone(d.artifact_status) + self.assertIsNone(d.clone_session) + self.assertIsNone(d.get_status) + self.assertIsNone(d.fetch_error_details) + + def test_rpc_deadlines_defaults_are_set(self): + """RpcDeadlines() default instance should have documented timeout values.""" + d = RpcDeadlines() + self.assertEqual(d.reattachable_execute_plan, 10 * 60) + self.assertEqual(d.reattach_execute, 10 * 60) + self.assertEqual(d.analyze_plan, 60 * 60) + self.assertEqual(d.add_artifacts, 60 * 60) + 
self.assertEqual(d.config, 10 * 60) + self.assertEqual(d.interrupt, 10 * 60) + self.assertEqual(d.release_session, 10 * 60) + self.assertEqual(d.artifact_status, 10 * 60) + self.assertEqual(d.clone_session, 10 * 60) + self.assertEqual(d.get_status, 10 * 60) + self.assertEqual(d.fetch_error_details, 10 * 60) + + def test_rpc_deadlines_rejects_non_positive_values(self): + """RpcDeadlines should raise ValueError for any non-None field that is <= 0.""" + from dataclasses import replace + + with self.assertRaises(ValueError): + replace(RpcDeadlines(), analyze_plan=-1.0) + with self.assertRaises(ValueError): + replace(RpcDeadlines(), config=0) + with self.assertRaises(ValueError): + replace(RpcDeadlines(), reattachable_execute_plan=-0.001) + with self.assertRaises(ValueError): + replace(RpcDeadlines(), config=float("-inf")) + # None (disabled) and positive values should be accepted without error. + replace(RpcDeadlines(), analyze_plan=None) + replace(RpcDeadlines(), analyze_plan=0.001) + if __name__ == "__main__": from pyspark.testing import main diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala index a6ccf39886e22..bbe3fd16a9693 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala @@ -57,7 +57,7 @@ class ArtifactSuite extends ConnectFunSuite { private def createArtifactManager(): Unit = { channel = InProcessChannelBuilder.forName(getClass.getName).directExecutor().build() - state = new SparkConnectStubState(channel, RetryPolicy.defaultPolicies()) + state = new SparkConnectStubState(channel, Configuration()) bstub = new CustomSparkConnectBlockingStub(channel, state) stub = new CustomSparkConnectStub(channel, state) artifactManager = new ArtifactManager(Configuration(), "", bstub, stub) diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientRetriesSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientRetriesSuite.scala index 7ea01e34ec88a..595f48b9f5485 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientRetriesSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientRetriesSuite.scala @@ -275,4 +275,38 @@ class SparkConnectClientRetriesSuite extends ConnectFunSuite with Eventually { policy.initialBackoff.toMillis * math.pow(policy.backoffMultiplier, i + 2).toLong) assertLongSequencesAlmostEqual(st.times, expectedSleeps, delta = policy.jitter.toMillis) } + + test("DEADLINE_EXCEEDED is not retryable by defaultPolicy") { + // DEADLINE_EXCEEDED must not be retried via canRetry. The reattachable execute path + // handles it separately by converting it to RetryException in the iterator. + val exception = new StatusRuntimeException(Status.DEADLINE_EXCEEDED) + val canRetry = RetryPolicy.defaultPolicy().canRetry + assert(canRetry(exception) == false) + } + + test("DEADLINE_EXCEEDED is not retried by retry handler") { + // Verify the function is called exactly once and the exception propagates immediately. 
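+    // DummyFn throws for its first numFails invocations; asserting counter == 1 below
+    // proves the handler surfaced DEADLINE_EXCEEDED without sleeping or retrying.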
+ val dummyFn = new DummyFn(new StatusRuntimeException(Status.DEADLINE_EXCEEDED), numFails = 1) + val retryPolicies = RetryPolicy.defaultPolicies() + val retryHandler = new GrpcRetryHandler(retryPolicies, sleep = _ => {}) + intercept[StatusRuntimeException] { + retryHandler.retry { dummyFn.fn() } + } + assert(dummyFn.counter == 1) + } + + test("RpcDeadlines.disabled creates an instance with all deadlines set to None") { + val disabled = RpcDeadlines.disabled + assert(disabled.reattachableExecutePlan.isEmpty) + assert(disabled.reattachExecute.isEmpty) + assert(disabled.analyzePlan.isEmpty) + assert(disabled.addArtifacts.isEmpty) + assert(disabled.config.isEmpty) + assert(disabled.interrupt.isEmpty) + assert(disabled.releaseSession.isEmpty) + assert(disabled.artifactStatus.isEmpty) + assert(disabled.cloneSession.isEmpty) + assert(disabled.getStatus.isEmpty) + assert(disabled.fetchErrorDetails.isEmpty) + } } diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 22feaff1c77f1..2f4d95984083c 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.connect.client import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Base64, UUID} -import java.util.concurrent.TimeUnit +import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable +import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters._ import com.google.protobuf.{Any => PAny, StringValue} @@ -824,6 +825,224 @@ class SparkConnectClientSuite extends ConnectFunSuite { assert(!headerInterceptor.headers.exists(_.containsKey(key))) } } + + test("one-shot RPC deadlines fire on slow server") { + val latch = new CountDownLatch(1) + val slowService = new DummySparkConnectService { + override def analyzePlan( + request: AnalyzePlanRequest, + responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = { + latch.await(250, TimeUnit.MILLISECONDS) + super.analyzePlan(request, responseObserver) + } + override def config( + request: proto.ConfigRequest, + responseObserver: StreamObserver[proto.ConfigResponse]): Unit = { + latch.await(250, TimeUnit.MILLISECONDS) + super.config(request, responseObserver) + } + override def interrupt( + request: proto.InterruptRequest, + responseObserver: StreamObserver[proto.InterruptResponse]): Unit = { + latch.await(250, TimeUnit.MILLISECONDS) + super.interrupt(request, responseObserver) + } + override def releaseSession( + request: proto.ReleaseSessionRequest, + responseObserver: StreamObserver[proto.ReleaseSessionResponse]): Unit = { + latch.await(250, TimeUnit.MILLISECONDS) + responseObserver.onNext( + proto.ReleaseSessionResponse + .newBuilder() + .setSessionId(request.getSessionId) + .build()) + responseObserver.onCompleted() + } + } + server = NettyServerBuilder.forPort(0).addService(slowService).build().start() + service = slowService + val d = FiniteDuration(50, TimeUnit.MILLISECONDS) + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .rpcDeadlines( + RpcDeadlines( + analyzePlan = Some(d), + config = Some(d), + interrupt = Some(d), + releaseSession = Some(d))) + .retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry 
= _ => false, name = "NoRetry")) + .build() + + def expectDeadlineExceeded(thunk: => Any): Unit = { + val ex = intercept[SparkException] { thunk } + assert(ex.getCause.isInstanceOf[StatusRuntimeException]) + assert( + ex.getCause.asInstanceOf[StatusRuntimeException].getStatus.getCode == + Status.Code.DEADLINE_EXCEEDED) + } + + try { + expectDeadlineExceeded( + client.analyze(proto.AnalyzePlanRequest.newBuilder().setSessionId("abc123").build())) + val configOp = proto.ConfigRequest.Operation + .newBuilder() + .setGetOption( + proto.ConfigRequest.GetOption.newBuilder().addKeys("spark.sql.shuffle.partitions")) + .build() + expectDeadlineExceeded(client.config(configOp)) + expectDeadlineExceeded(client.interruptAll()) + expectDeadlineExceeded(client.releaseSession()) + } finally { + latch.countDown() + } + } + + test("analyzePlan with short deadline fires, then succeeds after disabling deadlines") { + val latch = new CountDownLatch(1) + val slowService = new DummySparkConnectService { + override def analyzePlan( + request: AnalyzePlanRequest, + responseObserver: StreamObserver[AnalyzePlanResponse]): Unit = { + latch.await(250, TimeUnit.MILLISECONDS) + super.analyzePlan(request, responseObserver) + } + } + server = NettyServerBuilder.forPort(0).addService(slowService).build().start() + service = slowService + val noRetry = RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "NoRetry") + + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .rpcDeadlines(RpcDeadlines(analyzePlan = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) + .retryPolicy(noRetry) + .build() + try { + val ex = intercept[SparkException] { + client.analyze(proto.AnalyzePlanRequest.newBuilder().setSessionId("abc123").build()) + } + assert( + ex.getCause.asInstanceOf[StatusRuntimeException].getStatus.getCode == + Status.Code.DEADLINE_EXCEEDED) + client.shutdown() + } finally { + latch.countDown() + } + + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .rpcDeadlines(RpcDeadlines.disabled) + .retryPolicy(noRetry) + .build() + client.analyze(proto.AnalyzePlanRequest.newBuilder().setSessionId("abc123").build()) + } + + test("reattachable execute recovers from mid-stream deadline via ReattachExecute") { + val executeLatch = new CountDownLatch(1) + @volatile var reattachCalled = false + val midStreamService = new DummySparkConnectService { + override def executePlan( + request: ExecutePlanRequest, + responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + val sessionId = request.getSessionId + val operationId = + if (request.hasOperationId) request.getOperationId + else UUID.randomUUID().toString + val serverSideSessionId = "srv-deadline-test" + responseObserver.onNext( + ExecutePlanResponse + .newBuilder() + .setSessionId(sessionId) + .setServerSideSessionId(serverSideSessionId) + .setOperationId(operationId) + .setResponseId("r1") + .build()) + executeLatch.await(250, TimeUnit.MILLISECONDS) + responseObserver.onCompleted() + } + + override def reattachExecute( + request: proto.ReattachExecuteRequest, + responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + reattachCalled = true + val sessionId = request.getSessionId + val operationId = request.getOperationId + val serverSideSessionId = "srv-deadline-test" + responseObserver.onNext( + ExecutePlanResponse + .newBuilder() + .setSessionId(sessionId) + .setServerSideSessionId(serverSideSessionId) + .setOperationId(operationId) + .setResponseId("r2") + 
.build()) + responseObserver.onNext( + ExecutePlanResponse + .newBuilder() + .setSessionId(sessionId) + .setServerSideSessionId(serverSideSessionId) + .setOperationId(operationId) + .setResponseId("r3") + .setResultComplete(proto.ExecutePlanResponse.ResultComplete.newBuilder().build()) + .build()) + responseObserver.onCompleted() + } + } + server = NettyServerBuilder.forPort(0).addService(midStreamService).build().start() + service = midStreamService + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .rpcDeadlines( + RpcDeadlines(reattachableExecutePlan = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) + .enableReattachableExecute() + .build() + + try { + val iter = client.execute(buildPlan("select 1")) + val reattachableIter = ExecutePlanResponseReattachableIterator.fromIterator(iter) + iter.foreach(_ => ()) + + assert(reattachableIter.resultComplete, "iterator should complete after reattach") + assert(reattachCalled, "ReattachExecute should have been called") + } finally { + executeLatch.countDown() + } + } + + test("non-reattachable executePlan has no client-side deadline") { + val latch = new CountDownLatch(1) + val slowService = new DummySparkConnectService { + override def executePlan( + request: ExecutePlanRequest, + responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { + latch.await(250, TimeUnit.MILLISECONDS) + super.executePlan(request, responseObserver) + } + } + server = NettyServerBuilder.forPort(0).addService(slowService).build().start() + service = slowService + client = SparkConnectClient + .builder() + .connectionString(s"sc://localhost:${server.getPort}") + .disableReattachableExecute() + .rpcDeadlines( + RpcDeadlines(reattachableExecutePlan = Some(FiniteDuration(50, TimeUnit.MILLISECONDS)))) + .retryPolicy(RetryPolicy(maxRetries = Some(0), canRetry = _ => false, name = "NoRetry")) + .build() + try { + val iter = client.execute(buildPlan("select 1")) + iter.foreach(_ => ()) + } finally { + latch.countDown() + } + } + + test("SPARK-56538: RpcDeadlines.disabled has no configured deadlines in toString") { + assert(RpcDeadlines.disabled.toString === "RpcDeadlines(all disabled)") + } } class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala index ad39a3fb29f24..a4406c2a68fda 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala @@ -16,9 +16,12 @@ */ package org.apache.spark.sql.connect.client +import java.util.concurrent.TimeUnit + +import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters._ -import io.grpc.ManagedChannel +import io.grpc.{Deadline, ManagedChannel} import org.apache.spark.connect.proto._ import org.apache.spark.sql.util.CloseableIterator @@ -29,11 +32,22 @@ private[connect] class CustomSparkConnectBlockingStub( private val stub = SparkConnectServiceGrpc.newBlockingStub(channel) + // Build a fresh deadline-bound stub on every call so each retry gets a full budget. + // gRPC deadlines are absolute (fixed at stub construction); a shared stub would let + // later attempts inherit an already-elapsed deadline and time out prematurely. 
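+  // For example, with a 10 minute deadline and three attempts, each attempt gets its own
+  // 10 minute budget rather than all three racing the first attempt's absolute expiry.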
+ private def withDeadline( + d: Option[FiniteDuration]): SparkConnectServiceGrpc.SparkConnectServiceBlockingStub = + d.map(dur => stub.withDeadline(Deadline.after(dur.toMillis, TimeUnit.MILLISECONDS))) + .getOrElse(stub) + private val retryHandler = stubState.retryHandler // GrpcExceptionConverter with a GRPC stub for fetching error details from server. private val grpcExceptionConverter = stubState.exceptionConverter + // Non-reattachable executePlan intentionally has no deadline: a timeout here would kill the + // server-side execution with no way to recover (there is no ReattachExecute for this path). + // Use reattachable execution for long-running queries that need deadline protection. def executePlan(request: ExecutePlanRequest): CloseableIterator[ExecutePlanResponse] = { grpcExceptionConverter.convert( request.getSessionId, @@ -63,8 +77,13 @@ private[connect] class CustomSparkConnectBlockingStub( request.getUserContext, request.getClientType, stubState.responseValidator.wrapIterator( - // ExecutePlanResponseReattachableIterator does all retries by itself, don't wrap it here - new ExecutePlanResponseReattachableIterator(request, channel, stubState.retryHandler))) + // Reattachable iterator retries internally; omit RetryIterator wrapper here. + new ExecutePlanResponseReattachableIterator( + request, + channel, + stubState.retryHandler, + stubState.rpcDeadlines.reattachableExecutePlan, + stubState.rpcDeadlines.reattachExecute))) } } @@ -75,7 +94,7 @@ private[connect] class CustomSparkConnectBlockingStub( request.getClientType) { retryHandler.retry { stubState.responseValidator.verifyResponse { - stub.analyzePlan(request) + withDeadline(stubState.rpcDeadlines.analyzePlan).analyzePlan(request) } } } @@ -88,7 +107,7 @@ private[connect] class CustomSparkConnectBlockingStub( request.getClientType) { retryHandler.retry { stubState.responseValidator.verifyResponse { - stub.config(request) + withDeadline(stubState.rpcDeadlines.config).config(request) } } } @@ -101,7 +120,7 @@ private[connect] class CustomSparkConnectBlockingStub( request.getClientType) { retryHandler.retry { stubState.responseValidator.verifyResponse { - stub.interrupt(request) + withDeadline(stubState.rpcDeadlines.interrupt).interrupt(request) } } } @@ -114,7 +133,7 @@ private[connect] class CustomSparkConnectBlockingStub( request.getClientType) { retryHandler.retry { stubState.responseValidator.verifyResponse { - stub.releaseSession(request) + withDeadline(stubState.rpcDeadlines.releaseSession).releaseSession(request) } } } @@ -127,7 +146,7 @@ private[connect] class CustomSparkConnectBlockingStub( request.getClientType) { retryHandler.retry { stubState.responseValidator.verifyResponse { - stub.artifactStatus(request) + withDeadline(stubState.rpcDeadlines.artifactStatus).artifactStatus(request) } } } @@ -140,7 +159,7 @@ private[connect] class CustomSparkConnectBlockingStub( request.getClientType) { retryHandler.retry { stubState.responseValidator.verifyResponse { - stub.cloneSession(request) + withDeadline(stubState.rpcDeadlines.cloneSession).cloneSession(request) } } } @@ -153,7 +172,7 @@ private[connect] class CustomSparkConnectBlockingStub( request.getClientType) { retryHandler.retry { stubState.responseValidator.verifyResponse { - stub.getStatus(request) + withDeadline(stubState.rpcDeadlines.getStatus).getStatus(request) } } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala 
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala index 187c2842a0bc8..4edd84ed5de76 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectStub.scala @@ -16,7 +16,9 @@ */ package org.apache.spark.sql.connect.client -import io.grpc.ManagedChannel +import java.util.concurrent.TimeUnit + +import io.grpc.{Deadline, ManagedChannel} import io.grpc.stub.StreamObserver import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, SparkConnectServiceGrpc} @@ -29,7 +31,15 @@ private[client] class CustomSparkConnectStub( def addArtifacts(responseObserver: StreamObserver[AddArtifactsResponse]) : StreamObserver[AddArtifactsRequest] = { + // Evaluated on every call (including each RetryStreamObserver retry) so each attempt gets a + // fresh Deadline object with a new absolute expiry, not one shared from the first attempt. + def freshStub: SparkConnectServiceGrpc.SparkConnectServiceStub = + stubState.rpcDeadlines.addArtifacts + .map(d => stub.withDeadline(Deadline.after(d.toMillis, TimeUnit.MILLISECONDS))) + .getOrElse(stub) stubState.responseValidator.wrapStreamObserver( - stubState.retryHandler.RetryStreamObserver(responseObserver, stub.addArtifacts)) + stubState.retryHandler.RetryStreamObserver( + responseObserver, + (obs: StreamObserver[AddArtifactsResponse]) => freshStub.addArtifacts(obs))) } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala index 131a2e77cc431..e6428270c20ba 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.connect.client import java.util.UUID +import java.util.concurrent.TimeUnit +import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal -import io.grpc.{ManagedChannel, StatusRuntimeException} +import io.grpc.{Deadline, ManagedChannel, Status, StatusRuntimeException} import io.grpc.protobuf.StatusProto import io.grpc.stub.StreamObserver @@ -53,7 +55,9 @@ import org.apache.spark.sql.util.WrappedCloseableIterator class ExecutePlanResponseReattachableIterator( request: proto.ExecutePlanRequest, channel: ManagedChannel, - retryHandler: GrpcRetryHandler) + retryHandler: GrpcRetryHandler, + reattachableExecutePlanDeadline: Option[FiniteDuration] = None, + reattachExecuteDeadline: Option[FiniteDuration] = None) extends WrappedCloseableIterator[proto.ExecutePlanResponse] with Logging { @@ -80,6 +84,12 @@ class ExecutePlanResponseReattachableIterator( private val rawBlockingStub = proto.SparkConnectServiceGrpc.newBlockingStub(channel) private val rawAsyncStub = proto.SparkConnectServiceGrpc.newStub(channel) + private def stubWithDeadline(deadline: Option[FiniteDuration]) + : proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub = + deadline + .map(d => rawBlockingStub.withDeadline(Deadline.after(d.toMillis, TimeUnit.MILLISECONDS))) + .getOrElse(rawBlockingStub) + private val initialRequest: proto.ExecutePlanRequest = request .toBuilder() .addRequestOptions( 
@@ -104,7 +114,7 @@ class ExecutePlanResponseReattachableIterator( // throw error on first iter.hasNext() or iter.next() // Visible for testing. private[connect] var iter: Option[java.util.Iterator[proto.ExecutePlanResponse]] = - Some(rawBlockingStub.executePlan(initialRequest)) + Some(stubWithDeadline(reattachableExecutePlanDeadline).executePlan(initialRequest)) // Server side session ID, used to detect if the server side session changed. This is set upon // receiving the first response from the server. @@ -231,7 +241,9 @@ class ExecutePlanResponseReattachableIterator( private def callIter[V](iterFun: java.util.Iterator[proto.ExecutePlanResponse] => V) = { try { if (iter.isEmpty) { - iter = Some(rawBlockingStub.reattachExecute(createReattachExecuteRequest())) + iter = Some( + stubWithDeadline(reattachExecuteDeadline) + .reattachExecute(createReattachExecuteRequest())) } iterFun(iter.get) } catch { @@ -248,7 +260,18 @@ class ExecutePlanResponseReattachableIterator( ex) } // Try a new ExecutePlan, and throw upstream for retry. - iter = Some(rawBlockingStub.executePlan(initialRequest)) + iter = Some(stubWithDeadline(reattachableExecutePlanDeadline).executePlan(initialRequest)) + val error = new RetryException() + error.addSuppressed(ex) + throw error + case ex: StatusRuntimeException if ex.getStatus.getCode == Status.Code.DEADLINE_EXCEEDED => + // The per-RPC deadline fired. The server-side operation is still alive; we clear the + // iterator and raise RetryException so the outer retry loop opens a fresh + // ReattachExecute stream (a new per-RPC deadline countdown) to resume receiving results. + logDebug( + s"Deadline exceeded on stream for operation $operationId; will reattach. " + + s"(last response: $lastReturnedResponseId)") + iter = None val error = new RetryException() error.addSuppressed(ex) throw error diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala index 6e5304b8cc772..b52261e5efde5 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala @@ -17,12 +17,14 @@ package org.apache.spark.sql.connect.client import java.time.DateTimeException +import java.util.concurrent.TimeUnit +import scala.concurrent.duration.FiniteDuration import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag import com.google.rpc.ErrorInfo -import io.grpc.{ManagedChannel, StatusRuntimeException} +import io.grpc.{Deadline, ManagedChannel, Status, StatusRuntimeException} import io.grpc.protobuf.StatusProto import org.json4s.{DefaultFormats, Formats} import org.json4s.jackson.JsonMethods @@ -49,10 +51,18 @@ import org.apache.spark.util.ArrayImplicits._ * the ErrorInfo is missing, the exception will be constructed based on the StatusRuntimeException * itself. 
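+ * For a client-side DEADLINE_EXCEEDED with no ErrorInfo, the StatusRuntimeException itself is
+ * preserved as the cause so callers can still inspect the gRPC status code and description.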
*/ -private[client] class GrpcExceptionConverter(channel: ManagedChannel) extends Logging { +private[client] class GrpcExceptionConverter( + channel: ManagedChannel, + fetchErrorDetailsDeadline: Option[FiniteDuration] = None) + extends Logging { import GrpcExceptionConverter._ - val grpcStub = SparkConnectServiceGrpc.newBlockingStub(channel) + private val grpcStub = SparkConnectServiceGrpc.newBlockingStub(channel) + + private def stubWithDeadline: SparkConnectServiceGrpc.SparkConnectServiceBlockingStub = + fetchErrorDetailsDeadline + .map(d => grpcStub.withDeadline(Deadline.after(d.toMillis, TimeUnit.MILLISECONDS))) + .getOrElse(grpcStub) def convert[T](sessionId: String, userContext: UserContext, clientType: String)(f: => T): T = { try { @@ -108,7 +118,7 @@ private[client] class GrpcExceptionConverter(channel: ManagedChannel) extends Lo } try { - val errorDetailsResponse = grpcStub.fetchErrorDetails( + val errorDetailsResponse = stubWithDeadline.fetchErrorDetails( FetchErrorDetailsRequest .newBuilder() .setSessionId(sessionId) @@ -160,11 +170,25 @@ private[client] class GrpcExceptionConverter(channel: ManagedChannel) extends Lo } // If no ErrorInfo is found, create a SparkException based on the StatusRuntimeException. + val (message, cause) = if (ex.getStatus.getCode == Status.Code.DEADLINE_EXCEEDED) { + val msg = s"${ex.toString}: RPC deadline exceeded. Deadlines can be configured via " + + "SparkConnectClient.Builder.rpcDeadlines(). To disable all deadlines: " + + "SparkConnectClient.builder().rpcDeadlines(RpcDeadlines.disabled).build()" + // For DEADLINE_EXCEEDED, we pass `ex` itself as the cause rather than `ex.getCause`. + // StatusRuntimeException.getCause() returns status.getCause(), which is always null for + // client-side deadline fires (gRPC constructs the status without a wrapped cause). Using + // ex.getCause would produce a SparkException with cause = null, losing the gRPC status + // code and description from the exception chain. Passing ex preserves full context and + // allows callers to programmatically inspect the status code via getCause().getStatus(). + (msg, ex) + } else { + (ex.toString, ex.getCause) + } new SparkException( - message = ex.toString, - cause = ex.getCause, + message = message, + cause = cause, errorClass = Some("CONNECT_CLIENT_UNEXPECTED_MISSING_SQL_STATE"), - messageParameters = Map("message" -> ex.toString), + messageParameters = Map("message" -> message), context = Array.empty) } } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala index 5b5c4b517923e..b396ee96a4e5b 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RetryPolicy.scala @@ -169,6 +169,11 @@ object RetryPolicy extends Logging { return true } + // DEADLINE_EXCEEDED on the reattachable execute path is handled directly in + // ExecutePlanResponseReattachableIterator, which converts it to RetryException so the + // server-side operation continues and a fresh ReattachExecute is issued. We do not retry + // other RPCs on deadline: those are non-idempotent or the deadline signals a genuine + // timeout that retrying won't fix. 
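+      // Keep this in sync with DefaultPolicy.can_retry in
+      // python/pyspark/sql/connect/client/retries.py.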
false case _ => false } diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RpcDeadlines.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RpcDeadlines.scala new file mode 100644 index 0000000000000..0742a206b44c5 --- /dev/null +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/RpcDeadlines.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.client + +import scala.concurrent.duration.{DurationInt, FiniteDuration} + +/** + * Per-RPC deadline configuration. Each field controls the deadline for one gRPC call. Set a field + * to None to disable the deadline for that call. Use RpcDeadlines.disabled to create an instance + * with all deadlines disabled. + * + * Note on reattachableExecutePlan and reattachExecute: these deadlines apply to each individual + * gRPC stream segment, not to the overall query execution lifetime. When a deadline fires, the + * server-side operation continues running; the client simply opens a new ReattachExecute stream + * to resume receiving results. This avoids hanging connections while preserving execution. + * + * Non-reattachable ExecutePlan has no deadline because a timeout there would kill the execution + * with no recovery path. ReleaseExecute has no deadline (fire-and-forget cleanup). + */ +private[sql] case class RpcDeadlines( + reattachableExecutePlan: Option[FiniteDuration] = Some(10.minutes), + reattachExecute: Option[FiniteDuration] = Some(10.minutes), + analyzePlan: Option[FiniteDuration] = Some(1.hour), + addArtifacts: Option[FiniteDuration] = Some(1.hour), + config: Option[FiniteDuration] = Some(10.minutes), + interrupt: Option[FiniteDuration] = Some(10.minutes), + releaseSession: Option[FiniteDuration] = Some(10.minutes), + artifactStatus: Option[FiniteDuration] = Some(10.minutes), + cloneSession: Option[FiniteDuration] = Some(10.minutes), + getStatus: Option[FiniteDuration] = Some(10.minutes), + fetchErrorDetails: Option[FiniteDuration] = Some(10.minutes)) { + + // Validate all fields: each must be a positive duration or None. 
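+  // productElementNames and productIterator traverse the case class fields in declaration
+  // order, pairing each field name with its Option[FiniteDuration] value.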
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index d9b9ba35b5e6c..cab3540a59b1f 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -54,7 +54,7 @@ private[sql] class SparkConnectClient(
 
   private val userContext: UserContext = configuration.userContext
 
-  private[this] val stubState = new SparkConnectStubState(channel, configuration.retryPolicies)
+  private[this] val stubState = new SparkConnectStubState(channel, configuration)
   private[this] val bstub = new CustomSparkConnectBlockingStub(channel, stubState)
   private[this] val stub =
@@ -804,6 +804,11 @@ object SparkConnectClient {
       retryPolicy(List(policy))
     }
 
+    def rpcDeadlines(deadlines: RpcDeadlines): Builder = {
+      _configuration = _configuration.copy(rpcDeadlines = deadlines)
+      this
+    }
+
     private object URIParams {
       val PARAM_USER_ID = "user_id"
       val PARAM_USE_SSL = "use_ssl"
@@ -1037,6 +1042,7 @@ object SparkConnectClient {
       userAgent: String = genUserAgent(
         sys.env.getOrElse("SPARK_CONNECT_USER_AGENT", DEFAULT_USER_AGENT)),
       retryPolicies: Seq[RetryPolicy] = RetryPolicy.defaultPolicies(),
+      rpcDeadlines: RpcDeadlines = RpcDeadlines(),
       useReattachableExecute: Boolean = true,
       interceptors: List[ClientInterceptor] = List.empty,
       sessionId: Option[String] = None,
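An end-to-end sketch of the new builder hook (the endpoint is a placeholder, and `connectionString` is assumed to be the existing builder entry point):

import scala.concurrent.duration.DurationInt

// Override a single deadline, keeping the defaults for the rest.
val client = SparkConnectClient
  .builder()
  .connectionString("sc://localhost:15002") // placeholder endpoint
  .rpcDeadlines(RpcDeadlines(fetchErrorDetails = Some(2.minutes)))
  .build()

// Or rely purely on server-side and network-layer timeouts:
val clientWithoutDeadlines = SparkConnectClient
  .builder()
  .connectionString("sc://localhost:15002")
  .rpcDeadlines(RpcDeadlines.disabled)
  .build()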
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala
index 2ec9ecad90309..4d2d15073f222 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectStubState.scala
@@ -26,14 +26,23 @@ import org.apache.spark.internal.Logging
 // that the same stub instance is used for all requests from the same client. In addition,
 // this class provides access to the commonly configured retry policy and exception conversion
 // logic.
-class SparkConnectStubState(channel: ManagedChannel, retryPolicies: Seq[RetryPolicy])
+class SparkConnectStubState(
+    channel: ManagedChannel,
+    val configuration: SparkConnectClient.Configuration)
     extends Logging {
 
+  val rpcDeadlines: RpcDeadlines = configuration.rpcDeadlines
+
+  if (log.isInfoEnabled) {
+    logInfo(s"Spark Connect RPC deadlines: $rpcDeadlines")
+  }
+
   // Manages the retry handler logic used by the stubs.
-  lazy val retryHandler = new GrpcRetryHandler(retryPolicies)
+  lazy val retryHandler = new GrpcRetryHandler(configuration.retryPolicies)
 
   // Responsible to convert the GRPC Status exceptions into Spark exceptions.
-  lazy val exceptionConverter: GrpcExceptionConverter = new GrpcExceptionConverter(channel)
+  lazy val exceptionConverter: GrpcExceptionConverter =
+    new GrpcExceptionConverter(channel, configuration.rpcDeadlines.fetchErrorDetails)
 
   // Provides a helper for validating the responses processed by the stub.
   lazy val responseValidator = new ResponseValidator()
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index 748e7623aba3a..d4ecbca28e9dc 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -280,9 +280,9 @@ trait SparkConnectServerTest extends SharedSparkSession {
   protected def withCustomBlockingStub(
       retryPolicies: Seq[RetryPolicy] = RetryPolicy.defaultPolicies())(
       f: CustomSparkConnectBlockingStub => Unit): Unit = {
-    val conf = SparkConnectClient.Configuration(port = serverPort)
+    val conf = SparkConnectClient.Configuration(port = serverPort, retryPolicies = retryPolicies)
     val channel = conf.createChannel()
-    val stubState = new SparkConnectStubState(channel, retryPolicies)
+    val stubState = new SparkConnectStubState(channel, conf)
     val bstub = new CustomSparkConnectBlockingStub(channel, stubState)
     try f(bstub)
    finally {