18 changes: 15 additions & 3 deletions python/pyspark/sql/connect/client/artifact.py
@@ -169,13 +169,17 @@ def __init__(
session_id: str,
channel: grpc.Channel,
metadata: Iterable[Tuple[str, str]],
add_artifacts_timeout: Optional[float] = None,
artifact_status_timeout: Optional[float] = None,
):
self._user_context = proto.UserContext()
if user_id is not None:
self._user_context.user_id = user_id
self._stub = grpc_lib.SparkConnectServiceStub(channel)
self._session_id = session_id
self._metadata = metadata
self._add_artifacts_timeout = add_artifacts_timeout
self._artifact_status_timeout = artifact_status_timeout

def _parse_artifacts(
self, path_or_uri: str, pyfile: bool, archive: bool, file: bool
@@ -288,7 +292,11 @@ def _retrieve_responses(
self, requests: Iterator[proto.AddArtifactsRequest]
) -> proto.AddArtifactsResponse:
"""Separated for the testing purpose."""
return self._stub.AddArtifacts(requests, metadata=self._metadata)
return self._stub.AddArtifacts(
requests,
metadata=self._metadata,
timeout=self._add_artifacts_timeout,
)

def _request_add_artifacts(self, requests: Iterator[proto.AddArtifactsRequest]) -> None:
response: proto.AddArtifactsResponse = self._retrieve_responses(requests)
@@ -428,7 +436,9 @@ def is_cached_artifact(self, hash: str) -> bool:
user_context=self._user_context, session_id=self._session_id, names=[artifactName]
)
resp: proto.ArtifactStatusesResponse = self._stub.ArtifactStatus(
request, metadata=self._metadata
request,
metadata=self._metadata,
timeout=self._artifact_status_timeout,
)
status = resp.statuses.get(artifactName)
return status.exists if status is not None else False
@@ -446,7 +456,9 @@ def get_cached_artifacts(self, hashes: list[str]) -> set[str]:
user_context=self._user_context, session_id=self._session_id, names=artifact_names
)
resp: proto.ArtifactStatusesResponse = self._stub.ArtifactStatus(
request, metadata=self._metadata
request,
metadata=self._metadata,
timeout=self._artifact_status_timeout,
)

cached = set()
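For context on what these new timeout arguments do at the transport layer: in Python gRPC, `timeout=` sets a per-call deadline in seconds, and when it expires the call fails with `StatusCode.DEADLINE_EXCEEDED`. A minimal, self-contained sketch (placeholder port and method path, not the Spark Connect service):

import grpc

# Nothing is listening on this port; with wait_for_ready=True the call
# blocks waiting for the channel instead of failing fast with UNAVAILABLE,
# so the 1-second deadline is what terminates it.
channel = grpc.insecure_channel("localhost:59999")
call = channel.unary_unary("/example.Service/Method")
try:
    call(b"", timeout=1.0, wait_for_ready=True)  # timeout is in seconds
except grpc.RpcError as e:
    assert e.code() == grpc.StatusCode.DEADLINE_EXCEEDED
finally:
    channel.close()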
162 changes: 149 additions & 13 deletions python/pyspark/sql/connect/client/core.py
@@ -17,10 +17,12 @@
__all__ = [
"ChannelBuilder",
"DefaultChannelBuilder",
"RpcDeadlines",
"SparkConnectClient",
]

import atexit
from dataclasses import dataclass, fields

import pyspark
from pyspark.sql.connect.proto.base_pb2 import FetchErrorDetailsResponse
@@ -107,7 +109,11 @@
from pyspark.sql.types import DataType, StructType
from pyspark.util import PythonEvalType
from pyspark.storagelevel import StorageLevel
from pyspark.errors import PySparkValueError, PySparkAssertionError, PySparkNotImplementedError
from pyspark.errors import (
PySparkAssertionError,
PySparkNotImplementedError,
PySparkValueError,
)
from pyspark.sql.connect.shell.progress import Progress, ProgressHandler, from_proto

if TYPE_CHECKING:
@@ -121,6 +127,65 @@
PYSPARK_ROOT = os.path.dirname(pyspark.__file__)


@dataclass(frozen=True)
class RpcDeadlines:
"""Per-RPC timeout configuration for :class:`SparkConnectClient`.

Each field controls the timeout (in seconds as a float) for one gRPC call type. Set a field to
``None`` to disable the per-RPC timeout for that call. Use :meth:`RpcDeadlines.disabled` to
create an instance with all timeouts disabled.

Note on ``reattachable_execute_plan`` and ``reattach_execute``: these timeouts apply to each
individual gRPC stream segment, not to the overall query execution lifetime. When a deadline
fires, the server-side operation continues running; the client opens a new ReattachExecute
stream to resume receiving results. Non-reattachable ExecutePlan has no deadline because a
timeout there would kill the execution with no recovery path.
"""

reattachable_execute_plan: Optional[float] = 10 * 60 # 10 min
reattach_execute: Optional[float] = 10 * 60 # 10 min
analyze_plan: Optional[float] = 60 * 60 # 1 hour
add_artifacts: Optional[float] = 60 * 60 # 1 hour
config: Optional[float] = 10 * 60 # 10 min
interrupt: Optional[float] = 10 * 60 # 10 min
release_session: Optional[float] = 10 * 60 # 10 min
artifact_status: Optional[float] = 10 * 60 # 10 min
clone_session: Optional[float] = 10 * 60 # 10 min
get_status: Optional[float] = 10 * 60 # 10 min
fetch_error_details: Optional[float] = 10 * 60 # 10 min

def __post_init__(self) -> None:
Contributor: NIT: While this is fancy, it is also quite complex and not really needed. Just check the individual values.

Contributor: I think this is okay; it's not difficult to read for a Python programmer. The question is what the alternative would be, and I can't think of a much cleaner way to write this. The good thing about this approach is that we don't need to change any code when we add new fields in the future: single source of truth. The only thing that feels a bit unnecessary is the infinite check; that was a bit too much.

Contributor (Author): Simplified the validation. Removed the isinstance type check and the math.isfinite guard. Keeping the fields() iteration so we don't need to update the validation when fields are added.
for field in fields(self):
value = getattr(self, field.name)
if value is not None and value <= 0:
raise PySparkValueError(
message=(
f"RpcDeadlines.{field.name} must be a positive number or None, "
f"got {value!r}"
),
)

@classmethod
def disabled(cls) -> "RpcDeadlines":
"""Create an :class:`RpcDeadlines` with all per-RPC timeouts disabled.

Use this when you want to rely solely on server-side or network-layer timeouts.
"""
return cls(
reattachable_execute_plan=None,
reattach_execute=None,
analyze_plan=None,
add_artifacts=None,
config=None,
interrupt=None,
release_session=None,
artifact_status=None,
clone_session=None,
get_status=None,
fetch_error_details=None,
)


def _import_zstandard_if_available() -> Optional[Any]:
"""
Import zstandard if available, otherwise return None.
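A quick usage sketch for the RpcDeadlines dataclass above (import path taken from this file; `dataclasses.replace` is the standard way to override a single field on a frozen dataclass):

from dataclasses import replace

from pyspark.errors import PySparkValueError
from pyspark.sql.connect.client.core import RpcDeadlines

defaults = RpcDeadlines()                      # 10 min / 1 hour defaults
tight = replace(defaults, analyze_plan=120.0)  # override one field, keep the rest
off = RpcDeadlines.disabled()                  # every per-RPC timeout is None

# __post_init__ validates every field, so bad values surface early:
try:
    RpcDeadlines(config=0)
except PySparkValueError as e:
    print(e)  # RpcDeadlines.config must be a positive number or None, got 0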
@@ -647,6 +712,7 @@ def __init__(
session_hooks: Optional[list["SparkSession.Hook"]] = None,
allow_arrow_batch_chunking: bool = True,
preferred_arrow_chunk_size: Optional[int] = None,
rpc_deadlines: Optional[RpcDeadlines] = None,
):
"""
Creates a new SparkSession for the Spark Connect interface.
@@ -694,6 +760,10 @@ def __init__(
results.
The server will attempt to use this size if it is set and within the valid range
([1KB, max batch size on server]). Otherwise, the server's maximum batch size is used.
rpc_deadlines : RpcDeadlines, optional
Per-RPC gRPC call timeouts in seconds (10 min for most RPCs,
1 hour for analyze/addArtifacts, none for non-reattachable execute).
Use :meth:`RpcDeadlines.disabled` to turn off all deadlines.
"""
self.thread_local = threading.local()

@@ -729,8 +799,17 @@ def __init__(
self._channel = self._builder.toChannel()
self._closed = False
self._internal_stub = grpc_lib.SparkConnectServiceStub(self._channel)
self._rpc_deadlines: RpcDeadlines = (
rpc_deadlines if rpc_deadlines is not None else RpcDeadlines()
)
logger.info("Spark Connect RPC deadlines: %s", self._rpc_deadlines)
self._artifact_manager = ArtifactManager(
self._user_id, self._session_id, self._channel, self._builder.metadata()
self._user_id,
self._session_id,
self._channel,
self._builder.metadata(),
add_artifacts_timeout=self._rpc_deadlines.add_artifacts,
artifact_status_timeout=self._rpc_deadlines.artifact_status,
)
self._use_reattachable_execute = use_reattachable_execute
self._allow_arrow_batch_chunking = allow_arrow_batch_chunking
@@ -1450,7 +1529,11 @@ def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
try:
for attempt in self._retrying():
with attempt:
resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata())
resp = self._stub.AnalyzePlan(
req,
metadata=self._builder.metadata(),
timeout=self._rpc_deadlines.analyze_plan,
)
self._verify_response_integrity(resp)
return AnalyzeResult.fromProto(resp)
raise SparkConnectException("Invalid state during retry exception handling.")
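One consequence of passing `timeout=` inside the retry loop is that the deadline applies per attempt, not to the whole retry budget: each pass through `self._retrying()` issues a fresh stub call with a fresh deadline. A generic sketch of that shape (not the pyspark retry implementation, which additionally decides per status code whether to retry):

import grpc

def call_with_retries(stub_call, request, per_attempt_timeout, max_attempts=3):
    # Every attempt gets its own deadline; three attempts with a 60s timeout
    # can take up to ~180s overall, plus any backoff.
    last_error = None
    for _ in range(max_attempts):
        try:
            return stub_call(request, timeout=per_attempt_timeout)
        except grpc.RpcError as error:
            last_error = error  # a real policy would inspect error.code()
    raise last_error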
@@ -1479,7 +1562,12 @@ def handle_response(b: pb2.ExecutePlanResponse) -> None:
if self._use_reattachable_execute:
# Don't use retryHandler - own retry handling is inside.
generator = ExecutePlanResponseReattachableIterator(
req, self._stub, self._retrying, self._builder.metadata()
req,
self._stub,
self._retrying,
self._builder.metadata(),
reattachable_execute_plan_timeout=self._rpc_deadlines.reattachable_execute_plan,
reattach_execute_timeout=self._rpc_deadlines.reattach_execute,
)
try:
for b in generator:
@@ -1689,7 +1777,12 @@ def handle_response(
if self._use_reattachable_execute:
# Don't use retryHandler - own retry handling is inside.
generator = ExecutePlanResponseReattachableIterator(
req, self._stub, self._retrying, self._builder.metadata()
req,
self._stub,
self._retrying,
self._builder.metadata(),
reattachable_execute_plan_timeout=self._rpc_deadlines.reattachable_execute_plan,
reattach_execute_timeout=self._rpc_deadlines.reattach_execute,
)
try:
for b in generator:
@@ -1841,7 +1934,11 @@ def config(self, operation: pb2.ConfigRequest.Operation) -> ConfigResult:
try:
for attempt in self._retrying():
with attempt:
resp = self._stub.Config(req, metadata=self._builder.metadata())
resp = self._stub.Config(
req,
metadata=self._builder.metadata(),
timeout=self._rpc_deadlines.config,
)
self._verify_response_integrity(resp)
return ConfigResult.fromProto(resp)
raise SparkConnectException("Invalid state during retry exception handling.")
@@ -1883,7 +1980,11 @@ def interrupt_all(self) -> Optional[List[str]]:
try:
for attempt in self._retrying():
with attempt:
resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
resp = self._stub.Interrupt(
req,
metadata=self._builder.metadata(),
timeout=self._rpc_deadlines.interrupt,
)
self._verify_response_integrity(resp)
return list(resp.interrupted_ids)
raise SparkConnectException("Invalid state during retry exception handling.")
@@ -1895,7 +1996,11 @@ def interrupt_tag(self, tag: str) -> Optional[List[str]]:
try:
for attempt in self._retrying():
with attempt:
resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
resp = self._stub.Interrupt(
req,
metadata=self._builder.metadata(),
timeout=self._rpc_deadlines.interrupt,
)
self._verify_response_integrity(resp)
return list(resp.interrupted_ids)
raise SparkConnectException("Invalid state during retry exception handling.")
@@ -1907,7 +2012,11 @@ def interrupt_operation(self, op_id: str) -> Optional[List[str]]:
try:
for attempt in self._retrying():
with attempt:
resp = self._stub.Interrupt(req, metadata=self._builder.metadata())
resp = self._stub.Interrupt(
req,
metadata=self._builder.metadata(),
timeout=self._rpc_deadlines.interrupt,
)
self._verify_response_integrity(resp)
return list(resp.interrupted_ids)
raise SparkConnectException("Invalid state during retry exception handling.")
@@ -1923,7 +2032,11 @@ def release_session(self) -> None:
try:
for attempt in self._retrying():
with attempt:
resp = self._stub.ReleaseSession(req, metadata=self._builder.metadata())
resp = self._stub.ReleaseSession(
req,
metadata=self._builder.metadata(),
timeout=self._rpc_deadlines.release_session,
)
self._verify_response_integrity(resp)
return
raise SparkConnectException("Invalid state during retry exception handling.")
@@ -1975,7 +2088,11 @@ def _get_operation_statuses(
try:
for attempt in self._retrying():
with attempt:
resp = self._stub.GetStatus(req, metadata=self._builder.metadata())
resp = self._stub.GetStatus(
req,
metadata=self._builder.metadata(),
timeout=self._rpc_deadlines.get_status,
)
self._verify_response_integrity(resp)
return resp
raise SparkConnectException("Invalid state during retry exception handling.")
@@ -2100,7 +2217,11 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDetailsResponse]:
req.user_context.user_id = self._user_id
self._update_request_with_user_context_extensions(req)
try:
return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata())
return self._stub.FetchErrorDetails(
req,
metadata=self._builder.metadata(),
timeout=self._rpc_deadlines.fetch_error_details,
)
except grpc.RpcError:
return None

@@ -2142,6 +2263,17 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
# https://grpc.github.io/grpc/python/grpc.html#grpc.UnaryUnaryMultiCallable.__call__
error: grpc.Call = cast(grpc.Call, rpc_error)
status_code: grpc.StatusCode = error.code()
if status_code == grpc.StatusCode.DEADLINE_EXCEEDED:
raise SparkConnectGrpcException(
message=(
f"{error}: RPC deadline exceeded. "
"The client applies per-RPC timeouts to prevent silent hangs "
"on broken connections. Deadlines can be configured via the "
"rpc_deadlines parameter of SparkConnectClient. To disable all "
"deadlines: SparkConnectClient(url, rpc_deadlines=RpcDeadlines.disabled())."
),
grpc_status_code=status_code,
) from None
status: Optional[Status] = rpc_status.from_call(error)
if status:
for d in status.details:
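For callers, the net effect is a `SparkConnectGrpcException` whose message names the `rpc_deadlines` escape hatch. A hedged recovery sketch (the exception's import path and the URL are assumptions, and real code should inspect the error before blindly retrying):

from pyspark.errors.exceptions.connect import SparkConnectGrpcException
from pyspark.sql.connect.client.core import RpcDeadlines, SparkConnectClient

def with_deadline_fallback(url, do_rpc):
    client = SparkConnectClient(url)  # default RpcDeadlines()
    try:
        return do_rpc(client)
    except SparkConnectGrpcException as e:
        if "RPC deadline exceeded" not in str(e):
            raise
        # A legitimately slow call tripped the client-side deadline; retry
        # once with all per-RPC timeouts disabled.
        client = SparkConnectClient(url, rpc_deadlines=RpcDeadlines.disabled())
        return do_rpc(client)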
@@ -2466,7 +2598,9 @@ def clone(self, new_session_id: Optional[str] = None) -> "SparkConnectClient":
for attempt in self._retrying():
with attempt:
response: pb2.CloneSessionResponse = self._stub.CloneSession(
request, metadata=self._builder.metadata()
request,
metadata=self._builder.metadata(),
timeout=self._rpc_deadlines.clone_session,
)

# Assert that the returned session ID matches the requested ID if one was provided
@@ -2485,6 +2619,8 @@ def clone(self, new_session_id: Optional[str] = None) -> "SparkConnectClient":
connection=new_connection,
user_id=self._user_id,
use_reattachable_execute=self._use_reattachable_execute,
session_hooks=self._session_hooks,
rpc_deadlines=self._rpc_deadlines,
)
# Ensure the session ID is correctly set from the response
new_client._session_id = response.new_session_id