Commit 31c0837

SCHJonathan authored and huangxiaopingRD committed
[SPARK-54020] Support spark.sql(...) Python API inside query functions for Spark Declarative Pipeline
### What changes were proposed in this pull request?

This PR adds support for the `spark.sql(...)` Python API inside query functions for Spark Declarative Pipelines. Users can now use `spark.sql(...)` to define query functions, and dependencies are correctly tracked.

**Example usage:**

```python
@dp.materialized_view()
def source():
    return spark.sql("SELECT * FROM RANGE(5)")


@dp.materialized_view()
def target():
    return spark.sql("SELECT * FROM source")
```

This PR also adds restrictions on the set of SQL commands users can execute. Unsupported commands (e.g., `spark.sql("CREATE TABLE ...")`) inside query functions will raise an error.

**Implementation details:**

1. Added `PipelineAnalysisContext` to Spark Connect's user context extensions, enabling the server to identify requests originating from Spark Declarative Pipelines and apply appropriate restrictions.
2. The `flow_name` field in `PipelineAnalysisContext` determines execution behavior:
   - **Inside query functions** (`flow_name` is set): the Spark Connect server treats `spark.sql()` as a no-op and returns the raw logical plan to SDP for deferred analysis as part of the dataflow graph.
   - **Outside query functions** (`flow_name` is empty): the Spark Connect server eagerly executes the command, but only SDP-allowlisted commands are permitted.

### Why are the changes needed?

`spark.sql(...)` is a common and intuitive pattern for users who are more familiar with SQL to define query functions. Supporting this API improves usability and allows SQL-first developers to work more naturally with Spark Declarative Pipelines.

### Does this PR introduce _any_ user-facing change?

Yes. Previously, `spark.sql(...)` inside query functions was not supported and users would see an `ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION` exception. This PR lifts that restriction.

### How was this patch tested?

New test cases in the `PythonPipelineSuite` unit test.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#53024 from SCHJonathan/jonathan-chang_data/spark-sql.

Authored-by: Yuheng Chang <jonathanyuheng@gmail.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
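For illustration, a minimal sketch of the behavior described above, assuming the decorator-based `pyspark.pipelines` API shown in the example; the table names and the `spark` session are placeholders, and the error named for the unsupported case is the new condition added by this patch:

```python
from pyspark import pipelines as dp

# Supported: spark.sql(...) inside a query function is not executed eagerly; the
# server returns the raw logical plan, so SDP can track the dependency on `source`.
@dp.materialized_view()
def downstream():
    return spark.sql("SELECT * FROM source WHERE id > 2")

# Not supported: DDL-style commands inside a query function are rejected, e.g.
# spark.sql("CREATE TABLE t AS SELECT 1") is expected to surface
# UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND rather than execute.
```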
1 parent 107bc20 commit 31c0837

File tree

16 files changed: +732 -144 lines


common/utils/src/main/resources/error/error-conditions.json

Lines changed: 6 additions & 0 deletions

@@ -6961,6 +6961,12 @@
     ],
     "sqlState" : "0A000"
   },
+  "UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND" : {
+    "message" : [
+      "'<command>' is not supported in spark.sql(\"...\") API in Spark Declarative Pipeline."
+    ],
+    "sqlState" : "0A000"
+  },
   "UNSUPPORTED_SAVE_MODE" : {
     "message" : [
       "The save mode <saveMode> is not supported for:"
python/pyspark/pipelines/add_pipeline_analysis_context.py

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from contextlib import contextmanager
+from typing import Generator, Optional
+from pyspark.sql import SparkSession
+
+from typing import Any, cast
+
+
+@contextmanager
+def add_pipeline_analysis_context(
+    spark: SparkSession, dataflow_graph_id: str, flow_name: Optional[str]
+) -> Generator[None, None, None]:
+    """
+    Context manager that adds a PipelineAnalysisContext extension to the user context
+    used for pipeline-specific analysis.
+    """
+    extension_id = None
+    # Cast because mypy seems to think `spark` is a function, not an object.
+    # Likely related to SPARK-47544.
+    client = cast(Any, spark).client
+    try:
+        import pyspark.sql.connect.proto as pb2
+        from google.protobuf import any_pb2
+
+        analysis_context = pb2.PipelineAnalysisContext(
+            dataflow_graph_id=dataflow_graph_id, flow_name=flow_name
+        )
+        extension = any_pb2.Any()
+        extension.Pack(analysis_context)
+        extension_id = client.add_threadlocal_user_context_extension(extension)
+        yield
+    finally:
+        client.remove_user_context_extension(extension_id)
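For reference, a minimal usage sketch of the new context manager; the connection string, graph id, flow name, and table name are placeholders:

```python
from pyspark.sql import SparkSession
from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()  # placeholder connection

# While the block is active, requests issued from this thread carry a
# PipelineAnalysisContext extension identifying the dataflow graph and (optionally)
# the flow, which is what lets the server defer analysis of spark.sql() calls.
with add_pipeline_analysis_context(
    spark=spark, dataflow_graph_id="example-graph-id", flow_name="example_flow"
):
    df = spark.sql("SELECT * FROM upstream_table")
```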

python/pyspark/pipelines/block_connect_access.py

Lines changed: 37 additions & 9 deletions

@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 from contextlib import contextmanager
-from typing import Callable, Generator, NoReturn
+from typing import Any, Callable, Generator

 from pyspark.errors import PySparkException
 from pyspark.sql.connect.proto.base_pb2_grpc import SparkConnectServiceStub
@@ -24,6 +24,27 @@
 BLOCKED_RPC_NAMES = ["AnalyzePlan", "ExecutePlan"]


+def _is_sql_command_request(rpc_name: str, args: tuple) -> bool:
+    """
+    Check if the RPC call is a spark.sql() command (ExecutePlan with sql_command).
+
+    :param rpc_name: Name of the RPC being called
+    :param args: Arguments passed to the RPC
+    :return: True if this is an ExecutePlan request with a sql_command
+    """
+    if rpc_name != "ExecutePlan" or len(args) == 0:
+        return False
+
+    request = args[0]
+    if not hasattr(request, "plan"):
+        return False
+    plan = request.plan
+    if not plan.HasField("command"):
+        return False
+    command = plan.command
+    return command.HasField("sql_command")
+
+
 @contextmanager
 def block_spark_connect_execution_and_analysis() -> Generator[None, None, None]:
     """
@@ -38,16 +59,23 @@ def block_spark_connect_execution_and_analysis() -> Generator[None, None, None]:

     # Define a new __getattribute__ method that blocks RPC calls
     def blocked_getattr(self: SparkConnectServiceStub, name: str) -> Callable:
-        if name not in BLOCKED_RPC_NAMES:
-            return original_getattr(self, name)
+        original_method = original_getattr(self, name)

-        def blocked_method(*args: object, **kwargs: object) -> NoReturn:
-            raise PySparkException(
-                errorClass="ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION",
-                messageParameters={},
-            )
+        def intercepted_method(*args: object, **kwargs: object) -> Any:
+            # Allow all RPCs that are not AnalyzePlan or ExecutePlan
+            if name not in BLOCKED_RPC_NAMES:
+                return original_method(*args, **kwargs)
+            # Allow spark.sql() commands (ExecutePlan with sql_command)
+            elif _is_sql_command_request(name, args):
+                return original_method(*args, **kwargs)
+            # Block all other AnalyzePlan and ExecutePlan calls
+            else:
+                raise PySparkException(
+                    errorClass="ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION",
+                    messageParameters={},
+                )

-        return blocked_method
+        return intercepted_method

     try:
         # Apply our custom __getattribute__ method
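To make the interception rule concrete, here is a small illustrative check (not part of the patch) of what `_is_sql_command_request` accepts, assuming the standard Spark Connect protobuf messages exported by `pyspark.sql.connect.proto`:

```python
import pyspark.sql.connect.proto as pb2

from pyspark.pipelines.block_connect_access import _is_sql_command_request

# An ExecutePlan request whose plan is a command with sql_command set is let
# through, since it corresponds to a spark.sql(...) call inside a query function.
sql_request = pb2.ExecutePlanRequest(
    plan=pb2.Plan(command=pb2.Command(sql_command=pb2.SqlCommand()))
)
assert _is_sql_command_request("ExecutePlan", (sql_request,))

# A plain relation plan (what e.g. df.collect() would send) is still blocked.
relation_request = pb2.ExecutePlanRequest(plan=pb2.Plan(root=pb2.Relation()))
assert not _is_sql_command_request("ExecutePlan", (relation_request,))

# AnalyzePlan requests are never treated as sql_command requests.
assert not _is_sql_command_request("AnalyzePlan", (sql_request,))
```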

python/pyspark/pipelines/cli.py

Lines changed: 13 additions & 4 deletions

@@ -49,6 +49,8 @@
     handle_pipeline_events,
 )

+from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context
+
 PIPELINE_SPEC_FILE_NAMES = ["pipeline.yaml", "pipeline.yml"]


@@ -216,7 +218,11 @@ def validate_str_dict(d: Mapping[str, str], field_name: str) -> Mapping[str, str


 def register_definitions(
-    spec_path: Path, registry: GraphElementRegistry, spec: PipelineSpec
+    spec_path: Path,
+    registry: GraphElementRegistry,
+    spec: PipelineSpec,
+    spark: SparkSession,
+    dataflow_graph_id: str,
 ) -> None:
     """Register the graph element definitions in the pipeline spec with the given registry.
     - Looks for Python files matching the glob patterns in the spec and imports them.
@@ -245,8 +251,11 @@ def register_definitions(
             assert (
                 module_spec.loader is not None
             ), f"Module spec has no loader for {file}"
-            with block_session_mutations():
-                module_spec.loader.exec_module(module)
+            with add_pipeline_analysis_context(
+                spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name=None
+            ):
+                with block_session_mutations():
+                    module_spec.loader.exec_module(module)
         elif file.suffix == ".sql":
             log_with_curr_timestamp(f"Registering SQL file {file}...")
             with file.open("r") as f:
@@ -324,7 +333,7 @@ def run(

     log_with_curr_timestamp("Registering graph elements...")
     registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id)
-    register_definitions(spec_path, registry, spec)
+    register_definitions(spec_path, registry, spec, spark, dataflow_graph_id)

     log_with_curr_timestamp("Starting run...")
     result_iter = start_run(

python/pyspark/pipelines/spark_connect_graph_element_registry.py

Lines changed: 7 additions & 2 deletions

@@ -35,6 +35,7 @@
 from pyspark.sql.types import StructType
 from typing import Any, cast
 import pyspark.sql.connect.proto as pb2
+from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context


 class SparkConnectGraphElementRegistry(GraphElementRegistry):
@@ -43,6 +44,7 @@ class SparkConnectGraphElementRegistry(GraphElementRegistry):
     def __init__(self, spark: SparkSession, dataflow_graph_id: str) -> None:
         # Cast because mypy seems to think `spark`` is a function, not an object. Likely related to
         # SPARK-47544.
+        self._spark = spark
         self._client = cast(Any, spark).client
         self._dataflow_graph_id = dataflow_graph_id

@@ -110,8 +112,11 @@ def register_output(self, output: Output) -> None:
         self._client.execute_command(command)

     def register_flow(self, flow: Flow) -> None:
-        with block_spark_connect_execution_and_analysis():
-            df = flow.func()
+        with add_pipeline_analysis_context(
+            spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, flow_name=flow.name
+        ):
+            with block_spark_connect_execution_and_analysis():
+                df = flow.func()
         relation = cast(ConnectDataFrame, df)._plan.plan(self._client)

         relation_flow_details = pb2.PipelineCommand.DefineFlow.WriteRelationFlowDetails(
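From a query function's point of view, the effect of wrapping `flow.func()` in both context managers is roughly the following sketch; the table and view names are hypothetical and `spark` refers to the pipeline's session:

```python
from pyspark import pipelines as dp

@dp.materialized_view()
def orders_summary():
    # Allowed: sent as a sql_command and returned as a raw (unanalyzed) logical plan,
    # so SDP can resolve the dependency on `orders` later during graph analysis.
    df = spark.sql("SELECT customer_id, count(*) AS n FROM orders GROUP BY customer_id")

    # Still blocked inside query functions: anything that forces eager analysis or
    # execution (e.g. df.collect() or df.schema) raises
    # ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION, as before.
    return df
```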
python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py

Lines changed: 100 additions & 0 deletions

@@ -0,0 +1,100 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.testing.connectutils import (
+    ReusedConnectTestCase,
+    should_test_connect,
+    connect_requirement_message,
+)
+
+if should_test_connect:
+    from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class AddPipelineAnalysisContextTests(ReusedConnectTestCase):
+    def test_add_pipeline_analysis_context_with_flow_name(self):
+        with add_pipeline_analysis_context(self.spark, "test_dataflow_graph_id", "test_flow_name"):
+            import pyspark.sql.connect.proto as pb2
+
+            thread_local_extensions = self.spark.client.thread_local.user_context_extensions
+            self.assertEqual(len(thread_local_extensions), 1)
+            # Extension is stored as (id, extension), unpack the extension
+            _extension_id, extension = thread_local_extensions[0]
+            context = pb2.PipelineAnalysisContext()
+            extension.Unpack(context)
+            self.assertEqual(context.dataflow_graph_id, "test_dataflow_graph_id")
+            self.assertEqual(context.flow_name, "test_flow_name")
+        thread_local_extensions_after = self.spark.client.thread_local.user_context_extensions
+        self.assertEqual(len(thread_local_extensions_after), 0)
+
+    def test_add_pipeline_analysis_context_without_flow_name(self):
+        with add_pipeline_analysis_context(self.spark, "test_dataflow_graph_id", None):
+            import pyspark.sql.connect.proto as pb2
+
+            thread_local_extensions = self.spark.client.thread_local.user_context_extensions
+            self.assertEqual(len(thread_local_extensions), 1)
+            # Extension is stored as (id, extension), unpack the extension
+            _extension_id, extension = thread_local_extensions[0]
+            context = pb2.PipelineAnalysisContext()
+            extension.Unpack(context)
+            self.assertEqual(context.dataflow_graph_id, "test_dataflow_graph_id")
+            # Empty string means no flow name
+            self.assertEqual(context.flow_name, "")
+        thread_local_extensions_after = self.spark.client.thread_local.user_context_extensions
+        self.assertEqual(len(thread_local_extensions_after), 0)
+
+    def test_nested_add_pipeline_analysis_context(self):
+        import pyspark.sql.connect.proto as pb2
+
+        with add_pipeline_analysis_context(self.spark, "test_dataflow_graph_id_1", flow_name=None):
+            with add_pipeline_analysis_context(
+                self.spark, "test_dataflow_graph_id_2", flow_name="test_flow_name"
+            ):
+                thread_local_extensions = self.spark.client.thread_local.user_context_extensions
+                self.assertEqual(len(thread_local_extensions), 2)
+                # Extension is stored as (id, extension), unpack the extensions
+                _, extension_1 = thread_local_extensions[0]
+                context_1 = pb2.PipelineAnalysisContext()
+                extension_1.Unpack(context_1)
+                self.assertEqual(context_1.dataflow_graph_id, "test_dataflow_graph_id_1")
+                self.assertEqual(context_1.flow_name, "")
+                _, extension_2 = thread_local_extensions[1]
+                context_2 = pb2.PipelineAnalysisContext()
+                extension_2.Unpack(context_2)
+                self.assertEqual(context_2.dataflow_graph_id, "test_dataflow_graph_id_2")
+                self.assertEqual(context_2.flow_name, "test_flow_name")
+            thread_local_extensions_after_1 = self.spark.client.thread_local.user_context_extensions
+            self.assertEqual(len(thread_local_extensions_after_1), 1)
+            _, extension_3 = thread_local_extensions_after_1[0]
+            context_3 = pb2.PipelineAnalysisContext()
+            extension_3.Unpack(context_3)
+            self.assertEqual(context_3.dataflow_graph_id, "test_dataflow_graph_id_1")
+            self.assertEqual(context_3.flow_name, "")
+        thread_local_extensions_after_2 = self.spark.client.thread_local.user_context_extensions
+        self.assertEqual(len(thread_local_extensions_after_2), 0)
+
+
+if __name__ == "__main__":
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)

python/pyspark/pipelines/tests/test_cli.py

Lines changed: 11 additions & 4 deletions

@@ -22,6 +22,7 @@

 from pyspark.errors import PySparkException
 from pyspark.testing.connectutils import (
+    ReusedConnectTestCase,
     should_test_connect,
     connect_requirement_message,
 )
@@ -45,7 +46,7 @@
     not should_test_connect or not have_yaml,
     connect_requirement_message or yaml_requirement_message,
 )
-class CLIUtilityTests(unittest.TestCase):
+class CLIUtilityTests(ReusedConnectTestCase):
     def test_load_pipeline_spec(self):
         with tempfile.NamedTemporaryFile(mode="w") as tmpfile:
             tmpfile.write(
@@ -294,7 +295,9 @@ def mv2():
         )

         registry = LocalGraphElementRegistry()
-        register_definitions(outer_dir / "pipeline.yaml", registry, spec)
+        register_definitions(
+            outer_dir / "pipeline.yaml", registry, spec, self.spark, "test_graph_id"
+        )
         self.assertEqual(len(registry.outputs), 1)
         self.assertEqual(registry.outputs[0].name, "mv1")

@@ -315,7 +318,9 @@ def test_register_definitions_file_raises_error(self):

         registry = LocalGraphElementRegistry()
         with self.assertRaises(RuntimeError) as context:
-            register_definitions(outer_dir / "pipeline.yml", registry, spec)
+            register_definitions(
+                outer_dir / "pipeline.yml", registry, spec, self.spark, "test_graph_id"
+            )
         self.assertIn("This is a test exception", str(context.exception))

     def test_register_definitions_unsupported_file_extension_matches_glob(self):
@@ -334,7 +339,7 @@ def test_register_definitions_unsupported_file_extension_matches_glob(self):

         registry = LocalGraphElementRegistry()
         with self.assertRaises(PySparkException) as context:
-            register_definitions(outer_dir, registry, spec)
+            register_definitions(outer_dir, registry, spec, self.spark, "test_graph_id")
         self.assertEqual(
             context.exception.getCondition(), "PIPELINE_UNSUPPORTED_DEFINITIONS_FILE_EXTENSION"
         )
@@ -382,6 +387,8 @@ def test_python_import_current_directory(self):
                    configuration={},
                    libraries=[LibrariesGlob(include="defs.py")],
                ),
+                self.spark,
+                "test_graph_id",
            )

    def test_full_refresh_all_conflicts_with_full_refresh(self):

python/pyspark/pipelines/tests/test_init_cli.py

Lines changed: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ def test_init(self):
         self.assertTrue((Path.cwd() / "pipeline-storage").exists())

         registry = LocalGraphElementRegistry()
-        register_definitions(spec_path, registry, spec)
+        register_definitions(spec_path, registry, spec, self.spark, "test_graph_id")
         self.assertEqual(len(registry.outputs), 1)
         self.assertEqual(registry.outputs[0].name, "example_python_materialized_view")
         self.assertEqual(len(registry.flows), 1)
