diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index 3e90bb329be56..eb42160c378de 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -136,6 +136,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
fail-fast: false
+ max-parallel: 20
matrix:
split: ${{fromJSON(needs.matrix-gen.outputs.matrix)}}
env:
diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 3776a116fd785..742990622712b 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -30,8 +30,7 @@ on:
description: Branch to run the build against
required: false
type: string
- # Change 'master' to 'branch-4.0' in branch-4.0 branch after cutting it.
- default: master
+ default: branch-4.1
hadoop:
description: Hadoop version to run with. HADOOP_PROFILE environment variable should accept it.
required: false
@@ -232,6 +231,7 @@ jobs:
timeout-minutes: 120
strategy:
fail-fast: false
+ max-parallel: 20
matrix:
java:
- ${{ inputs.java }}
@@ -362,7 +362,7 @@ jobs:
- name: Install Python packages (Python 3.11)
if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect') || contains(matrix.modules, 'yarn')
run: |
- python3.11 -m pip install 'numpy>=1.22' pyarrow pandas pyyaml scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.67.0' 'grpcio-status==1.67.0' 'protobuf==5.29.5'
+ python3.11 -m pip install 'numpy>=1.22' pyarrow 'pandas==2.3.3' pyyaml scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.0' 'zstandard==0.25.0'
python3.11 -m pip list
# Run the tests.
- name: Run tests
@@ -504,6 +504,7 @@ jobs:
image: ${{ needs.precondition.outputs.image_pyspark_url_link }}
strategy:
fail-fast: false
+ max-parallel: 20
matrix:
java:
- ${{ inputs.java }}
@@ -548,6 +549,7 @@ jobs:
HIVE_PROFILE: hive2.3
GITHUB_PREV_SHA: ${{ github.event.before }}
SPARK_LOCAL_IP: localhost
+ NOLINT_ON_COMPILE: true
SKIP_UNIDOC: true
SKIP_MIMA: true
SKIP_PACKAGING: true
@@ -766,7 +768,7 @@ jobs:
python-version: '3.11'
- name: Install dependencies for Python CodeGen check
run: |
- python3.11 -m pip install 'black==23.12.1' 'protobuf==5.29.5' 'mypy==1.8.0' 'mypy-protobuf==3.3.0'
+ python3.11 -m pip install 'black==23.12.1' 'protobuf==6.33.0' 'mypy==1.8.0' 'mypy-protobuf==3.3.0'
python3.11 -m pip list
- name: Python CodeGen check for branch-3.5
if: inputs.branch == 'branch-3.5'
@@ -1314,6 +1316,11 @@ jobs:
key: k8s-integration-coursier-${{ hashFiles('**/pom.xml', '**/plugins.sbt') }}
restore-keys: |
k8s-integration-coursier-
+ - name: Free up disk space
+ run: |
+ if [ -f ./dev/free_disk_space ]; then
+ ./dev/free_disk_space
+ fi
- name: Install Java ${{ inputs.java }}
uses: actions/setup-java@v4
with:
diff --git a/.github/workflows/build_python_3.11_macos.yml b/.github/workflows/build_python_3.11_macos.yml
deleted file mode 100644
index 9566bfd8271d1..0000000000000
--- a/.github/workflows/build_python_3.11_macos.yml
+++ /dev/null
@@ -1,33 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-name: "Build / Python-only (master, Python 3.11, MacOS)"
-
-on:
- schedule:
- - cron: '0 21 * * *'
- workflow_dispatch:
-
-jobs:
- run-build:
- permissions:
- packages: write
- name: Run
- uses: ./.github/workflows/python_hosted_runner_test.yml
- if: github.repository == 'apache/spark'
diff --git a/.github/workflows/build_python_connect.yml b/.github/workflows/build_python_connect.yml
index 5edb54de82b6d..b1ebb45b9cbc1 100644
--- a/.github/workflows/build_python_connect.yml
+++ b/.github/workflows/build_python_connect.yml
@@ -72,7 +72,7 @@ jobs:
python packaging/client/setup.py sdist
cd dist
pip install pyspark*client-*.tar.gz
- pip install 'grpcio==1.67.0' 'grpcio-status==1.67.0' 'protobuf==5.29.5' 'googleapis-common-protos==1.65.0' 'graphviz==0.20.3' 'six==1.16.0' 'pandas==2.3.2' scipy 'plotly<6.0.0' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' 'torch<2.6.0' torchvision torcheval deepspeed unittest-xml-reporting
+ pip install 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.0' 'googleapis-common-protos==1.71.0' 'graphviz==0.20.3' 'six==1.16.0' 'pandas==2.3.3' scipy 'plotly<6.0.0' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' 'graphviz==0.20.3' 'torch<2.6.0' torchvision torcheval deepspeed unittest-xml-reporting
- name: List Python packages
run: python -m pip list
- name: Run tests (local)
diff --git a/.github/workflows/maven_test.yml b/.github/workflows/maven_test.yml
index 95c9aac33fc6c..580593f1cfe5e 100644
--- a/.github/workflows/maven_test.yml
+++ b/.github/workflows/maven_test.yml
@@ -56,8 +56,10 @@ jobs:
build:
name: "Build modules using Maven: ${{ matrix.modules }} ${{ matrix.comment }}"
runs-on: ${{ inputs.os }}
+ timeout-minutes: 150
strategy:
fail-fast: false
+ max-parallel: 20
matrix:
java:
- ${{ inputs.java }}
@@ -175,7 +177,7 @@ jobs:
- name: Install Python packages (Python 3.11)
if: contains(matrix.modules, 'resource-managers#yarn') || (contains(matrix.modules, 'sql#core')) || contains(matrix.modules, 'connect')
run: |
- python3.11 -m pip install 'numpy>=1.22' pyarrow pandas pyyaml scipy unittest-xml-reporting 'grpcio==1.67.0' 'grpcio-status==1.67.0' 'protobuf==5.29.5'
+ python3.11 -m pip install 'numpy>=1.22' pyarrow 'pandas==2.3.3' pyyaml scipy unittest-xml-reporting 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.0' 'zstandard==0.25.0'
python3.11 -m pip list
# Run the tests using script command.
# BSD's script command doesn't support -c option, and the usage is different from Linux's one.
diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml
index 86ef00220b373..2bba3dcaf176d 100644
--- a/.github/workflows/pages.yml
+++ b/.github/workflows/pages.yml
@@ -61,9 +61,9 @@ jobs:
- name: Install Python dependencies
run: |
pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \
- ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.22' pyarrow 'pandas==2.3.2' 'plotly>=4.8' 'docutils<0.18.0' \
+ ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.22' pyarrow 'pandas==2.3.3' 'plotly>=4.8' 'docutils<0.18.0' \
'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.12.1' \
- 'pandas-stubs==1.2.0.53' 'grpcio==1.67.0' 'grpcio-status==1.67.0' 'protobuf==5.29.5' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \
+ 'pandas-stubs==1.2.0.53' 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \
'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5'
- name: Install Ruby for documentation generation
uses: ruby/setup-ruby@v1
diff --git a/.github/workflows/publish_snapshot.yml b/.github/workflows/publish_snapshot.yml
index 6e2e5709bbd18..0608ba1fd1f1b 100644
--- a/.github/workflows/publish_snapshot.yml
+++ b/.github/workflows/publish_snapshot.yml
@@ -36,6 +36,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
fail-fast: false
+ max-parallel: 20
matrix:
# keep in sync with default value of workflow_dispatch input 'branch'
branch: ${{ fromJSON( inputs.branch || '["master", "branch-4.0", "branch-3.5"]' ) }}
@@ -67,6 +68,7 @@ jobs:
env:
ASF_USERNAME: ${{ secrets.NEXUS_USER }}
ASF_PASSWORD: ${{ secrets.NEXUS_PW }}
+ ASF_NEXUS_TOKEN: ${{ secrets.NEXUS_TOKEN }}
GPG_KEY: "not_used"
GPG_PASSPHRASE: "not_used"
GIT_REF: ${{ matrix.branch }}
diff --git a/.github/workflows/python_hosted_runner_test.yml b/.github/workflows/python_hosted_runner_test.yml
index 9a6afc095063c..659171b901d3f 100644
--- a/.github/workflows/python_hosted_runner_test.yml
+++ b/.github/workflows/python_hosted_runner_test.yml
@@ -62,6 +62,7 @@ jobs:
timeout-minutes: 120
strategy:
fail-fast: false
+ max-parallel: 20
matrix:
java:
- ${{ inputs.java }}
@@ -147,8 +148,8 @@ jobs:
run: |
python${{matrix.python}} -m pip install --ignore-installed 'blinker>=1.6.2'
python${{matrix.python}} -m pip install --ignore-installed 'six==1.16.0'
- python${{matrix.python}} -m pip install numpy 'pyarrow>=21.0.0' 'six==1.16.0' 'pandas==2.3.2' scipy 'plotly<6.0.0' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' unittest-xml-reporting && \
- python${{matrix.python}} -m pip install 'grpcio==1.67.0' 'grpcio-status==1.67.0' 'protobuf==5.29.5' 'googleapis-common-protos==1.65.0' 'graphviz==0.20.3' && \
+ python${{matrix.python}} -m pip install numpy 'pyarrow>=22.0.0' 'six==1.16.0' 'pandas==2.3.3' scipy 'plotly<6.0.0' 'mlflow>=2.8.1' coverage matplotlib openpyxl 'memory-profiler>=0.61.0' 'scikit-learn>=1.3.2' unittest-xml-reporting && \
+ python${{matrix.python}} -m pip install 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.0' 'googleapis-common-protos==1.71.0' 'zstandard==0.25.0' 'graphviz==0.20.3' && \
python${{matrix.python}} -m pip cache purge
- name: List Python packages
run: python${{matrix.python}} -m pip list
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 5de61c831cbef..ab10d04d9badd 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -28,6 +28,10 @@
# - ASF_PASSWORD:
# The password associated with your ASF account.
#
+# - ASF_NEXUS_TOKEN:
+# ASF Nexus API token associated with your ASF account.
+# Can be found in https://repository.apache.org/#profile;User%20Token
+#
# - GPG_PRIVATE_KEY:
# Your GPG private key, exported using:
# gpg --armor --export-secret-keys ABCD1234 > private.key
@@ -128,6 +132,7 @@ jobs:
GIT_NAME: "${{ github.actor }}"
ASF_USERNAME: "${{ secrets.ASF_USERNAME }}"
ASF_PASSWORD: "${{ secrets.ASF_PASSWORD }}"
+ ASF_NEXUS_TOKEN: "${{ secrets.ASF_NEXUS_TOKEN }}"
GPG_PRIVATE_KEY: "${{ secrets.GPG_PRIVATE_KEY }}"
GPG_PASSPHRASE: "${{ secrets.GPG_PASSPHRASE }}"
PYPI_API_TOKEN: "${{ secrets.PYPI_API_TOKEN }}"
@@ -163,6 +168,7 @@ jobs:
GPG_PRIVATE_KEY="not_used"
GPG_PASSPHRASE="not_used"
ASF_USERNAME="gurwls223"
+ ASF_NEXUS_TOKEN="not_used"
export SKIP_TAG=1
unset RELEASE_VERSION
else
@@ -170,7 +176,7 @@ jobs:
export DRYRUN_MODE=0
fi
- export ASF_PASSWORD GPG_PRIVATE_KEY GPG_PASSPHRASE ASF_USERNAME
+ export ASF_PASSWORD GPG_PRIVATE_KEY GPG_PASSPHRASE ASF_USERNAME ASF_NEXUS_TOKEN
export GIT_BRANCH="${GIT_BRANCH:-master}"
[ -n "$RELEASE_VERSION" ] && export RELEASE_VERSION
@@ -239,7 +245,7 @@ jobs:
# Redact sensitive information in log files
shopt -s globstar nullglob
FILES=("$RELEASE_DIR/docker-build.log" "$OUTPUT_DIR/"*.log)
- PATTERNS=("$ASF_USERNAME" "$ASF_PASSWORD" "$GPG_PRIVATE_KEY" "$GPG_PASSPHRASE" "$PYPI_API_TOKEN")
+ PATTERNS=("$ASF_USERNAME" "$ASF_PASSWORD" "$GPG_PRIVATE_KEY" "$GPG_PASSPHRASE" "$PYPI_API_TOKEN" "$ASF_NEXUS_TOKEN")
for file in "${FILES[@]}"; do
[ -f "$file" ] || continue
cp "$file" "$file.bak"
diff --git a/.sbtopts b/.sbtopts
new file mode 100644
index 0000000000000..3516fc4bd7ebc
--- /dev/null
+++ b/.sbtopts
@@ -0,0 +1,20 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+-J-Xmx8g
+-J-Xms8g
+-J-XX:MaxMetaspaceSize=1g
diff --git a/LICENSE-binary b/LICENSE-binary
index 95087a0a0de23..fcc54c51bc820 100644
--- a/LICENSE-binary
+++ b/LICENSE-binary
@@ -215,10 +215,8 @@ com.github.joshelser:dropwizard-metrics-hadoop-metrics2-reporter
com.google.code.findbugs:jsr305
com.google.code.gson:gson
com.google.crypto.tink:tink
-com.google.errorprone:error_prone_annotations
com.google.flatbuffers:flatbuffers-java
com.google.guava:guava
-com.google.j2objc:j2objc-annotations
com.jamesmurty.utils:java-xmlbuilder
com.ning:compress-lzf
com.squareup.okhttp3:logging-interceptor
@@ -478,7 +476,6 @@ dev.ludovic.netlib:blas
dev.ludovic.netlib:arpack
dev.ludovic.netlib:lapack
net.razorvine:pickle
-org.checkerframework:checker-qual
org.typelevel:algebra_2.13:jar
org.typelevel:cats-kernel_2.13
org.typelevel:spire_2.13
diff --git a/NOTICE-binary b/NOTICE-binary
index a3f302b1cb04d..69ade56b36c45 100644
--- a/NOTICE-binary
+++ b/NOTICE-binary
@@ -102,7 +102,7 @@ which has the following notices:
Please visit the Netty web site for more information:
- * http://netty.io/
+ * https://netty.io/
Copyright 2014 The Netty Project
@@ -110,7 +110,7 @@ The Netty Project licenses this file to you under the Apache License,
version 2.0 (the "License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at:
- http://www.apache.org/licenses/LICENSE-2.0
+ https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
@@ -132,6 +132,14 @@ been derived from the works by JSR-166 EG, Doug Lea, and Jason T. Greene:
* http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/
* http://viewvc.jboss.org/cgi-bin/viewvc.cgi/jbosscache/experimental/jsr166/
+This product contains a modified version of Robert Harder's Public Domain
+Base64 Encoder and Decoder, which can be obtained at:
+
+ * LICENSE:
+ * license/LICENSE.base64.txt (Public Domain)
+ * HOMEPAGE:
+ * http://iharder.sourceforge.net/current/java/base64/
+
This product contains a modified portion of 'Webbit', an event based
WebSocket and HTTP server, which can be obtained at:
@@ -146,7 +154,7 @@ facade for Java, which can be obtained at:
* LICENSE:
* license/LICENSE.slf4j.txt (MIT License)
* HOMEPAGE:
- * http://www.slf4j.org/
+ * https://www.slf4j.org/
This product contains a modified portion of 'Apache Harmony', an open source
Java SE, which can be obtained at:
@@ -156,7 +164,7 @@ Java SE, which can be obtained at:
* LICENSE:
* license/LICENSE.harmony.txt (Apache License 2.0)
* HOMEPAGE:
- * http://archive.apache.org/dist/harmony/
+ * https://archive.apache.org/dist/harmony/
This product contains a modified portion of 'jbzip2', a Java bzip2 compression
and decompression library written by Matthew J. Francis. It can be obtained at:
@@ -215,6 +223,14 @@ and decompression library, which can be obtained at:
* HOMEPAGE:
* https://github.com/jponge/lzma-java
+This product optionally depends on 'zstd-jni', a zstd-jni Java compression
+and decompression library, which can be obtained at:
+
+ * LICENSE:
+ * license/LICENSE.zstd-jni.txt (BSD)
+ * HOMEPAGE:
+ * https://github.com/luben/zstd-jni
+
This product contains a modified portion of 'jfastlz', a Java port of FastLZ compression
and decompression library written by William Kinney. It can be obtained at:
@@ -238,7 +254,7 @@ equivalent functionality. It can be obtained at:
* LICENSE:
* license/LICENSE.bouncycastle.txt (MIT License)
* HOMEPAGE:
- * http://www.bouncycastle.org/
+ * https://www.bouncycastle.org/
This product optionally depends on 'Snappy', a compression library produced
by Google Inc, which can be obtained at:
@@ -252,9 +268,9 @@ This product optionally depends on 'JBoss Marshalling', an alternative Java
serialization API, which can be obtained at:
* LICENSE:
- * license/LICENSE.jboss-marshalling.txt (GNU LGPL 2.1)
+ * license/LICENSE.jboss-marshalling.txt (Apache License 2.0)
* HOMEPAGE:
- * http://www.jboss.org/jbossmarshalling
+ * https://github.com/jboss-remoting/jboss-marshalling
This product optionally depends on 'Caliper', Google's micro-
benchmarking framework, which can be obtained at:
@@ -264,13 +280,21 @@ benchmarking framework, which can be obtained at:
* HOMEPAGE:
* https://github.com/google/caliper
+This product optionally depends on 'Apache Commons Logging', a logging
+framework, which can be obtained at:
+
+ * LICENSE:
+ * license/LICENSE.commons-logging.txt (Apache License 2.0)
+ * HOMEPAGE:
+ * https://commons.apache.org/logging/
+
This product optionally depends on 'Apache Log4J', a logging framework, which
can be obtained at:
* LICENSE:
* license/LICENSE.log4j.txt (Apache License 2.0)
* HOMEPAGE:
- * http://logging.apache.org/log4j/
+ * https://logging.apache.org/log4j/
This product optionally depends on 'Aalto XML', an ultra-high performance
non-blocking XML processor, which can be obtained at:
@@ -278,7 +302,7 @@ non-blocking XML processor, which can be obtained at:
* LICENSE:
* license/LICENSE.aalto-xml.txt (Apache License 2.0)
* HOMEPAGE:
- * http://wiki.fasterxml.com/AaltoHome
+ * https://wiki.fasterxml.com/AaltoHome
This product contains a modified version of 'HPACK', a Java implementation of
the HTTP/2 HPACK algorithm written by Twitter. It can be obtained at:
@@ -288,6 +312,22 @@ the HTTP/2 HPACK algorithm written by Twitter. It can be obtained at:
* HOMEPAGE:
* https://github.com/twitter/hpack
+This product contains a modified version of 'HPACK', a Java implementation of
+the HTTP/2 HPACK algorithm written by Cory Benfield. It can be obtained at:
+
+ * LICENSE:
+ * license/LICENSE.hyper-hpack.txt (MIT License)
+ * HOMEPAGE:
+ * https://github.com/python-hyper/hpack/
+
+This product contains a modified version of 'HPACK', a Java implementation of
+the HTTP/2 HPACK algorithm written by Tatsuhiro Tsujikawa. It can be obtained at:
+
+ * LICENSE:
+ * license/LICENSE.nghttp2-hpack.txt (MIT License)
+ * HOMEPAGE:
+ * https://github.com/nghttp2/nghttp2/
+
This product contains a modified portion of 'Apache Commons Lang', a Java library
provides utilities for the java.lang API, which can be obtained at:
@@ -304,6 +344,37 @@ This product contains the Maven wrapper scripts from 'Maven Wrapper', that provi
* HOMEPAGE:
* https://github.com/takari/maven-wrapper
+This product contains the dnsinfo.h header file, that provides a way to retrieve the system DNS configuration on MacOS.
+This private header is also used by Apple's open source
+ mDNSResponder (https://opensource.apple.com/tarballs/mDNSResponder/).
+
+ * LICENSE:
+ * license/LICENSE.dnsinfo.txt (Apple Public Source License 2.0)
+ * HOMEPAGE:
+ * https://www.opensource.apple.com/source/configd/configd-453.19/dnsinfo/dnsinfo.h
+
+This product optionally depends on 'Brotli4j', Brotli compression and
+decompression for Java., which can be obtained at:
+
+ * LICENSE:
+ * license/LICENSE.brotli4j.txt (Apache License 2.0)
+ * HOMEPAGE:
+ * https://github.com/hyperxpro/Brotli4j
+
+This product is statically linked against Quiche.
+
+ * LICENSE:
+ * license/LICENSE.quiche.txt (BSD2)
+ * HOMEPAGE:
+ * https://github.com/cloudflare/quiche
+
+
+This product is statically linked against boringssl.
+
+ * LICENSE
+ * license/LICENSE.boringssl.txt (Apache License 2.0)
+ * HOMEPAGE:
+ * https://boringssl.googlesource.com/boringssl/
The binary distribution of this product bundles binaries of
Commons Codec 1.4,
@@ -755,10 +826,6 @@ project. The following notice covers the Felix files:
I. Included Software
- This product includes software developed at
- The Apache Software Foundation (http://www.apache.org/).
- Licensed under the Apache License 2.0.
-
This product includes software developed at
The OSGi Alliance (http://www.osgi.org/).
Copyright (c) OSGi Alliance (2000, 2007).
@@ -1071,199 +1138,6 @@ Copyright 2019 The Apache Software Foundation
Hive Storage API
Copyright 2018 The Apache Software Foundation
-
-
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
-
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
- 1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- Copyright 2015-2015 DataNucleus
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
Android JSON library
Copyright (C) 2010 The Android Open Source Project
@@ -1273,9 +1147,6 @@ The Android Open Source Project
Apache Yetus - Audience Annotations
Copyright 2015-2017 The Apache Software Foundation
-This product includes software developed at
-The Apache Software Foundation (http://www.apache.org/).
-
Ehcache V3
Copyright 2014-2016 Terracotta, Inc.
@@ -1285,19 +1156,9 @@ under the Apache License 2.0 (see: org.ehcache.impl.internal.classes.commonslang
Apache Geronimo JCache Spec 1.0
Copyright 2003-2014 The Apache Software Foundation
-This product includes software developed at
-The Apache Software Foundation (http://www.apache.org/).
-
-
-
-
Token provider
Copyright 2014-2017 The Apache Software Foundation
-This product includes software developed at
-The Apache Software Foundation (http://www.apache.org/).
-
-
Metrics
Copyright 2010-2013 Coda Hale and Yammer, Inc.
diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 4393175430265..d72a6a562432f 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -1,6 +1,6 @@
Package: SparkR
Type: Package
-Version: 4.1.0
+Version: 4.1.2
Title: R Front End for 'Apache Spark'
Description: Provides an R Front end for 'Apache Spark' <https://spark.apache.org>.
Authors@R:
diff --git a/README.md b/README.md
index f3435178c68eb..77d12031106dd 100644
--- a/README.md
+++ b/README.md
@@ -39,7 +39,6 @@ This README file only contains basic setup instructions.
| | [](https://github.com/apache/spark/actions/workflows/build_python_3.10.yml) |
| | [](https://github.com/apache/spark/actions/workflows/build_python_3.11_classic_only.yml) |
| | [](https://github.com/apache/spark/actions/workflows/build_python_3.11_arm.yml) |
-| | [](https://github.com/apache/spark/actions/workflows/build_python_3.11_macos.yml) |
| | [](https://github.com/apache/spark/actions/workflows/build_python_3.11_macos26.yml) |
| | [](https://github.com/apache/spark/actions/workflows/build_python_numpy_2.1.3.yml) |
| | [](https://github.com/apache/spark/actions/workflows/build_python_3.12.yml) |
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 0e6012062313e..f2afb76fef9ed 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.13</artifactId>
-    <version>4.1.0-SNAPSHOT</version>
+    <version>4.1.2-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
@@ -142,11 +142,25 @@
Because we don't shade dependencies anymore, we need to restore Guava to compile scope so
that the libraries Spark depend on have it available. We'll package the version that Spark
uses which is not the same as Hadoop dependencies, but works.
+ As mentioned in https://github.com/google/guava/wiki/UseGuavaInYourBuild
+ Guava has one dependency that is needed for linkage at runtime:
+ com.google.guava:failureaccess:
-->
      <groupId>com.google.guava</groupId>
      <artifactId>guava</artifactId>
      <scope>${hadoop.deps.scope}</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>*</groupId>
+          <artifactId>*</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>failureaccess</artifactId>
+      <scope>${hadoop.deps.scope}</scope>
diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml
index 5d923ecc69ffe..d2d9521f73834 100644
--- a/common/kvstore/pom.xml
+++ b/common/kvstore/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.13</artifactId>
-    <version>4.1.0-SNAPSHOT</version>
+    <version>4.1.2-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml
index 51b782920e6d9..e5dc64c542f5b 100644
--- a/common/network-common/pom.xml
+++ b/common/network-common/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.13</artifactId>
-    <version>4.1.0-SNAPSHOT</version>
+    <version>4.1.2-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
@@ -123,6 +123,11 @@
      <artifactId>guava</artifactId>
      <scope>compile</scope>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>failureaccess</artifactId>
+      <scope>compile</scope>
+    </dependency>
      <groupId>org.apache.commons</groupId>
      <artifactId>commons-crypto</artifactId>
diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
index e229e32e91717..e76e843b053b4 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java
@@ -367,7 +367,7 @@ private void assertErrorAndClosed(RpcResult result, String expectedError) {
"Connection reset",
"java.nio.channels.ClosedChannelException",
"io.netty.channel.StacklessClosedChannelException",
- "java.io.IOException: Broken pipe"
+ "Broken pipe"
);
    Set<String> containsAndClosed = new HashSet<>(Set.of(expectedError));
containsAndClosed.addAll(possibleClosedErrors);
diff --git a/common/network-common/src/test/java/org/apache/spark/network/ssl/ReloadingX509TrustManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ssl/ReloadingX509TrustManagerSuite.java
index 5bb47ff388671..b373d99a8e404 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/ssl/ReloadingX509TrustManagerSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/ssl/ReloadingX509TrustManagerSuite.java
@@ -189,7 +189,9 @@ public void testReload() throws Exception {
public void testReloadMissingTrustStore() throws Exception {
KeyPair kp = generateKeyPair("RSA");
X509Certificate cert1 = generateCertificate("CN=Cert1", kp, 30, "SHA1withRSA");
- File trustStore = new File("testmissing.jks");
+ File trustStore = File.createTempFile("testmissing", "jks");
+ trustStore.delete();
+ // trustStore is going to be re-created later so delete it on exit.
trustStore.deleteOnExit();
assertFalse(trustStore.exists());
createTrustStore(trustStore, "password", "cert1", cert1);
diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml
index 60ad971573997..eb8c7817c8ef0 100644
--- a/common/network-shuffle/pom.xml
+++ b/common/network-shuffle/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.13</artifactId>
-    <version>4.1.0-SNAPSHOT</version>
+    <version>4.1.2-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml
index c4451923b17a5..abc36aaa92b0c 100644
--- a/common/network-yarn/pom.xml
+++ b/common/network-yarn/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.13</artifactId>
-    <version>4.1.0-SNAPSHOT</version>
+    <version>4.1.2-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
@@ -171,18 +171,25 @@
package
+
+
+
+
+
+
+
+
diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml
index 8cf59f603a5bb..2c4460fac9637 100644
--- a/common/sketch/pom.xml
+++ b/common/sketch/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.13</artifactId>
-    <version>4.1.0-SNAPSHOT</version>
+    <version>4.1.2-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/common/tags/pom.xml b/common/tags/pom.xml
index de93a9205ac91..426ff50535b76 100644
--- a/common/tags/pom.xml
+++ b/common/tags/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.13</artifactId>
-    <version>4.1.0-SNAPSHOT</version>
+    <version>4.1.2-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml
index 47fc39abd2a67..cc73ba82b29f1 100644
--- a/common/unsafe/pom.xml
+++ b/common/unsafe/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.13</artifactId>
-    <version>4.1.0-SNAPSHOT</version>
+    <version>4.1.2-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/GeographyVal.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/GeographyVal.java
index 48dc6f896e91a..48b121ba894a5 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/GeographyVal.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/GeographyVal.java
@@ -17,9 +17,12 @@
package org.apache.spark.unsafe.types;
+import org.apache.spark.annotation.Unstable;
+
import java.io.Serializable;
// This class represents the physical type for the GEOGRAPHY data type.
+@Unstable
public final class GeographyVal implements Comparable<GeographyVal>, Serializable {
// The GEOGRAPHY type is implemented as a byte array. We provide `getBytes` and `fromBytes`
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/GeometryVal.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/GeometryVal.java
index 2bb7f194c940d..381d3e25c68af 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/GeometryVal.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/GeometryVal.java
@@ -17,9 +17,12 @@
package org.apache.spark.unsafe.types;
+import org.apache.spark.annotation.Unstable;
+
import java.io.Serializable;
// This class represents the physical type for the GEOMETRY data type.
+@Unstable
public final class GeometryVal implements Comparable<GeometryVal>, Serializable {
// The GEOMETRY type is implemented as a byte array. We provide `getBytes` and `fromBytes`
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 87d004040c3a0..96b103ae33881 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -1160,9 +1160,10 @@ public UTF8String reverse() {
int i = 0; // position in byte
while (i < numBytes) {
- int len = numBytesForFirstByte(getByte(i));
+ int len = Math.min(numBytesForFirstByte(getByte(i)), numBytes);
+ int targetOffset = Math.max(result.length - i - len, 0);
copyMemory(this.base, this.offset + i, result,
- BYTE_ARRAY_OFFSET + result.length - i - len, len);
+ BYTE_ARRAY_OFFSET + targetOffset, len);
i += len;
}
diff --git a/common/utils-java/pom.xml b/common/utils-java/pom.xml
index ba3603f810856..f0486ebe8f215 100644
--- a/common/utils-java/pom.xml
+++ b/common/utils-java/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.13</artifactId>
-    <version>4.1.0-SNAPSHOT</version>
+    <version>4.1.2-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
index 8b6d3614b86de..48bc6f201bc7f 100644
--- a/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
+++ b/common/utils-java/src/main/java/org/apache/spark/internal/LogKeys.java
@@ -614,6 +614,7 @@ public enum LogKeys implements LogKey {
PYTHON_WORKER_CHANNEL_IS_BLOCKING_MODE,
PYTHON_WORKER_CHANNEL_IS_CONNECTED,
PYTHON_WORKER_HAS_INPUTS,
+ PYTHON_WORKER_ID,
PYTHON_WORKER_IDLE_TIMEOUT,
PYTHON_WORKER_IS_ALIVE,
PYTHON_WORKER_MODULE,
diff --git a/common/utils/pom.xml b/common/utils/pom.xml
index df3bc5adb10bd..45f640a406784 100644
--- a/common/utils/pom.xml
+++ b/common/utils/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.13</artifactId>
-    <version>4.1.0-SNAPSHOT</version>
+    <version>4.1.2-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 34b72975cc077..aa0f0a89f97c4 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -901,6 +901,28 @@
},
"sqlState" : "56K00"
},
+ "CONNECT_INVALID_PLAN" : {
+ "message" : [
+ "The Spark Connect plan is invalid."
+ ],
+ "subClass" : {
+ "CANNOT_PARSE" : {
+ "message" : [
+ "Cannot decompress or parse the input plan ()",
+ "This may be caused by a corrupted compressed plan.",
+ "To disable plan compression, set 'spark.connect.session.planCompression.threshold' to -1."
+ ]
+ },
+ "PLAN_SIZE_LARGER_THAN_MAX" : {
+ "message" : [
+ "The plan size is larger than max ( vs. )",
+ "This typically occurs when building very complex queries with many operations, large literals, or deeply nested expressions.",
+ "Consider splitting the query into smaller parts using temporary views for intermediate results or reducing the number of operations."
+ ]
+ }
+ },
+ "sqlState" : "56K00"
+ },
"CONNECT_ML" : {
"message" : [
"Generic Spark Connect ML error."
@@ -1407,6 +1429,12 @@
],
"sqlState" : "42623"
},
+ "DEFINE_FLOW_ONCE_OPTION_NOT_SUPPORTED" : {
+ "message" : [
+ "Defining a one-time flow with the 'once' option is not supported."
+ ],
+ "sqlState" : "0A000"
+ },
"DESCRIBE_JSON_NOT_EXTENDED" : {
"message" : [
"DESCRIBE TABLE ... AS JSON only supported when [EXTENDED|FORMATTED] is specified.",
@@ -1660,6 +1688,12 @@
],
"sqlState" : "42846"
},
+ "EXPRESSION_TRANSLATION_TO_V2_IS_NOT_SUPPORTED" : {
+ "message" : [
+ "Expression cannot be translated to v2 expression."
+ ],
+ "sqlState" : "0A000"
+ },
"EXPRESSION_TYPE_IS_NOT_ORDERABLE" : {
"message" : [
"Column expression cannot be sorted because its type is not orderable."
@@ -1827,6 +1861,12 @@
],
"sqlState" : "2203G"
},
+ "FAILED_TO_CREATE_PLAN_FOR_DIRECT_QUERY" : {
+ "message" : [
+ "Failed to create plan for direct query on files: "
+ ],
+ "sqlState" : "58030"
+ },
"FAILED_TO_LOAD_ROUTINE" : {
"message" : [
"Failed to load routine ."
@@ -1888,6 +1928,12 @@
],
"sqlState" : "42623"
},
+ "GEO_ENCODER_SRID_MISMATCH_ERROR" : {
+ "message" : [
+ "Failed to encode value because provided SRID of a value to encode does not match type SRID: ."
+ ],
+ "sqlState" : "42K09"
+ },
"GET_TABLES_BY_TYPE_UNSUPPORTED_BY_HIVE_VERSION" : {
"message" : [
"Hive 2.2 and lower versions don't support getTablesByType. Please use Hive 2.3 or higher version."
@@ -1966,6 +2012,12 @@
],
"sqlState" : "22546"
},
+ "HLL_K_MUST_BE_CONSTANT" : {
+ "message" : [
+ "Invalid call to ; the `K` value must be a constant value, but got a non-constant expression."
+ ],
+ "sqlState" : "42K0E"
+ },
"HLL_UNION_DIFFERENT_LG_K" : {
"message" : [
"Sketches have different `lgConfigK` values: and . Set the `allowDifferentLgConfigK` parameter to true to call with different `lgConfigK` values."
@@ -2007,7 +2059,7 @@
},
"IDENTIFIER_TOO_MANY_NAME_PARTS" : {
"message" : [
- " is not a valid identifier as it has more than 2 name parts."
+ " is not a valid identifier as it has more than name parts."
],
"sqlState" : "42601"
},
@@ -2071,6 +2123,15 @@
],
"sqlState" : "42000"
},
+ "INCOMPATIBLE_COLUMN_CHANGES_AFTER_VIEW_WITH_PLAN_CREATION" : {
+ "message" : [
+ "View plan references table whose columns changed since the view plan was initially captured.",
+ "Column changes:",
+ "",
+ "This indicates the table has evolved and the view based on the plan must be recreated."
+ ],
+ "sqlState" : "51024"
+ },
"INCOMPATIBLE_COLUMN_TYPE" : {
"message" : [
" can only be performed on tables with compatible column types. The column of the table is type which is not compatible with at the same column of the first table.."
@@ -2153,6 +2214,31 @@
],
"sqlState" : "42000"
},
+ "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS" : {
+ "message" : [
+ "Detected incompatible changes to table after DataFrame/Dataset has been resolved and analyzed, meaning the underlying plan is out of sync. Please, re-create DataFrame/Dataset before attempting to execute the query again."
+ ],
+ "subClass" : {
+ "COLUMNS_MISMATCH" : {
+ "message" : [
+ "Data columns have changed:",
+ ""
+ ]
+ },
+ "METADATA_COLUMNS_MISMATCH" : {
+ "message" : [
+ "Metadata columns have changed:",
+ ""
+ ]
+ },
+ "TABLE_ID_MISMATCH" : {
+ "message" : [
+ "Table ID has changed from to ."
+ ]
+ }
+ },
+ "sqlState" : "51024"
+ },
"INCOMPATIBLE_VIEW_SCHEMA_CHANGE" : {
"message" : [
"The SQL query of view has an incompatible schema change and column cannot be resolved. Expected columns named but got .",
@@ -2478,6 +2564,12 @@
],
"sqlState" : "22003"
},
+ "INVALID_ARTIFACT_PATH" : {
+ "message" : [
+ "Artifact with name is invalid. The name must be a relative path and cannot reference parent/sibling/nephew directories."
+ ],
+ "sqlState" : "22023"
+ },
"INVALID_ATTRIBUTE_NAME_SYNTAX" : {
"message" : [
"Syntax error in the attribute name: . Check that backticks appear in pairs, a quoted string is a complete name part and use a backtick only inside quoted name parts."
@@ -3466,6 +3558,16 @@
"expects an integer literal, but got ."
]
},
+ "INTERRUPT_TYPE_OPERATION_ID_REQUIRES_ID" : {
+ "message" : [
+ "INTERRUPT_TYPE_OPERATION_ID requested, but no operation_id provided."
+ ]
+ },
+ "INTERRUPT_TYPE_TAG_REQUIRES_TAG" : {
+ "message" : [
+ "INTERRUPT_TYPE_TAG requested, but no operation_tag provided."
+ ]
+ },
"LENGTH" : {
"message" : [
"Expects `length` greater than or equal to 0, but got ."
@@ -3496,6 +3598,11 @@
"Expects a positive or a negative value for `start`, but got 0."
]
},
+ "STREAMING_LISTENER_COMMAND_MISSING" : {
+ "message" : [
+ "Missing command in StreamingQueryListenerBusCommand."
+ ]
+ },
"STRING" : {
"message" : [
"expects a string literal, but got ."
@@ -4121,6 +4228,30 @@
],
"sqlState" : "42K0E"
},
+ "KLL_INVALID_INPUT_SKETCH_BUFFER" : {
+ "message" : [
+ "Invalid call to ; only valid KLL sketch buffers are supported as inputs (such as those produced by the `kll_sketch_agg` function)."
+ ],
+ "sqlState" : "22000"
+ },
+ "KLL_SKETCH_INVALID_QUANTILE_RANGE" : {
+ "message" : [
+ "For function , the quantile value must be between 0.0 and 1.0 (inclusive)."
+ ],
+ "sqlState" : "22003"
+ },
+ "KLL_SKETCH_K_MUST_BE_CONSTANT" : {
+ "message" : [
+ "For function , the k parameter must be a constant value, but got a non-constant expression."
+ ],
+ "sqlState" : "42K0E"
+ },
+ "KLL_SKETCH_K_OUT_OF_RANGE" : {
+ "message" : [
+ "For function , the k parameter must be between 8 and 65535 (inclusive), but got ."
+ ],
+ "sqlState" : "22003"
+ },
"KRYO_BUFFER_OVERFLOW" : {
"message" : [
"Kryo serialization failed: . To avoid this, increase \"\" value."
@@ -4282,6 +4413,59 @@
},
"sqlState" : "XX000"
},
+ "MISSING_CATALOG_ABILITY" : {
+ "message" : [
+ "Catalog does not support"
+ ],
+ "subClass" : {
+ "CREATE_FUNCTION" : {
+ "message" : [
+ "CREATE FUNCTION."
+ ]
+ },
+ "DROP_FUNCTION" : {
+ "message" : [
+ "DROP FUNCTION."
+ ]
+ },
+ "FUNCTIONS" : {
+ "message" : [
+ "functions."
+ ]
+ },
+ "NAMESPACES" : {
+ "message" : [
+ "namespaces."
+ ]
+ },
+ "PROCEDURES" : {
+ "message" : [
+ "procedures."
+ ]
+ },
+ "REFRESH_FUNCTION" : {
+ "message" : [
+ "REFRESH FUNCTION."
+ ]
+ },
+ "TABLES" : {
+ "message" : [
+ "tables."
+ ]
+ },
+ "TABLE_VALUED_FUNCTIONS" : {
+ "message" : [
+ "table-valued functions."
+ ]
+ },
+ "VIEWS" : {
+ "message" : [
+ "views."
+ ]
+ }
+ },
+ "sqlState" : "0A000"
+ },
"MISSING_DATABASE_FOR_V1_SESSION_CATALOG" : {
"message" : [
"Database name is not specified in the v1 session catalog. Please ensure to provide a valid database name when interacting with the v1 catalog."
@@ -4841,6 +5025,12 @@
],
"sqlState" : "42000"
},
+ "PIPELINE_STORAGE_ROOT_INVALID" : {
+ "message" : [
+ "Pipeline storage root must be an absolute path with a URI scheme (e.g., file://, s3a://, hdfs://). Got: ``."
+ ],
+ "sqlState" : "42K03"
+ },
"PIPE_OPERATOR_AGGREGATE_EXPRESSION_CONTAINS_NO_AGGREGATE_FUNCTION" : {
"message" : [
"Non-grouping expression is provided as an argument to the |> AGGREGATE pipe operator but does not contain any aggregate function; please update it to include an aggregate function and then retry the query again."
@@ -5052,7 +5242,7 @@
"RUN_EMPTY_PIPELINE" : {
"message" : [
"Pipelines are expected to have at least one non-temporary dataset defined (tables, persisted views) but no non-temporary datasets were found in your pipeline.",
- "Please verify that you have included the expected source files, and that your source code includes table definitions (e.g., CREATE MATERIALIZED VIEW in SQL code, @sdp.table in python code)."
+ "Please verify that you have included the expected source files, and that your source code includes table definitions (e.g., CREATE MATERIALIZED VIEW in SQL code, @dp.table in python code)."
],
"sqlState" : "42617"
},
@@ -5590,6 +5780,21 @@
"message" : [
"The input stream is not supported in Real-time Mode."
]
+ },
+ "OPERATOR_OR_SINK_NOT_IN_ALLOWLIST" : {
+ "message" : [
+ "The (s): not in the allowlist for Real-Time Mode. To bypass this check, set spark.sql.streaming.realTimeMode.allowlistCheck to false. By changing this, you agree to run the query at your own risk."
+ ]
+ },
+ "OUTPUT_MODE_NOT_SUPPORTED" : {
+ "message" : [
+ "The output mode is not supported. To work around this limitation, set the output mode to Update. In the future, may be supported."
+ ]
+ },
+ "SINK_NOT_SUPPORTED" : {
+ "message" : [
+ "The sink is currently not supported. See the Real-Time Mode User Guide for a list of supported sinks."
+ ]
}
},
"sqlState" : "0A000"
@@ -5708,7 +5913,7 @@
"TEMP_TABLE_OR_VIEW_ALREADY_EXISTS" : {
"message" : [
"Cannot create the temporary view because it already exists.",
- "Choose a different name, drop or replace the existing view, or add the IF NOT EXISTS clause to tolerate pre-existing views."
+ "Choose a different name, drop or replace the existing view."
],
"sqlState" : "42P07"
},
@@ -5730,6 +5935,12 @@
],
"sqlState" : "22546"
},
+ "THETA_LG_NOM_ENTRIES_MUST_BE_CONSTANT" : {
+ "message" : [
+ "Invalid call to ; the `lgNomEntries` value must be a constant value, but got a non-constant expression."
+ ],
+ "sqlState" : "42K0E"
+ },
"TRAILING_COMMA_IN_SELECT" : {
"message" : [
"Trailing comma detected in SELECT clause. Remove the trailing comma before the FROM clause."
@@ -6337,6 +6548,11 @@
"Drop the namespace ."
]
},
+ "GEOSPATIAL_DISABLED" : {
+ "message" : [
+ "Geospatial feature is disabled."
+ ]
+ },
"HIVE_TABLE_TYPE" : {
"message" : [
"The is hive ."
@@ -6352,6 +6568,11 @@
"INSERT INTO with IF NOT EXISTS in the PARTITION spec."
]
},
+ "INTERRUPT_TYPE" : {
+ "message" : [
+ "Unsupported interrupt type: ."
+ ]
+ },
"LAMBDA_FUNCTION_WITH_PYTHON_UDF" : {
"message" : [
"Lambda function with Python UDF in a higher order function."
@@ -6722,6 +6943,12 @@
],
"sqlState" : "0A000"
},
+ "UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND" : {
+ "message" : [
+ "'' is not supported in spark.sql(\"...\") API in Spark Declarative Pipeline."
+ ],
+ "sqlState" : "0A000"
+ },
"UNSUPPORTED_SAVE_MODE" : {
"message" : [
"The save mode is not supported for:"
@@ -6908,6 +7135,12 @@
],
"sqlState" : "0A001"
},
+ "UNSUPPORTED_TIME_TYPE" : {
+ "message" : [
+ "The data type TIME is not supported."
+ ],
+ "sqlState" : "0A000"
+ },
"UNSUPPORTED_TYPED_LITERAL" : {
"message" : [
"Literals of the type are not supported. Supported types are ."
@@ -7742,11 +7975,6 @@
"Cannot use \"INTERVAL\" type in the table schema."
]
},
- "_LEGACY_ERROR_TEMP_1184" : {
- "message" : [
- "Catalog does not support ."
- ]
- },
"_LEGACY_ERROR_TEMP_1186" : {
"message" : [
"Multi-part identifier cannot be empty."
@@ -8447,11 +8675,6 @@
"Failed to merge incompatible schemas and ."
]
},
- "_LEGACY_ERROR_TEMP_2096" : {
- "message" : [
- " is not supported temporarily."
- ]
- },
"_LEGACY_ERROR_TEMP_2097" : {
"message" : [
"Could not execute broadcast in secs. You can increase the timeout for broadcasts via or disable broadcast join by setting to -1 or remove the broadcast hint if it exists in your code."
diff --git a/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala b/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
index af76056dfe928..0d958e3f71604 100644
--- a/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
+++ b/common/utils/src/main/scala/org/apache/spark/ErrorClassesJSONReader.scala
@@ -152,7 +152,7 @@ private object ErrorClassesJsonReader {
.addModule(DefaultScalaModule)
.build()
private def readAsMap(url: URL): Map[String, ErrorInfo] = {
- val map = mapper.readValue(url, new TypeReference[Map[String, ErrorInfo]]() {})
+ val map = mapper.readValue(url.openStream(), new TypeReference[Map[String, ErrorInfo]]() {})
val errorClassWithDots = map.collectFirst {
case (errorClass, _) if errorClass.contains('.') => errorClass
case (_, ErrorInfo(_, Some(map), _, _)) if map.keys.exists(_.contains('.')) =>
diff --git a/common/utils/src/main/scala/org/apache/spark/SparkBuildInfo.scala b/common/utils/src/main/scala/org/apache/spark/SparkBuildInfo.scala
index ebc62460d2318..7618105bd72ea 100644
--- a/common/utils/src/main/scala/org/apache/spark/SparkBuildInfo.scala
+++ b/common/utils/src/main/scala/org/apache/spark/SparkBuildInfo.scala
@@ -29,8 +29,8 @@ private[spark] object SparkBuildInfo {
spark_build_date: String,
spark_doc_root: String) = {
- val resourceStream = Thread.currentThread().getContextClassLoader.
- getResourceAsStream("spark-version-info.properties")
+ val resourceStream = getClass.getClassLoader
+ .getResourceAsStream("spark-version-info.properties")
if (resourceStream == null) {
throw new SparkException("Could not find spark-version-info.properties")
}
diff --git a/common/utils/src/test/scala/org/apache/spark/util/MaybeNull.scala b/common/utils/src/test/scala/org/apache/spark/util/MaybeNull.scala
new file mode 100644
index 0000000000000..44bdffeacfde6
--- /dev/null
+++ b/common/utils/src/test/scala/org/apache/spark/util/MaybeNull.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/* The MaybeNull class is a utility that introduces controlled nullability into a sequence
+ * of invocations. It is designed to return a null value at a specified interval while returning
+ * the provided value for all other invocations.
+ */
+case class MaybeNull(interval: Int) {
+ assert(interval > 1)
+ private var invocations = 0
+ def apply[T](value: T): T = {
+ val result = if (invocations % interval == 0) {
+ null.asInstanceOf[T]
+ } else {
+ value
+ }
+ invocations += 1
+ result
+ }
+}
diff --git a/common/variant/pom.xml b/common/variant/pom.xml
index 0fe977b8eadd2..bf200867a41d2 100644
--- a/common/variant/pom.xml
+++ b/common/variant/pom.xml
@@ -22,7 +22,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.13</artifactId>
-    <version>4.1.0-SNAPSHOT</version>
+    <version>4.1.2-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/connector/avro/pom.xml b/connector/avro/pom.xml
index 9c3f2249e2f19..da4366424525a 100644
--- a/connector/avro/pom.xml
+++ b/connector/avro/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.13</artifactId>
-    <version>4.1.0-SNAPSHOT</version>
+    <version>4.1.2-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index 0b3823ca16160..b0f510f3257ef 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -1423,6 +1423,61 @@ abstract class AvroSuite
}
}
+ test("to_avro with reordered fields and nullable target succeeds") {
+ // Test that when Catalyst and Avro field orders differ, null values
+ // are correctly validated against the mapped Avro field's nullability
+ val avroSchema = """{
+ "type": "record",
+ "name": "ReorderedRecord",
+ "fields": [
+ {"name": "a", "type": ["null", "string"]},
+ {"name": "b", "type": "string"}
+ ]
+ }"""
+
+ // Catalyst has fields in order [b, a], Avro has [a, b]
+ // Pass null for 'a' which is nullable in Avro - should succeed
+ val df = Seq(("B", null.asInstanceOf[String])).toDF("b", "a")
+ .select(struct($"b", $"a").as("s"))
+ val result = df.select(avro.functions.to_avro($"s", avroSchema).as("avro"))
+
+ // Should succeed without throwing AVRO_CANNOT_WRITE_NULL_FIELD
+ val collected = result.collect()
+ assert(collected.length == 1)
+
+ // Verify data correctness by round-tripping through from_avro
+ val roundTrip = result.select(avro.functions.from_avro($"avro", avroSchema).as("s"))
+ // final field order should be [a, b] as per avro schema
+ checkAnswer(roundTrip, Row(Row(null, "B")))
+ }
+
+ test("to_avro with reordered fields fails with correct field name") {
+ // Test that when Catalyst and Avro field orders differ and we try to write
+ // null to a non-nullable field, the error message references the correct field name
+ val avroSchema = """{
+ "type": "record",
+ "name": "ReorderedRecord",
+ "fields": [
+ {"name": "a", "type": ["null", "string"]},
+ {"name": "b", "type": "string"}
+ ]
+ }"""
+
+ // Catalyst has fields in order [b, a], Avro has [a, b]
+ // Pass null for 'b' which is non-nullable in Avro - should fail with correct field name 'b'
+ val df = Seq((null.asInstanceOf[String], "A")).toDF("b", "a")
+ .select(struct($"b", $"a").as("s"))
+
+ checkError(
+ exception = intercept[SparkRuntimeException] {
+ df.select(avro.functions.to_avro($"s", avroSchema)).collect()
+ },
+ condition = "AVRO_CANNOT_WRITE_NULL_FIELD",
+ parameters = Map(
+ "name" -> "`b`",
+ "dataType" -> "\"string\""))
+ }
+
test("support user provided avro schema for writing nullable fixed type") {
withTempPath { tempDir =>
val avroSchema =
diff --git a/connector/docker-integration-tests/pom.xml b/connector/docker-integration-tests/pom.xml
index 09fabdf64dba8..d89a9f50f4626 100644
--- a/connector/docker-integration-tests/pom.xml
+++ b/connector/docker-integration-tests/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.13</artifactId>
- <version>4.1.0-SNAPSHOT</version>
+ <version>4.1.2-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/connector/docker-integration-tests/src/test/resources/postgres-krb-setup.sh b/connector/docker-integration-tests/src/test/resources/postgres-krb-setup.sh
index dd9fd8cb51adf..3c6aff0ee3227 100755
--- a/connector/docker-integration-tests/src/test/resources/postgres-krb-setup.sh
+++ b/connector/docker-integration-tests/src/test/resources/postgres-krb-setup.sh
@@ -16,6 +16,6 @@
# limitations under the License.
#
-sed -i 's/host all all all .*/host all all all gss/g' /var/lib/postgresql/data/pg_hba.conf
-echo "krb_server_keyfile='/docker-entrypoint-initdb.d/postgres.keytab'" >> /var/lib/postgresql/data/postgresql.conf
+sed -i 's/host all all all .*/host all all all gss/g' $(find /var/lib/postgresql -name pg_hba.conf)
+echo "krb_server_keyfile='/docker-entrypoint-initdb.d/postgres.keytab'" >> $(find /var/lib/postgresql -name postgresql.conf)
psql -U postgres -c "CREATE ROLE \"postgres/__IP_ADDRESS_REPLACE_ME__@EXAMPLE.COM\" LOGIN SUPERUSER"
diff --git a/connector/kafka-0-10-assembly/pom.xml b/connector/kafka-0-10-assembly/pom.xml
index 0305e8895ba3c..600a9c34769cf 100644
--- a/connector/kafka-0-10-assembly/pom.xml
+++ b/connector/kafka-0-10-assembly/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.13</artifactId>
- <version>4.1.0-SNAPSHOT</version>
+ <version>4.1.2-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/connector/kafka-0-10-sql/pom.xml b/connector/kafka-0-10-sql/pom.xml
index 4ab99f7929591..c5b85f4027bc2 100644
--- a/connector/kafka-0-10-sql/pom.xml
+++ b/connector/kafka-0-10-sql/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.13</artifactId>
- <version>4.1.0-SNAPSHOT</version>
+ <version>4.1.2-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/InternalKafkaConsumerPool.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/InternalKafkaConsumerPool.scala
index edd5121cfbeee..06ccd7548a040 100644
--- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/InternalKafkaConsumerPool.scala
+++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/InternalKafkaConsumerPool.scala
@@ -129,6 +129,14 @@ private[consumer] class InternalKafkaConsumerPool(
def size(key: CacheKey): Int = numIdle(key) + numActive(key)
+ private[kafka010] def numActiveInGroupIdPrefix(groupIdPrefix: String): Int = {
+ import scala.jdk.CollectionConverters._
+
+ pool.getNumActivePerKey().asScala.filter { case (key, _) =>
+ key.startsWith(groupIdPrefix + "-")
+ }.values.map(_.toInt).sum
+ }
+
// TODO: revisit the relation between CacheKey and kafkaParams - for now it looks a bit weird
// as we force all consumers having same (groupId, topicPartition) to have same kafkaParams
// which might be viable in performance perspective (kafkaParams might be too huge to use
diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
index af4e5bab2947d..126434625a8d5 100644
--- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
+++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/consumer/KafkaDataConsumer.scala
@@ -848,4 +848,8 @@ private[kafka010] object KafkaDataConsumer extends Logging {
new KafkaDataConsumer(topicPartition, kafkaParams, consumerPool, fetchedDataPool)
}
+
+ private[kafka010] def getActiveSizeInConsumerPool(groupIdPrefix: String): Int = {
+ consumerPool.numActiveInGroupIdPrefix(groupIdPrefix)
+ }
}
diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala
index 83aae64d84f7e..468d1da7f467f 100644
--- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala
+++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRealTimeModeSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.kafka010
+import java.util.UUID
+
import org.scalatest.matchers.should.Matchers
import org.scalatest.time.SpanSugar._
@@ -26,6 +28,7 @@ import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemorySink
import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.kafka010.consumer.KafkaDataConsumer
import org.apache.spark.sql.streaming.{StreamingQuery, Trigger}
import org.apache.spark.sql.streaming.OutputMode.Update
import org.apache.spark.sql.streaming.util.GlobalSingletonManualClock
@@ -39,9 +42,7 @@ class KafkaRealTimeModeSuite
override protected val defaultTrigger = RealTimeTrigger.apply("3 seconds")
override protected def sparkConf: SparkConf = {
- // Should turn to use StreamingShuffleManager when it is ready.
super.sparkConf
- .set("spark.databricks.streaming.realTimeMode.enabled", "true")
.set(
SQLConf.STATE_STORE_PROVIDER_CLASS,
classOf[RocksDBStateStoreProvider].getName)
@@ -679,3 +680,97 @@ class KafkaRealTimeModeSuite
)
}
}
+
+class KafkaConsumerPoolRealTimeModeSuite
+ extends KafkaSourceTest
+ with Matchers {
+ override protected val defaultTrigger = RealTimeTrigger.apply("3 seconds")
+
+ override protected def sparkConf: SparkConf = {
+ super.sparkConf
+ .set(
+ SQLConf.STATE_STORE_PROVIDER_CLASS,
+ classOf[RocksDBStateStoreProvider].getName)
+ }
+
+ import testImplicits._
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ spark.conf.set(
+ SQLConf.STREAMING_REAL_TIME_MODE_MIN_BATCH_DURATION,
+ defaultTrigger.batchDurationMs
+ )
+ }
+
+ test("SPARK-54200: Kafka consumers in consumer pool should be properly reused") {
+ val topic = newTopic()
+ testUtils.createTopic(topic, partitions = 2)
+
+ testUtils.sendMessages(topic, Array("1", "2"), Some(0))
+ testUtils.sendMessages(topic, Array("3"), Some(1))
+
+ val groupIdPrefix = UUID.randomUUID().toString
+
+ val reader = spark
+ .readStream
+ .format("kafka")
+ .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+ .option("subscribe", topic)
+ .option("startingOffsets", "earliest")
+ .option("groupIdPrefix", groupIdPrefix)
+ .load()
+ .selectExpr("CAST(value AS STRING)")
+ .as[String]
+ .map(_.toInt)
+ .map(_ + 1)
+
+ // At any point in time, the Kafka consumer pool should contain at most 2 active instances.
+ testStream(reader, Update, sink = new ContinuousMemorySink())(
+ StartStream(),
+ CheckAnswerWithTimeout(60000, 2, 3, 4),
+ WaitUntilCurrentBatchProcessed,
+ // After completion of batch 0
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ assertActiveSizeOnConsumerPool(groupIdPrefix, 2)
+
+ testUtils.sendMessages(topic, Array("4", "5"), Some(0))
+ testUtils.sendMessages(topic, Array("6"), Some(1))
+ }
+ },
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6, 7),
+ WaitUntilCurrentBatchProcessed,
+ // After completion of batch 1
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ assertActiveSizeOnConsumerPool(groupIdPrefix, 2)
+
+ testUtils.sendMessages(topic, Array("7"), Some(1))
+ }
+ },
+ CheckAnswerWithTimeout(5000, 2, 3, 4, 5, 6, 7, 8),
+ WaitUntilCurrentBatchProcessed,
+ // After completion of batch 2
+ new ExternalAction() {
+ override def runAction(): Unit = {
+ assertActiveSizeOnConsumerPool(groupIdPrefix, 2)
+ }
+ },
+ StopStream
+ )
+ }
+
+ /**
+ * NOTE: This method relies on the fact that, in a normal unit test setup (say, local[] in
+ * spark master), the test code, driver, and executor all run in the same process. With that
+ * setup, we can access the singleton object directly.
+ */
+ private def assertActiveSizeOnConsumerPool(
+ groupIdPrefix: String,
+ maxAllowedActiveSize: Int): Unit = {
+ val activeSize = KafkaDataConsumer.getActiveSizeInConsumerPool(groupIdPrefix)
+ assert(activeSize <= maxAllowedActiveSize, s"Consumer pool active size is expected to be " +
+ s"at most $maxAllowedActiveSize, but was $activeSize.")
+ }
+}
diff --git a/connector/kafka-0-10-token-provider/pom.xml b/connector/kafka-0-10-token-provider/pom.xml
index 57840a8a8d6b3..cf092cb94a9e9 100644
--- a/connector/kafka-0-10-token-provider/pom.xml
+++ b/connector/kafka-0-10-token-provider/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.13</artifactId>
- <version>4.1.0-SNAPSHOT</version>
+ <version>4.1.2-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/connector/kafka-0-10/pom.xml b/connector/kafka-0-10/pom.xml
index 62b0807e959d8..b47c502cadf1f 100644
--- a/connector/kafka-0-10/pom.xml
+++ b/connector/kafka-0-10/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.13</artifactId>
- <version>4.1.0-SNAPSHOT</version>
+ <version>4.1.2-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/connector/kinesis-asl-assembly/pom.xml b/connector/kinesis-asl-assembly/pom.xml
index 4903adfe420fd..2f43af3eff918 100644
--- a/connector/kinesis-asl-assembly/pom.xml
+++ b/connector/kinesis-asl-assembly/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.13</artifactId>
- <version>4.1.0-SNAPSHOT</version>
+ <version>4.1.2-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/connector/kinesis-asl/pom.xml b/connector/kinesis-asl/pom.xml
index aa90e3c8a94f5..d7eaf46ea110c 100644
--- a/connector/kinesis-asl/pom.xml
+++ b/connector/kinesis-asl/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.13</artifactId>
- <version>4.1.0-SNAPSHOT</version>
+ <version>4.1.2-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/connector/profiler/pom.xml b/connector/profiler/pom.xml
index 0eee0123ec071..dcf6efff18a9d 100644
--- a/connector/profiler/pom.xml
+++ b/connector/profiler/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.13</artifactId>
- <version>4.1.0-SNAPSHOT</version>
+ <version>4.1.2-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml
index 845a7bb64cd35..d65da3de461c2 100644
--- a/connector/protobuf/pom.xml
+++ b/connector/protobuf/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.13</artifactId>
- <version>4.1.0-SNAPSHOT</version>
+ <version>4.1.2-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/connector/spark-ganglia-lgpl/pom.xml b/connector/spark-ganglia-lgpl/pom.xml
index 4485b1ea414df..606ab7f01c964 100644
--- a/connector/spark-ganglia-lgpl/pom.xml
+++ b/connector/spark-ganglia-lgpl/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.13</artifactId>
- <version>4.1.0-SNAPSHOT</version>
+ <version>4.1.2-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/core/pom.xml b/core/pom.xml
index ef408a763323a..55cd208f70772 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.13</artifactId>
- <version>4.1.0-SNAPSHOT</version>
+ <version>4.1.2-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
index ab9e470e0c2c0..8b41df6b269f9 100644
--- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
+++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -339,9 +339,12 @@ public void showMemoryUsage() {
MDC.of(LogKeys.MEMORY_SIZE, memoryNotAccountedFor),
MDC.of(LogKeys.TASK_ATTEMPT_ID, taskAttemptId));
logger.info(
- "{} bytes of memory are used for execution and {} bytes of memory are used for storage",
+ "{} bytes of memory are used for execution " +
+ "and {} bytes of memory are used for storage " +
+ "and {} bytes of unmanaged memory are used",
MDC.of(LogKeys.EXECUTION_MEMORY_SIZE, memoryManager.executionMemoryUsed()),
- MDC.of(LogKeys.STORAGE_MEMORY_SIZE, memoryManager.storageMemoryUsed()));
+ MDC.of(LogKeys.STORAGE_MEMORY_SIZE, memoryManager.storageMemoryUsed()),
+ MDC.of(LogKeys.MEMORY_SIZE, UnifiedMemoryManager$.MODULE$.getUnmanagedMemoryUsed()));
}
}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js
index 8961140a40190..853dfa708ef48 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js
@@ -27,6 +27,16 @@ var appLimit = -1;
function setAppLimit(val) {
appLimit = val;
}
+/* escape XSS */
+function escapeHtml(text) {
+ if (typeof text !== 'string') return text;
+ return text
+ .replace(/&/g, "&amp;")
+ .replace(/</g, "&lt;")
+ .replace(/>/g, "&gt;")
+ .replace(/"/g, "&quot;")
+ .replace(/'/g, "&#039;");
+}
/* eslint-enable no-unused-vars*/
function makeIdNumeric(id) {
@@ -151,7 +161,7 @@ $(document).ready(function() {
attempt["durationMillisec"] = attempt["duration"];
attempt["duration"] = formatDuration(attempt["duration"]);
attempt["id"] = id;
- attempt["name"] = name;
+ attempt["name"] = escapeHtml(name);
attempt["version"] = version;
attempt["attemptUrl"] = uiRoot + "/history/" + id + "/" +
(attempt.hasOwnProperty("attemptId") ? attempt["attemptId"] + "/" : "") + "jobs/";
@@ -206,7 +216,11 @@ $(document).ready(function() {
data: 'duration',
render: (id, type, row) => `${row.duration}`
},
- {name: 'user', data: 'sparkUser' },
+ {
+ name: 'user',
+ data: 'sparkUser',
+ render: (name) => escapeHtml(name)
+ },
{name: 'lastUpdated', data: 'lastUpdated' },
{
name: 'eventLog',
diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
index fd0baec8af6c7..230c2059e6e3f 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js
@@ -295,6 +295,20 @@ function renderDagVizForJob(svgContainer) {
.append("g")
}
+ // Now we need to shift the container for this stage so it doesn't overlap with
+ // existing ones, taking into account the position and width of the last stage's
+ // container. We do not need to do this for the first stage of this job.
+ if (i > 0) {
+ const lastStage = svgContainer.selectAll("g.cluster.stage")
+ .filter((d, i, nodes) => i === nodes.length - 1);
+ if (lastStage) {
+ const lastStageWidth = toFloat(lastStage.select("rect").attr("width"));
+ const lastStagePosition = getAbsolutePosition(lastStage);
+ const offset = lastStagePosition.x + lastStageWidth + VizConstants.stageSep;
+ container.attr("transform", `translate(${offset}, 0)`);
+ }
+ }
+
var g = graphlibDot.read(dot);
// Actually render the stage
renderDot(g, container, true);
@@ -312,20 +326,6 @@ function renderDagVizForJob(svgContainer) {
.attr("rx", "4")
.attr("ry", "4");
- // Now we need to shift the container for this stage so it doesn't overlap with
- // existing ones, taking into account the position and width of the last stage's
- // container. We do not need to do this for the first stage of this job.
- if (i > 0) {
- var existingStages = svgContainer.selectAll("g.cluster.stage").nodes();
- if (existingStages.length > 0) {
- var lastStage = d3.select(existingStages.pop());
- var lastStageWidth = toFloat(lastStage.select("rect").attr("width"));
- var lastStagePosition = getAbsolutePosition(lastStage);
- var offset = lastStagePosition.x + lastStageWidth + VizConstants.stageSep;
- container.attr("transform", "translate(" + offset + ", 0)");
- }
- }
-
// If there are any incoming edges into this graph, keep track of them to render
// them separately later. Note that we cannot draw them now because we need to
// put these edges in a separate container that is on top of all stage graphs.
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 334eb832c4c2b..41a1b51a43154 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -857,37 +857,34 @@ private[spark] class MapOutputTrackerMaster(
}
}
+ private def getShuffleStatusOrError(shuffleId: Int, caller: String): ShuffleStatus = {
+ shuffleStatuses.get(shuffleId) match {
+ case Some(shuffleStatus) => shuffleStatus
+ case None => throw new ShuffleStatusNotFoundException(shuffleId, caller)
+ }
+ }
+
def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Boolean = {
- shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
+ getShuffleStatusOrError(shuffleId, "registerMapOutput").addMapOutput(mapIndex, status)
}
/** Unregister map output information of the given shuffle, mapper and block manager */
def unregisterMapOutput(shuffleId: Int, mapIndex: Int, bmAddress: BlockManagerId): Unit = {
- shuffleStatuses.get(shuffleId) match {
- case Some(shuffleStatus) =>
- shuffleStatus.removeMapOutput(mapIndex, bmAddress)
- incrementEpoch()
- case None =>
- throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
- }
+ getShuffleStatusOrError(shuffleId, "unregisterMapOutput").removeMapOutput(mapIndex, bmAddress)
+ incrementEpoch()
}
/** Unregister all map and merge output information of the given shuffle. */
def unregisterAllMapAndMergeOutput(shuffleId: Int): Unit = {
- shuffleStatuses.get(shuffleId) match {
- case Some(shuffleStatus) =>
- shuffleStatus.removeOutputsByFilter(x => true)
- shuffleStatus.removeMergeResultsByFilter(x => true)
- shuffleStatus.removeShuffleMergerLocations()
- incrementEpoch()
- case None =>
- throw new SparkException(
- s"unregisterAllMapAndMergeOutput called for nonexistent shuffle ID $shuffleId.")
- }
+ val shuffleStatus = getShuffleStatusOrError(shuffleId, "unregisterAllMapAndMergeOutput")
+ shuffleStatus.removeOutputsByFilter(x => true)
+ shuffleStatus.removeMergeResultsByFilter(x => true)
+ shuffleStatus.removeShuffleMergerLocations()
+ incrementEpoch()
}
def registerMergeResult(shuffleId: Int, reduceId: Int, status: MergeStatus): Unit = {
- shuffleStatuses(shuffleId).addMergeResult(reduceId, status)
+ getShuffleStatusOrError(shuffleId, "registerMergeResult").addMergeResult(reduceId, status)
}
def registerMergeResults(shuffleId: Int, statuses: Seq[(Int, MergeStatus)]): Unit = {
@@ -899,7 +896,8 @@ private[spark] class MapOutputTrackerMaster(
def registerShufflePushMergerLocations(
shuffleId: Int,
shuffleMergers: Seq[BlockManagerId]): Unit = {
- shuffleStatuses(shuffleId).registerShuffleMergerLocations(shuffleMergers)
+ getShuffleStatusOrError(shuffleId, "registerShufflePushMergerLocations")
+ .registerShuffleMergerLocations(shuffleMergers)
}
/**
@@ -918,28 +916,19 @@ private[spark] class MapOutputTrackerMaster(
reduceId: Int,
bmAddress: BlockManagerId,
mapIndex: Option[Int] = None): Unit = {
- shuffleStatuses.get(shuffleId) match {
- case Some(shuffleStatus) =>
- val mergeStatus = shuffleStatus.mergeStatuses(reduceId)
- if (mergeStatus != null &&
- (mapIndex.isEmpty || mergeStatus.tracker.contains(mapIndex.get))) {
- shuffleStatus.removeMergeResult(reduceId, bmAddress)
- incrementEpoch()
- }
- case None =>
- throw new SparkException("unregisterMergeResult called for nonexistent shuffle ID")
+ val shuffleStatus = getShuffleStatusOrError(shuffleId, "unregisterMergeResult")
+ val mergeStatus = shuffleStatus.mergeStatuses(reduceId)
+ if (mergeStatus != null &&
+ (mapIndex.isEmpty || mergeStatus.tracker.contains(mapIndex.get))) {
+ shuffleStatus.removeMergeResult(reduceId, bmAddress)
+ incrementEpoch()
}
}
def unregisterAllMergeResult(shuffleId: Int): Unit = {
- shuffleStatuses.get(shuffleId) match {
- case Some(shuffleStatus) =>
- shuffleStatus.removeMergeResultsByFilter(x => true)
- incrementEpoch()
- case None =>
- throw new SparkException(
- s"unregisterAllMergeResult called for nonexistent shuffle ID $shuffleId.")
- }
+ getShuffleStatusOrError(shuffleId, "unregisterAllMergeResult")
+ .removeMergeResultsByFilter(x => true)
+ incrementEpoch()
}
/** Unregister shuffle data */
@@ -1022,7 +1011,7 @@ private[spark] class MapOutputTrackerMaster(
* Return statistics about all of the outputs for a given shuffle.
*/
def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
- shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
+ getShuffleStatusOrError(dep.shuffleId, "getStatistics").withMapStatuses { statuses =>
val totalSizes = new Array[Long](dep.partitioner.numPartitions)
val parallelAggThreshold = conf.get(
SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD)
@@ -1285,6 +1274,9 @@ private[spark] class MapOutputTrackerMaster(
}
}
+case class ShuffleStatusNotFoundException(shuffleId: Int, methodName: String)
+ extends SparkException(s"$methodName called for nonexistent shuffle ID $shuffleId.")
+
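// Illustrative sketch (not part of the patch): the accessors above now surface missing shuffle
// state through this single typed error, e.g. with hypothetical arguments:
//
//   throw ShuffleStatusNotFoundException(7, "registerMapOutput")
//   // message: "registerMapOutput called for nonexistent shuffle ID 7."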
/**
* Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster.
* Note that this is not used in local-mode; instead, local-mode Executors access the
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 898bbad26b7e3..6f8be49e3959b 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -454,7 +454,14 @@ class SparkContext(config: SparkConf) extends Logging {
// Set Spark driver host and port system properties. This explicitly sets the configuration
// instead of relying on the default value of the config constant.
- _conf.set(DRIVER_HOST_ADDRESS, _conf.get(DRIVER_HOST_ADDRESS))
+ if (SparkMasterRegex.isK8s(master) &&
+ _conf.getBoolean("spark.kubernetes.executor.useDriverPodIP", false)) {
+ logInfo("Use DRIVER_BIND_ADDRESS instead of DRIVER_HOST_ADDRESS as driver address " +
+ "because spark.kubernetes.executor.useDriverPodIP is true in K8s mode.")
+ _conf.set(DRIVER_HOST_ADDRESS, _conf.get(DRIVER_BIND_ADDRESS))
+ } else {
+ _conf.set(DRIVER_HOST_ADDRESS, _conf.get(DRIVER_HOST_ADDRESS))
+ }
_conf.setIfMissing(DRIVER_PORT, 0)
_conf.set(EXECUTOR_ID, SparkContext.DRIVER_IDENTIFIER)
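// Hedged configuration sketch (illustrative values): the new branch above only triggers with a
// k8s master and spark.kubernetes.executor.useDriverPodIP=true, in which case the driver
// advertises its bind address (the pod IP) to executors.
//
//   val conf = new SparkConf()
//     .setMaster("k8s://https://kubernetes.default.svc")
//     .set("spark.kubernetes.executor.useDriverPodIP", "true")
//     .set("spark.driver.bindAddress", "10.0.12.34")  // pod IP, illustrative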
@@ -3452,6 +3459,15 @@ private object SparkMasterRegex {
val SPARK_REGEX = """spark://(.*)""".r
// Regular expression for connecting to kubernetes clusters
val KUBERNETES_REGEX = """k8s://(.*)""".r
+
+ def isK8s(master: String) : Boolean = isK8s(Option(master))
+
+ def isK8s(master: Option[String]) : Boolean = {
+ master match {
+ case Some(KUBERNETES_REGEX(_)) => true
+ case _ => false
+ }
+ }
}
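// Usage sketch for the new helper (illustrative inputs, not part of the patch):
//
//   SparkMasterRegex.isK8s("k8s://https://host:6443")  // true
//   SparkMasterRegex.isK8s("spark://host:7077")        // false
//   SparkMasterRegex.isK8s(None)                       // false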
/**
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonErrorUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonErrorUtils.scala
new file mode 100644
index 0000000000000..73c2a29ea4095
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonErrorUtils.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import java.util
+
+import org.apache.spark.{BreakingChangeInfo, QueryContext, SparkThrowable}
+
+/**
+ * Utility object that provides convenient accessors for extracting
+ * detailed information from a [[SparkThrowable]] instance.
+ *
+ * This object is primarily used in PySpark
+ * to retrieve structured error metadata because Py4J does not work
+ * with default methods.
+ */
+private[spark] object PythonErrorUtils {
+ def getCondition(e: SparkThrowable): String = e.getCondition
+ def getErrorClass(e: SparkThrowable): String = e.getCondition
+ def getSqlState(e: SparkThrowable): String = e.getSqlState
+ def isInternalError(e: SparkThrowable): Boolean = e.isInternalError
+ def getBreakingChangeInfo(e: SparkThrowable): BreakingChangeInfo = e.getBreakingChangeInfo
+ def getMessageParameters(e: SparkThrowable): util.Map[String, String] = e.getMessageParameters
+ def getDefaultMessageTemplate(e: SparkThrowable): String = e.getDefaultMessageTemplate
+ def getQueryContext(e: SparkThrowable): Array[QueryContext] = e.getQueryContext
+}
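// Minimal JVM-side usage sketch (illustrative): PySpark reaches these accessors through Py4J,
// but they can also be called directly, e.g.
//
//   def describe(e: SparkThrowable): String =
//     s"${PythonErrorUtils.getCondition(e)} (SQLSTATE: ${PythonErrorUtils.getSqlState(e)})"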
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 66e204fee44b9..7f1dc7fc86fcd 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -188,12 +188,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
protected val timelyFlushEnabled: Boolean = false
protected val timelyFlushTimeoutNanos: Long = 0
protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+ private val useDaemon = conf.get(PYTHON_USE_DAEMON)
private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
protected val faultHandlerEnabled: Boolean = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
protected val idleTimeoutSeconds: Long = conf.get(PYTHON_WORKER_IDLE_TIMEOUT_SECONDS)
protected val killOnIdleTimeout: Boolean = conf.get(PYTHON_WORKER_KILL_ON_IDLE_TIMEOUT)
protected val tracebackDumpIntervalSeconds: Long =
conf.get(PYTHON_WORKER_TRACEBACK_DUMP_INTERVAL_SECONDS)
+ protected val killWorkerOnFlushFailure: Boolean =
+ conf.get(PYTHON_DAEMON_KILL_WORKER_ON_FLUSH_FAILURE)
protected val hideTraceback: Boolean = false
protected val simplifiedTraceback: Boolean = false
@@ -294,13 +297,16 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
if (tracebackDumpIntervalSeconds > 0L) {
envVars.put("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", tracebackDumpIntervalSeconds.toString)
}
+ if (useDaemon && killWorkerOnFlushFailure) {
+ envVars.put("PYTHON_DAEMON_KILL_WORKER_ON_FLUSH_FAILURE", "1")
+ }
// allow the user to set the batch size for the BatchedSerializer on UDFs
envVars.put("PYTHON_UDF_BATCH_SIZE", batchSizeForPythonUDF.toString)
envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))
val (worker: PythonWorker, handle: Option[ProcessHandle]) = env.createPythonWorker(
- pythonExec, workerModule, daemonModule, envVars.asScala.toMap)
+ pythonExec, workerModule, daemonModule, envVars.asScala.toMap, useDaemon)
// Whether is the worker released into idle pool or closed. When any codes try to release or
// close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make
// sure there is only one winner that is going to release or close the worker.
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerLogCapture.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerLogCapture.scala
index 71fc00546ef6e..a2a7c5ea14513 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerLogCapture.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerLogCapture.scala
@@ -26,7 +26,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.jdk.CollectionConverters._
import org.apache.spark.SparkEnv
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{Logging, LogKeys}
import org.apache.spark.storage.{PythonWorkerLogBlockIdGenerator, PythonWorkerLogLine, RollingLogWriter}
/**
@@ -64,7 +64,9 @@ private[python] class PythonWorkerLogCapture(
writer.close()
} catch {
case e: Exception =>
- logWarning(s"Failed to close log writer for worker $workerId", e)
+ logWarning(
+ log"Failed to close log writer for worker ${MDC(LogKeys.PYTHON_WORKER_ID, workerId)}",
+ e)
}
}
}
@@ -73,12 +75,14 @@ private[python] class PythonWorkerLogCapture(
* Closes all active worker log writers.
*/
def closeAllWriters(): Unit = {
- workerLogWriters.values().asScala.foreach { case (writer, _) =>
+ workerLogWriters.asScala.foreach { case (workerId, (writer, _)) =>
try {
writer.close()
} catch {
case e: Exception =>
- logWarning("Failed to close log writer", e)
+ logWarning(
+ log"Failed to close log writer for worker ${MDC(LogKeys.PYTHON_WORKER_ID, workerId)}",
+ e)
}
}
workerLogWriters.clear()
@@ -128,7 +132,8 @@ private[python] class PythonWorkerLogCapture(
}
} catch {
case e: Exception =>
- logWarning(s"Failed to write log for worker $workerId", e)
+ logWarning(
+ log"Failed to write log for worker ${MDC(LogKeys.PYTHON_WORKER_ID, workerId)}", e)
}
}
prefix
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkPipelines.scala b/core/src/main/scala/org/apache/spark/deploy/SparkPipelines.scala
index 713937cadabfb..ee3bbd88646ff 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkPipelines.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkPipelines.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy
-import java.util
+import java.util.{Arrays => JArrays, List => JList}
import java.util.Locale
import scala.collection.mutable.ArrayBuffer
@@ -46,7 +46,7 @@ object SparkPipelines extends Logging {
pipelinesCliFile: String,
args: Array[String]): Seq[String] = {
val (sparkSubmitArgs, pipelinesArgs) = splitArgs(args)
- (sparkSubmitArgs ++ Seq(pipelinesCliFile) ++ pipelinesArgs)
+ sparkSubmitArgs ++ Seq(pipelinesCliFile) ++ pipelinesArgs
}
/**
@@ -59,7 +59,7 @@ object SparkPipelines extends Logging {
var remote = "local"
new SparkSubmitArgumentsParser() {
- parse(util.Arrays.asList(args: _*))
+ parse(JArrays.asList(args: _*))
override protected def handle(opt: String, value: String): Boolean = {
if (opt == "--remote") {
@@ -91,7 +91,7 @@ object SparkPipelines extends Logging {
true
}
- override protected def handleExtraArgs(extra: util.List[String]): Unit = {
+ override protected def handleExtraArgs(extra: JList[String]): Unit = {
pipelinesArgs.appendAll(extra.asScala)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index b5d026e39a906..6872c7c3bd717 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -170,7 +170,7 @@ private[spark] class SparkSubmit extends Logging {
// Here we are checking for client mode because when job is submitted in cluster
// deploy mode with k8s resource manager, the spark submit in the driver container
// is done in client mode.
- val isKubernetesClusterModeDriver = args.master.startsWith("k8s") &&
+ val isKubernetesClusterModeDriver = SparkMasterRegex.isK8s(args.master) &&
"client".equals(args.deployMode) &&
sparkConf.getBoolean("spark.kubernetes.submitInDriver", false)
if (isKubernetesClusterModeDriver) {
@@ -257,7 +257,7 @@ private[spark] class SparkSubmit extends Logging {
v match {
case "yarn" => YARN
case m if m.startsWith("spark") => STANDALONE
- case m if m.startsWith("k8s") => KUBERNETES
+ case m if SparkMasterRegex.isK8s(m) => KUBERNETES
case m if m.startsWith("local") => LOCAL
case _ =>
error("Master must either be yarn or start with spark, k8s, or local")
@@ -448,14 +448,16 @@ private[spark] class SparkSubmit extends Logging {
log" from ${MDC(LogKeys.SOURCE_PATH, source)}" +
log" to ${MDC(LogKeys.DESTINATION_PATH, dest)}")
Utils.deleteRecursively(dest)
- if (isArchive) {
+ val resourceUri = if (isArchive) {
Utils.unpack(source, dest)
+ localResources
} else {
Files.copy(source.toPath, dest.toPath)
+ dest.toURI
}
// Keep the URIs of local files with the given fragments.
Utils.getUriBuilder(
- localResources).fragment(resolvedUri.getFragment).build().toString
+ resourceUri).fragment(resolvedUri.getFragment).build().toString
} ++ avoidDownloads.map(_.toString)).mkString(",")
}
@@ -1041,7 +1043,7 @@ private[spark] class SparkSubmit extends Logging {
}
throw cause
} finally {
- if (args.master.startsWith("k8s") && !isShell(args.primaryResource) &&
+ if (SparkMasterRegex.isK8s(args.master) && !isShell(args.primaryResource) &&
!isSqlShell(args.mainClass) && !isThriftServer(args.mainClass) &&
!isConnectServer(args.mainClass)) {
try {
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileWriters.scala b/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileWriters.scala
index 4e3bee1015ff3..7c022c283db41 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileWriters.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/EventLogFileWriters.scala
@@ -131,7 +131,20 @@ abstract class EventLogFileWriter(
}
protected def closeWriter(): Unit = {
+ // 1. Flush first to check the errors
+ writer.foreach(_.flush())
+ if (writer.exists(_.checkError())) {
+ logError("Spark detects errors while flushing event logs.")
+ }
+ hadoopDataStream.foreach(_.hflush())
+
+ // 2. Try to close and check the errors
writer.foreach(_.close())
+ if (writer.exists(_.checkError())) {
+ logError("Spark detects errors while closing event logs.")
+ // 3. Ensure the underlying stream is closed, at least on a best-effort basis.
+ hadoopDataStream.foreach(_.close())
+ }
}
protected def renameFile(src: Path, dest: Path, overwrite: Boolean): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 4863291b529bb..c723a8de8c442 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -502,7 +502,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
var count: Int = 0
try {
val newLastScanTime = clock.getTimeMillis()
- logInfo(log"Scanning ${MDC(HISTORY_DIR, logDir)} with " +
+ logDebug(log"Scanning ${MDC(HISTORY_DIR, logDir)} with " +
log"lastScanTime=${MDC(LAST_SCAN_TIME, lastScanTime)}")
// Mark entries that are processing as not stale. Such entries do not have a chance to be
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 98da33a429eca..93fb64f485f62 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -211,6 +211,8 @@ private[deploy] class Worker(
private var registerMasterFutures: Array[JFuture[_]] = null
private var registrationRetryTimer: Option[JScheduledFuture[_]] = None
+ private var heartbeatTask: Option[JScheduledFuture[_]] = None
+ private var workDirCleanupTask: Option[JScheduledFuture[_]] = None
// A thread pool for registering with masters. Because registering with a master is a blocking
// action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same
@@ -492,16 +494,25 @@ private[deploy] class Worker(
logInfo(log"Successfully registered with master ${MDC(MASTER_URL, preferredMasterAddress)}")
registered = true
changeMaster(masterRef, masterWebUiUrl, masterAddress)
- forwardMessageScheduler.scheduleAtFixedRate(
- () => Utils.tryLogNonFatalError { self.send(SendHeartbeat) },
- 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS)
- if (CLEANUP_ENABLED) {
+
+ // Only schedule heartbeat task if not already scheduled. The existing task will
+ // continue running through reconnections, and the SendHeartbeat handler already
+ // checks the 'connected' flag before sending heartbeats to master.
+ if (heartbeatTask.isEmpty) {
+ heartbeatTask = Some(forwardMessageScheduler.scheduleAtFixedRate(
+ () => Utils.tryLogNonFatalError {
+ self.send(SendHeartbeat)
+ },
+ 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS))
+ }
+ // Only schedule work directory cleanup task if not already scheduled
+ if (CLEANUP_ENABLED && workDirCleanupTask.isEmpty) {
logInfo(
log"Worker cleanup enabled; old application directories will be deleted in: " +
log"${MDC(PATH, workDir)}")
- forwardMessageScheduler.scheduleAtFixedRate(
+ workDirCleanupTask = Some(forwardMessageScheduler.scheduleAtFixedRate(
() => Utils.tryLogNonFatalError { self.send(WorkDirCleanup) },
- CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS)
+ CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS))
}
val execs = executors.values.map { e =>
@@ -852,6 +863,10 @@ private[deploy] class Worker(
cleanupThreadExecutor.shutdownNow()
metricsSystem.report()
cancelLastRegistrationRetry()
+ heartbeatTask.foreach(_.cancel(true))
+ heartbeatTask = None
+ workDirCleanupTask.foreach(_.cancel(true))
+ workDirCleanupTask = None
forwardMessageScheduler.shutdownNow()
registerMasterThreadPool.shutdownNow()
executors.values.foreach(_.kill())
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index a14ba21a0c186..c8843ac3427e4 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -20,7 +20,7 @@ package org.apache.spark.executor
import java.io.{File, NotSerializableException}
import java.lang.Thread.UncaughtExceptionHandler
import java.lang.management.ManagementFactory
-import java.net.{URI, URL}
+import java.net.{URI, URL, URLClassLoader}
import java.nio.ByteBuffer
import java.util.{Locale, Properties}
import java.util.concurrent._
@@ -40,6 +40,7 @@ import org.slf4j.{MDC => SLF4JMDC}
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.executor.Executor.TASK_THREAD_NAME_PREFIX
import org.apache.spark.internal.{Logging, LogKeys}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.internal.config._
@@ -57,14 +58,186 @@ import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util._
import org.apache.spark.util.ArrayImplicits._
+private[spark] object IsolatedSessionState {
+ // Authoritative store for all isolated sessions. Sessions are put here when created
+ // and removed when cleanup runs. The Guava cache just tracks which sessions are
+ // "active" for LRU eviction policy, but this map is the source of truth.
+ // This ensures there's only ONE IsolatedSessionState per UUID at any time.
+ val sessions = new ConcurrentHashMap[String, IsolatedSessionState]()
+}
+
+/**
+ * Represents an isolated session state on the executor side, containing session-specific
+ * classloaders, files, jars, and archives. This class manages the lifecycle of these resources
+ * and prevents race conditions between concurrent task execution and cache eviction.
+ *
+ * == Architecture ==
+ *
+ * Sessions are managed through two mechanisms:
+ * 1. A Guava LRU cache (`isolatedSessionCache`) for active session lookup with size limits
+ * 2. An authoritative map (`IsolatedSessionState.sessions`) tracking all sessions until cleanup
+ *
+ * The Guava cache handles LRU eviction, while the authoritative map ensures there's only one
+ * IsolatedSessionState instance per UUID at any time and tracks sessions that are evicted but
+ * still in use.
+ *
+ * == State Machine ==
+ *
+ * Each session has two state variables protected by a synchronized lock:
+ * - `refCount`: Number of tasks currently using this session
+ * - `evicted`: Whether the session has been evicted from the Guava cache
+ *
+ * Valid state transitions:
+ * {{{
+ * Common workflow (no contention):
+ * [Created] --> acquire() --> [Active: refCount > 0]
+ * |
+ * release() (all tasks done)
+ * |
+ * v
+ * [Idle: refCount = 0]
+ * |
+ * markEvicted() (cache eviction)
+ * |
+ * v
+ * [Cleanup]
+ *
+ * Contention case (eviction while tasks running):
+ * [Active: refCount > 0] --> markEvicted() --> [Deferred: evicted = true]
+ * |
+ * +---------------------------------------------+
+ * | |
+ * v v
+ * release() (last task) tryUnEvict()
+ * | |
+ * v v
+ * [Cleanup] [Active] (back in cache)
+ * }}}
+ *
+ * == Cleanup ==
+ *
+ * Cleanup happens when both conditions are met: `refCount == 0` AND `evicted == true`.
+ * This can occur either:
+ * - Immediately when `markEvicted()` is called and no tasks are using the session
+ * - Deferred when the last task calls `release()` after the session was evicted
+ *
+ * Cleanup closes the classloader, deletes session files, and removes the session
+ * from the authoritative map.
+ *
+ * == Concurrency Guarantees ==
+ *
+ * The key insight is that as long as a session is still in use (refCount > 0), it
+ * remains in the authoritative map and we can get its instance. When a new task needs
+ * a session that was evicted from the LRU cache but is still in use:
+ *
+ * - If cleanup has NOT started (refCount > 0): we can cancel the pending cleanup
+ * via `tryUnEvict()`, put the instance back into the LRU cache, and safely reuse it.
+ *
+ * - If cleanup HAS started (refCount became 0): cleanup runs synchronously under the
+ * lock, so it must complete before any new task can proceed. Once cleanup finishes,
+ * the session is removed from the authoritative map, and a fresh instance is created.
+ *
+ * This design ensures there is never a race where a task uses a session that is being
+ * or has been cleaned up. The `acquire()` and `tryUnEvict()` methods are intentionally
+ * separate: `tryUnEvict()` is only called from the cache loader to guarantee the
+ * session is put back into the LRU cache, maintaining the invariant that a non-evicted
+ * session is always in the cache.
+ */
private[spark] class IsolatedSessionState(
- val sessionUUID: String,
- var urlClassLoader: MutableURLClassLoader,
- var replClassLoader: ClassLoader,
- val currentFiles: HashMap[String, Long],
- val currentJars: HashMap[String, Long],
- val currentArchives: HashMap[String, Long],
- val replClassDirUri: Option[String])
+ val sessionUUID: String,
+ var urlClassLoader: MutableURLClassLoader,
+ var replClassLoader: ClassLoader,
+ val currentFiles: HashMap[String, Long],
+ val currentJars: HashMap[String, Long],
+ val currentArchives: HashMap[String, Long],
+ val replClassDirUri: Option[String]) extends Logging {
+
+ // Reference count for the number of running tasks using this session.
+ // Access is synchronized via `lock`.
+ private var refCount: Int = 0
+
+ // Whether this session has been evicted from the cache.
+ // Access is synchronized via `lock`.
+ private var evicted: Boolean = false
+
+ // Lock to synchronize all state changes.
+ private val lock = new Object
+
+ /**
+ * Increment the reference count, indicating a task is using this session.
+ * @return true if the session was successfully acquired, false if it was already evicted
+ */
+ def acquire(): Boolean = lock.synchronized {
+ if (evicted) {
+ false
+ } else {
+ refCount += 1
+ true
+ }
+ }
+
+ /**
+ * Try to un-evict this session so it can be reused.
+ * This is called from the cache loader to reuse a deferred session.
+ * The caller should call acquire() separately after the session is in cache.
+ * @return true if successfully un-evicted, false if already cleaned up or refCount is 0
+ */
+ def tryUnEvict(): Boolean = lock.synchronized {
+ if (evicted && refCount > 0) {
+ evicted = false
+ logInfo(log"Session ${MDC(SESSION_ID, sessionUUID)} un-evicted, " +
+ log"still in use by ${MDC(LogKeys.COUNT, refCount)} task(s)")
+ true
+ } else {
+ false
+ }
+ }
+
+ /** Decrement the reference count. If evicted and no more tasks, clean up. */
+ def release(): Unit = lock.synchronized {
+ refCount -= 1
+ if (refCount == 0 && evicted) {
+ cleanup()
+ }
+ }
+
+ /** Mark this session as evicted. Cleans up immediately if refCount is 0. */
+ def markEvicted(): Unit = lock.synchronized {
+ evicted = true
+ if (refCount == 0) {
+ cleanup()
+ } else {
+ logInfo(log"Session ${MDC(SESSION_ID, sessionUUID)} evicted but still in use by " +
+ log"${MDC(LogKeys.COUNT, refCount)} task(s), deferring cleanup")
+ }
+ }
+
+ private def cleanup(): Unit = {
+ // Close the urlClassLoader to release resources.
+ try {
+ urlClassLoader match {
+ case cl: URLClassLoader =>
+ cl.close()
+ logInfo(log"Closed urlClassLoader for session ${MDC(SESSION_ID, sessionUUID)}")
+ case _ =>
+ }
+ } catch {
+ case NonFatal(e) =>
+ logWarning(log"Failed to close urlClassLoader for session " +
+ log"${MDC(SESSION_ID, sessionUUID)}", e)
+ }
+
+ // Delete session files.
+ val sessionBasedRoot = new File(SparkFiles.getRootDirectory(), sessionUUID)
+ if (sessionBasedRoot.isDirectory && sessionBasedRoot.exists()) {
+ Utils.deleteRecursively(sessionBasedRoot)
+ }
+
+ // Remove from authoritative sessions map after cleanup
+ IsolatedSessionState.sessions.remove(sessionUUID)
+ logInfo(log"Session cleaned up: ${MDC(SESSION_ID, sessionUUID)}")
+ }
+}
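// Condensed lifecycle sketch (illustrative, not part of the patch), mirroring the state machine
// documented above; `state` stands for an already-constructed IsolatedSessionState:
//
//   assert(state.acquire())     // task 1 starts using the session
//   state.markEvicted()         // LRU eviction while task 1 runs: cleanup is deferred
//   assert(!state.acquire())    // a new task cannot acquire an evicted session directly...
//   assert(state.tryUnEvict())  // ...but the cache loader can un-evict it while refCount > 0
//   state.release()             // last task done; cleanup would run here only if still evicted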
/**
* Spark executor, backed by a threadpool to run tasks.
@@ -132,7 +305,7 @@ private[spark] class Executor(
private[executor] val threadPool = {
val threadFactory = new ThreadFactoryBuilder()
.setDaemon(true)
- .setNameFormat("Executor task launch worker-%d")
+ .setNameFormat(s"$TASK_THREAD_NAME_PREFIX-%d")
.setThreadFactory((r: Runnable) => new UninterruptibleThread(r, "unused"))
.build()
Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
@@ -190,13 +363,17 @@ private[spark] class Executor(
isDefaultState(jobArtifactState.uuid))
val replClassLoader = addReplClassLoaderIfNeeded(
urlClassLoader, jobArtifactState.replClassDirUri, jobArtifactState.uuid)
- new IsolatedSessionState(
+ val state = new IsolatedSessionState(
jobArtifactState.uuid, urlClassLoader, replClassLoader,
currentFiles,
currentJars,
currentArchives,
jobArtifactState.replClassDirUri
)
+ // Store in the authoritative sessions map immediately.
+ // This ensures there's only one session per UUID at any time.
+ IsolatedSessionState.sessions.put(jobArtifactState.uuid, state)
+ state
}
private def isStubbingEnabledForState(name: String) = {
@@ -207,11 +384,11 @@ private[spark] class Executor(
private def isDefaultState(name: String) = name == "default"
// Classloader isolation
- // The default isolation group
+ // The default isolation group. Not in the cache, never evicted.
val defaultSessionState: IsolatedSessionState = newSessionState(JobArtifactState("default", None))
val isolatedSessionCache: Cache[String, IsolatedSessionState] = CacheBuilder.newBuilder()
- .maximumSize(100)
+ .maximumSize(conf.get(EXECUTOR_ISOLATED_SESSION_CACHE_SIZE))
.expireAfterAccess(30, TimeUnit.MINUTES)
.removalListener(new RemovalListener[String, IsolatedSessionState]() {
override def onRemoval(
@@ -219,11 +396,9 @@ private[spark] class Executor(
val state = notification.getValue
// Cache is always used for isolated sessions.
assert(!isDefaultState(state.sessionUUID))
- val sessionBasedRoot = new File(SparkFiles.getRootDirectory(), state.sessionUUID)
- if (sessionBasedRoot.isDirectory && sessionBasedRoot.exists()) {
- Utils.deleteRecursively(sessionBasedRoot)
- }
- logInfo(log"Session evicted: ${MDC(SESSION_ID, state.sessionUUID)}")
+ // Mark evicted. The session stays in the authoritative sessions map until cleanup.
+ // If refCount > 0, cleanup is deferred until all tasks release.
+ state.markEvicted()
}
})
.build[String, IsolatedSessionState]
@@ -380,7 +555,24 @@ private[spark] class Executor(
tr.kill(killMark._1, killMark._2)
killMarks.remove(taskId)
}
- threadPool.execute(tr)
+ try {
+ threadPool.execute(tr)
+ } catch {
+ case t: Throwable =>
+ try {
+ logError(log"Executor launch task ${MDC(TASK_NAME, taskDescription.name)} failed," +
+ log" reason: ${MDC(REASON, t.getMessage)}")
+ context.statusUpdate(
+ taskDescription.taskId,
+ TaskState.FAILED,
+ env.closureSerializer.newInstance().serialize(new ExceptionFailure(t, Seq.empty)))
+ } catch {
+ case t: Throwable =>
+ logError(log"Executor update launching task ${MDC(TASK_NAME, taskDescription.name)} " +
+ log"failed status failed, reason: ${MDC(REASON, t.getMessage)}")
+ System.exit(-1)
+ }
+ }
if (decommissioned) {
log.error(s"Launching a task while in decommissioned state.")
}
@@ -478,7 +670,7 @@ private[spark] class Executor(
val taskId = taskDescription.taskId
val taskName = taskDescription.name
- val threadName = s"Executor task launch worker for $taskName"
+ val threadName = s"$TASK_THREAD_NAME_PREFIX for $taskName"
val mdcProperties = taskDescription.properties.asScala
.filter(_._1.startsWith("mdc.")).toSeq
@@ -559,13 +751,46 @@ private[spark] class Executor(
(accums, accUpdates)
}
+ /**
+ * Obtains an IsolatedSessionState for the given job artifact state.
+ * Gets or creates a session from the cache, then acquires it. We need to retry the cache
+ * lookup if the session was evicted between get() and acquire(). This can happen when the
+ * cache is full and another task triggers eviction.
+ */
+ private def obtainSession(jobArtifactState: JobArtifactState): IsolatedSessionState = {
+ var session: IsolatedSessionState = null
+ var acquired = false
+ while (!acquired) {
+ // Get or create session. The loader uses sessions map as the authoritative store.
+ // This ensures there's only one IsolatedSessionState per UUID at any time.
+ session = isolatedSessionCache.get(jobArtifactState.uuid, () => {
+ // Check the authoritative sessions map first. tryUnEvict() will block if
+ // cleanup is in progress, so when it returns false, the session is already
+ // removed from the map and it's safe to create a new one.
+ val existingSession = IsolatedSessionState.sessions.get(jobArtifactState.uuid)
+ if (existingSession != null && existingSession.tryUnEvict()) {
+ existingSession
+ } else {
+ newSessionState(jobArtifactState)
+ }
+ })
+ // acquire() can return false if session was evicted between get() and now.
+ // In that case, retry - the session is already removed from cache.
+ acquired = session.acquire()
+ }
+ session
+ }
+
override def run(): Unit = {
// Classloader isolation
val isolatedSession = taskDescription.artifacts.state match {
case Some(jobArtifactState) =>
- isolatedSessionCache.get(jobArtifactState.uuid, () => newSessionState(jobArtifactState))
- case _ => defaultSessionState
+ obtainSession(jobArtifactState)
+ case _ =>
+ // The default session is never in the cache and never evicted,
+ // so no need to acquire/release.
+ defaultSessionState
}
setMDCForTask(taskName, mdcProperties)
@@ -728,7 +953,6 @@ private[spark] class Executor(
.inc(task.metrics.outputMetrics.bytesWritten)
executorSource.METRIC_OUTPUT_RECORDS_WRITTEN
.inc(task.metrics.outputMetrics.recordsWritten)
- executorSource.METRIC_RESULT_SIZE.inc(task.metrics.resultSize)
executorSource.METRIC_DISK_BYTES_SPILLED.inc(task.metrics.diskBytesSpilled)
executorSource.METRIC_MEMORY_BYTES_SPILLED.inc(task.metrics.memoryBytesSpilled)
incrementShuffleMetrics(executorSource, task.metrics)
@@ -742,6 +966,7 @@ private[spark] class Executor(
val serializedDirectResult = SerializerHelper.serializeToChunkedBuffer(ser, directResult,
valueByteBuffer.size + accumUpdates.size * 32 + metricPeaks.length * 8)
val resultSize = serializedDirectResult.size
+ executorSource.METRIC_RESULT_SIZE.inc(resultSize)
// directSend = sending directly back to the driver
val serializedResult: ByteBuffer = {
@@ -873,6 +1098,11 @@ private[spark] class Executor(
// are known, and metricsPoller.onTaskStart was called.
metricsPoller.onTaskCompletion(taskId, task.stageId, task.stageAttemptId)
}
+ // Release the session reference. If evicted and this was the last task, cleanup happens.
+ // Skip for defaultSessionState since it's never evicted.
+ if (isolatedSession ne defaultSessionState) {
+ isolatedSession.release()
+ }
}
}
@@ -1316,6 +1546,8 @@ private[spark] class Executor(
}
private[spark] object Executor extends Logging {
+ val TASK_THREAD_NAME_PREFIX = "Executor task launch worker"
+
// This is reserved for internal use by components that need to read task properties before a
// task is fully deserialized. When possible, the TaskContext.getLocalProperty call should be
// used instead.
diff --git a/core/src/main/scala/org/apache/spark/internal/config/Python.scala b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
index de95e2fa1f7a2..dc16d1ff255db 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/Python.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
@@ -138,4 +138,16 @@ private[spark] object Python {
.intConf
.checkValue(_ > 0, "If set, the idle worker max size must be > 0.")
.createOptional
+
+ val PYTHON_DAEMON_KILL_WORKER_ON_FLUSH_FAILURE =
+ ConfigBuilder("spark.python.daemon.killWorkerOnFlushFailure")
+ .doc("When enabled, exceptions raised during output flush operations in the Python " +
+ "worker managed under Python daemon are not caught, causing the worker to terminate " +
+ "with the exception. This allows Spark to detect the failure and launch a new worker " +
+ "and retry the task. " +
+ "When disabled, flush exceptions are caught and logged but the worker continues, " +
+ "which could cause the worker to get stuck due to protocol mismatch.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(true)
}
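// Usage sketch (illustrative, not part of the patch): the behaviour described above is toggled
// per application, e.g.
//
//   spark-submit --conf spark.python.daemon.killWorkerOnFlushFailure=false ...
//
// or programmatically:
//
//   val conf = new SparkConf().set("spark.python.daemon.killWorkerOnFlushFailure", "false")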
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 9876848f654a4..8b32c18aa3b61 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -358,6 +358,15 @@ package object config {
.intConf
.createWithDefault(60)
+ private[spark] val EXECUTOR_ISOLATED_SESSION_CACHE_SIZE =
+ ConfigBuilder("spark.executor.isolatedSessionCache.size")
+ .doc("Maximum number of isolated sessions to cache in the executor. Each cached session " +
+ "maintains its own classloader for artifact isolation.")
+ .version("4.1.0")
+ .intConf
+ .checkValue(_ > 0, "The cache size must be positive.")
+ .createWithDefault(100)
+
private[spark] val EXECUTOR_PROCESS_TREE_METRICS_ENABLED =
ConfigBuilder("spark.executor.processTreeMetrics.enabled")
.doc("Whether to collect process tree metrics (from the /proc filesystem) when collecting " +
diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
index db51f14415e1a..6b278c47f32f1 100644
--- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
@@ -273,6 +273,14 @@ object UnifiedMemoryManager extends Logging {
// Atomic flag to ensure polling is only started once per JVM
private val pollingStarted = new AtomicBoolean(false)
+ /**
+ * Returns the total unmanaged memory in bytes, including both
+ * on-heap unmanaged memory and off-heap unmanaged memory.
+ */
+ private[spark] def getUnmanagedMemoryUsed: Long = {
+ UnifiedMemoryManager.unmanagedOnHeapUsed.get() + UnifiedMemoryManager.unmanagedOffHeapUsed.get()
+ }
+
/**
* Register an unmanaged memory consumer to track its memory usage.
*
diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala
index 1b4b4f61016a4..971b14265979e 100644
--- a/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala
+++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfile.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable
import scala.jdk.CollectionConverters._
import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkException}
+import org.apache.spark.SparkMasterRegex._
import org.apache.spark.annotation.{Evolving, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys._
@@ -178,8 +179,8 @@ class ResourceProfile(
// only applies to yarn/k8s
private def shouldCheckExecutorCores(sparkConf: SparkConf): Boolean = {
val master = sparkConf.getOption("spark.master")
- sparkConf.contains(EXECUTOR_CORES) ||
- (master.isDefined && (master.get.equalsIgnoreCase("yarn") || master.get.startsWith("k8s")))
+ sparkConf.contains(EXECUTOR_CORES) || isK8s(master) ||
+ (master.isDefined && master.get.equalsIgnoreCase("yarn"))
}
/**
diff --git a/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala b/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala
index 10121f6ef2667..a3d76a92ddd8b 100644
--- a/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala
+++ b/core/src/main/scala/org/apache/spark/resource/ResourceProfileManager.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock
import scala.collection.mutable.HashMap
-import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.{SparkConf, SparkException, SparkMasterRegex}
import org.apache.spark.annotation.Evolving
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.internal.LogKeys
@@ -49,7 +49,7 @@ private[spark] class ResourceProfileManager(sparkConf: SparkConf,
private val dynamicEnabled = Utils.isDynamicAllocationEnabled(sparkConf)
private val master = sparkConf.getOption("spark.master")
private val isYarn = master.isDefined && master.get.equals("yarn")
- private val isK8s = master.isDefined && master.get.startsWith("k8s://")
+ private val isK8s = SparkMasterRegex.isK8s(master)
private val isStandaloneOrLocalCluster = master.isDefined && (
master.get.startsWith("spark://") || master.get.startsWith("local-cluster")
)
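The call sites above, and the one in Utils.checkAndGetK8sMasterUrl further down, route Kubernetes master detection through SparkMasterRegex.isK8s, whose definition is added elsewhere in this patch and not shown here. A minimal sketch of the overloads those call sites appear to assume:

```scala
// Sketch only: signatures are inferred from the call sites (Option[String] in
// ResourceProfile/ResourceProfileManager, String in Utils); the real helper
// lives in org.apache.spark.SparkMasterRegex.
object SparkMasterRegexSketch {
  def isK8s(master: String): Boolean = master != null && master.startsWith("k8s://")
  def isK8s(master: Option[String]): Boolean = master.exists(m => isK8s(m))
}
```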
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 7d77628c3f088..69e766ebcef25 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1560,42 +1560,27 @@ private[spark] class DAGScheduler(
// `findMissingPartitions()` returns all partitions every time.
stage match {
case sms: ShuffleMapStage if !sms.isAvailable =>
- val needFullStageRetry = if (sms.shuffleDep.checksumMismatchFullRetryEnabled) {
- // When the parents of this stage are indeterminate (e.g., some parents are not
- // checkpointed and checksum mismatches are detected), the output data of the parents
- // may have changed due to task retries. For correctness reason, we need to
- // retry all tasks of the current stage. The legacy way of using current stage's
- // deterministic level to trigger full stage retry is not accurate.
- stage.isParentIndeterminate
- } else {
- if (stage.isIndeterminate) {
- // already executed at least once
- if (sms.getNextAttemptId > 0) {
- // While we previously validated possible rollbacks during the handling of a FetchFailure,
- // where we were fetching from an indeterminate source map stages, this later check
- // covers additional cases like recalculating an indeterminate stage after an executor
- // loss. Moreover, because this check occurs later in the process, if a result stage task
- // has successfully completed, we can detect this and abort the job, as rolling back a
- // result stage is not possible.
- val stagesToRollback = collectSucceedingStages(sms)
- abortStageWithInvalidRollBack(stagesToRollback)
- // stages which cannot be rolled back were aborted which leads to removing the
- // the dependant job(s) from the active jobs set
- val numActiveJobsWithStageAfterRollback =
- activeJobs.count(job => stagesToRollback.contains(job.finalStage))
- if (numActiveJobsWithStageAfterRollback == 0) {
- logInfo(log"All jobs depending on the indeterminate stage " +
- log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.")
- return
- }
+ if (!sms.shuffleDep.checksumMismatchFullRetryEnabled && stage.isIndeterminate) {
+ // already executed at least once
+ if (sms.getNextAttemptId > 0) {
+ // While we previously validated possible rollbacks during the handling of a FetchFailure,
+ // where we were fetching from an indeterminate source map stages, this later check
+ // covers additional cases like recalculating an indeterminate stage after an executor
+ // loss. Moreover, because this check occurs later in the process, if a result stage task
+ // has successfully completed, we can detect this and abort the job, as rolling back a
+ // result stage is not possible.
+ val stagesToRollback = collectSucceedingStages(sms)
+ filterAndAbortUnrollbackableStages(stagesToRollback)
+ // stages which cannot be rolled back were aborted, which leads to removing the
+ // dependent job(s) from the active jobs set
+ val numActiveJobsWithStageAfterRollback =
+ activeJobs.count(job => stagesToRollback.contains(job.finalStage))
+ if (numActiveJobsWithStageAfterRollback == 0) {
+ logInfo(log"All jobs depending on the indeterminate stage " +
+ log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.")
+ return
}
- true
- } else {
- false
}
- }
-
- if (needFullStageRetry) {
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
sms.shuffleDep.newShuffleMergeState()
}
@@ -1913,16 +1898,127 @@ private[spark] class DAGScheduler(
/**
* If a map stage is non-deterministic, the map tasks of the stage may return different result
- * when re-try. To make sure data correctness, we need to re-try all the tasks of its succeeding
- * stages, as the input data may be changed after the map tasks are re-tried. For stages where
- * rollback and retry all tasks are not possible, we will need to abort the stages.
+ * when retried. To ensure data correctness, we need to clean up shuffles so that succeeding
+ * stages will be resubmitted and all of their tasks retried, as the input data may have changed
+ * after the map tasks were retried. For stages where rolling back and retrying all tasks is not
+ * possible, we will need to abort the stages.
+ */
+ private[scheduler] def rollbackSucceedingStages(mapStage: ShuffleMapStage): Unit = {
+ val stagesToRollback = collectSucceedingStages(mapStage).filterNot(_ == mapStage)
+ val stagesCanRollback = filterAndAbortUnrollbackableStages(stagesToRollback)
+ // stages which cannot be rolled back were aborted, which leads to removing the
+ // dependent job(s) from the active jobs set; there could be no active jobs
+ // left depending on the indeterminate stage and hence no need to roll back any stages.
+ val numActiveJobsWithStageAfterRollback =
+ activeJobs.count(job => stagesToRollback.contains(job.finalStage))
+ if (numActiveJobsWithStageAfterRollback == 0) {
+ logInfo(log"All jobs depending on the indeterminate stage " +
+ log"(${MDC(STAGE_ID, mapStage.id)}) were aborted.")
+ } else {
+ // Mark the rollback attempt to identify older attempts which could consume inconsistent data;
+ // the results from those attempts should be ignored.
+ // Rollback the running stages first to avoid triggering more fetch failures.
+ stagesToRollback.toSeq.sortBy(!runningStages.contains(_)).foreach {
+ case sms: ShuffleMapStage =>
+ rollbackShuffleMapStage(sms, "rolling back due to indeterminate " +
+ s"output of shuffle map stage $mapStage")
+ sms.markAsRollingBack()
+
+ case rs: ResultStage =>
+ rs.markAsRollingBack()
+ }
+
+ logInfo(log"The shuffle map stage ${MDC(STAGE, mapStage)} with indeterminate output " +
+ log"was retried, we will roll back and rerun its succeeding " +
+ log"stages: ${MDC(STAGES, stagesCanRollback)}")
+ }
+ }
+
+ /**
+ * Roll back the given shuffle map stage:
+ * 1. If the stage is running, cancel the stage and kill all running tasks. Clean up the shuffle
+ * output and resubmit the stage if it has not exceeded the max retries.
+ * 2. If the stage is not running but has generated output, clean up the shuffle output to
+ * ensure the stage will be fully re-executed.
+ *
+ * @param sms the shuffle map stage to roll back
+ * @param reason the reason for rolling back
+ */
+ private def rollbackShuffleMapStage(sms: ShuffleMapStage, reason: String): Unit = {
+ logInfo(log"Rolling back ${MDC(STAGE, sms)} due to indeterminate rollback")
+ val clearShuffle = if (runningStages.contains(sms)) {
+ logInfo(log"Stage ${MDC(STAGE, sms)} is running, marking it as failed and " +
+ log"resubmit if allowed")
+ cancelStageAndTryResubmit(sms, reason)
+ } else {
+ true
+ }
+
+ // Clean up shuffle outputs in case the stage is not aborted to ensure the stage
+ // will be re-executed.
+ if (clearShuffle) {
+ logInfo(log"Cleaning up shuffle for stage ${MDC(STAGE, sms)} to ensure re-execution")
+ mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
+ sms.shuffleDep.newShuffleMergeState()
+ }
+ }
+
+ /**
+ * Cancel the given running shuffle map stage, killing all running tasks, and resubmit it if it
+ * doesn't exceed the max retries.
+ *
+ * @param stage the stage to cancel and resubmit
+ * @param reason the reason for the operation
+ * @return true if the stage is successfully cancelled and resubmitted, otherwise false
*/
- private[scheduler] def abortUnrollbackableStages(mapStage: ShuffleMapStage): Unit = {
- val stagesToRollback = collectSucceedingStages(mapStage)
- val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback)
- logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output " +
- log"was failed, we will roll back and rerun below stages which include itself and all its " +
- log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
+ private def cancelStageAndTryResubmit(stage: ShuffleMapStage, reason: String): Boolean = {
+ assert(runningStages.contains(stage), "stage must be running to be cancelled and resubmitted")
+ try {
+ // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask.
+ val job = jobIdToActiveJob.get(stage.firstJobId)
+ val shouldInterrupt = job.exists(j => shouldInterruptTaskThread(j))
+ taskScheduler.killAllTaskAttempts(stage.id, shouldInterrupt, reason)
+ } catch {
+ case e: UnsupportedOperationException =>
+ logWarning(log"Could not kill all tasks for stage ${MDC(STAGE_ID, stage.id)}", e)
+ abortStage(stage, "Rollback failed due to: Not able to kill running tasks for stage " +
+ s"$stage (${stage.name})", Some(e))
+ return false
+ }
+
+ stage.failedAttemptIds.add(stage.latestInfo.attemptNumber())
+ val shouldAbortStage = stage.failedAttemptIds.size >= maxConsecutiveStageAttempts ||
+ disallowStageRetryForTest
+ markStageAsFinished(stage, Some(reason), willRetry = !shouldAbortStage)
+
+ if (shouldAbortStage) {
+ val abortMessage = if (disallowStageRetryForTest) {
+ "Stage will not retry stage due to testing config. Most recent failure " +
+ s"reason: $reason"
+ } else {
+ s"$stage (${stage.name}) has failed the maximum allowable number of " +
+ s"times: $maxConsecutiveStageAttempts. Most recent failure reason: $reason"
+ }
+ abortStage(stage, s"rollback failed due to: $abortMessage", None)
+ } else {
+ // In case multiple task failures are triggered for a single stage attempt, ensure we only
+ // resubmit the failed stage once.
+ val noResubmitEnqueued = !failedStages.contains(stage)
+ failedStages += stage
+ if (noResubmitEnqueued) {
+ logInfo(log"Resubmitting ${MDC(FAILED_STAGE, stage)} " +
+ log"(${MDC(FAILED_STAGE_NAME, stage.name)}) due to rollback.")
+ messageScheduler.schedule(
+ new Runnable {
+ override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+ },
+ DAGScheduler.RESUBMIT_TIMEOUT,
+ TimeUnit.MILLISECONDS
+ )
+ }
+ }
+
+ !shouldAbortStage
}
/**
@@ -1990,7 +2086,21 @@ private[spark] class DAGScheduler(
// tasks complete, they still count and we can mark the corresponding partitions as
// finished if the stage is determinate. Here we notify the task scheduler to skip running
// tasks for the same partition to save resource.
- if (!stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber()) {
+ def stageWithChecksumMismatchFullRetryEnabled(stage: Stage): Boolean = {
+ stage match {
+ case s: ShuffleMapStage => s.shuffleDep.checksumMismatchFullRetryEnabled
+ case _ => stage.parents.exists(stageWithChecksumMismatchFullRetryEnabled)
+ }
+ }
+
+ // Ignore task completion for old attempt of indeterminate stage
+ val ignoreOldTaskAttempts = if (stageWithChecksumMismatchFullRetryEnabled(stage)) {
+ stage.maxAttemptIdToIgnore.exists(_ >= task.stageAttemptId)
+ } else {
+ stage.isIndeterminate && task.stageAttemptId < stage.latestInfo.attemptNumber()
+ }
+
+ if (!ignoreOldTaskAttempts && task.stageAttemptId < stage.latestInfo.attemptNumber()) {
taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
}
@@ -2002,6 +2112,13 @@ private[spark] class DAGScheduler(
resultStage.activeJob match {
case Some(job) =>
if (!job.finished(rt.outputId)) {
+ if (ignoreOldTaskAttempts) {
+ val reason = "Task with indeterminate results from old attempt succeeded, " +
+ s"aborting the stage $resultStage to ensure data correctness."
+ abortStage(resultStage, reason, None)
+ return
+ }
+
job.finished(rt.outputId) = true
job.numFinished += 1
// If the whole job has finished, remove it
@@ -2045,10 +2162,7 @@ private[spark] class DAGScheduler(
case smt: ShuffleMapTask =>
val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
- // Ignore task completion for old attempt of indeterminate stage
- val ignoreIndeterminate = stage.isIndeterminate &&
- task.stageAttemptId < stage.latestInfo.attemptNumber()
- if (!ignoreIndeterminate) {
+ if (!ignoreOldTaskAttempts) {
shuffleStage.pendingPartitions -= task.partitionId
val status = event.result.asInstanceOf[MapStatus]
val execId = status.location.executorId
@@ -2077,7 +2191,7 @@ private[spark] class DAGScheduler(
shuffleStage.maxChecksumMismatchedId = smt.stageAttemptId
if (shuffleStage.shuffleDep.checksumMismatchFullRetryEnabled
&& shuffleStage.isStageIndeterminate) {
- abortUnrollbackableStages(shuffleStage)
+ rollbackSucceedingStages(shuffleStage)
}
}
}
@@ -2206,7 +2320,11 @@ private[spark] class DAGScheduler(
// guaranteed to be determinate, so the input data of the reducers will not change
// even if the map tasks are re-tried.
if (mapStage.isIndeterminate && !mapStage.shuffleDep.checksumMismatchFullRetryEnabled) {
- abortUnrollbackableStages(mapStage)
+ val stagesToRollback = collectSucceedingStages(mapStage)
+ val stagesCanRollback = filterAndAbortUnrollbackableStages(stagesToRollback)
+ logInfo(log"The shuffle map stage ${MDC(STAGE, mapStage)} with indeterminate output " +
+ log"was failed, we will roll back and rerun below stages which include itself and all " +
+ log"its indeterminate child stages: ${MDC(STAGES, stagesCanRollback)}")
}
// We expect one executor failure to trigger many FetchFailures in rapid succession,
@@ -2396,7 +2514,8 @@ private[spark] class DAGScheduler(
* @param stagesToRollback stages to roll back
* @return Shuffle map stages which need and can be rolled back
*/
- private def abortStageWithInvalidRollBack(stagesToRollback: HashSet[Stage]): HashSet[Stage] = {
+ private def filterAndAbortUnrollbackableStages(
+ stagesToRollback: HashSet[Stage]): HashSet[Stage] = {
def generateErrorMessage(stage: Stage): String = {
"A shuffle map stage with indeterminate output was failed and retried. " +
@@ -2920,6 +3039,21 @@ private[spark] class DAGScheduler(
}
}
+ private[scheduler] def handleShuffleStatusNotFoundException(
+ ex: ShuffleStatusNotFoundException): Unit = {
+ val stage = shuffleIdToMapStage.get(ex.shuffleId)
+ val reason = "exceptions encountered while invoking " +
+ s"MapOutputTracker.${ex.methodName} with shuffleId=${ex.shuffleId}"
+ if (stage.isDefined) {
+ abortStage(stage.get, reason, Some(ex))
+ logWarning(s"Aborting stage because of $reason. It is possible that the stage is " +
+ "being cancelled.")
+ } else {
+ logWarning(s"Tried aborting stage because of $reason, but the stage was not found. " +
+ "It is possible that the stage has been cancelled earlier.")
+ }
+ }
+
/**
* Marks a stage as finished and removes it from the list of running stages.
*/
@@ -3192,6 +3326,9 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
val timerContext = timer.time()
try {
doOnReceive(event)
+ } catch {
+ case ex: ShuffleStatusNotFoundException =>
+ dagScheduler.handleShuffleStatusNotFoundException(ex)
} finally {
timerContext.stop()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 9bf604e9a83cf..d8aaea013ee65 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -84,6 +84,14 @@ private[scheduler] abstract class Stage(
*/
private[scheduler] var maxChecksumMismatchedId: Int = nextAttemptId
+ /**
+ * The max attempt id up to which results for this stage should be ignored, indicating that
+ * checksum mismatches have been detected in ancestor stages. This stage is probably also
+ * indeterminate, so we need to avoid completing the stage and the job with an incorrect result
+ * by ignoring task output from previous attempts, which might have consumed inconsistent data.
+ */
+ private[scheduler] var maxAttemptIdToIgnore: Option[Int] = None
+
val name: String = callSite.shortForm
val details: String = callSite.longForm
@@ -108,6 +116,14 @@ private[scheduler] abstract class Stage(
failedAttemptIds.clear()
}
+ /** Mark the latest attempt as rolled back. */
+ private[scheduler] def markAsRollingBack(): Unit = {
+ // Only if the stage has been submitted
+ if (getNextAttemptId > 0) {
+ maxAttemptIdToIgnore = Some(latestInfo.attemptNumber())
+ }
+ }
+
/** Creates a new attempt for this stage by creating a new StageInfo with a new attempt ID. */
def makeNewStageAttempt(
numPartitionsToCompute: Int,
diff --git a/core/src/main/scala/org/apache/spark/status/KVUtils.scala b/core/src/main/scala/org/apache/spark/status/KVUtils.scala
index 76fb654f8da2d..f9ea732f95140 100644
--- a/core/src/main/scala/org/apache/spark/status/KVUtils.scala
+++ b/core/src/main/scala/org/apache/spark/status/KVUtils.scala
@@ -25,6 +25,7 @@ import scala.jdk.CollectionConverters._
import scala.reflect.{classTag, ClassTag}
import com.fasterxml.jackson.annotation.JsonInclude
+import com.fasterxml.jackson.core.StreamReadConstraints
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.fusesource.leveldbjni.internal.NativeDB
import org.rocksdb.RocksDBException
@@ -74,8 +75,12 @@ private[spark] object KVUtils extends Logging {
private[spark] class KVStoreScalaSerializer extends KVStoreSerializer {
mapper.registerModule(DefaultScalaModule)
- mapper.setSerializationInclusion(JsonInclude.Include.NON_ABSENT)
+ mapper.setDefaultPropertyInclusion(JsonInclude.Include.NON_ABSENT)
+ // SPARK-49872: Remove jackson JSON string length limitation.
+ mapper.getFactory.setStreamReadConstraints(
+ StreamReadConstraints.builder().maxStringLength(Int.MaxValue).build()
+ )
}
/**
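The StreamReadConstraints change above can be hard to read out of diff context; the standalone sketch below shows the same Jackson 2.15+ API applied to a plain ObjectMapper (independent of KVStoreScalaSerializer).

```scala
import com.fasterxml.jackson.core.StreamReadConstraints
import com.fasterxml.jackson.databind.ObjectMapper

// Jackson 2.15 introduced a default cap on JSON string value length; lifting it to
// Int.MaxValue mirrors what the KVStoreScalaSerializer change does for stored UI data.
val mapper = new ObjectMapper()
mapper.getFactory.setStreamReadConstraints(
  StreamReadConstraints.builder().maxStringLength(Int.MaxValue).build())
```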
diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala
index efc670440bc64..3c4efd8a5ead1 100644
--- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala
+++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala
@@ -859,40 +859,49 @@ private[spark] object LiveEntityHelpers {
}
createMetrics(
- updateMetricValue(m.executorDeserializeTime),
- updateMetricValue(m.executorDeserializeCpuTime),
- updateMetricValue(m.executorRunTime),
- updateMetricValue(m.executorCpuTime),
- updateMetricValue(m.resultSize),
- updateMetricValue(m.jvmGcTime),
- updateMetricValue(m.resultSerializationTime),
- updateMetricValue(m.memoryBytesSpilled),
- updateMetricValue(m.diskBytesSpilled),
- updateMetricValue(m.peakExecutionMemory),
- updateMetricValue(m.inputMetrics.bytesRead),
- updateMetricValue(m.inputMetrics.recordsRead),
- updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.corruptMergedBlockChunks),
- updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.mergedFetchFallbackCount),
- updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.remoteMergedBlocksFetched),
- updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.localMergedBlocksFetched),
- updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.remoteMergedChunksFetched),
- updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.localMergedChunksFetched),
- updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.remoteMergedBytesRead),
- updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.localMergedBytesRead),
- updateMetricValue(m.shuffleReadMetrics.remoteReqsDuration),
- updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.remoteMergedReqsDuration),
- updateMetricValue(m.outputMetrics.bytesWritten),
- updateMetricValue(m.outputMetrics.recordsWritten),
- updateMetricValue(m.shuffleReadMetrics.remoteBlocksFetched),
- updateMetricValue(m.shuffleReadMetrics.localBlocksFetched),
- updateMetricValue(m.shuffleReadMetrics.fetchWaitTime),
- updateMetricValue(m.shuffleReadMetrics.remoteBytesRead),
- updateMetricValue(m.shuffleReadMetrics.remoteBytesReadToDisk),
- updateMetricValue(m.shuffleReadMetrics.localBytesRead),
- updateMetricValue(m.shuffleReadMetrics.recordsRead),
- updateMetricValue(m.shuffleWriteMetrics.bytesWritten),
- updateMetricValue(m.shuffleWriteMetrics.writeTime),
- updateMetricValue(m.shuffleWriteMetrics.recordsWritten))
+ executorDeserializeTime = updateMetricValue(m.executorDeserializeTime),
+ executorDeserializeCpuTime = updateMetricValue(m.executorDeserializeCpuTime),
+ executorRunTime = updateMetricValue(m.executorRunTime),
+ executorCpuTime = updateMetricValue(m.executorCpuTime),
+ resultSize = updateMetricValue(m.resultSize),
+ jvmGcTime = updateMetricValue(m.jvmGcTime),
+ resultSerializationTime = updateMetricValue(m.resultSerializationTime),
+ memoryBytesSpilled = updateMetricValue(m.memoryBytesSpilled),
+ diskBytesSpilled = updateMetricValue(m.diskBytesSpilled),
+ peakExecutionMemory = updateMetricValue(m.peakExecutionMemory),
+ inputBytesRead = updateMetricValue(m.inputMetrics.bytesRead),
+ inputRecordsRead = updateMetricValue(m.inputMetrics.recordsRead),
+ outputBytesWritten = updateMetricValue(m.outputMetrics.bytesWritten),
+ outputRecordsWritten = updateMetricValue(m.outputMetrics.recordsWritten),
+ shuffleRemoteBlocksFetched = updateMetricValue(m.shuffleReadMetrics.remoteBlocksFetched),
+ shuffleLocalBlocksFetched = updateMetricValue(m.shuffleReadMetrics.localBlocksFetched),
+ shuffleFetchWaitTime = updateMetricValue(m.shuffleReadMetrics.fetchWaitTime),
+ shuffleRemoteBytesRead = updateMetricValue(m.shuffleReadMetrics.remoteBytesRead),
+ shuffleRemoteBytesReadToDisk = updateMetricValue(m.shuffleReadMetrics.remoteBytesReadToDisk),
+ shuffleLocalBytesRead = updateMetricValue(m.shuffleReadMetrics.localBytesRead),
+ shuffleRecordsRead = updateMetricValue(m.shuffleReadMetrics.recordsRead),
+ shuffleCorruptMergedBlockChunks =
+ updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.corruptMergedBlockChunks),
+ shuffleMergedFetchFallbackCount =
+ updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.mergedFetchFallbackCount),
+ shuffleMergedRemoteBlocksFetched =
+ updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.remoteMergedBlocksFetched),
+ shuffleMergedLocalBlocksFetched =
+ updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.localMergedBlocksFetched),
+ shuffleMergedRemoteChunksFetched =
+ updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.remoteMergedChunksFetched),
+ shuffleMergedLocalChunksFetched =
+ updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.localMergedChunksFetched),
+ shuffleMergedRemoteBytesRead =
+ updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.remoteMergedBytesRead),
+ shuffleMergedLocalBytesRead =
+ updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.localMergedBytesRead),
+ shuffleRemoteReqsDuration = updateMetricValue(m.shuffleReadMetrics.remoteReqsDuration),
+ shuffleMergedRemoteReqsDuration =
+ updateMetricValue(m.shuffleReadMetrics.shufflePushReadMetrics.remoteMergedReqsDuration),
+ shuffleBytesWritten = updateMetricValue(m.shuffleWriteMetrics.bytesWritten),
+ shuffleWriteTime = updateMetricValue(m.shuffleWriteMetrics.writeTime),
+ shuffleRecordsWritten = updateMetricValue(m.shuffleWriteMetrics.recordsWritten))
}
private def addMetrics(m1: v1.TaskMetrics, m2: v1.TaskMetrics, mult: Int): v1.TaskMetrics = {
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala
index 259d0aacc5755..1103f55297821 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/JacksonMessageWriter.scala
@@ -48,7 +48,7 @@ private[v1] class JacksonMessageWriter extends MessageBodyWriter[Object]{
}
mapper.registerModule(com.fasterxml.jackson.module.scala.DefaultScalaModule)
mapper.enable(SerializationFeature.INDENT_OUTPUT)
- mapper.setSerializationInclusion(JsonInclude.Include.NON_ABSENT)
+ mapper.setDefaultPropertyInclusion(JsonInclude.Include.NON_ABSENT)
mapper.setDateFormat(JacksonMessageWriter.makeISODateFormat)
override def isWriteable(
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index b2f185bc590fd..cc552a2985f7e 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -193,6 +193,14 @@ final class ShuffleBlockFetcherIterator(
initialize()
+ private def withFetchWaitTimeTracked[T](f: => T): T = {
+ val startFetchWait = System.nanoTime()
+ val res = f
+ val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait)
+ shuffleMetrics.incFetchWaitTime(fetchWaitTime)
+ res
+ }
+
// Decrements the buffer reference count.
// The currentResult is set to null to prevent releasing the buffer again on cleanup()
private[storage] def releaseCurrentResultBuffer(): Unit = {
@@ -718,7 +726,7 @@ final class ShuffleBlockFetcherIterator(
", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)
// Send out initial requests for blocks, up to our maxBytesInFlight
- fetchUpToMaxBytes()
+ withFetchWaitTimeTracked(fetchUpToMaxBytes())
val numDeferredRequest = deferredFetchRequests.values.map(_.size).sum
val numFetches = remoteRequests.size - fetchRequests.size - numDeferredRequest
@@ -731,7 +739,7 @@ final class ShuffleBlockFetcherIterator(
fetchLocalBlocks(localBlocks)
logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}")
// Get host local blocks if any
- fetchAllHostLocalBlocks(hostLocalBlocksByExecutor)
+ withFetchWaitTimeTracked(fetchAllHostLocalBlocks(hostLocalBlocksByExecutor))
pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks)
}
@@ -813,10 +821,7 @@ final class ShuffleBlockFetcherIterator(
// is also corrupt, so the previous stage could be retried.
// For local shuffle block, throw FailureFetchResult for the first IOException.
while (result == null) {
- val startFetchWait = System.nanoTime()
- result = results.take()
- val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait)
- shuffleMetrics.incFetchWaitTime(fetchWaitTime)
+ result = withFetchWaitTimeTracked[FetchResult](results.take())
result match {
case SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) =>
@@ -1076,7 +1081,7 @@ final class ShuffleBlockFetcherIterator(
}
// Send fetch requests up to maxBytesInFlight
- fetchUpToMaxBytes()
+ withFetchWaitTimeTracked(fetchUpToMaxBytes())
}
currentResult = result.asInstanceOf[SuccessFetchResult]
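As a side note, withFetchWaitTimeTracked is an instance of a generic timing wrapper; a self-contained sketch of the pattern is below, where recordMillis stands in for shuffleMetrics.incFetchWaitTime.

```scala
import java.util.concurrent.TimeUnit

// Generic form of the wrapper: run a (possibly blocking) operation, report the elapsed
// wall-clock time in milliseconds to a metrics callback, and return the result.
def timed[T](recordMillis: Long => Unit)(f: => T): T = {
  val start = System.nanoTime()
  val res = f
  recordMillis(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start))
  res
}

// Example: val block = timed(ms => println(s"waited $ms ms"))(results.take())
```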
diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala
index cd057ed08c3c3..4aa4954b84a91 100644
--- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala
+++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala
@@ -269,7 +269,7 @@ private[spark] object RDDOperationGraph extends Logging {
val label = StringEscapeUtils.escapeJava(
s"${node.name} [${node.id}]$isCached$isBarrier$outputDeterministicLevel" +
s"
$escapedCallsite")
- s"""${node.id} [id="node_${node.id}" labelType="html" label="$label}"]"""
+ s"""${node.id} [id="node_${node.id}" labelType="html" label="$label"]"""
}
/** Update the dot representation of the RDDOperationGraph in cluster to subgraph.
diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
index af93f781343d2..6a6b0299a0bcd 100644
--- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
+++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala
@@ -61,12 +61,12 @@ private[spark] object ShutdownHookManager extends Logging {
// Add a shutdown hook to delete the temp dirs when the JVM exits
logDebug("Adding shutdown hook") // force eager creation of logger
addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () =>
- logInfo("Shutdown hook called")
+ logDebug("Shutdown hook called")
// we need to materialize the paths to delete because deleteRecursively removes items from
// shutdownDeletePaths as we are traversing through it.
shutdownDeletePaths.toArray.foreach { dirPath =>
try {
- logInfo(log"Deleting directory ${MDC(LogKeys.PATH, dirPath)}")
+ logDebug(log"Deleting directory ${MDC(LogKeys.PATH, dirPath)}")
Utils.deleteRecursively(new File(dirPath))
} catch {
case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 8b1ea4d25592f..0907d6d049bf5 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -65,6 +65,7 @@ import org.slf4j.Logger
import org.apache.spark.{SPARK_VERSION, _}
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.executor.Executor.TASK_THREAD_NAME_PREFIX
import org.apache.spark.internal.{Logging, MessageWithContext}
import org.apache.spark.internal.LogKeys
import org.apache.spark.internal.LogKeys._
@@ -249,6 +250,22 @@ private[spark] object Utils
dir
}
+ /**
+ * Create a temporary directory that will always be cleaned up when the executor stops,
+ * even in the case of a hard shutdown when the shutdown hooks don't get run.
+ *
+ * Currently this only provides special behavior on YARN, where the local dirs are not
+ * guaranteed to be cleaned up on an executor's hard shutdown.
+ */
+ def createExecutorLocalTempDir(conf: SparkConf, namePrefix: String): File = {
+ if (Utils.isRunningInYarnContainer(conf)) {
+ // Just use the default Java tmp dir which is set to inside the container directory on YARN
+ createTempDir(namePrefix = namePrefix)
+ } else {
+ createTempDir(getLocalDir(conf), namePrefix)
+ }
+ }
+
/**
* Copy the first `maxSize` bytes of data from the InputStream to an in-memory
* buffer, primarily to check for corruption.
@@ -2086,27 +2103,39 @@ private[spark] object Utils
}
}
+ val CONNECT_EXECUTE_THREAD_PREFIX = "SparkConnectExecuteThread"
+
+ private[spark] val threadInfoOrdering = Ordering.fromLessThan {
+ (threadTrace1: ThreadInfo, threadTrace2: ThreadInfo) => {
+ def priority(ti: ThreadInfo): Int = ti.getThreadName match {
+ case name if name.startsWith(TASK_THREAD_NAME_PREFIX) => 100
+ case name if name.startsWith(CONNECT_EXECUTE_THREAD_PREFIX) => 80
+ case _ => 0
+ }
+
+ val v1 = priority(threadTrace1)
+ val v2 = priority(threadTrace2)
+ if (v1 == v2) {
+ val name1 = threadTrace1.getThreadName.toLowerCase(Locale.ROOT)
+ val name2 = threadTrace2.getThreadName.toLowerCase(Locale.ROOT)
+ val nameCmpRes = name1.compareTo(name2)
+ if (nameCmpRes == 0) {
+ threadTrace1.getThreadId < threadTrace2.getThreadId
+ } else {
+ nameCmpRes < 0
+ }
+ } else {
+ v1 > v2
+ }
+ }
+ }
+
/** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */
def getThreadDump(): Array[ThreadStackTrace] = {
// We need to filter out null values here because dumpAllThreads() may return null array
// elements for threads that are dead / don't exist.
- val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null)
- threadInfos.sortWith { case (threadTrace1, threadTrace2) =>
- val v1 = if (threadTrace1.getThreadName.contains("Executor task launch")) 1 else 0
- val v2 = if (threadTrace2.getThreadName.contains("Executor task launch")) 1 else 0
- if (v1 == v2) {
- val name1 = threadTrace1.getThreadName().toLowerCase(Locale.ROOT)
- val name2 = threadTrace2.getThreadName().toLowerCase(Locale.ROOT)
- val nameCmpRes = name1.compareTo(name2)
- if (nameCmpRes == 0) {
- threadTrace1.getThreadId < threadTrace2.getThreadId
- } else {
- nameCmpRes < 0
- }
- } else {
- v1 > v2
- }
- }.map(threadInfoToThreadStackTrace)
+ ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null)
+ .sorted(threadInfoOrdering).map(threadInfoToThreadStackTrace)
}
/** Return a heap dump. Used to capture dumps for the web UI */
@@ -2865,7 +2894,7 @@ private[spark] object Utils
* in canCreate to determine if the KubernetesClusterManager should be used.
*/
def checkAndGetK8sMasterUrl(rawMasterURL: String): String = {
- require(rawMasterURL.startsWith("k8s://"),
+ require(SparkMasterRegex.isK8s(rawMasterURL),
"Kubernetes master URL must start with k8s://.")
val masterWithoutK8sPrefix = rawMasterURL.substring("k8s://".length)
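The thread-dump ordering refactor above replaces an inline sortWith with a reusable Ordering. The sketch below reproduces the same priority scheme outside of Utils so the intent is easier to see; it is a standalone illustration, not the private threadInfoOrdering itself.

```scala
import java.lang.management.{ManagementFactory, ThreadInfo}
import java.util.Locale

// Executor task threads first (100), Spark Connect execute threads next (80), everything
// else last (0); ties are broken case-insensitively by name and then by thread id.
val ordering: Ordering[ThreadInfo] = Ordering.fromLessThan[ThreadInfo] { (a, b) =>
  def priority(ti: ThreadInfo): Int = ti.getThreadName match {
    case n if n.startsWith("Executor task launch worker") => 100
    case n if n.startsWith("SparkConnectExecuteThread") => 80
    case _ => 0
  }
  val (pa, pb) = (priority(a), priority(b))
  if (pa != pb) {
    pa > pb
  } else {
    val cmp = a.getThreadName.toLowerCase(Locale.ROOT)
      .compareTo(b.getThreadName.toLowerCase(Locale.ROOT))
    if (cmp != 0) cmp < 0 else a.getThreadId < b.getThreadId
  }
}

val dump = ManagementFactory.getThreadMXBean.dumpAllThreads(false, false).filter(_ != null)
dump.sorted(ordering).foreach(ti => println(ti.getThreadName))
```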
diff --git a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
index d182bd165f1f7..24a1fa7401752 100644
--- a/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkThrowableSuite.scala
@@ -75,7 +75,8 @@ class SparkThrowableSuite extends SparkFunSuite {
.addModule(DefaultScalaModule)
.enable(STRICT_DUPLICATE_DETECTION)
.build()
- mapper.readValue(errorJsonFilePath.toUri.toURL, new TypeReference[Map[String, ErrorInfo]]() {})
+ mapper.readValue(
+ errorJsonFilePath.toUri.toURL.openStream(), new TypeReference[Map[String, ErrorInfo]]() {})
}
test("Error conditions are correctly formatted") {
@@ -88,7 +89,7 @@ class SparkThrowableSuite extends SparkFunSuite {
val prettyPrinter = new DefaultPrettyPrinter()
.withArrayIndenter(DefaultIndenter.SYSTEM_LINEFEED_INSTANCE)
val rewrittenString = mapper.configure(SerializationFeature.ORDER_MAP_ENTRIES_BY_KEYS, true)
- .setSerializationInclusion(Include.NON_ABSENT)
+ .setDefaultPropertyInclusion(Include.NON_ABSENT)
.writer(prettyPrinter)
.writeValueAsString(errorReader.errorInfoMap)
@@ -124,9 +125,9 @@ class SparkThrowableSuite extends SparkFunSuite {
.enable(STRICT_DUPLICATE_DETECTION)
.build()
val errorClasses = mapper.readValue(
- errorClassesJson, new TypeReference[Map[String, String]]() {})
+ errorClassesJson.openStream(), new TypeReference[Map[String, String]]() {})
val errorStates = mapper.readValue(
- errorStatesJson, new TypeReference[Map[String, ErrorStateInfo]]() {})
+ errorStatesJson.openStream(), new TypeReference[Map[String, ErrorStateInfo]]() {})
val errorConditionStates = errorReader.errorInfoMap.values.toSeq.flatMap(_.sqlState).toSet
assert(Set("22012", "22003", "42601").subsetOf(errorStates.keySet))
assert(errorClasses.keySet.filter(!_.matches("[A-Z0-9]{2}")).isEmpty)
diff --git a/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala
index 473a2d7b2a258..2cce3d306e60c 100644
--- a/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/PythonRunnerSuite.scala
@@ -17,7 +17,8 @@
package org.apache.spark.deploy
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkFunSuite, SparkThrowable}
+import org.apache.spark.api.python.PythonErrorUtils
import org.apache.spark.util.Utils
class PythonRunnerSuite extends SparkFunSuite {
@@ -64,4 +65,31 @@ class PythonRunnerSuite extends SparkFunSuite {
intercept[IllegalArgumentException] { PythonRunner.formatPaths("hdfs:/some.py,foo.py") }
intercept[IllegalArgumentException] { PythonRunner.formatPaths("foo.py,hdfs:/some.py") }
}
+
+ test("SPARK-54052: PythonErrorUtils should have corresponding methods in SparkThrowable") {
+ // Find default methods in SparkThrowable
+ val defaultMethods = classOf[SparkThrowable]
+ .getMethods
+ .filter(m => m.getDeclaringClass == classOf[SparkThrowable])
+ .map(_.getName)
+ .toSet
+
+ // Find methods defined in PythonErrorUtils object
+ val utilsMethods = PythonErrorUtils.getClass
+ .getDeclaredMethods
+ .filterNot(_.isSynthetic)
+ .map(_.getName)
+ .filterNot(_.contains("$"))
+ .toSet
+
+ // Compare
+ assert(
+ utilsMethods == defaultMethods,
+ s"""
+ |PythonErrorUtils methods and SparkThrowable default methods differ!
+ |Missing in PythonErrorUtils: ${defaultMethods.diff(utilsMethods).mkString(", ")}
+ |Extra in PythonErrorUtils: ${utilsMethods.diff(defaultMethods).mkString(", ")}
+ |""".stripMargin
+ )
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkPipelinesSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkPipelinesSuite.scala
index a97aabfd5a371..60e279ba2ddc5 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkPipelinesSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkPipelinesSuite.scala
@@ -58,7 +58,7 @@ class SparkPipelinesSuite extends SparkSubmitTestUtils with BeforeAndAfterEach {
val args = Array(
"run",
"--spec",
- "pipeline.yml"
+ "spark-pipeline.yml"
)
assert(
SparkPipelines.constructSparkSubmitArgs(
@@ -71,7 +71,7 @@ class SparkPipelinesSuite extends SparkSubmitTestUtils with BeforeAndAfterEach {
"abc/python/pyspark/pipelines/cli.py",
"run",
"--spec",
- "pipeline.yml"
+ "spark-pipeline.yml"
)
)
}
@@ -83,7 +83,7 @@ class SparkPipelinesSuite extends SparkSubmitTestUtils with BeforeAndAfterEach {
"run",
"--supervise",
"--spec",
- "pipeline.yml",
+ "spark-pipeline.yml",
"--conf",
"spark.conf2=3"
)
@@ -101,7 +101,7 @@ class SparkPipelinesSuite extends SparkSubmitTestUtils with BeforeAndAfterEach {
"abc/python/pyspark/pipelines/cli.py",
"run",
"--spec",
- "pipeline.yml"
+ "spark-pipeline.yml"
)
)
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 0a44045742ffe..18d3c35ea94f4 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -1845,6 +1845,48 @@ class SparkSubmitSuite
assert(classpath.contains("."))
}
+ test("SPARK-52334: Update all files, jars, and pyFiles to" +
+ "reference the working directory after they are downloaded") {
+ withTempDir { dir =>
+ val text1 = File.createTempFile("test1_", ".txt", dir)
+ val zipFile1 = File.createTempFile("test1_", ".zip", dir)
+ TestUtils.createJar(Seq(text1), zipFile1)
+ val testFile = "test_metrics_config.properties"
+ val testPyFile = "test_metrics_system.properties"
+ val testJar = "TestUDTF.jar"
+ val clArgs = Seq(
+ "--deploy-mode", "client",
+ "--proxy-user", "test.user",
+ "--master", "k8s://host:port",
+ "--executor-memory", "5g",
+ "--class", "org.SomeClass",
+ "--driver-memory", "4g",
+ "--conf", "spark.kubernetes.namespace=spark",
+ "--conf", "spark.kubernetes.driver.container.image=bar",
+ "--conf", "spark.kubernetes.submitInDriver=true",
+ "--files", s"src/test/resources/$testFile",
+ "--py-files", s"src/test/resources/$testPyFile",
+ "--jars", s"src/test/resources/$testJar",
+ "--archives", s"${zipFile1.getAbsolutePath}#test_archives",
+ "/home/thejar.jar",
+ "arg1")
+ val appArgs = new SparkSubmitArguments(clArgs)
+ val _ = submit.prepareSubmitEnvironment(appArgs)
+
+ appArgs.files should be (Utils.resolveURIs(s"$testFile,$testPyFile"))
+ appArgs.pyFiles should be (Utils.resolveURIs(testPyFile))
+ appArgs.jars should be (Utils.resolveURIs(testJar))
+ appArgs.archives should be (Utils.resolveURIs(s"${zipFile1.getAbsolutePath}#test_archives"))
+
+ Files.isDirectory(Paths.get("test_archives")) should be(true)
+ Files.delete(Paths.get(testFile))
+ Files.delete(Paths.get(testPyFile))
+ Files.delete(Paths.get(testJar))
+ Files.delete(Paths.get(s"test_archives/${text1.getName}"))
+ Files.delete(Paths.get("test_archives/META-INF/MANIFEST.MF"))
+ }
+ }
+
// Requires Python dependencies for Spark Connect. Should be enabled by default.
ignore("Spark Connect application submission (Python)") {
val pyFile = File.createTempFile("remote_test", ".py")
diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
index 5ecc551c16b8c..487a90f157a9b 100644
--- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
@@ -632,6 +632,9 @@ class StandaloneDynamicAllocationSuite
Map.empty, ResourceProfile.DEFAULT_RESOURCE_PROFILE_ID)
backend.driverEndpoint.askSync[Boolean](message)
backend.driverEndpoint.send(LaunchedExecutor(id))
+ eventually(timeout(10.seconds), interval(100.millis)) {
+ assert(backend.getExecutorAvailableCpus(id).exists(_ > 0))
+ }
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala
index d9d6a4f8d35df..00a92c503be4e 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/EventLogFileWritersSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.history
-import java.io.{File, FileOutputStream, IOException}
+import java.io.{File, FileOutputStream, IOException, OutputStream, PrintWriter}
import java.net.URI
import scala.collection.mutable
@@ -160,8 +160,152 @@ abstract class EventLogFileWritersSuite extends SparkFunSuite with LocalSparkCon
expectedLines: Seq[String] = Seq.empty): Unit
}
+/**
+ * A test OutputStream that simulates IO errors.
+ */
+class ErrorThrowingOutputStream extends OutputStream {
+ var throwOnWrite: Boolean = false
+ var throwOnFlush: Boolean = false
+ var throwOnClose: Boolean = false
+
+ override def write(b: Int): Unit = {
+ if (throwOnWrite) {
+ throw new IOException("Simulated write error")
+ }
+ }
+
+ override def write(b: Array[Byte], off: Int, len: Int): Unit = {
+ if (throwOnWrite) {
+ throw new IOException("Simulated write error")
+ }
+ }
+
+ override def flush(): Unit = {
+ if (throwOnFlush) {
+ throw new IOException("Simulated flush error")
+ }
+ }
+
+ override def close(): Unit = {
+ if (throwOnClose) {
+ throw new IOException("Simulated close error")
+ }
+ }
+}
+
+/**
+ * A testable subclass of SingleEventLogFileWriter that exposes the writer field
+ * and closeWriter method for testing.
+ */
+class TestableSingleEventLogFileWriter(
+ appId: String,
+ appAttemptId: Option[String],
+ logBaseDir: URI,
+ sparkConf: SparkConf,
+ hadoopConf: Configuration)
+ extends SingleEventLogFileWriter(appId, appAttemptId, logBaseDir, sparkConf, hadoopConf) {
+
+ def setWriterForTest(pw: PrintWriter): Unit = {
+ writer = Some(pw)
+ }
+
+ def callCloseWriter(): Unit = {
+ closeWriter()
+ }
+}
+
class SingleEventLogFileWriterSuite extends EventLogFileWritersSuite {
+ test("SPARK-55495: closeWriter should log warning when flush error occurs") {
+ val appId = getUniqueApplicationId
+ val attemptId = None
+ val conf = getLoggingConf(testDirPath, None)
+ val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
+
+ val writer = new TestableSingleEventLogFileWriter(
+ appId, attemptId, testDirPath.toUri, conf, hadoopConf)
+
+ // Create a PrintWriter with an ErrorThrowingOutputStream
+ val errorStream = new ErrorThrowingOutputStream()
+ val printWriter = new PrintWriter(errorStream)
+
+ // Simulate an error by writing to a closed stream that causes checkError to return true
+ errorStream.throwOnWrite = true
+ // scalastyle:off println
+ printWriter.println("test") // This will set the error flag
+ // scalastyle:on println
+
+ writer.setWriterForTest(printWriter)
+
+ val logAppender = new LogAppender("closeWriter flush error test")
+ withLogAppender(logAppender, level = Some(org.apache.logging.log4j.Level.WARN)) {
+ writer.callCloseWriter()
+ }
+
+ val warningMessages = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage)
+ assert(warningMessages.exists(_.contains("Spark detects errors while flushing")),
+ s"Expected warning message not found. Messages: $warningMessages")
+ }
+
+ test("SPARK-55495: closeWriter should log warning when close error occurs") {
+ val appId = getUniqueApplicationId
+ val attemptId = None
+ val conf = getLoggingConf(testDirPath, None)
+ val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
+
+ val writer = new TestableSingleEventLogFileWriter(
+ appId, attemptId, testDirPath.toUri, conf, hadoopConf)
+
+ // Create a PrintWriter with an ErrorThrowingOutputStream that errors on close
+ val errorStream = new ErrorThrowingOutputStream()
+ val printWriter = new PrintWriter(errorStream)
+
+ // First write something successfully
+ // scalastyle:off println
+ printWriter.println("test")
+ // scalastyle:on println
+ printWriter.flush()
+
+ // Now set up to error on close
+ errorStream.throwOnClose = true
+
+ writer.setWriterForTest(printWriter)
+
+ val logAppender = new LogAppender("closeWriter close error test")
+ withLogAppender(logAppender, level = Some(org.apache.logging.log4j.Level.WARN)) {
+ writer.callCloseWriter()
+ }
+
+ val warningMessages = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage)
+ assert(warningMessages.exists(_.contains("Spark detects errors while closing")),
+ s"Expected warning message not found. Messages: $warningMessages")
+ }
+
+ test("SPARK-55495: closeWriter should complete without warnings when no errors") {
+ val appId = getUniqueApplicationId
+ val attemptId = None
+ val conf = getLoggingConf(testDirPath, None)
+ val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
+
+ val writer = new TestableSingleEventLogFileWriter(
+ appId, attemptId, testDirPath.toUri, conf, hadoopConf)
+
+ // Create a normal PrintWriter with no errors
+ val normalStream = new ErrorThrowingOutputStream()
+ val printWriter = new PrintWriter(normalStream)
+
+ writer.setWriterForTest(printWriter)
+
+ val logAppender = new LogAppender("closeWriter no error test")
+ withLogAppender(logAppender, level = Some(org.apache.logging.log4j.Level.WARN)) {
+ writer.callCloseWriter()
+ }
+
+ val warningMessages = logAppender.loggingEvents.map(_.getMessage.getFormattedMessage)
+ assert(!warningMessages.exists(_.contains("Spark detects errors")),
+ s"Unexpected warning message found. Messages: $warningMessages")
+ }
+
test("Log overwriting") {
val appId = "test"
val appAttemptId = None
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala
index ff5d314d1688a..f9a0efce88708 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.deploy.worker
import java.io.{File, IOException}
+import java.util.concurrent.{ScheduledFuture => JScheduledFuture}
import java.util.concurrent.atomic.AtomicBoolean
import java.util.function.Supplier
@@ -37,7 +38,7 @@ import org.scalatest.matchers.should.Matchers._
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.TestUtils.{createTempJsonFile, createTempScriptWithExpectedOutput}
import org.apache.spark.deploy.{Command, ExecutorState, ExternalShuffleService}
-import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged, WorkDirCleanup}
+import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged, RegisteredWorker, WorkDirCleanup}
import org.apache.spark.deploy.master.DriverState
import org.apache.spark.internal.config
import org.apache.spark.internal.config.SHUFFLE_SERVICE_DB_BACKEND
@@ -46,7 +47,7 @@ import org.apache.spark.network.shuffledb.DBBackend
import org.apache.spark.resource.{ResourceAllocation, ResourceInformation}
import org.apache.spark.resource.ResourceUtils._
import org.apache.spark.resource.TestResourceIDs.{WORKER_FPGA_ID, WORKER_GPU_ID}
-import org.apache.spark.rpc.{RpcAddress, RpcEnv}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv}
import org.apache.spark.util.Utils
class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter with PrivateMethodTester {
@@ -405,4 +406,41 @@ class WorkerSuite extends SparkFunSuite with Matchers with BeforeAndAfter with P
}.getMessage
assert(m.contains("Whitespace is not allowed"))
}
+
+ test("SPARK-54312: heartbeat task and workdir cleanup task should only be scheduled once " +
+ "across multiple registrations") {
+ val worker = spy(makeWorker())
+ val masterWebUiUrl = "https://1.2.3.4:8080"
+ val masterAddress = RpcAddress("1.2.3.4", 1234)
+ val masterRef = mock(classOf[RpcEndpointRef])
+ when(masterRef.address).thenReturn(masterAddress)
+
+ def getHeartbeatTask(worker: Worker): Option[JScheduledFuture[_]] = {
+ val _heartbeatTask =
+ PrivateMethod[Option[JScheduledFuture[_]]](Symbol("heartbeatTask"))
+ worker.invokePrivate(_heartbeatTask())
+ }
+
+ def getWorkDirCleanupTask(worker: Worker): Option[JScheduledFuture[_]] = {
+ val _workDirCleanupTask =
+ PrivateMethod[Option[JScheduledFuture[_]]](Symbol("workDirCleanupTask"))
+ worker.invokePrivate(_workDirCleanupTask())
+ }
+
+ // Tasks should not be scheduled yet before registration
+ assert(getHeartbeatTask(worker).isEmpty && getWorkDirCleanupTask(worker).isEmpty)
+
+ val msg = RegisteredWorker(masterRef, masterWebUiUrl, masterAddress, duplicate = false)
+ // Simulate first registration - this should schedule both tasks
+ worker.receive(msg)
+ val heartbeatTask = getHeartbeatTask(worker)
+ val workDirCleanupTask = getWorkDirCleanupTask(worker)
+ assert(heartbeatTask.isDefined && workDirCleanupTask.isDefined)
+
+ // Simulate disconnection and re-registration
+ worker.receive(msg)
+ // After re-registration, the task references should be the same (not rescheduled)
+ assert(getHeartbeatTask(worker) == heartbeatTask)
+ assert(getWorkDirCleanupTask(worker) == workDirCleanupTask)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSideSessionManagementSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSideSessionManagementSuite.scala
new file mode 100644
index 0000000000000..f127951054e78
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSideSessionManagementSuite.scala
@@ -0,0 +1,233 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import java.io.File
+
+import scala.collection.mutable.HashMap
+
+import org.mockito.Mockito.when
+import org.scalatest.BeforeAndAfterEach
+import org.scalatestplus.mockito.MockitoSugar
+
+import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite}
+import org.apache.spark.util.{MutableURLClassLoader, Utils}
+
+/**
+ * Unit tests for IsolatedSessionState lifecycle management.
+ * These tests verify the fix for race conditions in session acquire/release/eviction.
+ */
+class ExecutorSideSessionManagementSuite
+ extends SparkFunSuite
+ with BeforeAndAfterEach
+ with MockitoSugar {
+
+ private var testSessionCounter = 0
+ private var tempDir: File = _
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ // Clear the sessions map before each test
+ IsolatedSessionState.sessions.clear()
+ testSessionCounter = 0
+
+ // Set up a mock SparkEnv so that cleanup() can access SparkFiles.getRootDirectory()
+ tempDir = Utils.createTempDir()
+ val mockEnv = mock[SparkEnv]
+ val conf = new SparkConf(false)
+ when(mockEnv.conf).thenReturn(conf)
+ when(mockEnv.driverTmpDir).thenReturn(Some(tempDir.getAbsolutePath))
+ SparkEnv.set(mockEnv)
+ }
+
+ override def afterEach(): Unit = {
+ // Clear the sessions map after each test
+ IsolatedSessionState.sessions.clear()
+ SparkEnv.set(null)
+ if (tempDir != null && tempDir.exists()) {
+ Utils.deleteRecursively(tempDir)
+ tempDir = null
+ }
+ super.afterEach()
+ }
+
+ /**
+ * Creates a test IsolatedSessionState with a mock classloader and unique UUID.
+ */
+ private def createTestSession(uuid: String): IsolatedSessionState = {
+ val classLoader = new MutableURLClassLoader(
+ Array.empty,
+ Thread.currentThread().getContextClassLoader
+ )
+ val session = new IsolatedSessionState(
+ sessionUUID = uuid,
+ urlClassLoader = classLoader,
+ replClassLoader = classLoader,
+ currentFiles = new HashMap[String, Long](),
+ currentJars = new HashMap[String, Long](),
+ currentArchives = new HashMap[String, Long](),
+ replClassDirUri = None
+ )
+ // Register in authoritative sessions map as would happen in production
+ IsolatedSessionState.sessions.put(uuid, session)
+ session
+ }
+
+ private def nextUniqueUuid(): String = {
+ testSessionCounter += 1
+ s"test-uuid-$testSessionCounter"
+ }
+
+ test("acquire returns true for new session") {
+ val session = createTestSession(nextUniqueUuid())
+ assert(session.acquire())
+ }
+
+ test("acquire returns true for session acquired multiple times") {
+ val session = createTestSession(nextUniqueUuid())
+ assert(session.acquire())
+ assert(session.acquire())
+ assert(session.acquire())
+ }
+
+ test("acquire returns false after session is evicted with no references") {
+ val session = createTestSession(nextUniqueUuid())
+ session.markEvicted()
+ // Session should be cleaned up immediately since refCount is 0
+ assert(!IsolatedSessionState.sessions.containsKey(session.sessionUUID))
+ // Cannot acquire an evicted session.
+ assert(!session.acquire())
+ }
+
+ test("acquire returns false after session is evicted even with existing references") {
+ val uuid = nextUniqueUuid()
+ val session = createTestSession(uuid)
+
+ // First task acquires the session
+ assert(session.acquire())
+
+ // Session gets evicted (e.g., due to cache pressure)
+ session.markEvicted()
+ // Session should still be in the map because refCount > 0 (deferred cleanup)
+ assert(IsolatedSessionState.sessions.containsKey(uuid))
+
+ // A new task tries to acquire the same session - should fail because it's evicted
+ assert(!session.acquire())
+
+ // The original task releases - now cleanup happens
+ session.release()
+ assert(!IsolatedSessionState.sessions.containsKey(uuid))
+ }
+
+ test("deferred cleanup with multiple references") {
+ val uuid = nextUniqueUuid()
+ val session = createTestSession(uuid)
+
+ // Acquire the session multiple times (simulating multiple tasks)
+ assert(session.acquire())
+ assert(session.acquire())
+ assert(session.acquire())
+
+ // Evict the session - cleanup should be deferred
+ session.markEvicted()
+ assert(IsolatedSessionState.sessions.containsKey(uuid))
+
+ // Release twice - cleanup should still be deferred
+ session.release()
+ assert(IsolatedSessionState.sessions.containsKey(uuid))
+ session.release()
+ assert(IsolatedSessionState.sessions.containsKey(uuid))
+
+ // Release the last reference - cleanup should happen now
+ session.release()
+ assert(!IsolatedSessionState.sessions.containsKey(uuid))
+ }
+
+ test("tryUnEvict succeeds when session is evicted but still has references") {
+ val session = createTestSession(nextUniqueUuid())
+
+ // Acquire the session
+ assert(session.acquire())
+
+ // Evict the session
+ session.markEvicted()
+
+ // Try to un-evict - should succeed because refCount > 0
+ assert(session.tryUnEvict())
+
+ // Now acquire should succeed again
+ assert(session.acquire())
+ }
+
+ test("tryUnEvict fails when session is not evicted") {
+ val session = createTestSession(nextUniqueUuid())
+
+ // Acquire without eviction
+ assert(session.acquire())
+
+ // Try to un-evict - should fail because session is not evicted
+ assert(!session.tryUnEvict())
+ }
+
+ test("tryUnEvict and acquire fail when session has no references") {
+ val uuid = nextUniqueUuid()
+ val session = createTestSession(uuid)
+
+ // Evict with no references - triggers immediate cleanup
+ session.markEvicted()
+ assert(!IsolatedSessionState.sessions.containsKey(uuid))
+
+ // tryUnEvict should fail because refCount is 0 and session is already cleaned up
+ assert(!session.tryUnEvict())
+
+ // acquire should also fail
+ assert(!session.acquire())
+ }
+
+ test("session reuse via tryUnEvict keeps session in map when not evicted") {
+ // Note: This test verifies `IsolatedSessionState.sessions` behavior in isolation.
+ // In production, after tryUnEvict(), the session is put back into the Guava cache.
+ // When the cache eventually evicts it again (due to LRU policy), markEvicted() will be called,
+ // and cleanup will happen if refCount is 0. So there's no resource leak in practice.
+ val uuid = nextUniqueUuid()
+ val session = createTestSession(uuid)
+
+ // Simulate task 1 acquiring the session
+ assert(session.acquire())
+
+ // Session gets evicted (e.g., due to cache pressure)
+ session.markEvicted()
+ assert(IsolatedSessionState.sessions.containsKey(uuid)) // Deferred cleanup
+
+ // Simulate cache loader trying to reuse the session via tryUnEvict
+ assert(session.tryUnEvict())
+
+ // Now a new task can acquire the session
+ assert(session.acquire())
+
+ // Task 1 releases
+ session.release()
+ assert(IsolatedSessionState.sessions.containsKey(uuid)) // Still has 1 reference
+
+ // Task 2 releases - session stays in map because it's not evicted
+ session.release()
+ // Session stays in map because it's not evicted anymore (was un-evicted).
+ // In production, the Guava cache would eventually evict it again, triggering cleanup.
+ assert(IsolatedSessionState.sessions.containsKey(uuid))
+ }
+}
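
The suite above pins down a small reference-counting protocol around session eviction. As a reading aid, here is a minimal, self-contained sketch of that protocol; the method names (`acquire`, `release`, `markEvicted`, `tryUnEvict`) and the registry map mirror the test's API, but the class body is only an illustration, not Spark's `IsolatedSessionState` implementation.

```scala
import java.util.concurrent.ConcurrentHashMap

// Sketch of the acquire/release/evict life cycle exercised by the suite above.
// `cleanup()` stands in for whatever resource teardown the real class performs.
object SessionRefCountSketch {
  final class Session(val uuid: String, registry: ConcurrentHashMap[String, Session]) {
    private var refCount = 0
    private var evicted = false

    // A task takes a reference; this fails once the session has been evicted.
    def acquire(): Boolean = synchronized {
      if (evicted) false else { refCount += 1; true }
    }

    // Mark the session evicted; clean up immediately only if nobody holds a reference.
    def markEvicted(): Unit = synchronized {
      evicted = true
      if (refCount == 0) cleanup()
    }

    // The cache loader may reuse an evicted session, but only while references remain.
    def tryUnEvict(): Boolean = synchronized {
      if (evicted && refCount > 0) { evicted = false; true } else false
    }

    // Dropping the last reference of an evicted session triggers the deferred cleanup.
    def release(): Unit = synchronized {
      refCount -= 1
      if (evicted && refCount == 0) cleanup()
    }

    private def cleanup(): Unit = registry.remove(uuid)
  }

  def main(args: Array[String]): Unit = {
    val sessions = new ConcurrentHashMap[String, Session]()
    val s = new Session("demo", sessions)
    sessions.put("demo", s)
    assert(s.acquire())                    // task 1 holds a reference
    s.markEvicted()                        // cleanup is deferred: refCount > 0
    assert(!s.acquire())                   // new tasks are rejected while evicted
    s.release()                            // last reference dropped -> cleaned up
    assert(!sessions.containsKey("demo"))
  }
}
```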
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index 6f525cf8b898a..dd9884bffb285 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -22,7 +22,7 @@ import java.lang.Thread.UncaughtExceptionHandler
import java.net.URL
import java.nio.ByteBuffer
import java.util.{HashMap, Properties}
-import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.{CountDownLatch, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.immutable
@@ -603,6 +603,54 @@ class ExecutorSuite extends SparkFunSuite
}
}
+ test("SPARK-54087: launchTask should return task killed message when threadPool.execute fails") {
+ val conf = new SparkConf
+ val serializer = new JavaSerializer(conf)
+ val env = createMockEnv(conf, serializer)
+ val serializedTask = serializer.newInstance().serialize(new FakeTask(0, 0))
+ val taskDescription = createFakeTaskDescription(serializedTask)
+
+ val mockExecutorBackend = mock[ExecutorBackend]
+ val statusCaptor = ArgumentCaptor.forClass(classOf[ByteBuffer])
+
+ withExecutor("id", "localhost", env) { executor =>
+ // Use reflection to replace threadPool with a mock that throws an exception
+ val executorClass = classOf[Executor]
+ val threadPoolField = executorClass.getDeclaredField("threadPool")
+ threadPoolField.setAccessible(true)
+ val originalThreadPool = threadPoolField.get(executor).asInstanceOf[ThreadPoolExecutor]
+
+ // Create a mock ThreadPoolExecutor that throws an exception when execute is called
+ val mockThreadPool = mock[ThreadPoolExecutor]
+ val testException = new OutOfMemoryError("unable to create new native thread")
+ when(mockThreadPool.execute(any[Runnable])).thenThrow(testException)
+ threadPoolField.set(executor, mockThreadPool)
+
+ try {
+ // Launch the task - this should catch the exception and send statusUpdate
+ executor.launchTask(mockExecutorBackend, taskDescription)
+
+ // Verify that statusUpdate was called with FAILED state
+ verify(mockExecutorBackend).statusUpdate(
+ meq(taskDescription.taskId),
+ meq(TaskState.FAILED),
+ statusCaptor.capture()
+ )
+
+ // Verify that the exception was correctly serialized
+ val failureData = statusCaptor.getValue
+ val failReason = serializer.newInstance()
+ .deserialize[ExceptionFailure](failureData)
+ assert(failReason.exception.isDefined)
+ assert(failReason.exception.get.isInstanceOf[OutOfMemoryError])
+ assert(failReason.exception.get.getMessage === "unable to create new native thread")
+ } finally {
+ // Restore the original threadPool
+ threadPoolField.set(executor, originalThreadPool)
+ }
+ }
+ }
+
private def createMockEnv(conf: SparkConf, serializer: JavaSerializer): SparkEnv = {
val mockEnv = mock[SparkEnv]
val mockRpcEnv = mock[RpcEnv]
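
The new SPARK-54087 test above asserts the externally visible behavior: if `threadPool.execute` throws (for example an `OutOfMemoryError` when no native thread can be created), the executor must still report the task as `FAILED` with a serialized `ExceptionFailure` via `ExecutorBackend.statusUpdate`. A hedged sketch of that control flow, with stand-in types so it is self-contained; it is not the actual `Executor.launchTask` body.

```scala
import java.nio.ByteBuffer
import java.util.concurrent.{Executor => JExecutor}

object LaunchTaskFailureSketch {
  // Stand-ins for the Spark types involved; names are illustrative only.
  sealed trait TaskState
  case object FAILED extends TaskState
  trait Backend { def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer): Unit }

  def launchTask(pool: JExecutor, backend: Backend, taskId: Long, runner: Runnable)
                (serializeFailure: Throwable => ByteBuffer): Unit = {
    try {
      pool.execute(runner)
    } catch {
      // e.g. an OutOfMemoryError("unable to create new native thread") from the pool:
      // the task never started, so report it as failed instead of dropping it silently.
      case t: Throwable =>
        backend.statusUpdate(taskId, FAILED, serializeFailure(t))
    }
  }
}
```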
diff --git a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
index 9c74f2fdd459b..9f0e622b1d515 100644
--- a/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/memory/UnifiedMemoryManagerSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.internal.config._
import org.apache.spark.internal.config.Tests._
import org.apache.spark.storage.TestBlockId
import org.apache.spark.storage.memory.MemoryStore
+import org.apache.spark.util.Utils
class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTester {
private val dummyBlock = TestBlockId("--")
@@ -554,6 +555,7 @@ class UnifiedMemoryManagerSuite extends MemoryManagerSuite with PrivateMethodTes
}
test("unmanaged memory tracking with off-heap memory enabled") {
+ assume(!Utils.isMacOnAppleSilicon)
val maxOnHeapMemory = 1000L
val maxOffHeapMemory = 1500L
val taskAttemptId = 0L
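
The `assume(!Utils.isMacOnAppleSilicon)` guard added above cancels the test on Apple-silicon macOS rather than failing it, which is the relevant difference from `assert`. A small illustration of that distinction; the platform check below is a stand-in, not `Utils.isMacOnAppleSilicon` itself.

```scala
import org.scalatest.funsuite.AnyFunSuite

class AssumeVsAssertExample extends AnyFunSuite {
  test("cancelled, not failed, when the precondition does not hold") {
    // assume(...) throws TestCanceledException when false, so the test is reported
    // as cancelled/skipped; assert(...) would mark it as a failure instead.
    assume(!sys.props.getOrElse("os.arch", "").contains("aarch64"))
    assert(1 + 1 == 2)
  }
}
```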
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 6ec0ea320eaa0..48f1c49e7af23 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -3421,11 +3421,12 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
stageId: Int,
shuffleId: Int,
numTasks: Int = 2,
- checksumVal: Long = 0): Unit = {
+ checksumVal: Long = 0,
+ stageAttemptId: Int = 1): Unit = {
assert(taskSets(taskSetIndex).stageId == stageId)
- assert(taskSets(taskSetIndex).stageAttemptId == 1)
+ assert(taskSets(taskSetIndex).stageAttemptId == stageAttemptId)
assert(taskSets(taskSetIndex).tasks.length == numTasks)
- completeShuffleMapStageSuccessfully(stageId, 1, 2, checksumVal = checksumVal)
+ completeShuffleMapStageSuccessfully(stageId, stageAttemptId, 2, checksumVal = checksumVal)
assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
}
@@ -3835,6 +3836,129 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
}
}
+ test("SPARK-54556: ensure rollback all the succeeding stages and ignore stale task results " +
+ "when shuffle checksum mismatch detected") {
+ /**
+ * Construct the following RDD graph:
+ *
+ *    ShuffleMapRdd1 (Indeterminate)
+ *       /        \
+ * ShuffleMapRdd2   \
+ *      /           |
+ * ShuffleMapRdd3   |
+ *      \           |
+ *       FinalRdd
+ *
+ * While executing the result stage, a shuffle fetch fails on shuffle1, leading to an executor
+ * loss and the loss of some map outputs of shuffle2.
+ * Both stage 0 and stage 2 will be resubmitted.
+ * A checksum mismatch is detected when retrying stage 0.
+ * The retry task of stage 2 completes but its result should be ignored.
+ */
+ val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+ val shuffleDep1 = new ShuffleDependency(
+ shuffleMapRdd1,
+ new HashPartitioner(2),
+ checksumMismatchFullRetryEnabled = true)
+ val shuffleId1 = shuffleDep1.shuffleId
+
+ val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+ val shuffleDep2 = new ShuffleDependency(
+ shuffleMapRdd2,
+ new HashPartitioner(2),
+ checksumMismatchFullRetryEnabled = true)
+ val shuffleId2 = shuffleDep2.shuffleId
+
+ val shuffleMapRdd3 = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker)
+ val shuffleDep3 = new ShuffleDependency(
+ shuffleMapRdd3,
+ new HashPartitioner(2),
+ checksumMismatchFullRetryEnabled = true)
+ val shuffleId3 = shuffleDep3.shuffleId
+
+ val finalRdd = new MyRDD(sc, 2, List(shuffleDep1, shuffleDep3), tracker = mapOutputTracker)
+
+ // Submit the job and complete the shuffle stages
+ submit(finalRdd, Array(0, 1))
+ completeShuffleMapStageSuccessfully(
+ 0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+ completeShuffleMapStageSuccessfully(
+ 1, 0, 2, Seq("hostC", "hostD"), checksumVal = 200)
+ completeShuffleMapStageSuccessfully(
+ 2, 0, 2, Seq("hostB", "hostC"), checksumVal = 300)
+ assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty))
+ assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty))
+ assert(mapOutputTracker.findMissingPartitions(shuffleId3) === Some(Seq.empty))
+
+ // The first task of result stage 3 failed with FetchFailed.
+ runEvent(makeCompletionEvent(
+ taskSets(3).tasks(0),
+ FetchFailed(makeBlockManagerId("hostB"), shuffleId1, 0L, 0, 0, "ignored"),
+ null))
+ assert(mapOutputTracker.findMissingPartitions(shuffleId3).nonEmpty)
+
+ // Check status for all failedStages.
+ val failedStages = scheduler.failedStages.toSeq
+ assert(failedStages.map(_.id) === Seq(0, 3))
+ scheduler.resubmitFailedStages()
+ // Check status for runningStages.
+ assert(scheduler.runningStages.map(_.id) === Set(0, 2))
+
+ // Complete the re-attempt of shuffle map stage 0 (shuffleId1) with a different checksum.
+ completeShuffleMapStageSuccessfully(0, 1, 2, checksumVal = 101)
+ completeShuffleMapStageSuccessfully(2, 1, 2, checksumVal = 300)
+ // The result of stage 2 should be ignored
+ assert(mapOutputTracker.getNumAvailableOutputs(shuffleId3) === 0)
+ scheduler.resubmitFailedStages()
+ assert(scheduler.runningStages.map(_.id) === Set(1))
+
+ checkAndCompleteRetryStage(6, 1, shuffleId2, 2, checksumVal = 201)
+ checkAndCompleteRetryStage(7, 2, shuffleId3, 2, checksumVal = 301, stageAttemptId = 2)
+ completeAndCheckAnswer(taskSets(8), Seq((Success, 11), (Success, 12)), Map(0 -> 11, 1 -> 12))
+ }
+
+ test("SPARK-54556: abort stage if result task from old attempt with indeterminate " +
+ "result succeeded") {
+ val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+ val shuffleDep1 = new ShuffleDependency(
+ shuffleMapRdd1,
+ new HashPartitioner(2),
+ checksumMismatchFullRetryEnabled = true)
+ val shuffleId1 = shuffleDep1.shuffleId
+
+ // Submit a job depending on shuffleDep1
+ val finalRdd1 = new MyRDD(
+ sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+ submit(finalRdd1, Array(0, 1))
+
+ // Finish stage 0.
+ completeShuffleMapStageSuccessfully(
+ 0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+ assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty))
+
+ // The first task of result stage failed with FetchFailed.
+ runEvent(makeCompletionEvent(
+ taskSets(1).tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, "ignored"),
+ null))
+
+ // Check status for all failedStages.
+ val failedStages = scheduler.failedStages.toSeq
+ assert(failedStages.map(_.id) === Seq(0, 1))
+ scheduler.resubmitFailedStages()
+
+ // Complete the shuffle map stage with a different checksum
+ completeShuffleMapStageSuccessfully(0, 1, 2, checksumVal = 101)
+
+ // Complete the second task of the 1st attempt of the result stage.
+ runEvent(makeCompletionEvent(
+ taskSets(1).tasks(1),
+ Success,
+ 42))
+ assert(failure != null && failure.getMessage.contains(
+ "Task with indeterminate results from old attempt succeeded"))
+ }
+
test("SPARK-27164: RDD.countApprox on empty RDDs schedules jobs which never complete") {
val latch = new CountDownLatch(1)
val jobListener = new SparkListener {
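
To summarize what the two SPARK-54556 tests above encode: when a shuffle map stage whose dependency has `checksumMismatchFullRetryEnabled` is recomputed and produces a different checksum, every downstream stage must be rolled back and re-run, results arriving from older attempts of those stages must be ignored, and a result task from an old attempt that already succeeded forces the job to abort. A hedged, standalone sketch of that decision logic; the names are invented for illustration and are not the DAGScheduler's.

```scala
object ChecksumRollbackSketch {
  final case class MapStageAttempt(shuffleId: Int, stageAttemptId: Int, checksum: Long)

  // The recomputed output differs from what downstream stages may already have consumed.
  def needsFullRollback(previous: MapStageAttempt, retried: MapStageAttempt,
                        fullRetryEnabled: Boolean): Boolean = {
    require(previous.shuffleId == retried.shuffleId)
    fullRetryEnabled && previous.checksum != retried.checksum
  }

  // Completion events from superseded attempts are recognised by their attempt id.
  def isStaleCompletion(eventStageAttemptId: Int, latestStageAttemptId: Int): Boolean =
    eventStageAttemptId < latestStageAttemptId
}
```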
diff --git a/core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala b/core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala
index 35e8a62c93c99..bed822f0b457b 100644
--- a/core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala
+++ b/core/src/test/scala/org/apache/spark/status/LiveEntitySuite.scala
@@ -66,6 +66,135 @@ class LiveEntitySuite extends SparkFunSuite {
assert(accuInfo.value == "[1,2,3,4,5,... 5 more items]")
}
+ test("makeNegative correctly negates all metrics with proper argument order") {
+ import LiveEntityHelpers._
+
+ val originalMetrics = createMetrics(
+ executorDeserializeTime = 1L,
+ executorDeserializeCpuTime = 2L,
+ executorRunTime = 3L,
+ executorCpuTime = 4L,
+ resultSize = 5L,
+ jvmGcTime = 6L,
+ resultSerializationTime = 7L,
+ memoryBytesSpilled = 8L,
+ diskBytesSpilled = 9L,
+ peakExecutionMemory = 10L,
+ inputBytesRead = 11L,
+ inputRecordsRead = 12L,
+ outputBytesWritten = 13L,
+ outputRecordsWritten = 14L,
+ shuffleRemoteBlocksFetched = 15L,
+ shuffleLocalBlocksFetched = 16L,
+ shuffleFetchWaitTime = 17L,
+ shuffleRemoteBytesRead = 18L,
+ shuffleRemoteBytesReadToDisk = 19L,
+ shuffleLocalBytesRead = 20L,
+ shuffleRecordsRead = 21L,
+ shuffleCorruptMergedBlockChunks = 22L,
+ shuffleMergedFetchFallbackCount = 23L,
+ shuffleMergedRemoteBlocksFetched = 24L,
+ shuffleMergedLocalBlocksFetched = 25L,
+ shuffleMergedRemoteChunksFetched = 26L,
+ shuffleMergedLocalChunksFetched = 27L,
+ shuffleMergedRemoteBytesRead = 28L,
+ shuffleMergedLocalBytesRead = 29L,
+ shuffleRemoteReqsDuration = 30L,
+ shuffleMergedRemoteReqsDuration = 31L,
+ shuffleBytesWritten = 32L,
+ shuffleWriteTime = 33L,
+ shuffleRecordsWritten = 34L
+ )
+
+ val negatedMetrics = makeNegative(originalMetrics)
+
+ def expectedNegated(v: Long): Long = v * -1L - 1L
+
+ // Verify all fields are correctly negated
+ assert(negatedMetrics.executorDeserializeTime === expectedNegated(1L))
+ assert(negatedMetrics.executorDeserializeCpuTime === expectedNegated(2L))
+ assert(negatedMetrics.executorRunTime === expectedNegated(3L))
+ assert(negatedMetrics.executorCpuTime === expectedNegated(4L))
+ assert(negatedMetrics.resultSize === expectedNegated(5L))
+ assert(negatedMetrics.jvmGcTime === expectedNegated(6L))
+ assert(negatedMetrics.resultSerializationTime === expectedNegated(7L))
+ assert(negatedMetrics.memoryBytesSpilled === expectedNegated(8L))
+ assert(negatedMetrics.diskBytesSpilled === expectedNegated(9L))
+ assert(negatedMetrics.peakExecutionMemory === expectedNegated(10L))
+
+ // Verify input metrics
+ assert(negatedMetrics.inputMetrics.bytesRead === expectedNegated(11L))
+ assert(negatedMetrics.inputMetrics.recordsRead === expectedNegated(12L))
+
+ // Verify output metrics (these were in the wrong position in current master)
+ assert(negatedMetrics.outputMetrics.bytesWritten === expectedNegated(13L),
+ "outputMetrics.bytesWritten should be correctly negated")
+ assert(negatedMetrics.outputMetrics.recordsWritten === expectedNegated(14L),
+ "outputMetrics.recordsWritten should be correctly negated")
+
+ // Verify shuffle read metrics (these were in the wrong position in current master)
+ assert(negatedMetrics.shuffleReadMetrics.remoteBlocksFetched === expectedNegated(15L),
+ "shuffleReadMetrics.remoteBlocksFetched should be correctly negated")
+ assert(negatedMetrics.shuffleReadMetrics.localBlocksFetched === expectedNegated(16L),
+ "shuffleReadMetrics.localBlocksFetched should be correctly negated")
+ assert(negatedMetrics.shuffleReadMetrics.fetchWaitTime === expectedNegated(17L),
+ "shuffleReadMetrics.fetchWaitTime should be correctly negated")
+ assert(negatedMetrics.shuffleReadMetrics.remoteBytesRead === expectedNegated(18L),
+ "shuffleReadMetrics.remoteBytesRead should be correctly negated")
+ assert(negatedMetrics.shuffleReadMetrics.remoteBytesReadToDisk === expectedNegated(19L),
+ "shuffleReadMetrics.remoteBytesReadToDisk should be correctly negated")
+ assert(negatedMetrics.shuffleReadMetrics.localBytesRead === expectedNegated(20L),
+ "shuffleReadMetrics.localBytesRead should be correctly negated")
+ assert(negatedMetrics.shuffleReadMetrics.recordsRead === expectedNegated(21L),
+ "shuffleReadMetrics.recordsRead should be correctly negated")
+
+ // Verify shuffle push read metrics (these were in the wrong position in current master)
+ assert(negatedMetrics.shuffleReadMetrics.shufflePushReadMetrics.corruptMergedBlockChunks ===
+ expectedNegated(22L),
+ "shufflePushReadMetrics.corruptMergedBlockChunks should be correctly negated")
+ assert(negatedMetrics.shuffleReadMetrics.shufflePushReadMetrics.mergedFetchFallbackCount ===
+ expectedNegated(23L),
+ "shufflePushReadMetrics.mergedFetchFallbackCount should be correctly negated")
+ assert(negatedMetrics.shuffleReadMetrics.shufflePushReadMetrics.remoteMergedBlocksFetched ===
+ expectedNegated(24L),
+ "shufflePushReadMetrics.remoteMergedBlocksFetched should be correctly negated")
+ assert(negatedMetrics.shuffleReadMetrics.shufflePushReadMetrics.localMergedBlocksFetched ===
+ expectedNegated(25L),
+ "shufflePushReadMetrics.localMergedBlocksFetched should be correctly negated")
+ assert(negatedMetrics.shuffleReadMetrics.shufflePushReadMetrics.remoteMergedChunksFetched ===
+ expectedNegated(26L),
+ "shufflePushReadMetrics.remoteMergedChunksFetched should be correctly negated")
+ assert(negatedMetrics.shuffleReadMetrics.shufflePushReadMetrics.localMergedChunksFetched ===
+ expectedNegated(27L),
+ "shufflePushReadMetrics.localMergedChunksFetched should be correctly negated")
+ assert(negatedMetrics.shuffleReadMetrics.shufflePushReadMetrics.remoteMergedBytesRead ===
+ expectedNegated(28L),
+ "shufflePushReadMetrics.remoteMergedBytesRead should be correctly negated")
+ assert(negatedMetrics.shuffleReadMetrics.shufflePushReadMetrics.localMergedBytesRead ===
+ expectedNegated(29L),
+ "shufflePushReadMetrics.localMergedBytesRead should be correctly negated")
+ assert(negatedMetrics.shuffleReadMetrics.remoteReqsDuration === expectedNegated(30L),
+ "shuffleReadMetrics.remoteReqsDuration should be correctly negated")
+ assert(negatedMetrics.shuffleReadMetrics.shufflePushReadMetrics.remoteMergedReqsDuration ===
+ expectedNegated(31L),
+ "shufflePushReadMetrics.remoteMergedReqsDuration should be correctly negated")
+
+ // Verify shuffle write metrics
+ assert(negatedMetrics.shuffleWriteMetrics.bytesWritten === expectedNegated(32L))
+ assert(negatedMetrics.shuffleWriteMetrics.writeTime === expectedNegated(33L))
+ assert(negatedMetrics.shuffleWriteMetrics.recordsWritten === expectedNegated(34L))
+
+ // Verify zero handling: 0 should become -1
+ val zeroMetrics = createMetrics(default = 0L)
+ val negatedZeroMetrics = makeNegative(zeroMetrics)
+ assert(negatedZeroMetrics.executorDeserializeTime === -1L,
+ "Zero value should be converted to -1")
+ assert(negatedZeroMetrics.inputMetrics.bytesRead === -1L,
+ "Zero input metric should be converted to -1")
+ assert(negatedZeroMetrics.outputMetrics.bytesWritten === -1L,
+ "Zero output metric should be converted to -1")
+ }
+
private def checkSize(seq: Seq[_], expected: Int): Unit = {
assert(seq.length === expected)
var count = 0
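
The `expectedNegated` helper above pins down the encoding `makeNegative` is expected to use: every value `v` becomes `-v - 1`, so a later merge by plain addition subtracts the original contribution while a genuine zero remains observable as `-1` instead of silently staying `0`. A small sketch of that update pattern; `applyDelta` and its `+ 1` compensation are an assumption added for illustration, since the test itself only fixes the `-v - 1` encoding.

```scala
object NegatedMetricsSketch {
  // makeNegative-style encoding: v -> -v - 1  (so 0 -> -1, 15 -> -16)
  def negate(v: Long): Long = v * -1L - 1L

  // Assumed merge step: folding a negated delta back in by addition, compensating the -1.
  def applyDelta(aggregate: Long, delta: Long): Long =
    if (delta < 0) aggregate + delta + 1 else aggregate + delta

  def main(args: Array[String]): Unit = {
    assert(negate(0L) == -1L)                    // zero stays distinguishable
    assert(applyDelta(40L, negate(15L)) == 25L)  // removes a prior contribution of 15
  }
}
```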
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala
index 2b2a67c3c00ad..8c6b9cc288ec1 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerDecommissionIntegrationSuite.scala
@@ -418,8 +418,11 @@ class BlockManagerDecommissionIntegrationSuite extends SparkFunSuite with LocalS
.map(_.asInstanceOf[ShuffleIndexBlockId].shuffleId)
.get
+ eventually(timeout(1.minute), interval(10.milliseconds)) {
+ val newShuffleFiles = shuffleFiles.diff(existingShuffleFiles)
+ assert(newShuffleFiles.size >= shuffleBlockUpdates.size)
+ }
val newShuffleFiles = shuffleFiles.diff(existingShuffleFiles)
- assert(newShuffleFiles.size >= shuffleBlockUpdates.size)
// Remove the shuffle data
sc.shuffleDriverComponents.removeShuffle(shuffleId, true)
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 933b6fc39e913..61952c4018534 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util
import java.io._
+import java.lang.management.ThreadInfo
import java.lang.reflect.Field
import java.net.{BindException, ServerSocket, URI}
import java.nio.{ByteBuffer, ByteOrder}
@@ -37,6 +38,8 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.audit.CommonAuditContext.currentAuditContext
import org.apache.hadoop.ipc.{CallerContext => HadoopCallerContext}
import org.apache.logging.log4j.Level
+import org.mockito.Mockito.when
+import org.scalatestplus.mockito.MockitoSugar.mock
import org.apache.spark.{SparkConf, SparkException, SparkFunSuite, TaskContext}
import org.apache.spark.internal.config._
@@ -1126,6 +1129,40 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties {
assert(pValue > threshold)
}
+ test("ThreadInfoOrdering") {
+ val task1T = mock[ThreadInfo]
+ when(task1T.getThreadId).thenReturn(11L)
+ when(task1T.getThreadName)
+ .thenReturn("Executor task launch worker for task 1.0 in stage 1.0 (TID 11)")
+ when(task1T.toString)
+ .thenReturn("Executor task launch worker for task 1.0 in stage 1.0 (TID 11)")
+
+ val task2T = mock[ThreadInfo]
+ when(task2T.getThreadId).thenReturn(12L)
+ when(task2T.getThreadName)
+ .thenReturn("Executor task launch worker for task 2.0 in stage 1.0 (TID 22)")
+ when(task2T.toString)
+ .thenReturn("Executor task launch worker for task 2.0 in stage 1.0 (TID 22)")
+
+ val connectExecuteOp1T = mock[ThreadInfo]
+ when(connectExecuteOp1T.getThreadId).thenReturn(21L)
+ when(connectExecuteOp1T.getThreadName)
+ .thenReturn("SparkConnectExecuteThread_opId=16148fb4-4189-43c3-b8d4-8b3b6ddd41c7")
+ when(connectExecuteOp1T.toString)
+ .thenReturn("SparkConnectExecuteThread_opId=16148fb4-4189-43c3-b8d4-8b3b6ddd41c7")
+
+ val connectExecuteOp2T = mock[ThreadInfo]
+ when(connectExecuteOp2T.getThreadId).thenReturn(22L)
+ when(connectExecuteOp2T.getThreadName)
+ .thenReturn("SparkConnectExecuteThread_opId=4e4d1cac-ffde-46c1-b7c2-808b726cb47e")
+ when(connectExecuteOp2T.toString)
+ .thenReturn("SparkConnectExecuteThread_opId=4e4d1cac-ffde-46c1-b7c2-808b726cb47e")
+
+ val sorted = Seq(connectExecuteOp1T, connectExecuteOp2T, task1T, task2T)
+ .sorted(Utils.threadInfoOrdering)
+ assert(sorted === Seq(task1T, task2T, connectExecuteOp1T, connectExecuteOp2T))
+ }
+
test("redact sensitive information") {
val sparkConf = new SparkConf
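
The `ThreadInfoOrdering` test above only fixes the relative order: executor task-launch threads sort ahead of Spark Connect execution threads. One ordering consistent with those assertions, sketched as an assumption rather than the actual `Utils.threadInfoOrdering`, is "task-launch worker threads first, then everything else, ties broken by thread name":

```scala
import java.lang.management.ThreadInfo

object ThreadInfoOrderingSketch {
  // An ordering consistent with the assertions above (illustrative, not Spark's own):
  // executor task-launch threads first, then other threads, each group sorted by name.
  val ordering: Ordering[ThreadInfo] = Ordering.by { t: ThreadInfo =>
    val isTaskThread = t.getThreadName.startsWith("Executor task launch worker")
    (if (isTaskThread) 0 else 1, t.getThreadName)
  }
}
```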
diff --git a/dev/create-release/do-release-docker.sh b/dev/create-release/do-release-docker.sh
index eaa8073fbca6e..e231d7a48eec0 100755
--- a/dev/create-release/do-release-docker.sh
+++ b/dev/create-release/do-release-docker.sh
@@ -120,6 +120,11 @@ GPG_KEY_FILE="$WORKDIR/gpg.key"
fcreate_secure "$GPG_KEY_FILE"
$GPG --export-secret-key --armor --pinentry-mode loopback --passphrase "$GPG_PASSPHRASE" "$GPG_KEY" > "$GPG_KEY_FILE"
+# Build base image first (contains common tools shared across all branches)
+run_silent "Building spark-rm-base image..." "docker-build-base.log" \
+ docker build -t "spark-rm-base:latest" -f "$SELF/spark-rm/Dockerfile.base" "$SELF/spark-rm"
+
+# Build branch-specific image (extends base with Java/Python versions for this branch)
run_silent "Building spark-rm image with tag $IMGTAG..." "docker-build.log" \
docker build -t "spark-rm:$IMGTAG" --build-arg UID=$UID "$SELF/spark-rm"
@@ -146,6 +151,7 @@ RELEASE_TAG=$RELEASE_TAG
GIT_REF=$GIT_REF
SPARK_PACKAGE_VERSION=$SPARK_PACKAGE_VERSION
ASF_USERNAME=$ASF_USERNAME
+ASF_NEXUS_TOKEN=$ASF_NEXUS_TOKEN
GIT_NAME=$GIT_NAME
GIT_EMAIL=$GIT_EMAIL
GPG_KEY=$GPG_KEY
diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh
index b984876f41643..1a80191cba4bb 100755
--- a/dev/create-release/release-build.sh
+++ b/dev/create-release/release-build.sh
@@ -40,6 +40,7 @@ SPARK_VERSION - (optional) Version of Spark being built (e.g. 2.1.2)
ASF_USERNAME - Username of ASF committer account
ASF_PASSWORD - Password of ASF committer account
+ASF_NEXUS_TOKEN - API token for the ASF Nexus repository
GPG_KEY - GPG key used to sign release artifacts
GPG_PASSPHRASE - Passphrase for GPG key
@@ -343,7 +344,7 @@ meta:
_edit_last: '4'
_wpas_done_all: '1'
---
-We are happy to announce the availability of Apache Spark ${RELEASE_VERSION}! Visit the release notes to read about the new features, or download the release today.
+We are happy to announce the availability of Apache Spark ${RELEASE_VERSION}! Visit the release notes to read about the new features, or download the release today.
EOF
fi
@@ -401,7 +402,7 @@ You can find the list of resolved issues and detailed changes in the [JIRA relea
We would like to acknowledge all community members for contributing ${ACKNOWLEDGE}"
- FILENAME="releases/_posts/${RELEASE_DATE}-spark-release-${RELEASE_VERSION}.md"
+ FILENAME="releases/_posts/${RELEASE_DATE}-spark-release-${RELEASE_VERSION//./-}.md"
mkdir -p releases/_posts
cat > "$FILENAME" <orgapachespark-" | \
- awk '// { id = $0 } // && $0 ~ /Apache Spark '"$RELEASE_VERSION"'/ { print id }' | \
- grep -oP '(?<=<repositoryId>)orgapachespark-[0-9]+(?=</repositoryId>)' | \
- sort -V | tail -n 1)
+ REPO_ID=$(
+ curl --retry 10 --retry-all-errors -s -u "$ASF_USERNAME:$ASF_NEXUS_TOKEN" \
+ https://repository.apache.org/service/local/staging/profile_repositories |
+ grep -A 13 "orgapachespark-" |
+ awk '// { id = $0 }
+ // && $0 ~ /Apache Spark '"$RELEASE_VERSION"'/ { print id }' |
+ sed -n 's/.*<repositoryId>\(orgapachespark-[0-9][0-9]*\)<\/repositoryId>.*/\1/p' |
+ sort -V |
+ tail -n 1
+ )
if [[ -z "$REPO_ID" ]]; then
echo "No matching staging repository found for Apache Spark $RELEASE_VERSION"
@@ -511,7 +515,7 @@ EOF
echo "Using repository ID: $REPO_ID"
# Release the repository
- curl --retry 10 --retry-all-errors -s -u "$APACHE_USERNAME:$APACHE_PASSWORD" \
+ curl --retry 10 --retry-all-errors -s -u "$ASF_USERNAME:$ASF_NEXUS_TOKEN" \
-H "Content-Type: application/json" \
-X POST https://repository.apache.org/service/local/staging/bulk/promote \
-d "{\"data\": {\"stagedRepositoryIds\": [\"$REPO_ID\"], \"description\": \"Apache Spark $RELEASE_VERSION\"}}"
@@ -519,9 +523,13 @@ EOF
# Wait for release to complete
echo "Waiting for release to complete..."
while true; do
- STATUS=$(curl --retry 10 --retry-all-errors -s -u "$APACHE_USERNAME:$APACHE_PASSWORD" \
- https://repository.apache.org/service/local/staging/repository/$REPO_ID | \
- grep -oPm1 "(?<=<type>)[^<]+")
+ STATUS=$(
+ curl --retry 10 --retry-all-errors -s -u "$ASF_USERNAME:$ASF_NEXUS_TOKEN" \
+ https://repository.apache.org/service/local/staging/repository/$REPO_ID |
+ sed -n 's:.*<type>\([^<]*\)</type>.*:\1:p' |
+ head -n 1
+ )
+
echo "Current state: $STATUS"
if [[ "$STATUS" == "released" ]]; then
echo "Release complete."
@@ -537,7 +545,7 @@ EOF
done
# Drop the repository after release
- curl --retry 10 --retry-all-errors -s -u "$APACHE_USERNAME:$APACHE_PASSWORD" \
+ curl --retry 10 --retry-all-errors -s -u "$ASF_USERNAME:$ASF_NEXUS_TOKEN" \
-H "Content-Type: application/json" \
-X POST https://repository.apache.org/service/local/staging/bulk/drop \
-d "{\"data\": {\"stagedRepositoryIds\": [\"$REPO_ID\"], \"description\": \"Dropped after release\"}}"
@@ -547,7 +555,7 @@ EOF
# Remove old releases from the mirror
# Extract major.minor prefix
RELEASE_SERIES=$(echo "$RELEASE_VERSION" | cut -d. -f1-2)
-
+
# Fetch existing dist URLs
OLD_VERSION=$(svn ls https://dist.apache.org/repos/dist/release/spark/ | \
grep "^spark-$RELEASE_SERIES" | \
@@ -557,7 +565,7 @@ EOF
if [[ -n "$OLD_VERSION" ]]; then
echo "Removing old version: spark-$OLD_VERSION"
- svn rm "https://dist.apache.org/repos/dist/release/spark/spark-$OLD_VERSION" -m "Remove older $RELEASE_SERIES release after $RELEASE_VERSION"
+ svn rm "https://dist.apache.org/repos/dist/release/spark/spark-$OLD_VERSION" --username "$ASF_USERNAME" --password "$ASF_PASSWORD" --non-interactive -m "Remove older $RELEASE_SERIES release after $RELEASE_VERSION"
else
echo "No previous $RELEASE_SERIES version found to remove. Manually remove it if there is."
fi
diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile
index e53ac6b439c87..32a2879e229d5 100644
--- a/dev/create-release/spark-rm/Dockerfile
+++ b/dev/create-release/spark-rm/Dockerfile
@@ -15,112 +15,63 @@
# limitations under the License.
#
-# Image for building Spark releases. Based on Ubuntu 22.04.
-FROM ubuntu:jammy-20250819
-LABEL org.opencontainers.image.authors="Apache Spark project <dev@spark.apache.org>"
-LABEL org.opencontainers.image.licenses="Apache-2.0"
-LABEL org.opencontainers.image.ref.name="Apache Spark Release Manager Image"
-# Overwrite this label to avoid exposing the underlying Ubuntu OS version label
-LABEL org.opencontainers.image.version=""
+# Spark 4.1 release image
+# Extends the base image with:
+# - Java 17
+# - Python 3.10 with required packages
-ENV FULL_REFRESH_DATE=20250819
-
-ENV DEBIAN_FRONTEND=noninteractive
-ENV DEBCONF_NONINTERACTIVE_SEEN=true
+FROM spark-rm-base:latest
+# Install Java 17 for Spark 4.x
RUN apt-get update && apt-get install -y \
- build-essential \
- ca-certificates \
- curl \
- gfortran \
- git \
- subversion \
- gnupg \
- libcurl4-openssl-dev \
- libfontconfig1-dev \
- libfreetype6-dev \
- libfribidi-dev \
- libgit2-dev \
- libharfbuzz-dev \
- libjpeg-dev \
- liblapack-dev \
- libopenblas-dev \
- libpng-dev \
- libpython3-dev \
- libssl-dev \
- libtiff5-dev \
- libwebp-dev \
- libxml2-dev \
- msmtp \
- nodejs \
- npm \
openjdk-17-jdk-headless \
- pandoc \
- pkg-config \
+ && rm -rf /var/lib/apt/lists/*
+
+# Install Python 3.10
+RUN apt-get update && apt-get install -y \
python3.10 \
+ python3.10-dev \
python3-psutil \
- texlive-latex-base \
- texlive \
- texlive-fonts-extra \
- texinfo \
- texlive-latex-extra \
- qpdf \
- jq \
- r-base \
- ruby \
- ruby-dev \
- software-properties-common \
- wget \
- zlib1g-dev \
+ libpython3-dev \
&& rm -rf /var/lib/apt/lists/*
+# Install pip for Python 3.10
+RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
-RUN echo 'deb https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/' >> /etc/apt/sources.list
-RUN gpg --keyserver hkps://keyserver.ubuntu.com --recv-key E298A3A825C0D65DFD57CBB651716619E084DAB9
-RUN gpg -a --export E084DAB9 | apt-key add -
-RUN add-apt-repository 'deb https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/'
-
-# See more in SPARK-39959, roxygen2 < 7.2.1
-RUN Rscript -e "install.packages(c('devtools', 'knitr', 'markdown', \
- 'rmarkdown', 'testthat', 'devtools', 'e1071', 'survival', 'arrow', \
- 'ggplot2', 'mvtnorm', 'statmod', 'xml2'), repos='https://cloud.r-project.org/')" && \
- Rscript -e "devtools::install_version('roxygen2', version='7.2.0', repos='https://cloud.r-project.org')" && \
- Rscript -e "devtools::install_version('lintr', version='2.0.1', repos='https://cloud.r-project.org')" && \
- Rscript -e "devtools::install_version('pkgdown', version='2.0.1', repos='https://cloud.r-project.org')" && \
- Rscript -e "devtools::install_version('preferably', version='0.4', repos='https://cloud.r-project.org')"
-
-# See more in SPARK-39735
-ENV R_LIBS_SITE="/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library"
+# Basic Python packages for Spark 4.1
+ARG BASIC_PIP_PKGS="numpy pyarrow>=18.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 \
+ mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2 twine==3.4.1"
-ARG BASIC_PIP_PKGS="numpy pyarrow>=18.0.0 six==1.16.0 pandas==2.3.2 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2 twine==3.4.1"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.5 googleapis-common-protos==1.65.0 graphviz==0.20.3"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 \
+ googleapis-common-protos==1.71.0 graphviz==0.20.3"
# Install Python 3.10 packages
-RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
-RUN python3.10 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
-RUN python3.10 -m pip install --ignore-installed 'six==1.16.0' # Avoid `python3-six` installation
-RUN python3.10 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS && \
+RUN python3.10 -m pip install --ignore-installed 'blinker>=1.6.2' && \
+ python3.10 -m pip install --ignore-installed 'six==1.16.0' && \
+ python3.10 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS && \
python3.10 -m pip install 'torch<2.6.0' torchvision --index-url https://download.pytorch.org/whl/cpu && \
python3.10 -m pip install deepspeed torcheval && \
python3.10 -m pip cache purge
+# Sphinx and documentation packages
# Should unpin 'sphinxcontrib-*' after upgrading sphinx>5
-# See 'ipython_genutils' in SPARK-38517
-# See 'docutils<0.18.0' in SPARK-39421
-RUN python3.10 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \
-ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.22' pyarrow pandas 'plotly>=4.8' 'docutils<0.18.0' \
-'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.12.1' \
-'pandas-stubs==1.2.0.53' 'grpcio==1.67.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \
-'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5'
-RUN python3.10 -m pip list
-
-RUN gem install --no-document "bundler:2.4.22"
-RUN ln -s "$(which python3.10)" "/usr/local/bin/python"
-RUN ln -s "$(which python3.10)" "/usr/local/bin/python3"
+# See 'ipython_genutils' in SPARK-38517, 'docutils<0.18.0' in SPARK-39421
+RUN python3.10 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' \
+ sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \
+ ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.22' pyarrow pandas \
+ 'plotly>=4.8' 'docutils<0.18.0' 'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' \
+ 'pytest-mypy-plugins==1.9.3' 'black==23.12.1' 'pandas-stubs==1.2.0.53' \
+ 'grpcio==1.76.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \
+ 'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' \
+ 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' \
+ 'sphinxcontrib-serializinghtml==1.1.5'
-WORKDIR /opt/spark-rm/output
+# Set Python 3.10 as the default
+RUN ln -sf "$(which python3.10)" "/usr/local/bin/python" && \
+ ln -sf "$(which python3.10)" "/usr/local/bin/python3"
+# Create user for release manager
ARG UID
RUN useradd -m -s /bin/bash -p spark-rm -u $UID spark-rm
USER spark-rm:spark-rm
diff --git a/dev/create-release/spark-rm/Dockerfile.base b/dev/create-release/spark-rm/Dockerfile.base
new file mode 100644
index 0000000000000..56e85256d52da
--- /dev/null
+++ b/dev/create-release/spark-rm/Dockerfile.base
@@ -0,0 +1,110 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Base image for building Spark releases. Based on Ubuntu 22.04.
+# This image contains common tools shared across all Spark versions:
+# - Build tools (gcc, make, etc.)
+# - R with pinned package versions
+# - Ruby with bundler
+# - TeX for documentation
+# - Node.js for documentation
+#
+# Branch-specific Dockerfiles should use "FROM spark-rm-base:latest" and add:
+# - Java version (8 or 17)
+# - Python version and pip packages
+
+FROM ubuntu:jammy-20250819
+LABEL org.opencontainers.image.authors="Apache Spark project <dev@spark.apache.org>"
+LABEL org.opencontainers.image.licenses="Apache-2.0"
+LABEL org.opencontainers.image.ref.name="Apache Spark Release Manager Base Image"
+LABEL org.opencontainers.image.version=""
+
+ENV FULL_REFRESH_DATE=20250819
+
+ENV DEBIAN_FRONTEND=noninteractive
+ENV DEBCONF_NONINTERACTIVE_SEEN=true
+
+# Install common system packages and build tools
+# Note: Java and Python are installed in branch-specific Dockerfiles
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ ca-certificates \
+ curl \
+ gfortran \
+ git \
+ subversion \
+ gnupg \
+ libcurl4-openssl-dev \
+ libfontconfig1-dev \
+ libfreetype6-dev \
+ libfribidi-dev \
+ libgit2-dev \
+ libharfbuzz-dev \
+ libjpeg-dev \
+ liblapack-dev \
+ libopenblas-dev \
+ libpng-dev \
+ libssl-dev \
+ libtiff5-dev \
+ libwebp-dev \
+ libxml2-dev \
+ msmtp \
+ nodejs \
+ npm \
+ pandoc \
+ pkg-config \
+ texlive-latex-base \
+ texlive \
+ texlive-fonts-extra \
+ texinfo \
+ texlive-latex-extra \
+ qpdf \
+ jq \
+ r-base \
+ ruby \
+ ruby-dev \
+ software-properties-common \
+ wget \
+ zlib1g-dev \
+ && rm -rf /var/lib/apt/lists/*
+
+# Set up R CRAN repository for latest R packages
+RUN echo 'deb https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/' >> /etc/apt/sources.list && \
+ gpg --keyserver hkps://keyserver.ubuntu.com --recv-key E298A3A825C0D65DFD57CBB651716619E084DAB9 && \
+ gpg -a --export E084DAB9 | apt-key add - && \
+ add-apt-repository 'deb https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/'
+
+# Install R packages (same versions across all branches)
+# See more in SPARK-39959, roxygen2 < 7.2.1
+RUN Rscript -e "install.packages(c('devtools', 'knitr', 'markdown', \
+ 'rmarkdown', 'testthat', 'e1071', 'survival', 'arrow', \
+ 'ggplot2', 'mvtnorm', 'statmod', 'xml2'), repos='https://cloud.r-project.org/')" && \
+ Rscript -e "devtools::install_version('roxygen2', version='7.2.0', repos='https://cloud.r-project.org')" && \
+ Rscript -e "devtools::install_version('lintr', version='2.0.1', repos='https://cloud.r-project.org')" && \
+ Rscript -e "devtools::install_version('preferably', version='0.4', repos='https://cloud.r-project.org')" && \
+ Rscript -e "devtools::install_version('pkgdown', version='2.0.1', repos='https://cloud.r-project.org')"
+
+# See more in SPARK-39735
+ENV R_LIBS_SITE="/usr/local/lib/R/site-library:${R_LIBS_SITE}:/usr/lib/R/library"
+
+# Install Ruby bundler (same version across all branches)
+RUN gem install --no-document "bundler:2.4.22"
+
+# Create workspace directory
+WORKDIR /opt/spark-rm/output
+
+# Note: Java, Python, and user creation are done in branch-specific Dockerfiles
diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3
index 5c4c053293e02..3720185bab318 100644
--- a/dev/deps/spark-deps-hadoop-3-hive-2.3
+++ b/dev/deps/spark-deps-hadoop-3-hive-2.3
@@ -3,7 +3,7 @@ JLargeArrays/1.5//JLargeArrays-1.5.jar
JTransforms/3.1//JTransforms-3.1.jar
RoaringBitmap/1.3.0//RoaringBitmap-1.3.0.jar
ST4/4.0.4//ST4-4.0.4.jar
-aircompressor/2.0.2//aircompressor-2.0.2.jar
+aircompressor/2.0.3//aircompressor-2.0.3.jar
algebra_2.13/2.8.0//algebra_2.13-2.8.0.jar
aliyun-java-sdk-core/4.5.10//aliyun-java-sdk-core-4.5.10.jar
aliyun-java-sdk-kms/2.11.0//aliyun-java-sdk-kms-2.11.0.jar
@@ -15,6 +15,7 @@ antlr4-runtime/4.13.1//antlr4-runtime-4.13.1.jar
aopalliance-repackaged/3.0.6//aopalliance-repackaged-3.0.6.jar
arpack/3.0.4//arpack-3.0.4.jar
arpack_combined_all/0.1//arpack_combined_all-0.1.jar
+arrow-compression/18.3.0//arrow-compression-18.3.0.jar
arrow-format/18.3.0//arrow-format-18.3.0.jar
arrow-memory-core/18.3.0//arrow-memory-core-18.3.0.jar
arrow-memory-netty-buffer-patch/18.3.0//arrow-memory-netty-buffer-patch-18.3.0.jar
@@ -32,7 +33,6 @@ breeze-macros_2.13/2.1.0//breeze-macros_2.13-2.1.0.jar
breeze_2.13/2.1.0//breeze_2.13-2.1.0.jar
bundle/2.29.52//bundle-2.29.52.jar
cats-kernel_2.13/2.8.0//cats-kernel_2.13-2.8.0.jar
-checker-qual/3.43.0//checker-qual-3.43.0.jar
chill-java/0.10.0//chill-java-0.10.0.jar
chill_2.13/0.10.0//chill_2.13-0.10.0.jar
commons-cli/1.10.0//commons-cli-1.10.0.jar
@@ -42,7 +42,7 @@ commons-compiler/3.1.9//commons-compiler-3.1.9.jar
commons-compress/1.28.0//commons-compress-1.28.0.jar
commons-crypto/1.1.0//commons-crypto-1.1.0.jar
commons-dbcp/1.4//commons-dbcp-1.4.jar
-commons-io/2.20.0//commons-io-2.20.0.jar
+commons-io/2.21.0//commons-io-2.21.0.jar
commons-lang/2.6//commons-lang-2.6.jar
commons-lang3/3.19.0//commons-lang3-3.19.0.jar
commons-math3/3.6.1//commons-math3-3.6.1.jar
@@ -61,14 +61,13 @@ derby/10.16.1.1//derby-10.16.1.1.jar
derbyshared/10.16.1.1//derbyshared-10.16.1.1.jar
derbytools/10.16.1.1//derbytools-10.16.1.1.jar
dropwizard-metrics-hadoop-metrics2-reporter/0.1.2//dropwizard-metrics-hadoop-metrics2-reporter-0.1.2.jar
-error_prone_annotations/2.36.0//error_prone_annotations-2.36.0.jar
esdk-obs-java/3.20.4.2//esdk-obs-java-3.20.4.2.jar
-failureaccess/1.0.2//failureaccess-1.0.2.jar
+failureaccess/1.0.3//failureaccess-1.0.3.jar
flatbuffers-java/25.2.10//flatbuffers-java-25.2.10.jar
gcs-connector/hadoop3-2.2.28/shaded/gcs-connector-hadoop3-2.2.28-shaded.jar
gmetric4j/1.0.10//gmetric4j-1.0.10.jar
gson/2.11.0//gson-2.11.0.jar
-guava/33.4.0-jre//guava-33.4.0-jre.jar
+guava/33.4.8-jre//guava-33.4.8-jre.jar
hadoop-aliyun/3.4.2//hadoop-aliyun-3.4.2.jar
hadoop-annotations/3.4.2//hadoop-annotations-3.4.2.jar
hadoop-aws/3.4.2//hadoop-aws-3.4.2.jar
@@ -101,7 +100,6 @@ icu4j/77.1//icu4j-77.1.jar
ini4j/0.5.4//ini4j-0.5.4.jar
istack-commons-runtime/4.1.2//istack-commons-runtime-4.1.2.jar
ivy/2.5.3//ivy-2.5.3.jar
-j2objc-annotations/3.0.0//j2objc-annotations-3.0.0.jar
jackson-annotations/2.20//jackson-annotations-2.20.jar
jackson-core/2.20.0//jackson-core-2.20.0.jar
jackson-databind/2.20.0//jackson-databind-2.20.0.jar
@@ -184,18 +182,17 @@ lapack/3.0.4//lapack-3.0.4.jar
leveldbjni-all/1.8//leveldbjni-all-1.8.jar
libfb303/0.9.3//libfb303-0.9.3.jar
libthrift/0.16.0//libthrift-0.16.0.jar
-listenablefuture/9999.0-empty-to-avoid-conflict-with-guava//listenablefuture-9999.0-empty-to-avoid-conflict-with-guava.jar
log4j-1.2-api/2.24.3//log4j-1.2-api-2.24.3.jar
log4j-api/2.24.3//log4j-api-2.24.3.jar
log4j-core/2.24.3//log4j-core-2.24.3.jar
log4j-layout-template-json/2.24.3//log4j-layout-template-json-2.24.3.jar
log4j-slf4j2-impl/2.24.3//log4j-slf4j2-impl-2.24.3.jar
lz4-java/1.8.0//lz4-java-1.8.0.jar
-metrics-core/4.2.33//metrics-core-4.2.33.jar
-metrics-graphite/4.2.33//metrics-graphite-4.2.33.jar
-metrics-jmx/4.2.33//metrics-jmx-4.2.33.jar
-metrics-json/4.2.33//metrics-json-4.2.33.jar
-metrics-jvm/4.2.33//metrics-jvm-4.2.33.jar
+metrics-core/4.2.37//metrics-core-4.2.37.jar
+metrics-graphite/4.2.37//metrics-graphite-4.2.37.jar
+metrics-jmx/4.2.37//metrics-jmx-4.2.37.jar
+metrics-json/4.2.37//metrics-json-4.2.37.jar
+metrics-jvm/4.2.37//metrics-jvm-4.2.37.jar
minlog/1.3.0//minlog-1.3.0.jar
netty-all/4.2.7.Final//netty-all-4.2.7.Final.jar
netty-buffer/4.2.7.Final//netty-buffer-4.2.7.Final.jar
@@ -246,10 +243,10 @@ opencsv/2.3//opencsv-2.3.jar
opentracing-api/0.33.0//opentracing-api-0.33.0.jar
opentracing-noop/0.33.0//opentracing-noop-0.33.0.jar
opentracing-util/0.33.0//opentracing-util-0.33.0.jar
-orc-core/2.2.1/shaded-protobuf/orc-core-2.2.1-shaded-protobuf.jar
+orc-core/2.2.2/shaded-protobuf/orc-core-2.2.2-shaded-protobuf.jar
orc-format/1.1.1/shaded-protobuf/orc-format-1.1.1-shaded-protobuf.jar
-orc-mapreduce/2.2.1/shaded-protobuf/orc-mapreduce-2.2.1-shaded-protobuf.jar
-orc-shims/2.2.1//orc-shims-2.2.1.jar
+orc-mapreduce/2.2.2/shaded-protobuf/orc-mapreduce-2.2.2-shaded-protobuf.jar
+orc-shims/2.2.2//orc-shims-2.2.2.jar
oro/2.0.8//oro-2.0.8.jar
osgi-resource-locator/1.0.3//osgi-resource-locator-1.0.3.jar
paranamer/2.8.3//paranamer-2.8.3.jar
diff --git a/dev/free_disk_space b/dev/free_disk_space
index 2e6621045901b..d0916a32f301a 100755
--- a/dev/free_disk_space
+++ b/dev/free_disk_space
@@ -44,6 +44,7 @@ sudo rm -rf /opt/hostedtoolcache/go
sudo rm -rf /opt/hostedtoolcache/node
du -sh /opt/*
+sudo apt-get update --fix-missing
sudo apt-get remove --purge -y '^aspnet.*'
sudo apt-get remove --purge -y '^dotnet-.*'
sudo apt-get remove --purge -y '^llvm-.*'
diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile
index 873d572063118..655e93d9eecc0 100644
--- a/dev/infra/Dockerfile
+++ b/dev/infra/Dockerfile
@@ -49,6 +49,7 @@ RUN apt-get update && apt-get install -y \
libpython3-dev \
libssl-dev \
libtiff5-dev \
+ libwebp-dev \
libxml2-dev \
nodejs \
npm \
@@ -91,16 +92,16 @@ RUN mkdir -p /usr/local/pypy/pypy3.10 && \
ln -sf /usr/local/pypy/pypy3.10/bin/pypy /usr/local/bin/pypy3.10 && \
ln -sf /usr/local/pypy/pypy3.10/bin/pypy /usr/local/bin/pypy3
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | pypy3
-RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.3.0' scipy coverage matplotlib lxml
+RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.3.3' scipy coverage matplotlib lxml
-ARG BASIC_PIP_PKGS="numpy pyarrow>=18.0.0 six==1.16.0 pandas==2.3.0 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG BASIC_PIP_PKGS="numpy pyarrow>=18.0.0 six==1.16.0 pandas==2.3.3 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.5 googleapis-common-protos==1.65.0 graphviz==0.20.3"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 graphviz==0.20.3"
# Install Python 3.10 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
-RUN python3.10 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
+RUN python3.10 -m pip install --ignore-installed 'blinker>=1.6.2' # mlflow needs this
RUN python3.10 -m pip install --ignore-installed 'six==1.16.0' # Avoid `python3-six` installation
RUN python3.10 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS && \
python3.10 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
@@ -113,7 +114,7 @@ RUN apt-get update && apt-get install -y \
python3.9 python3.9-distutils \
&& rm -rf /var/lib/apt/lists/*
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.9
-RUN python3.9 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
+RUN python3.9 -m pip install --ignore-installed 'blinker>=1.6.2' # mlflow needs this
RUN python3.9 -m pip install --force $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS && \
python3.9 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
python3.9 -m pip install torcheval && \
@@ -124,7 +125,7 @@ RUN apt-get update && apt-get install -y \
python3.11 \
&& rm -rf /var/lib/apt/lists/*
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
-RUN python3.11 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
+RUN python3.11 -m pip install --ignore-installed 'blinker>=1.6.2' # mlflow needs this
RUN python3.11 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS && \
python3.11 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
python3.11 -m pip install deepspeed torcheval && \
@@ -135,7 +136,7 @@ RUN apt-get update && apt-get install -y \
python3.12 \
&& rm -rf /var/lib/apt/lists/*
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12
-RUN python3.12 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
+RUN python3.12 -m pip install --ignore-installed 'blinker>=1.6.2' # mlflow needs this
RUN python3.12 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS lxml && \
python3.12 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
python3.12 -m pip install torcheval && \
@@ -146,9 +147,10 @@ RUN apt-get update && apt-get install -y \
python3.13 \
&& rm -rf /var/lib/apt/lists/*
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.13
-# TODO(SPARK-49862) Add BASIC_PIP_PKGS and CONNECT_PIP_PKGS to Python 3.13 image when it supports Python 3.13
-RUN python3.13 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
-RUN python3.13 -m pip install numpy>=2.1 pyarrow>=18.0.0 six==1.16.0 pandas==2.3.0 scipy coverage matplotlib openpyxl grpcio==1.67.0 grpcio-status==1.67.0 lxml jinja2 && \
+RUN python3.13 -m pip install --ignore-installed 'blinker>=1.6.2' # mlflow needs this
+RUN python3.13 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS lxml && \
+ python3.13 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
+ python3.13 -m pip install torcheval && \
python3.13 -m pip cache purge
# Remove unused installation packages to free up disk space
diff --git a/dev/requirements.txt b/dev/requirements.txt
index 76652df744815..cde0957715bfe 100644
--- a/dev/requirements.txt
+++ b/dev/requirements.txt
@@ -61,10 +61,11 @@ black==23.12.1
py
# Spark Connect (required)
-grpcio>=1.67.0
-grpcio-status>=1.67.0
-googleapis-common-protos>=1.65.0
-protobuf==5.29.5
+grpcio>=1.76.0
+grpcio-status>=1.76.0
+googleapis-common-protos>=1.71.0
+protobuf==6.33.0
+zstandard>=0.25.0
# Spark Connect python proto generation plugin (optional)
mypy-protobuf==3.3.0
diff --git a/dev/run-tests b/dev/run-tests
index 91a1532a338b9..6067caf210ebb 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -20,9 +20,9 @@
FWDIR="$(cd "`dirname $0`"/..; pwd)"
cd "$FWDIR"
-PYTHON_VERSION_CHECK=$(python3 -c 'import sys; print(sys.version_info < (3, 8, 0))')
+PYTHON_VERSION_CHECK=$(python3 -c 'import sys; print(sys.version_info < (3, 10, 0))')
if [[ "$PYTHON_VERSION_CHECK" == "True" ]]; then
- echo "Python versions prior to 3.8 are not supported."
+ echo "Python versions prior to 3.10 are not supported."
exit -1
fi
diff --git a/dev/spark-test-image-util/docs/run-in-container b/dev/spark-test-image-util/docs/run-in-container
index 1d43c602f7c72..3bfb3c5f651dd 100644
--- a/dev/spark-test-image-util/docs/run-in-container
+++ b/dev/spark-test-image-util/docs/run-in-container
@@ -28,8 +28,8 @@ cd /__w/spark/spark/docs
bundle install
# 3.Build docs, includes: `error docs`, `scala doc`, `python doc`, `sql doc`, excludes: `r doc`.
-# We need this link to make sure `python3` points to `python3.9` which contains the prerequisite packages.
-ln -s "$(which python3.9)" "/usr/local/bin/python3"
+# We need this link to make sure `python3` points to `python3.11` which contains the prerequisite packages.
+ln -s "$(which python3.11)" "/usr/local/bin/python3"
# Build docs first with SKIP_API to ensure they are buildable without requiring any
# language docs to be built beforehand.
diff --git a/dev/spark-test-image/docs/Dockerfile b/dev/spark-test-image/docs/Dockerfile
index c4cd43b9eb3ba..e268ea7a8351b 100644
--- a/dev/spark-test-image/docs/Dockerfile
+++ b/dev/spark-test-image/docs/Dockerfile
@@ -49,6 +49,7 @@ RUN apt-get update && apt-get install -y \
libpython3-dev \
libssl-dev \
libtiff5-dev \
+ libwebp-dev \
libxml2-dev \
nodejs \
npm \
@@ -88,8 +89,8 @@ RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
# See 'ipython_genutils' in SPARK-38517
# See 'docutils<0.18.0' in SPARK-39421
RUN python3.11 -m pip install 'sphinx==4.5.0' mkdocs 'pydata_sphinx_theme>=0.13' sphinx-copybutton nbsphinx numpydoc jinja2 markupsafe 'pyzmq<24.0.0' \
- ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.22' pyarrow 'pandas==2.3.2' 'plotly>=4.8' 'docutils<0.18.0' \
+ ipython ipython_genutils sphinx_plotly_directive 'numpy>=1.22' pyarrow 'pandas==2.3.3' 'plotly>=4.8' 'docutils<0.18.0' \
'flake8==3.9.0' 'mypy==1.8.0' 'pytest==7.1.3' 'pytest-mypy-plugins==1.9.3' 'black==23.12.1' \
- 'pandas-stubs==1.2.0.53' 'grpcio==1.67.0' 'grpcio-status==1.67.0' 'protobuf==5.29.5' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \
+ 'pandas-stubs==1.2.0.53' 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' \
'sphinxcontrib-applehelp==1.0.4' 'sphinxcontrib-devhelp==1.0.2' 'sphinxcontrib-htmlhelp==2.0.1' 'sphinxcontrib-qthelp==1.0.3' 'sphinxcontrib-serializinghtml==1.1.5' \
&& python3.11 -m pip cache purge
diff --git a/dev/spark-test-image/lint/Dockerfile b/dev/spark-test-image/lint/Dockerfile
index 3b603d4ab4a68..6ab571bf35d6e 100644
--- a/dev/spark-test-image/lint/Dockerfile
+++ b/dev/spark-test-image/lint/Dockerfile
@@ -46,6 +46,7 @@ RUN apt-get update && apt-get install -y \
libpng-dev \
libssl-dev \
libtiff5-dev \
+ libwebp-dev \
libxml2-dev \
nodejs \
npm \
@@ -81,8 +82,9 @@ RUN python3.11 -m pip install \
'flake8==3.9.0' \
'googleapis-common-protos-stubs==2.2.0' \
'grpc-stubs==1.24.11' \
- 'grpcio-status==1.67.0' \
- 'grpcio==1.67.0' \
+ 'grpcio-status==1.76.0' \
+ 'grpcio==1.76.0' \
+ 'zstandard==0.25.0' \
'ipython' \
'ipython_genutils' \
'jinja2' \
@@ -93,7 +95,7 @@ RUN python3.11 -m pip install \
'pandas' \
'pandas-stubs==1.2.0.53' \
'plotly>=4.8' \
- 'pyarrow>=21.0.0' \
+ 'pyarrow>=22.0.0' \
'pytest-mypy-plugins==1.9.3' \
'pytest==7.1.3' \
&& python3.11 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu \
diff --git a/dev/spark-test-image/numpy-213/Dockerfile b/dev/spark-test-image/numpy-213/Dockerfile
index 116154b663b07..713e9e7d7ef4d 100644
--- a/dev/spark-test-image/numpy-213/Dockerfile
+++ b/dev/spark-test-image/numpy-213/Dockerfile
@@ -49,6 +49,7 @@ RUN apt-get update && apt-get install -y \
libpython3-dev \
libssl-dev \
libtiff5-dev \
+ libwebp-dev \
libxml2-dev \
openjdk-17-jdk-headless \
pkg-config \
@@ -68,12 +69,12 @@ RUN apt-get update && apt-get install -y \
# Pin numpy==2.1.3
-ARG BASIC_PIP_PKGS="numpy==2.1.3 pyarrow>=21.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG BASIC_PIP_PKGS="numpy==2.1.3 pyarrow>=22.0.0 six==1.16.0 pandas==2.2.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.5 googleapis-common-protos==1.65.0 graphviz==0.20.3"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"
# Install Python 3.11 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
-RUN python3.11 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
+RUN python3.11 -m pip install --ignore-installed 'blinker>=1.6.2' # mlflow needs this
RUN python3.11 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS && \
python3.11 -m pip cache purge
diff --git a/dev/spark-test-image/pypy-310/Dockerfile b/dev/spark-test-image/pypy-310/Dockerfile
index cddf0f8ea10a3..c8672fc0ec068 100644
--- a/dev/spark-test-image/pypy-310/Dockerfile
+++ b/dev/spark-test-image/pypy-310/Dockerfile
@@ -49,6 +49,7 @@ RUN apt-get update && apt-get install -y \
libpython3-dev \
libssl-dev \
libtiff5-dev \
+ libwebp-dev \
libxml2-dev \
openjdk-17-jdk-headless \
pkg-config \
@@ -68,4 +69,4 @@ RUN mkdir -p /usr/local/pypy/pypy3.10 && \
ln -sf /usr/local/pypy/pypy3.10/bin/pypy /usr/local/bin/pypy3.10 && \
ln -sf /usr/local/pypy/pypy3.10/bin/pypy /usr/local/bin/pypy3
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | pypy3
-RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.3.2' scipy coverage matplotlib lxml
+RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.3.3' scipy coverage matplotlib lxml
diff --git a/dev/spark-test-image/python-310/Dockerfile b/dev/spark-test-image/python-310/Dockerfile
index cfc03bccdf7c1..9b5b18d061c2e 100644
--- a/dev/spark-test-image/python-310/Dockerfile
+++ b/dev/spark-test-image/python-310/Dockerfile
@@ -49,6 +49,7 @@ RUN apt-get update && apt-get install -y \
libpython3-dev \
libssl-dev \
libtiff5-dev \
+ libwebp-dev \
libxml2-dev \
openjdk-17-jdk-headless \
pkg-config \
@@ -63,13 +64,13 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
-ARG BASIC_PIP_PKGS="numpy pyarrow>=21.0.0 six==1.16.0 pandas==2.3.2 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.5 googleapis-common-protos==1.65.0 graphviz==0.20.3"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"
# Install Python 3.10 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
-RUN python3.10 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
+RUN python3.10 -m pip install --ignore-installed 'blinker>=1.6.2' # mlflow needs this
RUN python3.10 -m pip install --ignore-installed 'six==1.16.0' # Avoid `python3-six` installation
RUN python3.10 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS && \
python3.10 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
diff --git a/dev/spark-test-image/python-311-classic-only/Dockerfile b/dev/spark-test-image/python-311-classic-only/Dockerfile
index 6a71317a5fe44..1c5f9a2335787 100644
--- a/dev/spark-test-image/python-311-classic-only/Dockerfile
+++ b/dev/spark-test-image/python-311-classic-only/Dockerfile
@@ -49,6 +49,7 @@ RUN apt-get update && apt-get install -y \
libpython3-dev \
libssl-dev \
libtiff5-dev \
+ libwebp-dev \
libxml2-dev \
openjdk-17-jdk-headless \
pkg-config \
@@ -67,12 +68,12 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
-ARG BASIC_PIP_PKGS="numpy pyarrow>=21.0.0 pandas==2.3.2 plotly<6.0.0 matplotlib openpyxl memory-profiler>=0.61.0 mlflow>=2.8.1 scipy scikit-learn>=1.3.2"
+ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 pandas==2.3.3 plotly<6.0.0 matplotlib openpyxl memory-profiler>=0.61.0 mlflow>=2.8.1 scipy scikit-learn>=1.3.2"
ARG TEST_PIP_PKGS="coverage unittest-xml-reporting"
# Install Python 3.11 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
-RUN python3.11 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
+RUN python3.11 -m pip install --ignore-installed 'blinker>=1.6.2' # mlflow needs this
RUN python3.11 -m pip install $BASIC_PIP_PKGS $TEST_PIP_PKGS && \
python3.11 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
python3.11 -m pip install deepspeed torcheval && \
diff --git a/dev/spark-test-image/python-311/Dockerfile b/dev/spark-test-image/python-311/Dockerfile
index 962f6427de6a8..f8a9df5842ce0 100644
--- a/dev/spark-test-image/python-311/Dockerfile
+++ b/dev/spark-test-image/python-311/Dockerfile
@@ -49,6 +49,7 @@ RUN apt-get update && apt-get install -y \
libpython3-dev \
libssl-dev \
libtiff5-dev \
+ libwebp-dev \
libxml2-dev \
openjdk-17-jdk-headless \
pkg-config \
@@ -67,13 +68,13 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
-ARG BASIC_PIP_PKGS="numpy pyarrow>=21.0.0 six==1.16.0 pandas==2.3.2 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.5 googleapis-common-protos==1.65.0 graphviz==0.20.3"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"
# Install Python 3.11 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.11
-RUN python3.11 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
+RUN python3.11 -m pip install --ignore-installed 'blinker>=1.6.2' # mlflow needs this
RUN python3.11 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS && \
python3.11 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
python3.11 -m pip install deepspeed torcheval && \
diff --git a/dev/spark-test-image/python-312/Dockerfile b/dev/spark-test-image/python-312/Dockerfile
index afa24025c46c7..ca62bc5ebc611 100644
--- a/dev/spark-test-image/python-312/Dockerfile
+++ b/dev/spark-test-image/python-312/Dockerfile
@@ -49,6 +49,7 @@ RUN apt-get update && apt-get install -y \
libpython3-dev \
libssl-dev \
libtiff5-dev \
+ libwebp-dev \
libxml2-dev \
openjdk-17-jdk-headless \
pkg-config \
@@ -67,13 +68,13 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
-ARG BASIC_PIP_PKGS="numpy pyarrow>=21.0.0 six==1.16.0 pandas==2.3.2 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.5 googleapis-common-protos==1.65.0 graphviz==0.20.3"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"
# Install Python 3.12 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12
-RUN python3.12 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
+RUN python3.12 -m pip install --ignore-installed 'blinker>=1.6.2' # mlflow needs this
RUN python3.12 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS lxml && \
python3.12 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
python3.12 -m pip install torcheval && \
diff --git a/dev/spark-test-image/python-313-nogil/Dockerfile b/dev/spark-test-image/python-313-nogil/Dockerfile
index c7d2faed010f1..b6e2dd7c80a97 100644
--- a/dev/spark-test-image/python-313-nogil/Dockerfile
+++ b/dev/spark-test-image/python-313-nogil/Dockerfile
@@ -49,6 +49,7 @@ RUN apt-get update && apt-get install -y \
libpython3-dev \
libssl-dev \
libtiff5-dev \
+ libwebp-dev \
libxml2-dev \
openjdk-17-jdk-headless \
pkg-config \
@@ -67,14 +68,14 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
-ARG BASIC_PIP_PKGS="numpy pyarrow>=21.0.0 six==1.16.0 pandas==2.3.2 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
-ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.5 googleapis-common-protos==1.65.0 graphviz==0.20.3"
+ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"
# Install Python 3.13 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.13t
# TODO: Add BASIC_PIP_PKGS and CONNECT_PIP_PKGS when it supports Python 3.13 free threaded
# TODO: Add lxml, grpcio, grpcio-status back when they support Python 3.13 free threaded
-RUN python3.13t -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
-RUN python3.13t -m pip install numpy>=2.1 pyarrow>=19.0.0 six==1.16.0 pandas==2.3.2 scipy coverage matplotlib openpyxl jinja2 && \
+RUN python3.13t -m pip install --ignore-installed 'blinker>=1.6.2' # mlflow needs this
+RUN python3.13t -m pip install 'numpy>=2.1' 'pyarrow>=19.0.0' 'six==1.16.0' 'pandas==2.3.3' scipy coverage matplotlib openpyxl jinja2 && \
python3.13t -m pip cache purge
diff --git a/dev/spark-test-image/python-313/Dockerfile b/dev/spark-test-image/python-313/Dockerfile
index dcc68575c496e..bd64ecb31087d 100644
--- a/dev/spark-test-image/python-313/Dockerfile
+++ b/dev/spark-test-image/python-313/Dockerfile
@@ -49,6 +49,7 @@ RUN apt-get update && apt-get install -y \
libpython3-dev \
libssl-dev \
libtiff5-dev \
+ libwebp-dev \
libxml2-dev \
openjdk-17-jdk-headless \
pkg-config \
@@ -67,13 +68,13 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
-ARG BASIC_PIP_PKGS="numpy pyarrow>=21.0.0 six==1.16.0 pandas==2.3.2 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 protobuf==5.29.5 googleapis-common-protos==1.65.0 graphviz==0.20.3"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"
# Install Python 3.13 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.13
-RUN python3.13 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
+RUN python3.13 -m pip install --ignore-installed 'blinker>=1.6.2' # mlflow needs this
RUN python3.13 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS lxml && \
python3.13 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
python3.13 -m pip install torcheval && \
diff --git a/dev/spark-test-image/python-314/Dockerfile b/dev/spark-test-image/python-314/Dockerfile
index 5ab4154dd0f71..f3da21e005b30 100644
--- a/dev/spark-test-image/python-314/Dockerfile
+++ b/dev/spark-test-image/python-314/Dockerfile
@@ -49,6 +49,7 @@ RUN apt-get update && apt-get install -y \
libpython3-dev \
libssl-dev \
libtiff5-dev \
+ libwebp-dev \
libxml2-dev \
openjdk-17-jdk-headless \
pkg-config \
@@ -67,13 +68,13 @@ RUN apt-get update && apt-get install -y \
&& rm -rf /var/lib/apt/lists/*
-ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
+ARG BASIC_PIP_PKGS="numpy pyarrow>=22.0.0 six==1.16.0 pandas==2.3.3 scipy plotly<6.0.0 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.75.1 grpcio-status==1.71.2 protobuf==5.29.5 googleapis-common-protos==1.65.0 graphviz==0.20.3"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 protobuf==6.33.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20.3"
# Install Python 3.14 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.14
-RUN python3.14 -m pip install --ignore-installed blinker>=1.6.2 # mlflow needs this
+RUN python3.14 -m pip install --ignore-installed 'blinker>=1.6.2' # mlflow needs this
RUN python3.14 -m pip install $BASIC_PIP_PKGS unittest-xml-reporting $CONNECT_PIP_PKGS lxml && \
python3.14 -m pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu && \
python3.14 -m pip install torcheval && \
diff --git a/dev/spark-test-image/python-minimum/Dockerfile b/dev/spark-test-image/python-minimum/Dockerfile
index 8f42d02023e50..575b4afdd02c0 100644
--- a/dev/spark-test-image/python-minimum/Dockerfile
+++ b/dev/spark-test-image/python-minimum/Dockerfile
@@ -50,6 +50,7 @@ RUN apt-get update && apt-get install -y \
libpython3-dev \
libssl-dev \
libtiff5-dev \
+ libwebp-dev \
libxml2-dev \
openjdk-17-jdk-headless \
pkg-config \
@@ -63,9 +64,9 @@ RUN apt-get update && apt-get install -y \
ARG BASIC_PIP_PKGS="numpy==1.22.4 pyarrow==15.0.0 pandas==2.2.0 six==1.16.0 scipy scikit-learn coverage unittest-xml-reporting"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 googleapis-common-protos==1.65.0 graphviz==0.20 protobuf"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20 protobuf"
-# Install Python 3.9 packages
+# Install Python 3.10 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
RUN python3.10 -m pip install --force $BASIC_PIP_PKGS $CONNECT_PIP_PKGS && \
python3.10 -m pip cache purge
diff --git a/dev/spark-test-image/python-ps-minimum/Dockerfile b/dev/spark-test-image/python-ps-minimum/Dockerfile
index 440fda96f0fc8..5142d46cc3eb0 100644
--- a/dev/spark-test-image/python-ps-minimum/Dockerfile
+++ b/dev/spark-test-image/python-ps-minimum/Dockerfile
@@ -50,6 +50,7 @@ RUN apt-get update && apt-get install -y \
libpython3-dev \
libssl-dev \
libtiff5-dev \
+ libwebp-dev \
libxml2-dev \
openjdk-17-jdk-headless \
pkg-config \
@@ -64,7 +65,7 @@ RUN apt-get update && apt-get install -y \
ARG BASIC_PIP_PKGS="pyarrow==15.0.0 pandas==2.2.0 six==1.16.0 numpy scipy coverage unittest-xml-reporting"
# Python deps for Spark Connect
-ARG CONNECT_PIP_PKGS="grpcio==1.67.0 grpcio-status==1.67.0 googleapis-common-protos==1.65.0 graphviz==0.20 protobuf"
+ARG CONNECT_PIP_PKGS="grpcio==1.76.0 grpcio-status==1.76.0 googleapis-common-protos==1.71.0 zstandard==0.25.0 graphviz==0.20 protobuf"
# Install Python 3.10 packages
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10
diff --git a/dev/spark-test-image/sparkr/Dockerfile b/dev/spark-test-image/sparkr/Dockerfile
index 3312c0852bd77..6c0314c051d17 100644
--- a/dev/spark-test-image/sparkr/Dockerfile
+++ b/dev/spark-test-image/sparkr/Dockerfile
@@ -49,6 +49,7 @@ RUN apt-get update && apt-get install -y \
libpython3-dev \
libssl-dev \
libtiff5-dev \
+ libwebp-dev \
libxml2-dev \
pandoc \
pkg-config \
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 07ac4c76b91a6..aa8ca58a5a75f 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1114,6 +1114,8 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_connect_retry",
"pyspark.sql.tests.connect.test_connect_session",
"pyspark.sql.tests.connect.test_connect_stat",
+ "pyspark.sql.tests.connect.test_parity_geographytype",
+ "pyspark.sql.tests.connect.test_parity_geometrytype",
"pyspark.sql.tests.connect.test_parity_datasources",
"pyspark.sql.tests.connect.test_parity_errors",
"pyspark.sql.tests.connect.test_parity_catalog",
diff --git a/docs/_config.yml b/docs/_config.yml
index 2cb5b0704bae6..bd5c50b8ace1c 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -19,8 +19,8 @@ include:
# These allow the documentation to be updated with newer releases
# of Spark, Scala.
-SPARK_VERSION: 4.1.0-SNAPSHOT
-SPARK_VERSION_SHORT: 4.1.0
+SPARK_VERSION: 4.1.2-SNAPSHOT
+SPARK_VERSION_SHORT: 4.1.2
SCALA_BINARY_VERSION: "2.13"
SCALA_VERSION: "2.13.17"
SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK
@@ -39,7 +39,7 @@ DOCSEARCH_SCRIPT: |
inputSelector: '#docsearch-input',
enhancedSearchInput: true,
algoliaOptions: {
- 'facetFilters': ["version:latest"]
+ 'facetFilters': ["version:4.1.2"]
},
debug: false // Set debug to true if you want to inspect the dropdown
});
diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml
index b1688aec57f01..f530d78dae453 100644
--- a/docs/_data/menu-sql.yaml
+++ b/docs/_data/menu-sql.yaml
@@ -99,6 +99,10 @@
url: sql-ref-literals.html
- text: Null Semantics
url: sql-ref-null-semantics.html
+ - text: Name Resolution
+ url: sql-ref-name-resolution.html
+ - text: SQL Scripting
+ url: sql-ref-scripting.html
- text: SQL Syntax
url: sql-ref-syntax.html
subitems:
@@ -108,6 +112,8 @@
url: sql-ref-syntax.html#dml-statements
- text: Data Retrieval(Queries)
url: sql-ref-syntax.html#data-retrieval-statements
+ - text: SQL Scripting Statements
+ url: sql-ref-syntax.html#sql-scripting-statements
- text: Auxiliary Statements
url: sql-ref-syntax.html#auxiliary-statements
- text: Pipe Syntax
diff --git a/docs/configuration.md b/docs/configuration.md
index dc9ca63d24d97..c1d3f082e87b1 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1548,7 +1548,7 @@ Apart from these, the following properties are also available, and may be useful
spark.eventLog.rolling.enabled |
- false |
+ true |
Whether rolling over event log files is enabled. If set to true, it cuts down each event
log file to the configured size.
diff --git a/docs/control-flow/case-stmt.md b/docs/control-flow/case-stmt.md
new file mode 100644
index 0000000000000..c92663905b061
--- /dev/null
+++ b/docs/control-flow/case-stmt.md
@@ -0,0 +1,102 @@
+---
+layout: global
+title: CASE statement
+displayTitle: CASE statement
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+Executes `thenStmtN` for the first `optN` that equals `expr`, or `elseStmt` if no `optN` matches `expr`.
+This is called a _simple case statement_.
+
+Executes `thenStmtN` for the first `condN` evaluating to `true`, or `elseStmt` if no `condN` evaluates to `true`.
+This is called a _searched case statement_.
+
+For case expressions that yield result values, see `CASE expression`.
+
+This statement may only be used within a [compound statement](compound-stmt.html).
+
+## Syntax
+
+```
+CASE expr
+ { WHEN opt THEN { thenStmt ; } [...] } [...]
+ [ ELSE { elseStmt ; } [...] ]
+END CASE
+
+CASE
+ { WHEN cond THEN { thenStmt ; } [...] } [...]
+ [ ELSE { elseStmt ; } [...] ]
+END CASE
+```
+
+## Parameters
+
+- **`expr`**: Any expression for which a comparison is defined.
+- **`opt`**: An expression that shares a least common type with `expr` and all other `optN`.
+- **`thenStmt`**: A SQL statement to execute if the preceding condition is `true`.
+- **`elseStmt`**: A SQL statement to execute if no condition is `true`.
+- **`cond`**: A `BOOLEAN` expression.
+
+Conditions are evaluated in order, and only the first list of `thenStmt` statements whose `opt` matches or whose `cond` evaluates to `true` is executed.
+
+## Examples
+
+```SQL
+-- a simple case statement
+> BEGIN
+ DECLARE choice INT DEFAULT 3;
+ DECLARE result STRING;
+ CASE choice
+ WHEN 1 THEN
+ VALUES ('one fish');
+ WHEN 2 THEN
+ VALUES ('two fish');
+ WHEN 3 THEN
+ VALUES ('red fish');
+ WHEN 4 THEN
+ VALUES ('blue fish');
+ ELSE
+ VALUES ('no fish');
+ END CASE;
+ END;
+ red fish
+
+-- A searched case statement
+> BEGIN
+ DECLARE choice DOUBLE DEFAULT 3.9;
+ DECLARE result STRING;
+ CASE
+ WHEN choice < 2 THEN
+ VALUES ('one fish');
+ WHEN choice < 3 THEN
+ VALUES ('two fish');
+ WHEN choice < 4 THEN
+ VALUES ('red fish');
+ WHEN choice < 5 OR choice IS NULL THEN
+ VALUES ('blue fish');
+ ELSE
+ VALUES ('no fish');
+ END CASE;
+ END;
+ red fish
+```
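+
+A single `WHEN` branch may also contain several statements, each terminated by a semicolon. A minimal sketch (using a hypothetical local variable `msg`):
+
+```SQL
+-- Multiple statements in one WHEN branch
+> BEGIN
+    DECLARE choice INT DEFAULT 1;
+    DECLARE msg STRING;
+    CASE choice
+      WHEN 1 THEN
+        SET msg = 'one fish';
+        VALUES (msg);
+      ELSE
+        VALUES ('no fish');
+    END CASE;
+  END;
+ one fish
+```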
+
+## Related articles
+
+- [SQL Scripting](../sql-ref-scripting.html)
+- [compound statement](compound-stmt.html)
+- [IF statement](if-stmt.html)
diff --git a/docs/control-flow/compound-stmt.md b/docs/control-flow/compound-stmt.md
new file mode 100644
index 0000000000000..d34e70648de43
--- /dev/null
+++ b/docs/control-flow/compound-stmt.md
@@ -0,0 +1,164 @@
+---
+layout: global
+title: compound statement
+displayTitle: compound statement
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+Implements a SQL script block that can contain a sequence of SQL statements, control-flow statements, local variable declarations, and exception handlers.
+
+## Syntax
+
+```
+[ label : ]
+ BEGIN [ NOT ATOMIC ]
+ [ { declare_variable | declare_condition } ; [...] ]
+ [ declare_handler ; [...] ]
+ [ SQL_statement ; [...] ]
+ END [ label ]
+
+declare_variable
+ DECLARE variable_name [, ...] data_type [ { DEFAULT | = } default_expression ]
+
+declare_condition
+ DECLARE condition_name CONDITION [ FOR SQLSTATE [ VALUE ] sqlstate ]
+
+declare_handler
+ DECLARE handler_type HANDLER FOR condition_values handler_action
+
+handler_type
+ EXIT
+
+condition_values
+ { { SQLSTATE [ VALUE ] sqlstate | condition_name } [, ...] |
+ { SQLEXCEPTION | NOT FOUND } [, ...] }
+```
+
+## Parameters
+
+- **`label`**
+
+ An optional identifier used to qualify variables defined within the compound statement and to leave the compound statement.
+ Both label occurrences must match, and the `END` label can only be specified if the leading `label:` is specified.
+
+ `label` must not be specified for a top-level compound statement.
+
+- **`NOT ATOMIC`**
+
+ Specifies that, if an SQL statement within the compound fails, previous SQL statements will not be rolled back.
+ This is the default and only behavior.
+
+- **`declare_variable`**
+
+ A local variable declaration for one or more variables.
+
+ - **`variable_name`**
+
+ A name for the variable.
+ The name must not be qualified and must be unique within the compound statement.
+
+ - **`data_type`**
+
+ Any supported data type. If `data_type` is omitted, you must specify `DEFAULT`, and the type is derived from `default_expression`.
+
+ - **`{ DEFAULT | = } default_expression`**
+
+ Defines the variable's initial value after declaration. `default_expression` must be castable to `data_type`. If no default is specified, the variable is initialized with `NULL`.
+
+- **`declare_condition`**
+
+ A local condition declaration.
+
+ - **`condition_name`**
+
+ An unqualified name for the condition, scoped to the compound statement.
+
+ - **`sqlstate`**
+
+ A `STRING` literal of 5 alphanumeric characters (case insensitive) consisting of `'A'-'Z'` and `'0'-'9'`. The SQLSTATE must not start with `'00'`, `'01'`, or `'XX'`. Any SQLSTATE starting with `'02'` is also caught by the predefined `NOT FOUND` condition. If not specified, the SQLSTATE is `'45000'`.
+
+- **`declare_handler`**
+
+ A declaration for an error handler.
+
+ - **`handler_type`**
+
+ - **`EXIT`**
+
+ Specifies that the handler exits the compound statement after the condition has been handled.
+
+ - **`condition_values`**
+
+ Specifies to which sqlstates or conditions the handler applies.
+ Condition values must be unique across all handlers within the compound statement.
+ Specific condition values take precedence over `SQLEXCEPTION`.
+
+ - **`sqlstate`**
+
+ A `STRING` literal of 5 characters `'A'-'Z'` and `'0'-'9'` (case insensitive).
+
+ - **`condition_name`**
+
+ A condition defined within this compound, an outer compound statement, or a system-defined error class.
+
+ - **`SQLEXCEPTION`**
+
+ Applies to any user-facing error condition.
+
+ - **`NOT FOUND`**
+
+ Applies to any error condition in the SQLSTATE `'02'` class.
+
+ - **`handler_action`**
+
+ A SQL statement to execute when any of the condition values occur.
+ To add multiple statements, use a nested compound statement.
+
+- **`SQL_statement`**
+
+ A SQL statement such as a DDL, DML, control statement, or compound statement.
+ Any `SELECT` or `VALUES` statement produces a result set that the invoker can consume.
+
+## Examples
+
+```SQL
+-- A compound statement with local variables, an exit handler, and a nested compound.
+> BEGIN
+ DECLARE a INT DEFAULT 1;
+ DECLARE b INT DEFAULT 5;
+ DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO
+ div0: BEGIN
+ VALUES (15);
+ END div0;
+ SET a = 10;
+ SET a = b / 0;
+ VALUES (a);
+END;
+15
+```
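+
+Handlers can also be keyed to an explicit SQLSTATE rather than a condition name. A minimal sketch, assuming the `DIVIDE_BY_ZERO` error maps to SQLSTATE `'22012'`:
+
+```SQL
+-- Handle a division-by-zero error by its SQLSTATE instead of its condition name.
+> BEGIN
+    DECLARE a INT DEFAULT 1;
+    DECLARE EXIT HANDLER FOR SQLSTATE '22012'
+      BEGIN
+        VALUES ('caught');
+      END;
+    SET a = a / 0;
+    VALUES (a);
+  END;
+caught
+```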
+
+## Related articles
+
+- [SQL Scripting](../sql-ref-scripting.html)
+- [CASE Statement](../control-flow/case-stmt.html)
+- [IF Statement](../control-flow/if-stmt.html)
+- [LOOP Statement](../control-flow/loop-stmt.html)
+- [WHILE Statement](../control-flow/while-stmt.html)
+- [REPEAT Statement](../control-flow/repeat-stmt.html)
+- [FOR Statement](../control-flow/for-stmt.html)
+- [ITERATE Statement](../control-flow/iterate-stmt.html)
+- [LEAVE Statement](../control-flow/leave-stmt.html)
diff --git a/docs/control-flow/for-stmt.md b/docs/control-flow/for-stmt.md
new file mode 100644
index 0000000000000..25a1cfa7218ec
--- /dev/null
+++ b/docs/control-flow/for-stmt.md
@@ -0,0 +1,92 @@
+---
+layout: global
+title: FOR statement
+displayTitle: FOR statement
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+Repeat the execution of a list of statements for each row returned by a query.
+
+This statement may only be used within a [compound statement](compound-stmt.html).
+
+## Syntax
+
+```
+[ label : ] FOR [ variable_name AS ] query
+ DO
+ { stmt ; } [...]
+ END FOR [ label ]
+```
+
+## Parameters
+
+- **label**
+
+ An optional label for the loop, which is unique amongst all labels for statements within which the `FOR` statement is contained.
+ If an end label is specified, it must match the beginning label.
+ The label can be used to [LEAVE](leave-stmt.html) or [ITERATE](iterate-stmt.html) the loop.
+ To qualify loop column references, use the `variable_name`, not the `label`.
+
+- **variable_name**
+
+ An optional name you can use as a qualifier when referencing the columns in the cursor.
+
+- **stmt**
+
+ A SQL statement.
+
+## Notes
+
+If the query operates on a table that is also modified within the loop's body, the semantics depend on the data source.
+For Delta tables, the query will remain unaffected.
+Spark does not guarantee the full execution of the query if the `FOR` loop completes prematurely due to a `LEAVE` statement or an error condition.
+When exceptions or side-effects occur during the execution of the query, Spark does not guarantee at which point in time within the loop these occur.
+Often `FOR` loops can be replaced with relational queries, which are typically more efficient.
+
+## Examples
+
+```SQL
+-- sum up all odd numbers from 1 through 10
+> BEGIN
+ DECLARE sum INT DEFAULT 0;
+ sumNumbers: FOR row AS SELECT num FROM range(1, 20) AS t(num) DO
+ IF num > 10 THEN
+ LEAVE sumNumbers;
+ ELSEIF num % 2 = 0 THEN
+ ITERATE sumNumbers;
+ END IF;
+ SET sum = sum + row.num;
+ END FOR sumNumbers;
+ VALUES (sum);
+ END;
+ 25
+
+-- Compare with the much more efficient relational computation:
+> SELECT sum(num) FROM range(1, 10) AS t(num) WHERE num % 2 = 1;
+ 25
+```
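+
+The loop body may contain arbitrary SQL statements, including DML. A sketch (assuming a pre-existing table `odd_numbers` with a single `INT` column) that writes one row per cursor row:
+
+```SQL
+-- Insert every odd number returned by the query into a table.
+> BEGIN
+    FOR row AS SELECT num FROM range(1, 10) AS t(num) WHERE num % 2 = 1 DO
+      INSERT INTO odd_numbers VALUES (row.num);
+    END FOR;
+  END;
+```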
+
+## Related articles
+
+- [SQL Scripting](../sql-ref-scripting.html)
+- [CASE Statement](../control-flow/case-stmt.html)
+- [Compound Statement](../control-flow/compound-stmt.html)
+- [LOOP Statement](../control-flow/loop-stmt.html)
+- [WHILE Statement](../control-flow/while-stmt.html)
+- [REPEAT Statement](../control-flow/repeat-stmt.html)
+- [LEAVE Statement](../control-flow/leave-stmt.html)
+- [ITERATE Statement](../control-flow/iterate-stmt.html)
diff --git a/docs/control-flow/if-stmt.md b/docs/control-flow/if-stmt.md
new file mode 100644
index 0000000000000..a687fc5c1c643
--- /dev/null
+++ b/docs/control-flow/if-stmt.md
@@ -0,0 +1,70 @@
+---
+layout: global
+title: IF statement
+displayTitle: IF statement
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+Executes the list of statements associated with the first condition that evaluates to `true`.
+
+This statement may only be used within a [compound statement](compound-stmt.html).
+
+## Syntax
+
+```
+IF condition THEN { stmt ; } [...]
+ [ { ELSEIF condition THEN { stmt ; } [...] } [...] ]
+ [ ELSE { stmt ; } [...] ]
+ END IF
+```
+
+## Parameters
+
+- **condition**
+
+ Any expression evaluating to a `BOOLEAN`.
+
+- **stmt**
+
+ A SQL statement to execute if the `condition` is `true`.
+
+## Examples
+
+```SQL
+> BEGIN
+ DECLARE choice DOUBLE DEFAULT 3.9;
+ IF choice < 2 THEN
+ VALUES ('one fish');
+ ELSEIF choice < 3 THEN
+ VALUES ('two fish');
+ ELSEIF choice < 4 THEN
+ VALUES ('red fish');
+ ELSEIF choice < 5 OR choice IS NULL THEN
+ VALUES ('blue fish');
+ ELSE
+ VALUES ('no fish');
+ END IF;
+ END;
+ red fish
+```
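+
+Both the `ELSEIF` and `ELSE` clauses are optional. A minimal sketch with a single branch (using a hypothetical variable `x`):
+
+```SQL
+> BEGIN
+    DECLARE x INT DEFAULT 10;
+    IF x > 5 THEN
+      VALUES ('big');
+    END IF;
+  END;
+ big
+```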
+
+## Related articles
+
+- [SQL Scripting](../sql-ref-scripting.html)
+- [CASE Statement](../control-flow/case-stmt.html)
+- [Compound Statement](../control-flow/compound-stmt.html)
+- [FOR Statement](../control-flow/for-stmt.html)
diff --git a/docs/control-flow/iterate-stmt.md b/docs/control-flow/iterate-stmt.md
new file mode 100644
index 0000000000000..d73f33a26bf95
--- /dev/null
+++ b/docs/control-flow/iterate-stmt.md
@@ -0,0 +1,70 @@
+---
+layout: global
+title: ITERATE statement
+displayTitle: ITERATE statement
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+Terminates the execution of an iteration of a looping statement and continues with the next iteration if the looping condition is met.
+
+This statement may only be used within a [compound statement](compound-stmt.html).
+
+## Syntax
+
+```
+ITERATE label
+```
+
+## Parameters
+
+- **label**
+
+ The label identifies a looping statement that contains the `ITERATE` statement directly or indirectly.
+
+## Examples
+
+```SQL
+-- sum up all odd numbers from 1 through 10
+> BEGIN
+ DECLARE sum INT DEFAULT 0;
+ DECLARE num INT DEFAULT 0;
+ sumNumbers: LOOP
+ SET num = num + 1;
+ IF num > 10 THEN
+ LEAVE sumNumbers;
+ END IF;
+ IF num % 2 = 0 THEN
+ ITERATE sumNumbers;
+ END IF;
+ SET sum = sum + num;
+ END LOOP sumNumbers;
+ VALUES (sum);
+ END;
+25
+```
+
+## Related articles
+
+- [SQL Scripting](../sql-ref-scripting.html)
+- [CASE Statement](../control-flow/case-stmt.html)
+- [Compound Statement](../control-flow/compound-stmt.html)
+- [FOR Statement](../control-flow/for-stmt.html)
+- [LOOP Statement](../control-flow/loop-stmt.html)
+- [WHILE Statement](../control-flow/while-stmt.html)
+- [REPEAT Statement](../control-flow/repeat-stmt.html)
+- [IF Statement](../control-flow/if-stmt.html)
+- [LEAVE Statement](../control-flow/leave-stmt.html)
diff --git a/docs/control-flow/leave-stmt.md b/docs/control-flow/leave-stmt.md
new file mode 100644
index 0000000000000..a705c48b239a3
--- /dev/null
+++ b/docs/control-flow/leave-stmt.md
@@ -0,0 +1,71 @@
+---
+layout: global
+title: LEAVE statement
+displayTitle: LEAVE statement
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+Terminates the execution of the labeled looping or compound statement and continues with the statement following it.
+
+This statement may only be used within a [compound statement](compound-stmt.html).
+
+## Syntax
+
+```
+LEAVE label
+```
+
+## Parameters
+
+- **label**
+
+ The label identifies a statement to leave that directly or indirectly contains the `LEAVE` statement.
+
+## Examples
+
+```SQL
+-- sum up all odd numbers from 1 through 10
+-- Skip even numbers with ITERATE and leave the loop once 10 has been reached.
+> BEGIN
+ DECLARE sum INT DEFAULT 0;
+ DECLARE num INT DEFAULT 0;
+ sumNumbers: LOOP
+ SET num = num + 1;
+ IF num > 10 THEN
+ LEAVE sumNumbers;
+ END IF;
+ IF num % 2 = 0 THEN
+ ITERATE sumNumbers;
+ END IF;
+ SET sum = sum + num;
+ END LOOP sumNumbers;
+ VALUES (sum);
+ END;
+25
+```
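+
+Assuming the label of an enclosing labeled compound statement can also be targeted, `LEAVE` exits that compound statement directly. A sketch with the hypothetical label `early_exit`:
+
+```SQL
+> BEGIN
+    DECLARE result STRING DEFAULT 'not set';
+    early_exit: BEGIN
+      SET result = 'set before leave';
+      LEAVE early_exit;
+      SET result = 'never reached';
+    END early_exit;
+    VALUES (result);
+  END;
+set before leave
+```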
+
+## Related articles
+
+- [SQL Scripting](../sql-ref-scripting.html)
+- [CASE Statement](../control-flow/case-stmt.html)
+- [Compound Statement](../control-flow/compound-stmt.html)
+- [FOR Statement](../control-flow/for-stmt.html)
+- [LOOP Statement](../control-flow/loop-stmt.html)
+- [WHILE Statement](../control-flow/while-stmt.html)
+- [IF Statement](../control-flow/if-stmt.html)
+- [ITERATE Statement](../control-flow/iterate-stmt.html)
+
diff --git a/docs/control-flow/loop-stmt.md b/docs/control-flow/loop-stmt.md
new file mode 100644
index 0000000000000..7ca3b3b5bbf96
--- /dev/null
+++ b/docs/control-flow/loop-stmt.md
@@ -0,0 +1,83 @@
+---
+layout: global
+title: LOOP statement
+displayTitle: LOOP statement
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+Repeat the execution of a list of statements.
+
+This statement may only be used within a [compound statement](compound-stmt.html).
+
+## Syntax
+
+```
+[ label : ] LOOP
+ { stmt ; } [...]
+ END LOOP [ label ]
+```
+
+## Parameters
+
+- **label**
+
+ An optional label for the loop, which is unique amongst all labels for statements within which the `LOOP` statement is contained.
+ If an end label is specified, it must match the beginning label.
+ The label can be used to [LEAVE](leave-stmt.html) or [ITERATE](iterate-stmt.html) the loop.
+
+- **stmt**
+
+ A SQL statement.
+
+## Examples
+
+```SQL
+-- sum up all odd numbers from 1 through 10
+> BEGIN
+ DECLARE sum INT DEFAULT 0;
+ DECLARE num INT DEFAULT 0;
+ sumNumbers: LOOP
+ SET num = num + 1;
+ IF num > 10 THEN
+ LEAVE sumNumbers;
+ END IF;
+ IF num % 2 = 0 THEN
+ ITERATE sumNumbers;
+ END IF;
+ SET sum = sum + num;
+ END LOOP sumNumbers;
+ VALUES (sum);
+ END;
+ 25
+
+-- Compare with the much more efficient relational computation:
+> SELECT sum(num) FROM range(1, 10) AS t(num) WHERE num % 2 = 1;
+ 25
+```
+
+## Related articles
+
+- [SQL Scripting](../sql-ref-scripting.html)
+- [CASE Statement](../control-flow/case-stmt.html)
+- [Compound Statement](../control-flow/compound-stmt.html)
+- [FOR Statement](../control-flow/for-stmt.html)
+- [WHILE Statement](../control-flow/while-stmt.html)
+- [REPEAT Statement](../control-flow/repeat-stmt.html)
+- [IF Statement](../control-flow/if-stmt.html)
+- [ITERATE Statement](../control-flow/iterate-stmt.html)
+- [LEAVE Statement](../control-flow/leave-stmt.html)
+
diff --git a/docs/control-flow/repeat-stmt.md b/docs/control-flow/repeat-stmt.md
new file mode 100644
index 0000000000000..4d28a6b05e0e4
--- /dev/null
+++ b/docs/control-flow/repeat-stmt.md
@@ -0,0 +1,84 @@
+---
+layout: global
+title: REPEAT statement
+displayTitle: REPEAT statement
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+Repeat the execution of a list of statements until a condition is true.
+
+This statement may only be used within a [compound statement](compound-stmt.html).
+
+## Syntax
+
+```
+[ label : ] REPEAT
+ { stmt ; } [...]
+ UNTIL cond
+ END REPEAT [ label ]
+```
+
+## Parameters
+
+- **label**
+
+ An optional label for the loop, which is unique amongst all labels for statements within which the `REPEAT` statement is contained.
+ The label can be used to [LEAVE](leave-stmt.html) or [ITERATE](iterate-stmt.html) the loop.
+
+- **cond**
+
+ Any expression evaluating to a `BOOLEAN`.
+
+- **stmt**
+
+ A SQL statement.
+
+## Examples
+
+```SQL
+-- sum up all odd numbers from 1 through 10
+> BEGIN
+ DECLARE sum INT DEFAULT 0;
+ DECLARE num INT DEFAULT 0;
+ sumNumbers: REPEAT
+ SET num = num + 1;
+ IF num % 2 = 0 THEN
+ ITERATE sumNumbers;
+ END IF;
+ SET sum = sum + num;
+ UNTIL num = 10
+ END REPEAT sumNumbers;
+ VALUES (sum);
+ END;
+ 25
+
+-- Compare with the much more efficient relational computation:
+> SELECT sum(num) FROM range(1, 10) AS t(num) WHERE num % 2 = 1;
+ 25
+```
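+
+Because `cond` is evaluated only after the statement list, the body of a `REPEAT` loop executes at least once, even when the condition is already `true`. A minimal sketch:
+
+```SQL
+> BEGIN
+    DECLARE cnt INT DEFAULT 0;
+    runOnce: REPEAT
+      SET cnt = cnt + 1;
+    UNTIL true
+    END REPEAT runOnce;
+    VALUES (cnt);
+  END;
+ 1
+```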
+
+## Related articles
+
+- [SQL Scripting](../sql-ref-scripting.html)
+- [CASE Statement](../control-flow/case-stmt.html)
+- [Compound Statement](../control-flow/compound-stmt.html)
+- [FOR Statement](../control-flow/for-stmt.html)
+- [IF Statement](../control-flow/if-stmt.html)
+- [ITERATE Statement](../control-flow/iterate-stmt.html)
+- [WHILE Statement](../control-flow/while-stmt.html)
+- [LEAVE Statement](../control-flow/leave-stmt.html)
+- [LOOP Statement](../control-flow/loop-stmt.html)
diff --git a/docs/control-flow/while-stmt.md b/docs/control-flow/while-stmt.md
new file mode 100644
index 0000000000000..0edf77f0cba0d
--- /dev/null
+++ b/docs/control-flow/while-stmt.md
@@ -0,0 +1,83 @@
+---
+layout: global
+title: WHILE statement
+displayTitle: WHILE statement
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+Repeat the execution of a list of statements while a condition is true.
+
+This statement may only be used within a [compound statement](compound-stmt.html).
+
+## Syntax
+
+```
+[ label : ] WHILE cond DO
+ { stmt ; } [...]
+ END WHILE [ label ]
+```
+
+## Parameters
+
+- **label**
+
+ An optional label for the loop, which is unique amongst all labels for statements within which the `WHILE` statement is contained.
+ The label can be used to [LEAVE](leave-stmt.html) or [ITERATE](iterate-stmt.html) the loop.
+
+- **cond**
+
+ Any expression evaluating to a `BOOLEAN`.
+
+- **stmt**
+
+ A SQL statement.
+
+## Examples
+
+```SQL
+-- sum up all odd numbers from 1 through 10
+> BEGIN
+ DECLARE sum INT DEFAULT 0;
+ DECLARE num INT DEFAULT 0;
+ sumNumbers: WHILE num < 10 DO
+ SET num = num + 1;
+ IF num % 2 = 0 THEN
+ ITERATE sumNumbers;
+ END IF;
+ SET sum = sum + num;
+ END WHILE sumNumbers;
+ VALUES (sum);
+ END;
+ 25
+
+-- Compare with the much more efficient relational computation:
+> SELECT sum(num) FROM range(1, 10) AS t(num) WHERE num % 2 = 1;
+ 25
+```
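+
+In contrast to [REPEAT](repeat-stmt.html), `cond` is evaluated before each iteration, so the body may not execute at all. A minimal sketch:
+
+```SQL
+> BEGIN
+    DECLARE cnt INT DEFAULT 0;
+    WHILE cnt > 0 DO
+      SET cnt = cnt - 1;
+    END WHILE;
+    VALUES (cnt);
+  END;
+ 0
+```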
+
+## Related articles
+
+- [SQL Scripting](../sql-ref-scripting.html)
+- [CASE Statement](../control-flow/case-stmt.html)
+- [Compound Statement](../control-flow/compound-stmt.html)
+- [FOR Statement](../control-flow/for-stmt.html)
+- [REPEAT Statement](../control-flow/repeat-stmt.html)
+- [IF Statement](../control-flow/if-stmt.html)
+- [ITERATE Statement](../control-flow/iterate-stmt.html)
+- [LEAVE Statement](../control-flow/leave-stmt.html)
+- [LOOP Statement](../control-flow/loop-stmt.html)
+
diff --git a/docs/core-migration-guide.md b/docs/core-migration-guide.md
index 1d55c4c3e66d2..4154043226412 100644
--- a/docs/core-migration-guide.md
+++ b/docs/core-migration-guide.md
@@ -24,7 +24,7 @@ license: |
## Upgrading from Core 4.0 to 4.1
-- Since Spark 4.1, Spark Master deamon provides REST API by default. To restore the behavior before Spark 4.1, you can set `spark.master.rest.enabled` to `false`.
+- Since Spark 4.1, Spark Master daemon provides REST API by default. To restore the behavior before Spark 4.1, you can set `spark.master.rest.enabled` to `false`.
- Since Spark 4.1, Spark will compress RDD checkpoints by default. To restore the behavior before Spark 4.1, you can set `spark.checkpoint.compress` to `false`.
- Since Spark 4.1, Spark uses Apache Hadoop Magic Committer for all S3 buckets by default. To restore the behavior before Spark 4.0, you can set `spark.hadoop.fs.s3a.committer.magic.enabled=false`.
- Since Spark 4.1, `java.lang.InternalError` encountered during file reading will no longer fail the task if the configuration `spark.sql.files.ignoreCorruptFiles` or the data source option `ignoreCorruptFiles` is set to `true`.
diff --git a/docs/declarative-pipelines-programming-guide.md b/docs/declarative-pipelines-programming-guide.md
index 3e33153e3d252..c5d18a7cb71be 100644
--- a/docs/declarative-pipelines-programming-guide.md
+++ b/docs/declarative-pipelines-programming-guide.md
@@ -24,7 +24,7 @@ license: |
## What is Spark Declarative Pipelines (SDP)?
-Spark Declarative Pipelines (SDP) is a declarative framework for building reliable, maintainable, and testable data pipelines on Spark. SDP simplifies ETL development by allowing you to focus on the transformations you want to apply to your data, rather than the mechanics of pipeline execution.
+Spark Declarative Pipelines (SDP) is a declarative framework for building reliable, maintainable, and testable data pipelines on Apache Spark. SDP simplifies ETL development by allowing you to focus on the transformations you want to apply to your data, rather than the mechanics of pipeline execution.
SDP is designed for both batch and streaming data processing, supporting common use cases such as:
- Data ingestion from cloud storage (Amazon S3, Azure ADLS Gen2, Google Cloud Storage)
@@ -35,6 +35,16 @@ The key advantage of SDP is its declarative approach - you define what tables sh

+### Quick install
+
+A quick way to install SDP is with pip:
+
+```
+pip install pyspark[pipelines]
+```
+
+See the [downloads page](//spark.apache.org/downloads.html) for more installation options.
+
## Key Concepts
### Flows
@@ -52,10 +62,10 @@ SDP creates the table named `target_table` along with a flow that reads new data
### Datasets
-A dataset is queryable object that's the output of one of more flows within a pipeline. Flows in the pipeline can also read from datasets produced in the pipeline.
+A dataset is a queryable object that's the output of one or more flows within a pipeline. Flows in the pipeline can also read from datasets produced in the pipeline.
- **Streaming Table** – a definition of a table and one or more streaming flows written into it. Streaming tables support incremental processing of data, allowing you to process only new data as it arrives.
-- **Materialized View** – is a view that is precomputed into a table. A materialized view always has exactly one batch flow writing to it.
+- **Materialized View** – a view that is precomputed into a table. A materialized view always has exactly one batch flow writing to it.
- **Temporary View** – a view that is scoped to an execution of the pipeline. It can be referenced from flows within the pipeline. It's useful for encapsulating transformations and intermediate logical entities that multiple other elements of the pipeline depend on.
### Pipelines
@@ -64,11 +74,16 @@ A pipeline is the primary unit of development and execution in SDP. A pipeline c
### Pipeline Projects
-A pipeline project is a set of source files that contain code that define the datasets and flows that make up a pipeline. These source files can be `.py` or `.sql` files.
+A pipeline project is a set of source files that contain code definitions of the datasets and flows that make up a pipeline. The source files can be `.py` or `.sql` files.
+
+It's conventional to name pipeline spec files `spark-pipeline.yml` or `spark-pipeline.yaml`.
+
+A YAML-formatted pipeline spec file contains the top-level configuration for the pipeline project with the following fields:
-A YAML-formatted pipeline spec file contains the top-level configuration for the pipeline project. It supports the following fields:
-- **definitions** (Required) - Paths where definition files can be found.
-- **database** (Optional) - The default target database for pipeline outputs.
+- **name** (Required) - The name of the pipeline project.
+- **libraries** (Required) - The paths to the transformation source files, written in SQL or Python.
+- **storage** (Required) - A directory where checkpoints can be stored for streaming tables within the pipeline.
+- **database** (Optional) - The default target database for pipeline outputs. **schema** can alternatively be used as an alias.
- **catalog** (Optional) - The default target catalog for pipeline outputs.
- **configuration** (Optional) - Map of Spark configuration properties.
@@ -76,180 +91,253 @@ An example pipeline spec file:
```yaml
name: my_pipeline
-definitions:
+libraries:
- glob:
- include: transformations/**/*.py
- - glob:
- include: transformations/**/*.sql
+ include: transformations/**
+storage: file:///absolute/path/to/storage/dir
catalog: my_catalog
database: my_db
configuration:
spark.sql.shuffle.partitions: "1000"
```
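+
+Transformation source files referenced by `libraries` can be written in SQL as well as Python. A minimal sketch of such a SQL file (assuming SDP's `CREATE MATERIALIZED VIEW` and `CREATE STREAMING TABLE ... FROM STREAM` syntax, with a hypothetical file name):
+
+```sql
+-- transformations/example.sql
+CREATE MATERIALIZED VIEW basic_mv AS
+SELECT * FROM samples.nyctaxi.trips;
+
+CREATE STREAMING TABLE basic_st AS
+SELECT * FROM STREAM samples.nyctaxi.trips;
+```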
-It's conventional to name pipeline spec files `pipeline.yml`.
-
The `spark-pipelines init` command, described below, makes it easy to generate a pipeline project with default configuration and directory structure.
-
## The `spark-pipelines` Command Line Interface
-The `spark-pipelines` command line interface (CLI) is the primary way to execute a pipeline. It also contains an `init` subcommand for generating a pipeline project and a `dry-run` subcommand for validating a pipeline.
+The `spark-pipelines` command line interface (CLI) is the primary way to manage a pipeline.
`spark-pipelines` is built on top of `spark-submit`, meaning that it supports all cluster managers supported by `spark-submit`. It supports all `spark-submit` arguments except for `--class`.
### `spark-pipelines init`
-`spark-pipelines init --name my_pipeline` generates a simple pipeline project, inside a directory named "my_pipeline", including a spec file and example definitions.
+`spark-pipelines init --name my_pipeline` generates a simple pipeline project, inside a directory named `my_pipeline`, including a spec file and example transformation definitions.
### `spark-pipelines run`
-`spark-pipelines run` launches an execution of a pipeline and monitors its progress until it completes. The `--spec` parameter allows selecting the pipeline spec file. If not provided, the CLI will look in the current directory and parent directories for a file named `pipeline.yml` or `pipeline.yaml`.
+`spark-pipelines run` launches an execution of a pipeline and monitors its progress until it completes.
+
+Since `spark-pipelines` is built on top of `spark-submit`, it supports all `spark-submit` arguments except for `--class`. For the complete list of available parameters, see the [Spark Submit documentation](https://spark.apache.org/docs/latest/submitting-applications.html#launching-applications-with-spark-submit).
+
+It also supports several pipeline-specific parameters:
+
+* `--spec PATH` - Path to the pipeline specification file. If not provided, the CLI will look in the current directory and parent directories for one of the following files:
+ * `spark-pipeline.yml`
+ * `spark-pipeline.yaml`
+
+* `--full-refresh DATASETS` - List of datasets to reset and recompute (comma-separated). This clears all existing data and checkpoints for the specified datasets and recomputes them from scratch.
+
+* `--full-refresh-all` - Perform a full graph reset and recompute. This is equivalent to `--full-refresh` for all datasets in the pipeline.
+
+* `--refresh DATASETS` - List of datasets to update (comma-separated). This triggers an update for the specified datasets without clearing existing data.
+
+#### Refresh Selection Behavior
+
+If no refresh options are specified, a default incremental update is performed. The refresh parameters are mutually exclusive:
+- `--full-refresh-all` cannot be combined with `--full-refresh` or `--refresh`
+- `--full-refresh` and `--refresh` can be used together to specify different behaviors for different datasets
+
+#### Examples
+
+```bash
+# Basic run with default incremental update
+spark-pipelines run
+
+# Run with specific spec file
+spark-pipelines run --spec /path/to/my-pipeline.yaml
+
+# Full refresh of specific datasets
+spark-pipelines run --full-refresh orders,customers
+
+# Full refresh of entire pipeline
+spark-pipelines run --full-refresh-all
+
+# Run with custom Spark configuration
+spark-pipelines run --conf spark.sql.shuffle.partitions=200 --driver-memory 4g
+
+# Run on remote Spark Connect server
+spark-pipelines run --remote sc://my-cluster:15002
+```
### `spark-pipelines dry-run`
`spark-pipelines dry-run` launches an execution of a pipeline that doesn't write or read any data, but catches many kinds of errors that would be caught if the pipeline were to actually run. E.g.
- Syntax errors – e.g. invalid Python or SQL code
-- Analysis errors – e.g. selecting from a table that doesn't exist or selecting a column that doesn't exist
+- Analysis errors – e.g. selecting from a table or a column that doesn't exist
- Graph validation errors - e.g. cyclic dependencies
+Since `spark-pipelines` is built on top of `spark-submit`, it supports all `spark-submit` arguments except for `--class`. For the complete list of available parameters, see the [Spark Submit documentation](https://spark.apache.org/docs/latest/submitting-applications.html#launching-applications-with-spark-submit).
+
+It also supports the pipeline-specific `--spec` parameter (see description above in the `run` section).
+
## Programming with SDP in Python
-SDP Python functions are defined in the `pyspark.pipelines` module. Your pipelines implemented with the Python API must import this module. It's common to alias the module to `dp` to limit the number of characters you need to type when using its APIs.
+SDP Python definitions are provided by the `pyspark.pipelines` module.
+
+Your pipelines implemented with the Python API must import this module. It's recommended to alias the module to `dp`.
```python
from pyspark import pipelines as dp
```
-### Creating a Materialized View with Python
+### Creating a Materialized View in Python
-The `@dp.materialized_view` decorator tells SDP to create a materialized view based on the results returned by a function that performs a batch read:
+The `@dp.materialized_view` decorator tells SDP to create a materialized view based on the results of a function that performs a batch read:
```python
from pyspark import pipelines as dp
+from pyspark.sql import DataFrame
@dp.materialized_view
-def basic_mv():
+def basic_mv() -> DataFrame:
return spark.table("samples.nyctaxi.trips")
```
-Optionally, you can specify the table name using the `name` argument:
+The name of the materialized view is derived from the name of the function.
+
+You can specify the name of the materialized view using the `name` argument:
```python
from pyspark import pipelines as dp
+from pyspark.sql import DataFrame
@dp.materialized_view(name="trips_mv")
-def basic_mv():
+def basic_mv() -> DataFrame:
return spark.table("samples.nyctaxi.trips")
```
-### Creating a Temporary View with Python
+### Creating a Temporary View in Python
-The `@dp.temporary_view` decorator tells SDP to create a temporary view based on the results returned by a function that performs a batch read:
+The `@dp.temporary_view` decorator tells SDP to create a temporary view based on the results of a function that performs a batch read:
```python
from pyspark import pipelines as dp
+from pyspark.sql import DataFrame
@dp.temporary_view
-def basic_tv():
+def basic_tv() -> DataFrame:
return spark.table("samples.nyctaxi.trips")
```
This temporary view can be read by other queries within the pipeline, but can't be read outside the scope of the pipeline.
-### Creating a Streaming Table with Python
+### Creating a Streaming Table in Python
-Similarly, you can create a streaming table by using the `@dp.table` decorator with a function that performs a streaming read:
+You can create a streaming table using the `@dp.table` decorator with a function that performs a streaming read:
```python
from pyspark import pipelines as dp
+from pyspark.sql import DataFrame
@dp.table
-def basic_st():
+def basic_st() -> DataFrame:
return spark.readStream.table("samples.nyctaxi.trips")
```
-### Loading Data from a Streaming Source
+### Loading Data from Streaming Sources in Python
+
+SDP supports loading data from all the formats supported by Spark Structured Streaming (`spark.readStream`).
-SDP supports loading data from all formats supported by Spark. For example, you can create a streaming table whose query reads from a Kafka topic:
+For example, you can create a streaming table whose query reads from a Kafka topic:
```python
from pyspark import pipelines as dp
+from pyspark.sql import DataFrame
@dp.table
-def ingestion_st():
+def ingestion_st() -> DataFrame:
return (
- spark.readStream.format("kafka")
+ spark.readStream
+ .format("kafka")
.option("kafka.bootstrap.servers", "localhost:9092")
.option("subscribe", "orders")
.load()
)
```
-For batch reads:
+### Loading Data from Batch Sources in Python
+
+SDP supports loading data from all the formats supported by Spark SQL (`spark.read`). For example:
```python
from pyspark import pipelines as dp
+from pyspark.sql import DataFrame
@dp.materialized_view
-def batch_mv():
+def batch_mv() -> DataFrame:
return spark.read.format("json").load("/datasets/retail-org/sales_orders")
```
-### Querying Tables Defined in Your Pipeline
+### Querying Tables Defined in a Pipeline in Python
You can reference other tables defined in your pipeline in the same way you'd reference tables defined outside your pipeline:
```python
from pyspark import pipelines as dp
+from pyspark.sql import DataFrame
from pyspark.sql.functions import col
@dp.table
-def orders():
+def orders() -> DataFrame:
return (
- spark.readStream.format("kafka")
+ spark.readStream
+ .format("kafka")
.option("kafka.bootstrap.servers", "localhost:9092")
.option("subscribe", "orders")
.load()
)
@dp.materialized_view
-def customers():
- return spark.read.format("csv").option("header", True).load("/datasets/retail-org/customers")
+def customers() -> DataFrame:
+ return (
+ spark.read
+ .format("csv")
+ .option("header", True)
+ .load("/datasets/retail-org/customers")
+ )
@dp.materialized_view
-def customer_orders():
- return (spark.table("orders")
- .join(spark.table("customers"), "customer_id")
- .select("customer_id",
- "order_number",
- "state",
- col("order_datetime").cast("int").cast("timestamp").cast("date").alias("order_date"),
+def customer_orders() -> DataFrame:
+ return (
+ spark.table("orders")
+ .join(
+ spark.table("customers"), "customer_id")
+ .select(
+ "customer_id",
+ "order_number",
+ "state",
+ col("order_datetime").cast("date").alias("order_date"),
+ )
)
)
@dp.materialized_view
-def daily_orders_by_state():
- return (spark.table("customer_orders")
+def daily_orders_by_state() -> DataFrame:
+ return (
+ spark.table("customer_orders")
.groupBy("state", "order_date")
- .count().withColumnRenamed("count", "order_count")
+ .count()
+ .withColumnRenamed("count", "order_count")
)
```
-### Creating Tables in a For Loop
+### Creating Tables in a For Loop in Python
You can use Python `for` loops to create multiple tables programmatically:
```python
from pyspark import pipelines as dp
+from pyspark.sql import DataFrame
from pyspark.sql.functions import collect_list, col
@dp.temporary_view()
-def customer_orders():
+def customer_orders() -> DataFrame:
orders = spark.table("samples.tpch.orders")
customer = spark.table("samples.tpch.customer")
- return (orders.join(customer, orders.o_custkey == customer.c_custkey)
+ return (
+ orders
+ .join(customer, orders.o_custkey == customer.c_custkey)
.select(
col("c_custkey").alias("custkey"),
col("c_name").alias("name"),
@@ -258,19 +346,22 @@ def customer_orders():
col("o_orderkey").alias("orderkey"),
col("o_orderstatus").alias("orderstatus"),
col("o_totalprice").alias("totalprice"),
- col("o_orderdate").alias("orderdate"))
+ col("o_orderdate").alias("orderdate"),
+ )
)
@dp.temporary_view()
-def nation_region():
+def nation_region() -> DataFrame:
nation = spark.table("samples.tpch.nation")
region = spark.table("samples.tpch.region")
- return (nation.join(region, nation.n_regionkey == region.r_regionkey)
+ return (
+ nation
+ .join(region, nation.n_regionkey == region.r_regionkey)
.select(
col("n_name").alias("nation"),
col("r_name").alias("region"),
- col("n_nationkey").alias("nationkey")
+ col("n_nationkey").alias("nationkey"),
)
)
@@ -280,11 +371,13 @@ region_list = spark.table("samples.tpch.region").select(collect_list("r_name")).
# Iterate through region names to create new region-specific materialized views
for region in region_list:
@dp.table(name=f"{region.lower().replace(' ', '_')}_customer_orders")
- def regional_customer_orders(region_filter=region):
+ def regional_customer_orders(region_filter=region) -> DataFrame:
customer_orders = spark.table("customer_orders")
nation_region = spark.table("nation_region")
- return (customer_orders.join(nation_region, customer_orders.nationkey == nation_region.nationkey)
+ return (
+ customer_orders
+ .join(nation_region, customer_orders.nationkey == nation_region.nationkey)
.select(
col("custkey"),
col("name"),
@@ -294,35 +387,37 @@ for region in region_list:
col("orderkey"),
col("orderstatus"),
col("totalprice"),
- col("orderdate")
- ).filter(f"region = '{region_filter}'")
+ col("orderdate"),
+ )
+ .filter(f"region = '{region_filter}'")
)
```
-### Using Multiple Flows to Write to a Single Target
+### Using Multiple Flows to Write to a Single Target in Python
-You can create multiple flows that append data to the same target:
+You can create multiple flows that append data to the same dataset:
```python
from pyspark import pipelines as dp
+from pyspark.sql import DataFrame
# create a streaming table
dp.create_streaming_table("customers_us")
-# add the first append flow
+# define the first append flow
@dp.append_flow(target = "customers_us")
-def append1():
+def append_customers_us_west() -> DataFrame:
return spark.readStream.table("customers_us_west")
-# add the second append flow
+# define the second append flow
@dp.append_flow(target = "customers_us")
-def append2():
+def append_customers_us_east() -> DataFrame:
return spark.readStream.table("customers_us_east")
```
## Programming with SDP in SQL
-### Creating a Materialized View with SQL
+### Creating a Materialized View in SQL
The basic syntax for creating a materialized view with SQL is:
@@ -331,7 +426,7 @@ CREATE MATERIALIZED VIEW basic_mv
AS SELECT * FROM samples.nyctaxi.trips;
```
-### Creating a Temporary View with SQL
+### Creating a Temporary View in SQL
The basic syntax for creating a temporary view with SQL is:
@@ -340,7 +435,7 @@ CREATE TEMPORARY VIEW basic_tv
AS SELECT * FROM samples.nyctaxi.trips;
```
-### Creating a Streaming Table with SQL
+### Creating a Streaming Table in SQL
When creating a streaming table, use the `STREAM` keyword to indicate streaming semantics for the source:
@@ -349,7 +444,7 @@ CREATE STREAMING TABLE basic_st
AS SELECT * FROM STREAM samples.nyctaxi.trips;
```
-### Querying Tables Defined in Your Pipeline
+### Querying Tables Defined in a Pipeline in SQL
You can reference other tables defined in your pipeline:
@@ -376,7 +471,7 @@ FROM customer_orders
GROUP BY state, order_date;
```
-### Using Multiple Flows to Write to a Single Target
+### Using Multiple Flows to Write to a Single Target in SQL
You can create multiple flows that append data to the same target:
@@ -384,28 +479,77 @@ You can create multiple flows that append data to the same target:
-- create a streaming table
CREATE STREAMING TABLE customers_us;
--- add the first append flow
-CREATE FLOW append1
+-- define the first append flow
+CREATE FLOW append_customers_us_west
AS INSERT INTO customers_us
SELECT * FROM STREAM(customers_us_west);
--- add the second append flow
-CREATE FLOW append2
+-- define the second append flow
+CREATE FLOW append_customers_us_east
AS INSERT INTO customers_us
SELECT * FROM STREAM(customers_us_east);
```
+## Writing Data to External Targets with Sinks
+
+Sinks in SDP provide a way to write transformed data to external destinations beyond the default streaming tables and materialized views. Sinks are particularly useful for operational use cases that require low-latency data processing, reverse ETL operations, or writing to external systems.
+
+Sinks enable a pipeline to write to any destination that a Spark Structured Streaming query can be written to, including, but not limited to, **Apache Kafka** and **Azure Event Hubs**.
+
+### Creating and Using Sinks in Python
+
+Working with sinks involves two main steps: creating the sink definition and implementing an append flow to write data.
+
+#### Creating a Kafka Sink
+
+You can create a sink that streams data to a Kafka topic:
+
+```python
+from pyspark import pipelines as dp
+from pyspark.sql import DataFrame
+from pyspark.sql.functions import to_json, struct, col
+
+dp.create_sink(
+ name="kafka_sink",
+ format="kafka",
+ options={
+ "kafka.bootstrap.servers": "localhost:9092",
+ "topic": "processed_orders"
+ }
+)
+
+@dp.append_flow(target="kafka_sink")
+def kafka_orders_flow() -> DataFrame:
+ return (
+ spark.readStream.table("customer_orders")
+ .select(
+ col("order_id").cast("string").alias("key"),
+ to_json(struct("*")).alias("value")
+ )
+ )
+```
+
+### Sink Considerations
+
+When working with sinks, keep the following considerations in mind:
+
+- **Streaming-only**: Sinks currently support only streaming queries through `append_flow` decorators
+- **Python API**: Sink functionality is available only through the Python API, not SQL
+- **Append-only**: Only append operations are supported; full refresh updates reset checkpoints but do not clean previously computed results
+
## Important Considerations
### Python Considerations
- SDP evaluates the code that defines a pipeline multiple times during planning and pipeline runs. Python functions that define datasets should include only the code required to define the table or view.
-- The function used to define a dataset must return a Spark DataFrame.
+- The function used to define a dataset must return a `pyspark.sql.DataFrame`.
- Never use methods that save or write to files or tables as part of your SDP dataset code.
+- When using the `for` loop pattern to define datasets in Python, ensure that the list of values passed to the `for` loop is always additive.
+
+Examples of Spark SQL operations that should never be used in SDP code:
-Examples of Apache Spark operations that should never be used in SDP code:
- `collect()`
- `count()`
+- `pivot()`
- `toPandas()`
- `save()`
- `saveAsTable()`
@@ -415,4 +559,3 @@ Examples of Apache Spark operations that should never be used in SDP code:
### SQL Considerations
- The `PIVOT` clause is not supported in SDP SQL.
-- When using the `for` loop pattern to define datasets in Python, ensure that the list of values passed to the `for` loop is always additive.
diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md
index 68cd9a78d0f3a..e253ee4225407 100644
--- a/docs/running-on-kubernetes.md
+++ b/docs/running-on-kubernetes.md
@@ -34,13 +34,13 @@ Please see [Spark Security](security.html) and the specific security sections in
Images built from the project provided Dockerfiles contain a default [`USER`](https://docs.docker.com/engine/reference/builder/#user) directive with a default UID of `185`. This means that the resulting images will be running the Spark processes as this UID inside the container. Security conscious deployments should consider providing custom images with `USER` directives specifying their desired unprivileged UID and GID. The resulting UID should include the root group in its supplementary groups in order to be able to run the Spark executables. Users building their own images with the provided `docker-image-tool.sh` script can use the `-u ` option to specify the desired UID.
-Alternatively the [Pod Template](#pod-template) feature can be used to add a [Security Context](https://kubernetes.io/docs/tasks/configure-pod-container/security-context/#volumes-and-file-systems) with a `runAsUser` to the pods that Spark submits. This can be used to override the `USER` directives in the images themselves. Please bear in mind that this requires cooperation from your users and as such may not be a suitable solution for shared environments. Cluster administrators should use [Pod Security Policies](https://kubernetes.io/docs/concepts/policy/pod-security-policy/#users-and-groups) if they wish to limit the users that pods may run as.
+Alternatively the [Pod Template](#pod-template) feature can be used to add a [Security Context](https://kubernetes.io/docs/tasks/configure-pod-container/security-context/#volumes-and-file-systems) with a `runAsUser` to the pods that Spark submits. This can be used to override the `USER` directives in the images themselves. Please bear in mind that this requires cooperation from your users and as such may not be a suitable solution for shared environments. Cluster administrators should use the [Pod Security Admission Controller](https://kubernetes.io/docs/concepts/security/pod-security-admission/) if they wish to limit the users that pods may run as.
## Volume Mounts
As described later in this document under [Using Kubernetes Volumes](#using-kubernetes-volumes) Spark on K8S provides configuration options that allow for mounting certain volume types into the driver and executor pods. In particular it allows for [`hostPath`](https://kubernetes.io/docs/concepts/storage/volumes/#hostpath) volumes which as described in the Kubernetes documentation have known security vulnerabilities.
-Cluster administrators should use [Pod Security Policies](https://kubernetes.io/docs/concepts/policy/pod-security-policy/) to limit the ability to mount `hostPath` volumes appropriately for their environments.
+Cluster administrators should use the [Pod Security Admission Controller](https://kubernetes.io/docs/concepts/security/pod-security-admission/) to limit the ability to mount `hostPath` volumes appropriately for their environments.
# Prerequisites
@@ -1284,17 +1284,6 @@ See the [configuration page](configuration.html) for information on Spark config
|
2.4.0 |
-
- spark.kubernetes.pyspark.pythonVersion |
- "3" |
-
- This sets the major Python version of the docker image used to run the driver and executor containers.
- It can be only "3". This configuration was deprecated from Spark 3.1.0, and is effectively no-op.
- Users should set 'spark.pyspark.python' and 'spark.pyspark.driver.python' configurations or
- 'PYSPARK_PYTHON' and 'PYSPARK_DRIVER_PYTHON' environment variables.
- |
- 2.4.0 |
-
spark.kubernetes.kerberos.krb5.path |
(none) |
diff --git a/docs/spark-connect-overview.md b/docs/spark-connect-overview.md
index f01ebf1b54f76..3c15153e03053 100644
--- a/docs/spark-connect-overview.md
+++ b/docs/spark-connect-overview.md
@@ -284,11 +284,11 @@ The connection may also be programmatically created using _SparkSession#builder_
-First, install PySpark with `pip install pyspark[connect]=={{site.SPARK_VERSION_SHORT}}` or if building a packaged PySpark application/library,
+First, install PySpark with `pip install pyspark-client=={{site.SPARK_VERSION_SHORT}}` or if building a packaged PySpark application/library,
add it your setup.py file as:
{% highlight python %}
install_requires=[
-'pyspark[connect]=={{site.SPARK_VERSION_SHORT}}'
+'pyspark-client=={{site.SPARK_VERSION_SHORT}}'
]
{% endhighlight %}
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index 5a8eb3f1e0602..ec1656b0348c8 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -389,7 +389,7 @@ SPARK_MASTER_OPTS supports the following system properties:
spark.dead.worker.persistence |
15 |
- Number of iterations to keep the deae worker information in UI. By default, the dead worker is visible for (15 + 1) * spark.worker.timeout since its last heartbeat.
+ Number of iterations to keep the dead worker information in UI. By default, the dead worker is visible for (15 + 1) * spark.worker.timeout since its last heartbeat.
|
0.8.0 |
diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md
index bf53ffa65d618..d927973e96ac3 100644
--- a/docs/sql-data-sources-jdbc.md
+++ b/docs/sql-data-sources-jdbc.md
@@ -389,12 +389,17 @@ Before using keytab and principal configuration option
* The included JDBC driver version supports kerberos authentication with keytab.
* There is a built-in connection provider which supports the used database.
-There is a built-in connection providers for the following databases:
+There are built-in connection providers for the following databases:
+* Databricks
* DB2
-* MariaDB
-* MS Sql
+* Derby
+* H2
+* MariaDB and MySQL
+* Microsoft SQL Server
* Oracle
* PostgreSQL
+* Snowflake
+* Teradata
If the requirements are not met, please consider using the JdbcConnectionProvider developer API to handle custom authentication.
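+
+For example, a JDBC source that authenticates through a built-in provider with a keytab might be declared as follows; the URL, table, keytab path, and principal are illustrative placeholders:
+
+```sql
+-- The keytab file must be available at the same path on every node
+CREATE TEMPORARY VIEW hr_employees
+USING org.apache.spark.sql.jdbc
+OPTIONS (
+  url 'jdbc:postgresql://dbserver:5432/hr',
+  dbtable 'employees',
+  keytab '/etc/security/keytabs/hr.keytab',
+  principal 'hr@EXAMPLE.COM'
+);
+```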
diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index e5becac540328..0a2533d28f0b8 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -25,6 +25,7 @@ license: |
## Upgrading from Spark SQL 4.0 to 4.1
- Since Spark 4.1, the Parquet reader no longer assumes all struct values to be null, if all the requested fields are missing in the parquet file. The new default behavior is to read an additional struct field that is present in the file to determine nullness. To restore the previous behavior, set `spark.sql.legacy.parquet.returnNullStructIfAllFieldsMissing` to `true`.
+- Since Spark 4.1, the Spark Thrift Server returns the correct 1-based ORDINAL_POSITION in the result of the GetColumns operation, instead of the previous 0-based value. To restore the previous behavior, set `spark.sql.legacy.hive.thriftServer.useZeroBasedColumnOrdinalPosition` to `true`.
## Upgrading from Spark SQL 3.5 to 4.0
diff --git a/docs/sql-ref-name-resolution.md b/docs/sql-ref-name-resolution.md
new file mode 100644
index 0000000000000..2532f05e164b3
--- /dev/null
+++ b/docs/sql-ref-name-resolution.md
@@ -0,0 +1,423 @@
+---
+layout: global
+title: Name Resolution
+displayTitle: Name Resolution
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+Name resolution is the process by which [identifiers](sql-ref-identifier.html) are resolved to specific column-, field-, parameter-, or table-references.
+
+## Column, field, parameter, and variable resolution
+
+Identifiers in expressions can be references to any one of the following:
+
+- Column name based on a view, table, common table expression (CTE), or a column_alias.
+- Field name or map key within a struct or map.
+ Fields and keys can never be unqualified.
+- Parameter name of a SQL User Defined Function.
+- Session or SQL script local variable name.
+- A special function such as `current_user` or `current_date`, which can be invoked without trailing `()`.
+- The `DEFAULT` keyword, which is used in the context of `INSERT` or `SET VARIABLE` to set a column or variable to its default value.
+
+Name resolution applies the following principles:
+
+- The _closest_ matching reference wins, and
+- Columns and parameters win over fields and keys.
+
+In detail, resolution of identifiers to a specific reference follows these rules in order:
+
+1. **Local references**
+
+ 1. **Column reference**
+
+ Match the identifier, which may be qualified, to a column name in a table reference of the `FROM clause`.
+
+ If there is more than one such match, raise an AMBIGUOUS_COLUMN_OR_FIELD error.
+
+ 1. **Parameterless function reference**
+
+ If the identifier is unqualified and matches `current_user`, `current_date`, or `current_timestamp`: Resolve it as one of these functions.
+
+ 1. **Column DEFAULT specification**
+
+ If the identifier is unqualified, matches `default`, and makes up the entire expression in the context of an `UPDATE SET`, `INSERT VALUES`, or `MERGE WHEN [NOT] MATCHED`: Resolve it as the `DEFAULT` value of the corresponding column of the statement's target table.
+
+ 1. **Struct field or map key reference**
+
+ If the identifier is qualified, then endeavor to match it to a field or map key according to the following steps:
+
+ A. Remove the last identifier and treat it as a field or key.
+
+ B. Match the remainder to a column in table reference of the `FROM clause`.
+
+ - If there is more than one such match, raise an AMBIGUOUS_COLUMN_OR_FIELD error.
+
+ - If there is a match and the column is a:
+
+ - **`STRUCT`**: Match the field.
+
+ If the field cannot be matched, raise a FIELD_NOT_FOUND error.
+
+ If more than one field matches, raise an AMBIGUOUS_COLUMN_OR_FIELD error.
+
+ - **`MAP`**: Raise an error if the key is qualified.
+
+ A runtime error may occur if the key is not actually present in the map.
+
+ - Any other type: Raise an error.
+
+ C. Repeat the preceding step to remove the trailing identifier as a field. Apply rules (A) and (B) while there is an identifier left to interpret as a column.
+
+1. **Lateral column aliasing**
+
+ If the expression is within a `SELECT` list, match the leading identifier to a preceding column alias in that `SELECT` list.
+
+ If there is more than one such match, raise an AMBIGUOUS_LATERAL_COLUMN_ALIAS error.
+
+ Match each remaining identifier as a field or a map key, and raise a FIELD_NOT_FOUND or AMBIGUOUS_COLUMN_OR_FIELD error if it cannot be matched.
+
+1. **Correlation**
+
+ - **LATERAL**
+
+ If the query is preceded by a `LATERAL` keyword, apply rules 1.a and 1.d considering the table references in the `FROM` containing the query and preceding the `LATERAL`.
+
+ - **Regular**
+
+ If the query is a scalar subquery, `IN`, or `EXISTS` subquery apply rules 1.a, 1.d, and 2 considering the table references in the containing query’s `FROM` clause.
+
+1. **Nested correlation**
+
+ Re-apply rule 3 iterating over the nesting levels of the query.
+
+1. **[FOR loop](control-flow/for-stmt.html)**
+
+ If the statement is contained in a `FOR` loop:
+
+ A. Match the identifier to a column in a `FOR` loop statement query.
+
+ If the identifier is qualified, the qualifier must match the name of the FOR loop variable if defined.
+
+ B. If the identifier is qualified, match to a field or map key of a column, following rule 1.d.
+
+1. **[Compound statement](control-flow/compound-stmt.html)**
+
+ If the statement is contained in a compound statement:
+
+ A. Match the identifier to a variable declared in that compound statement.
+
+ If the identifier is qualified, the qualifier must match the label of the compound statement if one was defined.
+
+ B. If the identifier is qualified, match to a field or map key of a variable, following rule 1.d.
+
+1. **Nested compound statement or `FOR` loop**
+
+ Re-apply rules 5 and 6, iterating over the nesting levels of the compound statement.
+
+1. **Routine parameters**
+
+ If the expression is part of a CREATE FUNCTION statement:
+
+ 1. Match the identifier to a parameter name. If the identifier is qualified, the qualifier must match the name of the routine.
+ 1. If the identifier is qualified, match to a field or map key of a parameter, following rule 1.d.
+
+1. **Session Variables**
+
+ 1. Match the identifier to a variable name. If the identifier is qualified, the qualifier must be `session` or `system.session`.
+ 1. If the identifier is qualified, match to a field or map key of a variable, following rule 1.d.
+
+### Limitations
+
+To prevent execution of potentially expensive correlated queries, Spark limits supported correlation to one level.
+This restriction also applies to parameter references in SQL functions.
+
+### Examples
+
+```sql
+-- Differentiating columns and fields
+> SELECT a FROM VALUES(1) AS t(a);
+ 1
+
+> SELECT t.a FROM VALUES(1) AS t(a);
+ 1
+
+> SELECT t.a FROM VALUES(named_struct('a', 1)) AS t(t);
+ 1
+
+-- A column takes precedence over a field
+> SELECT t.a FROM VALUES(named_struct('a', 1), 2) AS t(t, a);
+ 2
+
+-- Implicit lateral column alias
+> SELECT c1 AS a, a + c1 FROM VALUES(2) AS T(c1);
+ 2 4
+
+-- A local column reference takes precedence over a lateral column alias
+> SELECT c1 AS a, a + c1 FROM VALUES(2, 3) AS T(c1, a);
+ 2 5
+
+-- A scalar subquery correlation to S.c3
+> SELECT (SELECT c1 FROM VALUES(1, 2) AS t(c1, c2)
+ WHERE t.c2 * 2 = c3)
+ FROM VALUES(4) AS s(c3);
+ 1
+
+-- A local reference takes precedence over correlation
+> SELECT (SELECT c1 FROM VALUES(1, 2, 2) AS t(c1, c2, c3)
+ WHERE t.c2 * 2 = c3)
+ FROM VALUES(4) AS s(c3);
+ NULL
+
+-- An explicit scalar subquery correlation to s.c3
+> SELECT (SELECT c1 FROM VALUES(1, 2, 2) AS t(c1, c2, c3)
+ WHERE t.c2 * 2 = s.c3)
+ FROM VALUES(4) AS s(c3);
+ 1
+
+-- Correlation from an EXISTS predicate to t.c2
+> SELECT c1 FROM VALUES(1, 2) AS T(c1, c2)
+ WHERE EXISTS(SELECT 1 FROM VALUES(2) AS S(c2)
+ WHERE S.c2 = T.c2);
+ 1
+
+-- Attempt a lateral correlation to t.c2
+> SELECT c1, c2, c3
+ FROM VALUES(1, 2) AS t(c1, c2),
+ (SELECT c3 FROM VALUES(3, 4) AS s(c3, c4)
+ WHERE c4 = c2 * 2);
+ [UNRESOLVED_COLUMN] `c2`
+
+-- Successful usage of lateral correlation with keyword LATERAL
+> SELECT c1, c2, c3
+ FROM VALUES(1, 2) AS t(c1, c2),
+ LATERAL(SELECT c3 FROM VALUES(3, 4) AS s(c3, c4)
+ WHERE c4 = c2 * 2);
+ 1 2 3
+
+-- Referencing a parameter of a SQL function
+> CREATE OR REPLACE TEMPORARY FUNCTION func(a INT) RETURNS INT
+ RETURN (SELECT c1 FROM VALUES(1) AS T(c1) WHERE c1 = a);
+> SELECT func(1), func(2);
+ 1 NULL
+
+-- A column takes precedence over a parameter
+> CREATE OR REPLACE TEMPORARY FUNCTION func(a INT) RETURNS INT
+ RETURN (SELECT a FROM VALUES(1) AS T(a) WHERE t.a = a);
+> SELECT func(1), func(2);
+ 1 1
+
+-- Qualify the parameter with the function name
+> CREATE OR REPLACE TEMPORARY FUNCTION func(a INT) RETURNS INT
+ RETURN (SELECT a FROM VALUES(1) AS T(a) WHERE t.a = func.a);
+> SELECT func(1), func(2);
+ 1 NULL
+
+-- Lateral alias takes precedence over correlated reference
+> SELECT (SELECT c2 FROM (SELECT 1 AS c1, c1 AS c2) WHERE c2 > 5)
+ FROM VALUES(6) AS t(c1);
+ NULL
+
+-- Lateral alias takes precedence over function parameters
+> CREATE OR REPLACE TEMPORARY FUNCTION func(x INT)
+ RETURNS TABLE (a INT, b INT)
+ RETURN SELECT x + 1 AS x, x;
+> SELECT * FROM func(1);
+ 2 2
+
+-- All together now
+> CREATE OR REPLACE TEMPORARY VIEW lat(a, b) AS VALUES('lat.a', 'lat.b');
+
+> CREATE OR REPLACE TEMPORARY VIEW frm(a) AS VALUES('frm.a');
+
+> CREATE OR REPLACE TEMPORARY FUNCTION func(a STRING, b STRING, c STRING)
+ RETURNS TABLE
+ RETURN SELECT t.*
+ FROM lat,
+ LATERAL(SELECT a, b, c
+ FROM frm) AS t;
+
+> SELECT * FROM func('func.a', 'func.b', 'func.c');
+ a b c
+ ----- ----- ------
+ frm.a lat.b func.c
+```
+
+## Table and view resolution
+
+An identifier in table-reference can be any one of the following:
+
+- Persistent table or view
+- Common table expression (CTE)
+- [Temporary view](sql-ref-syntax-ddl-create-view.html)
+
+Resolution of an identifier depends on whether it is qualified:
+
+- **Qualified**
+
+ If the identifier is fully qualified with three parts: `catalog.schema.relation`, it is unique.
+
+ If the identifier consists of two parts: `schema.relation`, it is further qualified with the result of `SELECT current_catalog()` to make it unique.
+
+- **Unqualified**
+
+ 1. **Common table expression**
+
+ If the reference is within the scope of a `WITH` clause, match the identifier to a CTE starting with the immediately containing `WITH` clause and moving outwards from there.
+
+ 1. **Temporary view**
+
+ Match the identifier to any temporary view defined within the current session.
+
+ 1. **Persisted table**
+
+ Fully qualify the identifier by pre-pending the result of `SELECT current_catalog()` and `SELECT current_schema()` and look it up as a persistent relation.
+
+If the relation cannot be resolved to any table, view, or CTE, Spark raises a TABLE_OR_VIEW_NOT_FOUND error.
+
+### Examples
+
+```sql
+-- Setting up a scenario
+> USE CATALOG spark_catalog;
+> USE SCHEMA default;
+
+> CREATE TABLE rel(c1 int);
+> INSERT INTO rel VALUES(1);
+
+-- A fully qualified reference to rel:
+> SELECT c1 FROM spark_catalog.default.rel;
+ 1
+
+-- A partially qualified reference to rel:
+> SELECT c1 FROM default.rel;
+ 1
+
+-- An unqualified reference to rel:
+> SELECT c1 FROM rel;
+ 1
+
+-- Add a temporary view with a conflicting name:
+> CREATE TEMPORARY VIEW rel(c1) AS VALUES(2);
+
+-- For unqualified references the temporary view takes precedence over the persisted table:
+> SELECT c1 FROM rel;
+ 2
+
+-- Temporary views cannot be qualified, so qualification resolves to the table:
+> SELECT c1 FROM default.rel;
+ 1
+
+-- An unqualified reference to a common table expression wins even over a temporary view:
+> WITH rel(c1) AS (VALUES(3))
+ SELECT * FROM rel;
+ 3
+
+-- If CTEs are nested, the match nearest to the table reference takes precedence.
+> WITH rel(c1) AS (VALUES(3))
+ (WITH rel(c1) AS (VALUES(4))
+ SELECT * FROM rel);
+ 4
+
+-- To resolve the table instead of the CTE, qualify it:
+> WITH rel(c1) AS (VALUES(3))
+ (WITH rel(c1) AS (VALUES(4))
+ SELECT * FROM default.rel);
+ 1
+
+-- A CTE is visible only to queries contained within its defining WITH clause
+> SELECT * FROM (WITH cte(c1) AS (VALUES(1))
+ SELECT 1),
+ cte;
+ [TABLE_OR_VIEW_NOT_FOUND] The table or view `cte` cannot be found.
+```
+
+## Function resolution
+
+A function reference is recognized by the mandatory trailing set of parentheses.
+
+It can resolve to:
+
+- A builtin function provided by Spark,
+- A temporary user defined function scoped to the current session, or
+- A persistent user defined function.
+
+Resolution of a function name depends on whether it is qualified:
+
+- **Qualified**
+
+ If the name is fully qualified with three parts: `catalog.schema.function`, it is unique.
+
+ If the name consists of two parts: `schema.function`, it is further qualified with the result of `SELECT current_catalog()` to make it unique.
+
+ The function is then looked up in the catalog.
+
+- **Unqualified**
+
+ For unqualified function names Spark follows a fixed order of precedence (`PATH`):
+
+ 1. **Builtin function**
+
+ If a function by this name exists among the set of built-in functions, that function is chosen.
+
+ 1. **Temporary function**
+
+ If a function by this name exists among the set of temporary functions, that function is chosen.
+
+ 1. **Persisted function**
+
+ Fully qualify the function name by pre-pending the result of `SELECT current_catalog()` and `SELECT current_schema()` and look it up as a persistent function.
+
+If the function cannot be resolved Spark raises an `UNRESOLVED_ROUTINE` error.
+
+### Examples
+
+```sql
+> USE CATALOG spark_catalog;
+> USE SCHEMA default;
+
+-- Create a function with the same name as a builtin
+> CREATE FUNCTION concat(a STRING, b STRING) RETURNS STRING
+ RETURN b || a;
+
+-- unqualified reference resolves to the builtin CONCAT
+> SELECT concat('hello', 'world');
+ helloworld
+
+-- Qualified reference resolves to the persistent function
+> SELECT default.concat('hello', 'world');
+ worldhello
+
+-- Create a persistent function
+> CREATE FUNCTION func(a INT, b INT) RETURNS INT
+ RETURN a + b;
+
+-- The persistent function is resolved without qualifying it
+> SELECT func(4, 2);
+ 6
+
+-- Create a conflicting temporary function
+> CREATE TEMPORARY FUNCTION func(a INT, b INT) RETURNS INT
+ RETURN a / b;
+
+-- The temporary function takes precedence
+> SELECT func(4, 2);
+ 2
+
+-- To resolve the persistent function it now needs qualification
+> SELECT spark_catalog.default.func(4, 2);
+ 6
+```
diff --git a/docs/sql-ref-scripting.md b/docs/sql-ref-scripting.md
new file mode 100644
index 0000000000000..7d26bd07fed51
--- /dev/null
+++ b/docs/sql-ref-scripting.md
@@ -0,0 +1,85 @@
+---
+layout: global
+title: SQL Scripting
+displayTitle: SQL Scripting
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+You can employ powerful procedural logic using scripting syntax based on the SQL/PSM standard.
+Every SQL script is a [compound statement](control-flow/compound-stmt.html) block (`BEGIN ... END`).
+A compound statement starts with a section that declares local variables, user-defined conditions, and condition handlers, which are used to catch exceptions.
+This is followed by the compound statement body (a minimal example appears after the list), which can consist of:
+
+- Flow control statements, including loops over predicate expressions such as [WHILE](control-flow/while-stmt.html) and [REPEAT](control-flow/repeat-stmt.html), [FOR](control-flow/for-stmt.html) loops over query results, conditional logic such as [IF](control-flow/if-stmt.html) and [CASE](control-flow/case-stmt.html), and means to break out of loops such as [LEAVE](control-flow/leave-stmt.html) and [ITERATE](control-flow/iterate-stmt.html).
+- DDL statements such as `ALTER`, `CREATE`, `DROP`.
+- DML statements such as [INSERT](sql-ref-syntax-dml-insert-table.html).
+- [Queries](sql-ref-syntax-qry-select.html) that return result sets to the invoker of the script.
+- [SET](sql-ref-syntax-aux-set-var.html) statements to set local variables as well as session variables.
+- The [EXECUTE IMMEDIATE](sql-ref-syntax-aux-exec-imm.html) statement.
+- Nested compound statements, which provide nested scopes for variables, conditions, and condition handlers.
+
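+A minimal sketch of such a script, combining a local variable declaration, an assignment, and conditional logic:
+
+```sql
+BEGIN
+  -- Local variable, visible only within this compound statement
+  DECLARE cnt INT DEFAULT 0;
+  SET VAR cnt = (SELECT COUNT(*) FROM VALUES (1), (2), (3) AS t(c1));
+  IF cnt > 2 THEN
+    SELECT 'more than two rows' AS result;
+  ELSE
+    SELECT 'two rows or fewer' AS result;
+  END IF;
+END
+```
+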
+## Passing data between the invoker and the compound statement
+
+There are two ways to pass data to and from a SQL script (the first is sketched in the example below):
+
+- Use session variables to pass scalar values or small sets of arrays or maps from one SQL script to another.
+- Use parameter markers to pass scalar values or small sets of arrays or map data from Python or another language to the SQL script.
+
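+As a minimal sketch of the session-variable approach (the variable and data below are illustrative):
+
+```sql
+-- Declared and set outside the script
+DECLARE VARIABLE threshold INT DEFAULT 10;
+SET VAR threshold = 42;
+
+-- The script reads the session variable like any other identifier
+BEGIN
+  SELECT c1
+  FROM VALUES (7), (55), (101) AS t(c1)
+  WHERE c1 > threshold;
+END
+```
+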
+## Variable scoping
+
+Variables declared within a compound statement can be referenced in any expression within that compound statement.
+Spark resolves identifiers from the innermost scope outward, following the rules described in [Name Resolution](sql-ref-name-resolution.html).
+You can use the optional compound statement labels to disambiguate duplicate variable names.
+
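+For example, with two nested compound statements that both declare a variable named `x`, the label of the outer block can be used to reach its variable (a sketch; the labels and names are illustrative):
+
+```sql
+outer_block: BEGIN
+  DECLARE x INT DEFAULT 1;
+  inner_block: BEGIN
+    DECLARE x INT DEFAULT 2;
+    -- Unqualified x resolves to the innermost declaration;
+    -- outer_block.x reaches the variable declared in the labeled outer block
+    SELECT x AS inner_x, outer_block.x AS outer_x;
+  END inner_block;
+END outer_block
+```
+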
+## Condition handling
+
+SQL Scripting supports condition handlers, which intercept and process exceptions and then `EXIT` the compound statement in which they are declared.
+
+Condition handlers can be defined to handle three distinct classes of conditions:
+
+- One or more named conditions that can be a specific Spark-defined error class such as `DIVIDE_BY_ZERO` or a user-declared condition.
+ These handlers handle these specific conditions.
+
+- One or more `SQLSTATE`s, that can be raised by Spark.
+ These handlers can handle any condition associated with that `SQLSTATE`.
+
+- A generic `SQLEXCEPTION` handler can catch all conditions falling into the `SQLEXCEPTION` class (any `SQLSTATE` which is not `XX***` and not `02***`).
+
+The following are used to decide which condition handler applies to an exception.
+This condition handler is called the **most appropriate handler**:
+
+- A condition handler cannot apply to any statement defined in its own body or the body of any condition handler declared in the same compound statement.
+
+- The applicable condition handlers defined in the innermost compound statement within which the exception was raised are appropriate.
+
+- If more than one appropriate handler is available, the most specific handler is the most appropriate.
+ For example, a handler on a named condition is more specific than one on a `SQLSTATE`.
+ A generic `SQLEXCEPTION` handler is the least specific.
+
+After a condition handler completes, execution continues with the statement that follows the compound statement in which the handler is declared.
+
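+The following sketch illustrates this `EXIT` behavior, assuming the SQL/PSM-style `DECLARE ... HANDLER` syntax and ANSI mode so that the division raises `DIVIDE_BY_ZERO`:
+
+```sql
+BEGIN
+  scope: BEGIN
+    -- Handler for DIVIDE_BY_ZERO raised anywhere in this block
+    DECLARE EXIT HANDLER FOR DIVIDE_BY_ZERO
+    BEGIN
+      SELECT 'caught a division by zero' AS message;
+    END;
+    SELECT 10 / 0;          -- raises DIVIDE_BY_ZERO; the handler runs
+    SELECT 'skipped';       -- not executed: the handler EXITs the block
+  END scope;
+  SELECT 'execution resumes here' AS message;
+END
+```
+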
+The following is a list of supported control flow statements:
+
+* [CASE](control-flow/case-stmt.html)
+* [compound statement](control-flow/compound-stmt.html)
+* [FOR](control-flow/for-stmt.html)
+* [IF](control-flow/if-stmt.html)
+* [ITERATE](control-flow/iterate-stmt.html)
+* [LEAVE](control-flow/leave-stmt.html)
+* [LOOP](control-flow/loop-stmt.html)
+* [REPEAT](control-flow/repeat-stmt.html)
+* [WHILE](control-flow/while-stmt.html)
diff --git a/docs/sql-ref-sketch-aggregates.md b/docs/sql-ref-sketch-aggregates.md
new file mode 100644
index 0000000000000..6b92ba7b3c9e4
--- /dev/null
+++ b/docs/sql-ref-sketch-aggregates.md
@@ -0,0 +1,988 @@
+---
+layout: global
+title: Sketch-Based Approximate Functions
+displayTitle: Sketch-Based Approximate Functions
+license: |
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+---
+
+Spark's SQL and DataFrame APIs provide a collection of sketch-based approximate functions powered by the [Apache DataSketches](https://datasketches.apache.org/) library. These functions enable efficient probabilistic computations on large datasets with bounded memory usage and accuracy guarantees.
+
+Sketches are compact data structures that summarize large datasets, supporting distributed aggregation through serialization and merging. This makes them ideal for use cases including:
+- **Approximate count distinct** (HLL and Theta sketches)
+- **Approximate quantile estimation** (KLL sketches)
+- **Approximate frequent items** (Top-K sketches)
+- **Set operations** on distinct counts (Theta sketches)
+
+### Table of Contents
+
+* [HyperLogLog (HLL) Sketch Functions](#hyperloglog-hll-sketch-functions)
+ * [hll_sketch_agg](#hll_sketch_agg)
+ * [hll_union_agg](#hll_union_agg)
+ * [hll_sketch_estimate](#hll_sketch_estimate)
+ * [hll_union](#hll_union)
+* [Theta Sketch Functions](#theta-sketch-functions)
+ * [theta_sketch_agg](#theta_sketch_agg)
+ * [theta_union_agg](#theta_union_agg)
+ * [theta_intersection_agg](#theta_intersection_agg)
+ * [theta_sketch_estimate](#theta_sketch_estimate)
+ * [theta_union](#theta_union)
+ * [theta_intersection](#theta_intersection)
+ * [theta_difference](#theta_difference)
+* [KLL Quantile Sketch Functions](#kll-quantile-sketch-functions)
+ * [kll_sketch_agg_*](#kll_sketch_agg_)
+ * [kll_merge_agg_*](#kll_merge_agg_)
+ * [kll_sketch_to_string_*](#kll_sketch_to_string_)
+ * [kll_sketch_get_n_*](#kll_sketch_get_n_)
+ * [kll_sketch_merge_*](#kll_sketch_merge_)
+ * [kll_sketch_get_quantile_*](#kll_sketch_get_quantile_)
+ * [kll_sketch_get_rank_*](#kll_sketch_get_rank_)
+* [Approximate Top-K Functions](#approximate-top-k-functions)
+ * [approx_top_k_accumulate](#approx_top_k_accumulate)
+ * [approx_top_k_combine](#approx_top_k_combine)
+ * [approx_top_k_estimate](#approx_top_k_estimate)
+* [Best Practices](#best-practices)
+ * [Choosing Between HLL and Theta Sketches](#choosing-between-hll-and-theta-sketches)
+ * [Accuracy vs. Memory Trade-offs](#accuracy-vs-memory-trade-offs)
+ * [Storing and Reusing Sketches](#storing-and-reusing-sketches)
+* [Common Use Cases and Examples](#common-use-cases-and-examples)
+ * [Example: Tracking Daily Unique Users with HLL Sketches](#example-tracking-daily-unique-users-with-hll-sketches)
+ * [Example: Computing Percentiles Over Time with KLL Sketches](#example-computing-percentiles-over-time-with-kll-sketches)
+ * [Example: Set Operations with Theta Sketches](#example-set-operations-with-theta-sketches)
+ * [Example: Finding Trending Items with Top-K Sketches](#example-finding-trending-items-with-top-k-sketches)
+
+---
+
+## HyperLogLog (HLL) Sketch Functions
+
+HyperLogLog sketches provide approximate count distinct functionality with configurable accuracy and memory usage. They are well-suited for counting unique values in very large datasets.
+
+See the [Apache DataSketches HLL documentation](https://datasketches.apache.org/docs/HLL/HLL.html) for more information.
+
+### hll_sketch_agg
+
+Creates an HLL sketch from input values that can later be used to estimate count distinct.
+
+**Syntax:**
+```sql
+hll_sketch_agg(expr [, lgConfigK])
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `expr` | INT, BIGINT, STRING, or BINARY | The expression whose distinct values will be counted |
+| `lgConfigK` | INT (optional) | Log-base-2 of K, where K is the number of buckets. Range: 4-21. Default: 12. Higher values provide more accuracy but use more memory. |
+
+Returns a BINARY containing the HLL sketch in updatable binary representation.
+
+**Examples:**
+```sql
+-- Basic usage: create a sketch and estimate distinct count
+SELECT hll_sketch_estimate(hll_sketch_agg(col))
+FROM VALUES (1), (1), (2), (2), (3) tab(col);
+-- Result: 3
+
+-- With custom lgConfigK for higher accuracy
+SELECT hll_sketch_estimate(hll_sketch_agg(col, 16))
+FROM VALUES (50), (60), (60), (60), (75), (100) tab(col);
+-- Result: 4
+
+-- With string values
+SELECT hll_sketch_estimate(hll_sketch_agg(col))
+FROM VALUES ('abc'), ('def'), ('abc'), ('ghi'), ('abc') tab(col);
+-- Result: 3
+```
+
+**Notes:**
+- NULL values are ignored during aggregation.
+- Empty strings (for STRING type) and empty byte arrays (for BINARY type) are ignored.
+- The sketch can be stored and later merged with other sketches using `hll_union` or `hll_union_agg`.
+
+---
+
+### hll_union_agg
+
+Aggregates multiple HLL sketches into a single merged sketch.
+
+**Syntax:**
+```sql
+hll_union_agg(sketch [, allowDifferentLgConfigK])
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `sketch` | BINARY | An HLL sketch in binary format (produced by `hll_sketch_agg`) |
+| `allowDifferentLgConfigK` | BOOLEAN (optional) | If true, allows merging sketches with different lgConfigK values. Default: false. |
+
+Returns a BINARY containing the merged HLL sketch.
+
+**Examples:**
+```sql
+-- Merge sketches from different partitions
+SELECT hll_sketch_estimate(hll_union_agg(sketch, true))
+FROM (
+ SELECT hll_sketch_agg(col) as sketch
+ FROM VALUES (1) tab(col)
+ UNION ALL
+ SELECT hll_sketch_agg(col, 20) as sketch
+ FROM VALUES (1) tab(col)
+);
+-- Result: 1
+
+-- Standard merge (same lgConfigK)
+SELECT hll_sketch_estimate(hll_union_agg(sketch))
+FROM (
+ SELECT hll_sketch_agg(col) as sketch
+ FROM VALUES (1), (2) tab(col)
+ UNION ALL
+ SELECT hll_sketch_agg(col) as sketch
+ FROM VALUES (3), (4) tab(col)
+);
+-- Result: 4
+```
+
+**Notes:**
+- If `allowDifferentLgConfigK` is false and sketches have different lgConfigK values, an error is thrown.
+- The output sketch uses the minimum lgConfigK value of all input sketches when merging sketches with different sizes.
+
+---
+
+### hll_sketch_estimate
+
+Estimates the number of unique values from an HLL sketch.
+
+**Syntax:**
+```sql
+hll_sketch_estimate(sketch)
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `sketch` | BINARY | An HLL sketch in binary format |
+
+Returns a BIGINT representing the estimated count of distinct values.
+
+**Examples:**
+```sql
+SELECT hll_sketch_estimate(hll_sketch_agg(col))
+FROM VALUES (1), (1), (2), (2), (3) tab(col);
+-- Result: 3
+```
+
+**Errors:**
+- Throws an error if the input is not a valid HLL sketch binary representation.
+
+---
+
+### hll_union
+
+Merges two HLL sketches into one (scalar function).
+
+**Syntax:**
+```sql
+hll_union(first, second [, allowDifferentLgConfigK])
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `first` | BINARY | First HLL sketch |
+| `second` | BINARY | Second HLL sketch |
+| `allowDifferentLgConfigK` | BOOLEAN (optional) | Allow different lgConfigK values. Default: false. |
+
+Returns a BINARY containing the merged HLL sketch.
+
+**Examples:**
+```sql
+SELECT hll_sketch_estimate(
+ hll_union(
+ hll_sketch_agg(col1),
+ hll_sketch_agg(col2)))
+FROM VALUES (1, 4), (1, 4), (2, 5), (2, 5), (3, 6) tab(col1, col2);
+-- Result: 6
+```
+
+---
+
+## Theta Sketch Functions
+
+Theta sketches provide approximate count distinct with support for set operations (union, intersection, and difference). This makes them ideal for computing unique counts across overlapping datasets.
+
+See the [Apache DataSketches Theta documentation](https://datasketches.apache.org/docs/Theta/ThetaSketches.html) for more information.
+
+### theta_sketch_agg
+
+Creates a Theta sketch from input values.
+
+**Syntax:**
+```sql
+theta_sketch_agg(expr [, lgNomEntries])
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `expr` | INT, BIGINT, FLOAT, DOUBLE, STRING, BINARY, ARRAY<INT>, or ARRAY<BIGINT> | The expression whose distinct values will be counted |
+| `lgNomEntries` | INT (optional) | Log-base-2 of nominal entries. Range: 4-26. Default: 12. |
+
+Returns a BINARY containing the Theta sketch in compact binary representation.
+
+**Examples:**
+```sql
+-- Basic distinct count
+SELECT theta_sketch_estimate(theta_sketch_agg(col))
+FROM VALUES (1), (1), (2), (2), (3) tab(col);
+-- Result: 3
+
+-- With custom lgNomEntries
+SELECT theta_sketch_estimate(theta_sketch_agg(col, 22))
+FROM VALUES (1), (2), (3), (4), (5), (6), (7) tab(col);
+-- Result: 7
+
+-- With array values
+SELECT theta_sketch_estimate(theta_sketch_agg(col))
+FROM VALUES (ARRAY(1, 2)), (ARRAY(3, 4)), (ARRAY(1, 2)) tab(col);
+-- Result: 2
+```
+
+**Notes:**
+- NULL values are ignored.
+- Supports a wider range of input types compared to HLL sketches.
+- Empty arrays, empty strings, and empty binary values are ignored.
+
+---
+
+### theta_union_agg
+
+Aggregates multiple Theta sketches using union operation.
+
+**Syntax:**
+```sql
+theta_union_agg(sketch [, lgNomEntries])
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `sketch` | BINARY | A Theta sketch in binary format |
+| `lgNomEntries` | INT (optional) | Log-base-2 of nominal entries. Range: 4-26. Default: 12. |
+
+Returns a BINARY containing the merged Theta sketch.
+
+**Examples:**
+```sql
+SELECT theta_sketch_estimate(theta_union_agg(sketch, 15))
+FROM (
+ SELECT theta_sketch_agg(col1) as sketch
+ FROM VALUES (1), (2), (3), (4), (5), (6), (7) tab(col1)
+ UNION ALL
+ SELECT theta_sketch_agg(col2, 20) as sketch
+ FROM VALUES (5), (6), (7), (8), (9), (10), (11) tab(col2)
+);
+-- Result: 11
+```
+
+---
+
+### theta_intersection_agg
+
+Aggregates multiple Theta sketches using intersection operation (finds common distinct values).
+
+**Syntax:**
+```sql
+theta_intersection_agg(sketch)
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `sketch` | BINARY | A Theta sketch in binary format |
+
+Returns a BINARY containing the intersected Theta sketch.
+
+**Examples:**
+```sql
+SELECT theta_sketch_estimate(theta_intersection_agg(sketch))
+FROM (
+ SELECT theta_sketch_agg(col1) as sketch
+ FROM VALUES (1), (2), (3), (4), (5), (6), (7) tab(col1)
+ UNION ALL
+ SELECT theta_sketch_agg(col2) as sketch
+ FROM VALUES (5), (6), (7), (8), (9), (10), (11) tab(col2)
+);
+-- Result: 3 (values 5, 6, 7 are common)
+```
+
+---
+
+### theta_sketch_estimate
+
+Estimates the number of unique values from a Theta sketch.
+
+**Syntax:**
+```sql
+theta_sketch_estimate(sketch)
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `sketch` | BINARY | A Theta sketch in binary format |
+
+Returns a BIGINT representing the estimated count of distinct values.
+
+**Examples:**
+```sql
+SELECT theta_sketch_estimate(theta_sketch_agg(col))
+FROM VALUES (1), (1), (2), (2), (3) tab(col);
+-- Result: 3
+```
+
+---
+
+### theta_union
+
+Merges two Theta sketches using union (scalar function).
+
+**Syntax:**
+```sql
+theta_union(first, second [, lgNomEntries])
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `first` | BINARY | First Theta sketch |
+| `second` | BINARY | Second Theta sketch |
+| `lgNomEntries` | INT (optional) | Log-base-2 of nominal entries. Range: 4-26. Default: 12. |
+
+Returns a BINARY containing the merged Theta sketch.
+
+**Examples:**
+```sql
+SELECT theta_sketch_estimate(
+ theta_union(
+ theta_sketch_agg(col1),
+ theta_sketch_agg(col2)))
+FROM VALUES (1, 4), (1, 4), (2, 5), (2, 5), (3, 6) tab(col1, col2);
+-- Result: 6
+```
+
+---
+
+### theta_intersection
+
+Computes the intersection of two Theta sketches (scalar function).
+
+**Syntax:**
+```sql
+theta_intersection(first, second)
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `first` | BINARY | First Theta sketch |
+| `second` | BINARY | Second Theta sketch |
+
+Returns a BINARY containing the intersected Theta sketch.
+
+**Examples:**
+```sql
+SELECT theta_sketch_estimate(
+ theta_intersection(
+ theta_sketch_agg(col1),
+ theta_sketch_agg(col2)))
+FROM VALUES (5, 4), (1, 4), (2, 5), (2, 5), (3, 1) tab(col1, col2);
+-- Result: 2 (values 1 and 5 are common)
+```
+
+---
+
+### theta_difference
+
+Computes the set difference of two Theta sketches (A - B).
+
+**Syntax:**
+```sql
+theta_difference(first, second)
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `first` | BINARY | First Theta sketch (A) |
+| `second` | BINARY | Second Theta sketch (B) |
+
+Returns a BINARY containing a Theta sketch representing values in A but not in B.
+
+**Examples:**
+```sql
+SELECT theta_sketch_estimate(
+ theta_difference(
+ theta_sketch_agg(col1),
+ theta_sketch_agg(col2)))
+FROM VALUES (5, 4), (1, 4), (2, 5), (2, 5), (3, 1) tab(col1, col2);
+-- Result: 2 (values 2 and 3 are in col1 but not col2)
+```
+
+---
+
+## KLL Quantile Sketch Functions
+
+KLL (Karnin-Lang-Liberty) sketches provide approximate quantile estimation. They are useful for computing percentiles, medians, and other order statistics on large datasets without sorting.
+
+See the [Apache DataSketches KLL documentation](https://datasketches.apache.org/docs/KLL/KLLSketch.html) for more information.
+
+KLL functions are type-specific to avoid precision loss:
+- **BIGINT** variants: For integer types (TINYINT, SMALLINT, INT, BIGINT)
+- **FLOAT** variants: For FLOAT values only
+- **DOUBLE** variants: For FLOAT and DOUBLE values
+
+### kll_sketch_agg_*
+
+Creates a KLL sketch from numeric values for quantile estimation.
+
+**Syntax:**
+```sql
+kll_sketch_agg_bigint(expr [, k])
+kll_sketch_agg_float(expr [, k])
+kll_sketch_agg_double(expr [, k])
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `expr` | Numeric (see variants above) | The numeric column to summarize |
+| `k` | INT (optional) | Controls accuracy and size. Range: 8-65535. Default: 200 (~1.65% normalized rank error). |
+
+Returns a BINARY containing the KLL sketch in compact binary representation.
+
+**Examples:**
+```sql
+-- Get median (0.5 quantile)
+SELECT kll_sketch_get_quantile_bigint(kll_sketch_agg_bigint(col), 0.5)
+FROM VALUES (1), (2), (3), (4), (5), (6), (7) tab(col);
+-- Result: 4
+
+-- With custom k for higher accuracy
+SELECT kll_sketch_get_quantile_bigint(kll_sketch_agg_bigint(col, 400), 0.5)
+FROM VALUES (1), (2), (3), (4), (5), (6), (7) tab(col);
+-- Result: 4
+```
+
+**Notes:**
+- Use the appropriate variant to avoid precision loss: use `_bigint` for integers, `_float` for floats, `_double` for doubles.
+- NULL values are ignored during aggregation.
+
+---
+
+### kll_merge_agg_*
+
+Aggregates multiple KLL sketches of the same type by merging them together. This is useful for combining sketches created in separate aggregations (e.g., from different partitions or time windows). These are aggregate functions.
+
+**Syntax:**
+```sql
+kll_merge_agg_bigint(sketch [, k])
+kll_merge_agg_float(sketch [, k])
+kll_merge_agg_double(sketch [, k])
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `sketch` | BINARY | A KLL sketch in binary format (e.g., from `kll_sketch_agg_*`) |
+| `k` | INT (optional) | Controls accuracy and size of the merged sketch. Range: 8-65535. If not specified, the merged sketch adopts the k value from the first input sketch. |
+
+Returns a BINARY containing the merged KLL sketch.
+
+**Examples:**
+```sql
+-- Merge sketches from different partitions
+SELECT kll_sketch_get_quantile_bigint(
+ kll_merge_agg_bigint(sketch),
+ 0.5
+)
+FROM (
+ SELECT kll_sketch_agg_bigint(col) as sketch
+ FROM VALUES (1), (2), (3) tab(col)
+ UNION ALL
+ SELECT kll_sketch_agg_bigint(col) as sketch
+ FROM VALUES (4), (5), (6) tab(col)
+);
+-- Result: 3
+
+-- Get the total count from merged sketches
+SELECT kll_sketch_get_n_bigint(kll_merge_agg_bigint(sketch))
+FROM (
+ SELECT kll_sketch_agg_bigint(col) as sketch
+ FROM VALUES (1), (2), (3) tab(col)
+ UNION ALL
+ SELECT kll_sketch_agg_bigint(col) as sketch
+ FROM VALUES (4), (5), (6) tab(col)
+);
+-- Result: 6
+```
+
+**Notes:**
+- When `k` is not specified, the merged sketch adopts the k value from the first input sketch.
+- The merge operation can handle input sketches with different k values.
+- NULL values are ignored during aggregation.
+- Use this function when you need to merge multiple sketches in an aggregation context. For merging exactly two sketches, use the scalar `kll_sketch_merge_*` functions instead.
+
+---
+
+### kll_sketch_to_string_*
+
+Returns a human-readable summary of the sketch.
+
+**Syntax:**
+```sql
+kll_sketch_to_string_bigint(sketch)
+kll_sketch_to_string_float(sketch)
+kll_sketch_to_string_double(sketch)
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `sketch` | BINARY | A KLL sketch of the corresponding type |
+
+Returns a STRING containing a human-readable summary including sketch parameters and statistics.
+
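+**Examples:**
+
+The exact summary text depends on the underlying DataSketches library version; the query below only illustrates how to call the function.
+```sql
+SELECT kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col))
+FROM VALUES (1), (2), (3), (4), (5) tab(col);
+-- Result: a multi-line summary of the sketch's parameters and statistics
+```
+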
+---
+
+### kll_sketch_get_n_*
+
+Returns the number of items collected in the sketch.
+
+**Syntax:**
+```sql
+kll_sketch_get_n_bigint(sketch)
+kll_sketch_get_n_float(sketch)
+kll_sketch_get_n_double(sketch)
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `sketch` | BINARY | A KLL sketch of the corresponding type |
+
+Returns a BIGINT representing the count of items in the sketch.
+
+**Examples:**
+```sql
+SELECT kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col))
+FROM VALUES (1), (2), (3), (4), (5), (6), (7) tab(col);
+-- Result: 7
+```
+
+---
+
+### kll_sketch_merge_*
+
+Merges two KLL sketches of the same type. These are scalar functions.
+
+**Syntax:**
+```sql
+kll_sketch_merge_bigint(left, right)
+kll_sketch_merge_float(left, right)
+kll_sketch_merge_double(left, right)
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `left` | BINARY | First KLL sketch |
+| `right` | BINARY | Second KLL sketch (must be same type as left) |
+
+Returns a BINARY containing the merged KLL sketch.
+
+**Examples:**
+```sql
+-- Merge two sketches from different data partitions
+SELECT kll_sketch_get_quantile_bigint(
+ kll_sketch_merge_bigint(
+ kll_sketch_agg_bigint(col1),
+ kll_sketch_agg_bigint(col2)), 0.5)
+FROM VALUES (1, 6), (2, 7), (3, 8), (4, 9), (5, 10) tab(col1, col2);
+-- Result: approximately 5 (median of 1-10)
+```
+
+**Errors:**
+- Throws an error if sketches are of incompatible types or formats.
+
+**Notes:**
+- The merge operation can handle input sketches with different k values.
+- Use this function when you need to merge exactly two sketches in a scalar context. For merging multiple sketches in an aggregation context, use the aggregate `kll_merge_agg_*` functions instead.
+
+---
+
+### kll_sketch_get_quantile_*
+
+Gets the approximate value at a given quantile rank.
+
+**Syntax:**
+```sql
+kll_sketch_get_quantile_bigint(sketch, rank)
+kll_sketch_get_quantile_float(sketch, rank)
+kll_sketch_get_quantile_double(sketch, rank)
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `sketch` | BINARY | A KLL sketch of the corresponding type |
+| `rank` | DOUBLE or ARRAY<DOUBLE> | Quantile rank(s) between 0.0 and 1.0. Use 0.5 for median, 0.95 for 95th percentile, etc. |
+
+Returns the approximate value at the given quantile:
+- If `rank` is a scalar: Returns the corresponding type (BIGINT, FLOAT, or DOUBLE)
+- If `rank` is an array: Returns ARRAY of the corresponding type
+
+**Examples:**
+```sql
+-- Get the median
+SELECT kll_sketch_get_quantile_bigint(kll_sketch_agg_bigint(col), 0.5)
+FROM VALUES (1), (2), (3), (4), (5), (6), (7) tab(col);
+-- Result: 4
+
+-- Get multiple percentiles at once
+SELECT kll_sketch_get_quantile_bigint(
+ kll_sketch_agg_bigint(col),
+ ARRAY(0.25, 0.5, 0.75, 0.95))
+FROM VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10) tab(col);
+-- Result: Array of values at 25th, 50th, 75th, and 95th percentiles
+```
+
+**Errors:**
+- Throws an error if rank values are outside [0.0, 1.0].
+- If the input sketch is NULL, the function returns NULL rather than throwing an error.
+
+---
+
+### kll_sketch_get_rank_*
+
+Gets the normalized rank (0.0 to 1.0) of a given value in the sketch's distribution.
+
+**Syntax:**
+```sql
+kll_sketch_get_rank_bigint(sketch, value)
+kll_sketch_get_rank_float(sketch, value)
+kll_sketch_get_rank_double(sketch, value)
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `sketch` | BINARY | A KLL sketch of the corresponding type |
+| `value` | Corresponding type (BIGINT, FLOAT, or DOUBLE) | The value to find the rank for |
+
+Returns a DOUBLE representing the normalized rank between 0.0 and 1.0.
+
+**Examples:**
+```sql
+-- Find what percentile the value 3 is at
+SELECT kll_sketch_get_rank_bigint(kll_sketch_agg_bigint(col), 3)
+FROM VALUES (1), (2), (3), (4), (5), (6), (7) tab(col);
+-- Result: approximately 0.43 (3 is around the 43rd percentile)
+```
+
+---
+
+## Approximate Top-K Functions
+
+Top-K functions estimate the most frequent items (heavy hitters) in a dataset using the DataSketches Frequent Items sketch.
+
+See the [Apache DataSketches Frequency documentation](https://datasketches.apache.org/docs/Frequency/FrequencySketches.html) for more information.
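+
+As a quick end-to-end illustration before the pre-aggregation functions below, the one-shot `approx_top_k` aggregate (described earlier in this guide) computes the top items directly; the output shown is illustrative of its ARRAY<STRUCT<item, count>> result.
+
+```sql
+-- Top 2 most frequent values in a single pass
+SELECT approx_top_k(expr, 2)
+FROM VALUES (0), (0), (0), (1), (1), (2) tab(expr);
+-- Result: [{"item":0,"count":3},{"item":1,"count":2}]
+```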
+
+### approx_top_k_accumulate
+
+Creates a sketch that can be stored and later combined or estimated. Useful for pre-aggregating data.
+
+**Syntax:**
+```sql
+approx_top_k_accumulate(expr [, maxItemsTracked])
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `expr` | Same as `approx_top_k` | The column to accumulate |
+| `maxItemsTracked` | INT (optional) | Maximum items tracked. Range: 1 to 1,000,000. Default: 10,000. |
+
+Returns a STRUCT containing a sketch state that can be passed to `approx_top_k_combine` or `approx_top_k_estimate`.
+
+**Examples:**
+```sql
+-- Accumulate then estimate
+SELECT approx_top_k_estimate(approx_top_k_accumulate(expr))
+FROM VALUES (0), (0), (1), (1), (2), (3), (4), (4) tab(expr);
+-- Result: [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
+```
+
+---
+
+### approx_top_k_combine
+
+Combines multiple sketches into a single sketch.
+
+**Syntax:**
+```sql
+approx_top_k_combine(state [, maxItemsTracked])
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `state` | STRUCT | A sketch state from `approx_top_k_accumulate` or `approx_top_k_combine` |
+| `maxItemsTracked` | INT (optional) | If specified, sets the maximum number of items tracked by the combined sketch. If not specified, all input sketches must have the same `maxItemsTracked` value. |
+
+Returns a STRUCT containing the combined sketch state.
+
+**Examples:**
+```sql
+-- Combine sketches from different partitions
+SELECT approx_top_k_estimate(approx_top_k_combine(sketch, 10000), 5)
+FROM (
+ SELECT approx_top_k_accumulate(expr) AS sketch
+ FROM VALUES (0), (0), (1), (1) tab(expr)
+ UNION ALL
+ SELECT approx_top_k_accumulate(expr) AS sketch
+ FROM VALUES (2), (3), (4), (4) tab(expr)
+);
+-- Result: [{"item":0,"count":2},{"item":4,"count":2},{"item":1,"count":2},{"item":2,"count":1},{"item":3,"count":1}]
+```
+
+**Errors:**
+- Throws an error if input sketches have different `maxItemsTracked` values and no explicit value is provided.
+- Throws an error if input sketches have different item data types.
+
+---
+
+### approx_top_k_estimate
+
+Extracts the top K items from a sketch.
+
+**Syntax:**
+```sql
+approx_top_k_estimate(state [, k])
+```
+
+| Argument | Type | Description |
+|----------|------|-------------|
+| `state` | STRUCT | A sketch state from `approx_top_k_accumulate` or `approx_top_k_combine` |
+| `k` | INT (optional) | Number of top items to return. Default: 5. |
+
+Returns an ARRAY<STRUCT<item, count>> containing the frequent items sorted by count descending.
+
+**Examples:**
+```sql
+SELECT approx_top_k_estimate(approx_top_k_accumulate(expr), 2)
+FROM VALUES 'a', 'b', 'c', 'c', 'c', 'c', 'd', 'd' tab(expr);
+-- Result: [{"item":"c","count":4},{"item":"d","count":2}]
+```
+
+---
+
+## Best Practices
+
+### Choosing Between HLL and Theta Sketches
+
+| Use Case | Recommended Sketch |
+|----------|-------------------|
+| Simple count distinct | HLL (more memory efficient) |
+| Set operations (union, intersection, difference) | Theta |
+| Very high cardinality with moderate accuracy | HLL with higher lgConfigK |
+| Need to compute overlaps between datasets | Theta |
+
+### Accuracy vs. Memory Trade-offs
+
+| Sketch Type | Parameter | Effect of Increasing |
+|-------------|-----------|---------------------|
+| HLL | lgConfigK | Higher accuracy, more memory (2^lgConfigK bytes) |
+| Theta | lgNomEntries | Higher accuracy, more memory (8 * 2^lgNomEntries bytes) |
+| KLL | k | Higher accuracy, more memory |
+| Top-K | maxItemsTracked | Better heavy-hitter detection, more memory |
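+
+As a minimal sketch of how these parameters are passed, assuming the optional-parameter forms of the aggregate functions documented earlier in this guide (the `events`/`requests` tables and their columns are hypothetical):
+
+```sql
+-- Higher lgConfigK (e.g. 14) for a more accurate HLL distinct count
+SELECT hll_sketch_estimate(hll_sketch_agg(user_id, 14)) FROM events;
+
+-- Higher k (e.g. 400) for tighter KLL quantile error bounds
+SELECT kll_sketch_get_quantile_bigint(kll_sketch_agg_bigint(latency_ms, 400), 0.99) FROM requests;
+```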
+
+### Storing and Reusing Sketches
+
+Sketches can be stored in BINARY columns and later merged:
+
+```sql
+-- Create a table to store daily sketches
+CREATE TABLE daily_user_sketches (
+ date DATE,
+ user_sketch BINARY
+);
+
+-- Insert daily sketches
+INSERT INTO daily_user_sketches
+SELECT current_date(), hll_sketch_agg(user_id)
+FROM events;
+
+-- Compute weekly unique users by merging daily sketches
+SELECT hll_sketch_estimate(hll_union_agg(user_sketch))
+FROM daily_user_sketches
+WHERE date BETWEEN '2024-01-01' AND '2024-01-07';
+```
+
+---
+
+## Common Use Cases and Examples
+
+Sketches are particularly valuable for periodic ETL jobs where you need to maintain running statistics across multiple batches of data. The general workflow is:
+
+1. **Aggregate** input values into a sketch using an aggregate function
+2. **Store** the sketch (as BINARY) in a table
+3. **Merge** new sketches with previously stored sketches
+4. **Query** the final sketch to get approximate answers
+
+### Example: Tracking Daily Unique Users with HLL Sketches
+
+This example shows how to maintain a running count of unique users across daily batches.
+
+```sql
+-- Create a table to store daily HLL sketches
+CREATE TABLE daily_user_sketches (
+ event_date DATE,
+ user_sketch BINARY
+) USING PARQUET;
+
+-- Day 1: Process first batch of events and store the sketch
+INSERT INTO daily_user_sketches
+SELECT
+ DATE'2024-01-01' as event_date,
+ hll_sketch_agg(user_id) as user_sketch
+FROM day1_events;
+
+-- Day 2: Process second batch and store its sketch
+INSERT INTO daily_user_sketches
+SELECT
+ DATE'2024-01-02' as event_date,
+ hll_sketch_agg(user_id) as user_sketch
+FROM day2_events;
+
+-- Query: Get unique users for a single day
+SELECT
+ event_date,
+ hll_sketch_estimate(user_sketch) as unique_users
+FROM daily_user_sketches
+WHERE event_date = DATE'2024-01-01';
+
+-- Query: Get unique users across a date range (merging sketches)
+SELECT hll_sketch_estimate(hll_union_agg(user_sketch)) as unique_users_in_week
+FROM daily_user_sketches
+WHERE event_date BETWEEN DATE'2024-01-01' AND DATE'2024-01-07';
+```
+
+### Example: Computing Percentiles Over Time with KLL Sketches
+
+This example shows how to track response time percentiles across hourly batches.
+
+```sql
+-- Create a table to store hourly KLL sketches for response times
+CREATE TABLE hourly_latency_sketches (
+ hour_ts TIMESTAMP,
+ latency_sketch BINARY
+) USING PARQUET;
+
+-- Process each hour's data and store the sketch
+INSERT INTO hourly_latency_sketches
+SELECT
+ DATE_TRUNC('hour', event_time) as hour_ts,
+ kll_sketch_agg_bigint(response_time_ms) as latency_sketch
+FROM hourly_events
+GROUP BY DATE_TRUNC('hour', event_time);
+
+-- Query: Get p50, p95, p99 for a specific hour
+SELECT
+ hour_ts,
+ kll_sketch_get_quantile_bigint(latency_sketch, 0.5) as p50_ms,
+ kll_sketch_get_quantile_bigint(latency_sketch, 0.95) as p95_ms,
+ kll_sketch_get_quantile_bigint(latency_sketch, 0.99) as p99_ms
+FROM hourly_latency_sketches
+WHERE hour_ts = TIMESTAMP'2024-01-15 14:00:00';
+
+-- Query: Get percentiles across a full day by merging hourly sketches
+WITH daily_sketch AS (
+ SELECT kll_merge_agg_bigint(latency_sketch) as merged_sketch
+ FROM hourly_latency_sketches
+ WHERE DATE(hour_ts) = DATE'2024-01-15'
+)
+SELECT
+ kll_sketch_get_quantile_bigint(merged_sketch, 0.5) as p50_ms,
+ kll_sketch_get_quantile_bigint(merged_sketch, 0.95) as p95_ms,
+ kll_sketch_get_quantile_bigint(merged_sketch, 0.99) as p99_ms
+FROM daily_sketch;
+```
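+
+The same merged sketch also answers the inverse question via `kll_sketch_get_rank_bigint`: instead of asking for the value at a percentile, ask for the normalized rank of a value, i.e. approximately what fraction of requests completed within a threshold (the 500 ms threshold here is hypothetical).
+
+```sql
+-- Approximately what fraction of requests on 2024-01-15 completed within 500 ms?
+SELECT kll_sketch_get_rank_bigint(kll_merge_agg_bigint(latency_sketch), 500) AS fraction_under_500ms
+FROM hourly_latency_sketches
+WHERE DATE(hour_ts) = DATE'2024-01-15';
+```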
+
+### Example: Set Operations with Theta Sketches
+
+Theta sketches support set operations, making them useful for analyzing overlapping populations.
+
+```sql
+-- Create sketches for users who performed different actions
+CREATE TABLE action_sketches (
+ action_type STRING,
+ user_sketch BINARY
+) USING PARQUET;
+
+-- Store sketches for each action type
+INSERT INTO action_sketches
+SELECT 'purchase', theta_sketch_agg(user_id) FROM purchases;
+
+INSERT INTO action_sketches
+SELECT 'add_to_cart', theta_sketch_agg(user_id) FROM cart_additions;
+
+INSERT INTO action_sketches
+SELECT 'page_view', theta_sketch_agg(user_id) FROM page_views;
+
+-- Query: How many users purchased?
+SELECT theta_sketch_estimate(user_sketch) as purchasers
+FROM action_sketches WHERE action_type = 'purchase';
+
+-- Query: How many users added to cart but did NOT purchase?
+SELECT theta_sketch_estimate(
+ theta_difference(
+ (SELECT user_sketch FROM action_sketches WHERE action_type = 'add_to_cart'),
+ (SELECT user_sketch FROM action_sketches WHERE action_type = 'purchase')
+ )
+) as cart_abandoners;
+
+-- Query: How many users both viewed pages AND purchased (intersection)?
+SELECT theta_sketch_estimate(
+ theta_intersection(
+ (SELECT user_sketch FROM action_sketches WHERE action_type = 'page_view'),
+ (SELECT user_sketch FROM action_sketches WHERE action_type = 'purchase')
+ )
+) as engaged_purchasers;
+```
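+
+To complete the set algebra, a union gives the combined reach across actions. This is a sketch assuming a scalar `theta_union` function analogous to `theta_intersection` and `theta_difference`; see the Theta sketch section earlier in this guide for the exact name and signature.
+
+```sql
+-- Query: How many users added to cart OR purchased (union)?
+SELECT theta_sketch_estimate(
+  theta_union(
+    (SELECT user_sketch FROM action_sketches WHERE action_type = 'add_to_cart'),
+    (SELECT user_sketch FROM action_sketches WHERE action_type = 'purchase')
+  )
+) as cart_or_purchase_users;
+```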
+
+### Example: Finding Trending Items with Top-K Sketches
+
+Track the most frequently occurring items across batches.
+
+```sql
+-- Create a table to store hourly top-k sketches
+CREATE TABLE hourly_search_sketches (
+ hour_ts TIMESTAMP,
+  search_sketch STRUCT  -- use the full STRUCT type returned by approx_top_k_accumulate
+) USING PARQUET;
+
+-- Process each hour's search queries
+INSERT INTO hourly_search_sketches
+SELECT
+ DATE_TRUNC('hour', search_time) as hour_ts,
+ approx_top_k_accumulate(search_term, 10000) as search_sketch
+FROM search_logs
+GROUP BY DATE_TRUNC('hour', search_time);
+
+-- Query: Get top 10 searches for a specific hour
+SELECT approx_top_k_estimate(search_sketch, 10) as top_searches
+FROM hourly_search_sketches
+WHERE hour_ts = TIMESTAMP'2024-01-15 14:00:00';
+
+-- Query: Get top 10 searches across the full day by combining sketches
+SELECT approx_top_k_estimate(
+ approx_top_k_combine(search_sketch, 10000),
+ 10
+) as daily_top_searches
+FROM hourly_search_sketches
+WHERE DATE(hour_ts) = DATE'2024-01-15';
+```
diff --git a/docs/sql-ref-syntax-ddl-create-view.md b/docs/sql-ref-syntax-ddl-create-view.md
index 21174f12300e3..2d832636b38fc 100644
--- a/docs/sql-ref-syntax-ddl-create-view.md
+++ b/docs/sql-ref-syntax-ddl-create-view.md
@@ -47,6 +47,7 @@ CREATE [ OR REPLACE ] [ [ GLOBAL ] TEMPORARY ] VIEW [ IF NOT EXISTS ] view_ident
* **IF NOT EXISTS**
Creates a view if it does not exist.
+ This clause is not supported for `TEMPORARY` views yet.
* **view_identifier**
@@ -86,8 +87,8 @@ CREATE OR REPLACE VIEW experienced_employee
AS SELECT id, name FROM all_employee
WHERE working_years > 5;
--- Create a global temporary view `subscribed_movies` if it does not exist.
-CREATE GLOBAL TEMPORARY VIEW IF NOT EXISTS subscribed_movies
+-- Create a global temporary view `subscribed_movies`.
+CREATE GLOBAL TEMPORARY VIEW subscribed_movies
AS SELECT mo.member_id, mb.full_name, mo.movie_title
FROM movies AS mo INNER JOIN members AS mb
ON mo.member_id = mb.id;
diff --git a/docs/sql-ref-syntax-qry-select-tvf.md b/docs/sql-ref-syntax-qry-select-tvf.md
index 099f423f18531..9dd73a3f25263 100644
--- a/docs/sql-ref-syntax-qry-select-tvf.md
+++ b/docs/sql-ref-syntax-qry-select-tvf.md
@@ -172,5 +172,5 @@ SELECT * FROM test, LATERAL explode (ARRAY(3,4)) AS c2;
### Related Statements
* [SELECT](sql-ref-syntax-qry-select.html)
-* [LATERAL](sql-ref-syntax-qry-select-lateral-subquery.md)
+* [LATERAL](sql-ref-syntax-qry-select-lateral-subquery.html)
* [LATERAL VIEW Clause](sql-ref-syntax-qry-select-lateral-view.html)
diff --git a/docs/sql-ref-syntax.md b/docs/sql-ref-syntax.md
index 3dc7d47c4f454..0a2b9ba34b522 100644
--- a/docs/sql-ref-syntax.md
+++ b/docs/sql-ref-syntax.md
@@ -90,6 +90,20 @@ ability to generate logical and physical plan for a given query using
* [star (*) Clause](sql-ref-syntax-qry-star.html)
* [EXPLAIN](sql-ref-syntax-qry-explain.html)
+### SQL Scripting Statements
+
+You use SQL scripting to execute procedural logic in SQL.
+
+* [CASE](control-flow/case-stmt.html)
+* [compound statement](control-flow/compound-stmt.html)
+* [FOR](control-flow/for-stmt.html)
+* [IF](control-flow/if-stmt.html)
+* [ITERATE](control-flow/iterate-stmt.html)
+* [LEAVE](control-flow/leave-stmt.html)
+* [LOOP](control-flow/loop-stmt.html)
+* [REPEAT](control-flow/repeat-stmt.html)
+* [WHILE](control-flow/while-stmt.html)
+
### Auxiliary Statements
* [ADD FILE](sql-ref-syntax-aux-resource-mgmt-add-file.html)
diff --git a/docs/sql-ref.md b/docs/sql-ref.md
index 6d557caaca3d6..cf9a887bd4928 100644
--- a/docs/sql-ref.md
+++ b/docs/sql-ref.md
@@ -37,9 +37,12 @@ Spark SQL is Apache Spark's module for working with structured data. This guide
* [IDENTIFIER clause](sql-ref-identifier-clause.html)
* [Literals](sql-ref-literals.html)
* [Null Semantics](sql-ref-null-semantics.html)
+ * [Name Resolution](sql-ref-name-resolution.html)
+ * [SQL Scripting](sql-ref-scripting.html)
* [SQL Syntax](sql-ref-syntax.html)
* [DDL Statements](sql-ref-syntax.html#ddl-statements)
* [DML Statements](sql-ref-syntax.html#dml-statements)
* [Data Retrieval Statements](sql-ref-syntax.html#data-retrieval-statements)
+ * [SQL Scripting Statements](sql-ref-syntax.html#sql-scripting-statements)
* [Auxiliary Statements](sql-ref-syntax.html#auxiliary-statements)
* [Pipe Syntax](sql-pipe-syntax.html)
diff --git a/examples/pom.xml b/examples/pom.xml
index 5b654b89d7fd0..2e863dac54a09 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.13
- 4.1.0-SNAPSHOT
+ 4.1.2-SNAPSHOT
../pom.xml
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SparkDataFramePi.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SparkDataFramePi.scala
index 0102b2d291e9f..bddd6f9f206c0 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/SparkDataFramePi.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/SparkDataFramePi.scala
@@ -31,7 +31,7 @@ object SparkDataFramePi {
import spark.implicits._
val slices = if (args.length > 0) args(0).toInt else 2
val n = math.min(100000L * slices, Int.MaxValue).toInt // avoid overflow
- val count = spark.range(0, n, 1, slices)
+ val count = spark.range(1, n, 1, slices)
.select((pow(rand() * 2 - 1, lit(2)) + pow(rand() * 2 - 1, lit(2))).as("v"))
.where($"v" <= 1)
.count()
diff --git a/graphx/pom.xml b/graphx/pom.xml
index c165485652861..4448d976179fc 100644
--- a/graphx/pom.xml
+++ b/graphx/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.13
- 4.1.0-SNAPSHOT
+ 4.1.2-SNAPSHOT
../pom.xml
diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml
index dbdbe8846e189..390bab68a82de 100644
--- a/hadoop-cloud/pom.xml
+++ b/hadoop-cloud/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.13
- 4.1.0-SNAPSHOT
+ 4.1.2-SNAPSHOT
../pom.xml
diff --git a/launcher/pom.xml b/launcher/pom.xml
index a4a44db2e6ac6..c59360214a178 100644
--- a/launcher/pom.xml
+++ b/launcher/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.13
- 4.1.0-SNAPSHOT
+ 4.1.2-SNAPSHOT
../pom.xml
diff --git a/licenses-binary/LICENSE-check-qual.txt b/licenses-binary/LICENSE-check-qual.txt
deleted file mode 100644
index d542ab3ec3ed8..0000000000000
--- a/licenses-binary/LICENSE-check-qual.txt
+++ /dev/null
@@ -1,413 +0,0 @@
-The Checker Framework
-Copyright 2004-present by the Checker Framework developers
-
-
-Most of the Checker Framework is licensed under the GNU General Public
-License, version 2 (GPL2), with the classpath exception. The text of this
-license appears below. This is the same license used for OpenJDK.
-
-A few parts of the Checker Framework have more permissive licenses, notably
-the parts that you might want to include with your own program.
-
- * The annotations and utility files are licensed under the MIT License.
- (The text of this license also appears below.) This applies to
- checker-qual*.jar and checker-util.jar and all the files that appear in
- them, which is all files in checker-qual and checker-util directories.
- It also applies to the cleanroom implementations of
- third-party annotations (in checker/src/testannotations/,
- framework/src/main/java/org/jmlspecs/, and
- framework/src/main/java/com/google/).
-
-The Checker Framework includes annotations for some libraries. Those in
-.astub files use the MIT License. Those in https://github.com/typetools/jdk
-(which appears in the annotated-jdk directory of file checker.jar) use the
-GPL2 license.
-
-Some external libraries that are included with the Checker Framework
-distribution have different licenses. Here are some examples.
-
- * JavaParser is dual licensed under the LGPL or the Apache license -- you
- may use it under whichever one you want. (The JavaParser source code
- contains a file with the text of the GPL, but it is not clear why, since
- JavaParser does not use the GPL.) See
- https://github.com/typetools/stubparser .
-
- * Annotation Tools (https://github.com/typetools/annotation-tools) uses
- the MIT license.
-
- * Libraries in plume-lib (https://github.com/plume-lib/) are licensed
- under the MIT License.
-
-===========================================================================
-
-The GNU General Public License (GPL)
-
-Version 2, June 1991
-
-Copyright (C) 1989, 1991 Free Software Foundation, Inc.
-59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
-
-Everyone is permitted to copy and distribute verbatim copies of this license
-document, but changing it is not allowed.
-
-Preamble
-
-The licenses for most software are designed to take away your freedom to share
-and change it. By contrast, the GNU General Public License is intended to
-guarantee your freedom to share and change free software--to make sure the
-software is free for all its users. This General Public License applies to
-most of the Free Software Foundation's software and to any other program whose
-authors commit to using it. (Some other Free Software Foundation software is
-covered by the GNU Library General Public License instead.) You can apply it to
-your programs, too.
-
-When we speak of free software, we are referring to freedom, not price. Our
-General Public Licenses are designed to make sure that you have the freedom to
-distribute copies of free software (and charge for this service if you wish),
-that you receive source code or can get it if you want it, that you can change
-the software or use pieces of it in new free programs; and that you know you
-can do these things.
-
-To protect your rights, we need to make restrictions that forbid anyone to deny
-you these rights or to ask you to surrender the rights. These restrictions
-translate to certain responsibilities for you if you distribute copies of the
-software, or if you modify it.
-
-For example, if you distribute copies of such a program, whether gratis or for
-a fee, you must give the recipients all the rights that you have. You must
-make sure that they, too, receive or can get the source code. And you must
-show them these terms so they know their rights.
-
-We protect your rights with two steps: (1) copyright the software, and (2)
-offer you this license which gives you legal permission to copy, distribute
-and/or modify the software.
-
-Also, for each author's protection and ours, we want to make certain that
-everyone understands that there is no warranty for this free software. If the
-software is modified by someone else and passed on, we want its recipients to
-know that what they have is not the original, so that any problems introduced
-by others will not reflect on the original authors' reputations.
-
-Finally, any free program is threatened constantly by software patents. We
-wish to avoid the danger that redistributors of a free program will
-individually obtain patent licenses, in effect making the program proprietary.
-To prevent this, we have made it clear that any patent must be licensed for
-everyone's free use or not licensed at all.
-
-The precise terms and conditions for copying, distribution and modification
-follow.
-
-TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
-
-0. This License applies to any program or other work which contains a notice
-placed by the copyright holder saying it may be distributed under the terms of
-this General Public License. The "Program", below, refers to any such program
-or work, and a "work based on the Program" means either the Program or any
-derivative work under copyright law: that is to say, a work containing the
-Program or a portion of it, either verbatim or with modifications and/or
-translated into another language. (Hereinafter, translation is included
-without limitation in the term "modification".) Each licensee is addressed as
-"you".
-
-Activities other than copying, distribution and modification are not covered by
-this License; they are outside its scope. The act of running the Program is
-not restricted, and the output from the Program is covered only if its contents
-constitute a work based on the Program (independent of having been made by
-running the Program). Whether that is true depends on what the Program does.
-
-1. You may copy and distribute verbatim copies of the Program's source code as
-you receive it, in any medium, provided that you conspicuously and
-appropriately publish on each copy an appropriate copyright notice and
-disclaimer of warranty; keep intact all the notices that refer to this License
-and to the absence of any warranty; and give any other recipients of the
-Program a copy of this License along with the Program.
-
-You may charge a fee for the physical act of transferring a copy, and you may
-at your option offer warranty protection in exchange for a fee.
-
-2. You may modify your copy or copies of the Program or any portion of it, thus
-forming a work based on the Program, and copy and distribute such modifications
-or work under the terms of Section 1 above, provided that you also meet all of
-these conditions:
-
- a) You must cause the modified files to carry prominent notices stating
- that you changed the files and the date of any change.
-
- b) You must cause any work that you distribute or publish, that in whole or
- in part contains or is derived from the Program or any part thereof, to be
- licensed as a whole at no charge to all third parties under the terms of
- this License.
-
- c) If the modified program normally reads commands interactively when run,
- you must cause it, when started running for such interactive use in the
- most ordinary way, to print or display an announcement including an
- appropriate copyright notice and a notice that there is no warranty (or
- else, saying that you provide a warranty) and that users may redistribute
- the program under these conditions, and telling the user how to view a copy
- of this License. (Exception: if the Program itself is interactive but does
- not normally print such an announcement, your work based on the Program is
- not required to print an announcement.)
-
-These requirements apply to the modified work as a whole. If identifiable
-sections of that work are not derived from the Program, and can be reasonably
-considered independent and separate works in themselves, then this License, and
-its terms, do not apply to those sections when you distribute them as separate
-works. But when you distribute the same sections as part of a whole which is a
-work based on the Program, the distribution of the whole must be on the terms
-of this License, whose permissions for other licensees extend to the entire
-whole, and thus to each and every part regardless of who wrote it.
-
-Thus, it is not the intent of this section to claim rights or contest your
-rights to work written entirely by you; rather, the intent is to exercise the
-right to control the distribution of derivative or collective works based on
-the Program.
-
-In addition, mere aggregation of another work not based on the Program with the
-Program (or with a work based on the Program) on a volume of a storage or
-distribution medium does not bring the other work under the scope of this
-License.
-
-3. You may copy and distribute the Program (or a work based on it, under
-Section 2) in object code or executable form under the terms of Sections 1 and
-2 above provided that you also do one of the following:
-
- a) Accompany it with the complete corresponding machine-readable source
- code, which must be distributed under the terms of Sections 1 and 2 above
- on a medium customarily used for software interchange; or,
-
- b) Accompany it with a written offer, valid for at least three years, to
- give any third party, for a charge no more than your cost of physically
- performing source distribution, a complete machine-readable copy of the
- corresponding source code, to be distributed under the terms of Sections 1
- and 2 above on a medium customarily used for software interchange; or,
-
- c) Accompany it with the information you received as to the offer to
- distribute corresponding source code. (This alternative is allowed only
- for noncommercial distribution and only if you received the program in
- object code or executable form with such an offer, in accord with
- Subsection b above.)
-
-The source code for a work means the preferred form of the work for making
-modifications to it. For an executable work, complete source code means all
-the source code for all modules it contains, plus any associated interface
-definition files, plus the scripts used to control compilation and installation
-of the executable. However, as a special exception, the source code
-distributed need not include anything that is normally distributed (in either
-source or binary form) with the major components (compiler, kernel, and so on)
-of the operating system on which the executable runs, unless that component
-itself accompanies the executable.
-
-If distribution of executable or object code is made by offering access to copy
-from a designated place, then offering equivalent access to copy the source
-code from the same place counts as distribution of the source code, even though
-third parties are not compelled to copy the source along with the object code.
-
-4. You may not copy, modify, sublicense, or distribute the Program except as
-expressly provided under this License. Any attempt otherwise to copy, modify,
-sublicense or distribute the Program is void, and will automatically terminate
-your rights under this License. However, parties who have received copies, or
-rights, from you under this License will not have their licenses terminated so
-long as such parties remain in full compliance.
-
-5. You are not required to accept this License, since you have not signed it.
-However, nothing else grants you permission to modify or distribute the Program
-or its derivative works. These actions are prohibited by law if you do not
-accept this License. Therefore, by modifying or distributing the Program (or
-any work based on the Program), you indicate your acceptance of this License to
-do so, and all its terms and conditions for copying, distributing or modifying
-the Program or works based on it.
-
-6. Each time you redistribute the Program (or any work based on the Program),
-the recipient automatically receives a license from the original licensor to
-copy, distribute or modify the Program subject to these terms and conditions.
-You may not impose any further restrictions on the recipients' exercise of the
-rights granted herein. You are not responsible for enforcing compliance by
-third parties to this License.
-
-7. If, as a consequence of a court judgment or allegation of patent
-infringement or for any other reason (not limited to patent issues), conditions
-are imposed on you (whether by court order, agreement or otherwise) that
-contradict the conditions of this License, they do not excuse you from the
-conditions of this License. If you cannot distribute so as to satisfy
-simultaneously your obligations under this License and any other pertinent
-obligations, then as a consequence you may not distribute the Program at all.
-For example, if a patent license would not permit royalty-free redistribution
-of the Program by all those who receive copies directly or indirectly through
-you, then the only way you could satisfy both it and this License would be to
-refrain entirely from distribution of the Program.
-
-If any portion of this section is held invalid or unenforceable under any
-particular circumstance, the balance of the section is intended to apply and
-the section as a whole is intended to apply in other circumstances.
-
-It is not the purpose of this section to induce you to infringe any patents or
-other property right claims or to contest validity of any such claims; this
-section has the sole purpose of protecting the integrity of the free software
-distribution system, which is implemented by public license practices. Many
-people have made generous contributions to the wide range of software
-distributed through that system in reliance on consistent application of that
-system; it is up to the author/donor to decide if he or she is willing to
-distribute software through any other system and a licensee cannot impose that
-choice.
-
-This section is intended to make thoroughly clear what is believed to be a
-consequence of the rest of this License.
-
-8. If the distribution and/or use of the Program is restricted in certain
-countries either by patents or by copyrighted interfaces, the original
-copyright holder who places the Program under this License may add an explicit
-geographical distribution limitation excluding those countries, so that
-distribution is permitted only in or among countries not thus excluded. In
-such case, this License incorporates the limitation as if written in the body
-of this License.
-
-9. The Free Software Foundation may publish revised and/or new versions of the
-General Public License from time to time. Such new versions will be similar in
-spirit to the present version, but may differ in detail to address new problems
-or concerns.
-
-Each version is given a distinguishing version number. If the Program
-specifies a version number of this License which applies to it and "any later
-version", you have the option of following the terms and conditions either of
-that version or of any later version published by the Free Software Foundation.
-If the Program does not specify a version number of this License, you may
-choose any version ever published by the Free Software Foundation.
-
-10. If you wish to incorporate parts of the Program into other free programs
-whose distribution conditions are different, write to the author to ask for
-permission. For software which is copyrighted by the Free Software Foundation,
-write to the Free Software Foundation; we sometimes make exceptions for this.
-Our decision will be guided by the two goals of preserving the free status of
-all derivatives of our free software and of promoting the sharing and reuse of
-software generally.
-
-NO WARRANTY
-
-11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY FOR
-THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE
-STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE
-PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED,
-INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
-FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND
-PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE,
-YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
-
-12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL
-ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR REDISTRIBUTE THE
-PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
-GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR
-INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA
-BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A
-FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER
-OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
-
-END OF TERMS AND CONDITIONS
-
-How to Apply These Terms to Your New Programs
-
-If you develop a new program, and you want it to be of the greatest possible
-use to the public, the best way to achieve this is to make it free software
-which everyone can redistribute and change under these terms.
-
-To do so, attach the following notices to the program. It is safest to attach
-them to the start of each source file to most effectively convey the exclusion
-of warranty; and each file should have at least the "copyright" line and a
-pointer to where the full notice is found.
-
- One line to give the program's name and a brief idea of what it does.
-
- Copyright (C)
-
- This program is free software; you can redistribute it and/or modify it
- under the terms of the GNU General Public License as published by the Free
- Software Foundation; either version 2 of the License, or (at your option)
- any later version.
-
- This program is distributed in the hope that it will be useful, but WITHOUT
- ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
- FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
- more details.
-
- You should have received a copy of the GNU General Public License along
- with this program; if not, write to the Free Software Foundation, Inc., 59
- Temple Place, Suite 330, Boston, MA 02111-1307 USA
-
-Also add information on how to contact you by electronic and paper mail.
-
-If the program is interactive, make it output a short notice like this when it
-starts in an interactive mode:
-
- Gnomovision version 69, Copyright (C) year name of author Gnomovision comes
- with ABSOLUTELY NO WARRANTY; for details type 'show w'. This is free
- software, and you are welcome to redistribute it under certain conditions;
- type 'show c' for details.
-
-The hypothetical commands 'show w' and 'show c' should show the appropriate
-parts of the General Public License. Of course, the commands you use may be
-called something other than 'show w' and 'show c'; they could even be
-mouse-clicks or menu items--whatever suits your program.
-
-You should also get your employer (if you work as a programmer) or your school,
-if any, to sign a "copyright disclaimer" for the program, if necessary. Here
-is a sample; alter the names:
-
- Yoyodyne, Inc., hereby disclaims all copyright interest in the program
- 'Gnomovision' (which makes passes at compilers) written by James Hacker.
-
- signature of Ty Coon, 1 April 1989
-
- Ty Coon, President of Vice
-
-This General Public License does not permit incorporating your program into
-proprietary programs. If your program is a subroutine library, you may
-consider it more useful to permit linking proprietary applications with the
-library. If this is what you want to do, use the GNU Library General Public
-License instead of this License.
-
-
-"CLASSPATH" EXCEPTION TO THE GPL
-
-Certain source files distributed by Oracle America and/or its affiliates are
-subject to the following clarification and special exception to the GPL, but
-only where Oracle has expressly included in the particular source file's header
-the words "Oracle designates this particular file as subject to the "Classpath"
-exception as provided by Oracle in the LICENSE file that accompanied this code."
-
- Linking this library statically or dynamically with other modules is making
- a combined work based on this library. Thus, the terms and conditions of
- the GNU General Public License cover the whole combination.
-
- As a special exception, the copyright holders of this library give you
- permission to link this library with independent modules to produce an
- executable, regardless of the license terms of these independent modules,
- and to copy and distribute the resulting executable under terms of your
- choice, provided that you also meet, for each linked independent module,
- the terms and conditions of the license of that module. An independent
- module is a module which is not derived from or based on this library. If
- you modify this library, you may extend this exception to your version of
- the library, but you are not obligated to do so. If you do not wish to do
- so, delete this exception statement from your version.
-
-===========================================================================
-
-MIT License:
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in
-all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-THE SOFTWARE.
-
-===========================================================================
\ No newline at end of file
diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml
index d0d310b9371df..5b2e4512e5f23 100644
--- a/mllib-local/pom.xml
+++ b/mllib-local/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.13
- 4.1.0-SNAPSHOT
+ 4.1.2-SNAPSHOT
../pom.xml
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 0d87640c6b47d..80bf3b4053fe6 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.13
- 4.1.0-SNAPSHOT
+ 4.1.2-SNAPSHOT
../pom.xml
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index 0e1f64cc7b630..84ced24414379 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -106,7 +106,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
}
-object EstimatorUtils {
+private[spark] object EstimatorUtils {
// This warningMessagesBuffer is for collecting warning messages during `estimator.fit`
// execution in Spark Connect server.
private[spark] val warningMessagesBuffer = new java.lang.ThreadLocal[ArrayBuffer[String]]() {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 8902d12bdf94d..887d8277d3117 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -195,7 +195,7 @@ class DecisionTreeClassificationModel private[ml] (
// For ml connect only
private[ml] def this() = this("", Node.dummyNode, -1, -1)
- override def estimatedSize: Long = getEstimatedSize()
+ private[spark] override def estimatedSize: Long = getEstimatedSize()
override def predict(features: Vector): Double = {
rootNode.predictImpl(features).prediction
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
index b653383161e74..29ca909f79302 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
@@ -314,7 +314,7 @@ class FMClassificationModel private[classification] (
copyValues(new FMClassificationModel(uid, intercept, linear, factors), extra)
}
- override def estimatedSize: Long = {
+ private[spark] override def estimatedSize: Long = {
var size = this.estimateMatadataSize
if (this.linear != null) {
size += this.linear.getSizeInBytes
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 9ca3a16609586..2c5c7e7740a33 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -276,7 +276,7 @@ class GBTClassificationModel private[ml](
private[ml] def this() = this("",
Array(new DecisionTreeRegressionModel), Array(Double.NaN), -1, -1)
- override def estimatedSize: Long = getEstimatedSize()
+ private[spark] override def estimatedSize: Long = getEstimatedSize()
@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 8b580b1e075c5..fb61358536d0c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -240,7 +240,7 @@ class RandomForestClassificationModel private[ml] (
// For ml connect only
private[ml] def this() = this("", Array(new DecisionTreeClassificationModel), -1, -1)
- override def estimatedSize: Long = getEstimatedSize()
+ private[spark] override def estimatedSize: Long = getEstimatedSize()
@Since("1.4.0")
override def trees: Array[DecisionTreeClassificationModel] = _trees
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 9e09ee00c3e30..c129da5f0d7ee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -179,7 +179,8 @@ class BisectingKMeansModel private[ml] (
@Since("2.1.0")
override def summary: BisectingKMeansSummary = super.summary
- override def estimatedSize: Long = SizeEstimator.estimate(parentModel)
+ private[spark] override def estimatedSize: Long =
+ SizeEstimator.estimate(parentModel)
// BisectingKMeans model hasn't supported offloading, so put an empty `saveSummary` here for now
override private[spark] def saveSummary(path: String): Unit = {}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 5d3e36be28082..ddcca167ff30b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -222,7 +222,8 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0")
override def summary: GaussianMixtureSummary = super.summary
- override def estimatedSize: Long = SizeEstimator.estimate((weights, gaussians))
+ private[spark] override def estimatedSize: Long =
+ SizeEstimator.estimate((weights, gaussians))
private[spark] def createSummary(
predictions: DataFrame, logLikelihood: Double, iteration: Int
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 2abd82c712960..ad6ca0924064b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -213,7 +213,8 @@ class KMeansModel private[ml] (
@Since("2.0.0")
override def summary: KMeansSummary = super.summary
- override def estimatedSize: Long = SizeEstimator.estimate(parentModel.clusterCenters)
+ private[spark] override def estimatedSize: Long =
+ SizeEstimator.estimate(parentModel.clusterCenters)
private[spark] def createSummary(
predictions: DataFrame, numIter: Int, trainingCost: Double
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 67c9a8f58dd26..c64d3a98c0a92 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -826,7 +826,7 @@ class DistributedLDAModel private[ml] (
s"DistributedLDAModel: uid=$uid, k=${$(k)}, numFeatures=$vocabSize"
}
- override def estimatedSize: Long = {
+ private[spark] override def estimatedSize: Long = {
// TODO: Implement this method.
throw new UnsupportedOperationException
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 6fd20ceb562b1..acd4635c5bbf2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -323,7 +323,7 @@ class FPGrowthModel private[ml] (
s"FPGrowthModel: uid=$uid, numTrainingRecords=$numTrainingRecords"
}
- override def estimatedSize: Long = {
+ private[spark] override def estimatedSize: Long = {
// TODO: Implement this method.
throw new UnsupportedOperationException
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 538ad03820754..1cee915046c07 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -540,7 +540,7 @@ class ALSModel private[ml] (
}
}
- override def estimatedSize: Long = {
+ private[spark] override def estimatedSize: Long = {
val userCount = userFactors.count()
val itemCount = itemFactors.count()
(userCount + itemCount) * (rank + 1) * 4
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index f049e9a44cc28..5387e0e282a3d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -190,7 +190,7 @@ class DecisionTreeRegressionModel private[ml] (
// For ml connect only
private[ml] def this() = this("", Node.dummyNode, -1)
- override def estimatedSize: Long = getEstimatedSize()
+ private[spark] override def estimatedSize: Long = getEstimatedSize()
override def predict(features: Vector): Double = {
rootNode.predictImpl(features).prediction
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
index 1b624895c7f31..d2fcb9280c631 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/FMRegressor.scala
@@ -443,7 +443,7 @@ class FMRegressor @Since("3.0.0") (
@Since("3.0.0")
override def copy(extra: ParamMap): FMRegressor = defaultCopy(extra)
- override def estimateModelSize(dataset: Dataset[_]): Long = {
+ private[spark] override def estimateModelSize(dataset: Dataset[_]): Long = {
val numFeatures = DatasetUtils.getNumFeatures(dataset, $(featuresCol))
var size = this.estimateMatadataSize
@@ -488,7 +488,7 @@ class FMRegressionModel private[regression] (
copyValues(new FMRegressionModel(uid, intercept, linear, factors), extra)
}
- override def estimatedSize: Long = {
+ private[spark] override def estimatedSize: Long = {
var size = this.estimateMatadataSize
if (this.linear != null) {
size += this.linear.getSizeInBytes
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index c8fa97bfccce0..71436036d1ea6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -245,7 +245,7 @@ class GBTRegressionModel private[ml](
// For ml connect only
private[ml] def this() = this("", Array(new DecisionTreeRegressionModel), Array(Double.NaN), -1)
- override def estimatedSize: Long = getEstimatedSize()
+ private[spark] override def estimatedSize: Long = getEstimatedSize()
@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index a9e2c47a3229a..8d9b4817833bc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -215,7 +215,7 @@ class RandomForestRegressionModel private[ml] (
// For ml connect only
private[ml] def this() = this("", Array(new DecisionTreeRegressionModel), -1)
- override def estimatedSize: Long = getEstimatedSize()
+ private[spark] override def estimatedSize: Long = getEstimatedSize()
@Since("1.4.0")
override def trees: Array[DecisionTreeRegressionModel] = _trees
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
index b20b2e943deeb..4e9fa89cbde90 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -175,7 +175,7 @@ private[spark] trait TreeEnsembleModel[M <: DecisionTreeModel] {
new AttributeGroup(leafCol, attrs = trees.map(_.leafAttr)).toStructField()
}
- def getEstimatedSize(): Long = {
+ private[ml] def getEstimatedSize(): Long = {
org.apache.spark.util.SizeEstimator.estimate(trees.map(_.rootNode))
}
}
diff --git a/pom.xml b/pom.xml
index 812ffa89d4fc2..c48218cb2e74d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -26,7 +26,7 @@
org.apache.spark
spark-parent_2.13
- 4.1.0-SNAPSHOT
+ 4.1.2-SNAPSHOT
pom
Spark Project Parent POM
https://spark.apache.org/
@@ -142,7 +142,7 @@
10.16.1.1
1.16.0
- 2.2.1
+ 2.2.2
shaded-protobuf
11.0.26
5.0.0
@@ -156,7 +156,7 @@
If you change codahale.metrics.version, you also need to change
the link to metrics.dropwizard.io in docs/monitoring.md.
-->
- 4.2.33
+ 4.2.37
1.12.1
1.15.3
@@ -190,7 +190,7 @@
3.0.4
1.19.0
1.28.0
- 2.20.0
+ 2.21.0
2.6
@@ -198,7 +198,8 @@
2.12.1
4.1.17
- 33.4.0-jre
+ 33.4.8-jre
+ 1.0.3
2.11.0
3.1.9
3.0.18
@@ -303,11 +304,8 @@
true
- 33.4.0-jre
- 1.0.2
- 1.67.1
+ 1.76.0
1.1.4
- 6.0.53
4.0-10
@@ -344,7 +342,7 @@
9.2.0
42.7.7
11.5.9.0
- 12.8.1.jre11
+ 13.2.1.jre11
23.6.0.24.10
2.7.1
3.26.1
@@ -611,6 +609,12 @@
${guava.version}
provided
+
+ com.google.guava
+ failureaccess
+ ${guava.failureaccess.version}
+ provided
+
org.jpmml
pmml-model
@@ -2334,7 +2338,7 @@
io.airlift
aircompressor
- 2.0.2
+ 2.0.3
org.apache.orc
@@ -2519,6 +2523,25 @@
+
+ org.apache.arrow
+ arrow-compression
+ ${arrow.version}
+
+
+ com.fasterxml.jackson.core
+ jackson-annotations
+
+
+ com.fasterxml.jackson.core
+ jackson-core
+
+
+ io.netty
+ netty-common
+
+
+
org.apache.arrow
arrow-memory-netty
@@ -3129,6 +3152,10 @@
com.google.common
${spark.shade.packageName}.guava
+
+ com.google.thirdparty
+ ${spark.shade.packageName}.guava.thirdparty
+
org.dmg.pmml
${spark.shade.packageName}.dmg.pmml
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 1614ec212c2e8..2e4598810451c 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -48,7 +48,10 @@ object MimaExcludes {
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.Dataset.repartitionById"),
// [SPARK-54001][CONNECT] Replace block copying with ref-counting in ArtifactManager cloning
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.artifact.ArtifactManager.cachedBlockIdList")
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.artifact.ArtifactManager.cachedBlockIdList"),
+
+ // [SPARK-54323][PYTHON] Change the way to access logs to TVF instead of system view
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.TableValuedFunction.python_worker_logs")
)
// Default exclude rules
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index d852b4155bea9..bb505d390e5ee 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -31,6 +31,7 @@ import sbt.Classpaths.publishOrSkip
import sbt.Keys._
import sbt.librarymanagement.{ VersionNumber, SemanticSelector }
import com.etsy.sbt.checkstyle.CheckstylePlugin.autoImport._
+import com.github.sbt.junit.jupiter.sbt.JupiterPlugin.autoImport._
import com.here.bom.Bom
import com.simplytyped.Antlr4Plugin._
import sbtpomreader.{PomBuild, SbtPomKeys}
@@ -366,6 +367,7 @@ object SparkBuild extends PomBuild {
"org.apache.spark.kafka010",
"org.apache.spark.network",
"org.apache.spark.sql.avro",
+ "org.apache.spark.sql.pipelines",
"org.apache.spark.sql.scripting",
"org.apache.spark.types.variant",
"org.apache.spark.ui.flamegraph",
@@ -392,7 +394,8 @@ object SparkBuild extends PomBuild {
/* Enable shared settings on all projects */
(allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ copyJarsProjects ++ Seq(spark, tools))
.foreach(enable(sharedSettings ++ DependencyOverrides.settings ++
- ExcludedDependencies.settings ++ Checkstyle.settings ++ ExcludeShims.settings))
+ ExcludedDependencies.settings ++ (if (noLintOnCompile) Nil else Checkstyle.settings) ++
+ ExcludeShims.settings))
/* Enable tests settings for all projects except examples, assembly and tools */
(allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings))
@@ -644,8 +647,10 @@ object Core {
"com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf"
)
},
+ // Use Maven's output directory so sbt and Maven can share generated sources.
+ // Core uses protoc-jar-maven-plugin which outputs to target/generated-sources.
(Compile / PB.targets) := Seq(
- PB.gens.java -> (Compile / sourceManaged).value
+ PB.gens.java -> target.value / "generated-sources"
)
) ++ {
val sparkProtocExecPath = sys.props.get("spark.protoc.executable.path")
@@ -672,7 +677,7 @@ object SparkConnectCommon {
libraryDependencies ++= {
val guavaVersion =
SbtPomKeys.effectivePom.value.getProperties.get(
- "connect.guava.version").asInstanceOf[String]
+ "guava.version").asInstanceOf[String]
val guavaFailureaccessVersion =
SbtPomKeys.effectivePom.value.getProperties.get(
"guava.failureaccess.version").asInstanceOf[String]
@@ -690,7 +695,7 @@ object SparkConnectCommon {
dependencyOverrides ++= {
val guavaVersion =
SbtPomKeys.effectivePom.value.getProperties.get(
- "connect.guava.version").asInstanceOf[String]
+ "guava.version").asInstanceOf[String]
val guavaFailureaccessVersion =
SbtPomKeys.effectivePom.value.getProperties.get(
"guava.failureaccess.version").asInstanceOf[String]
@@ -730,19 +735,20 @@ object SparkConnectCommon {
) ++ {
val sparkProtocExecPath = sys.props.get("spark.protoc.executable.path")
val connectPluginExecPath = sys.props.get("connect.plugin.executable.path")
+ // Use Maven's output directory so sbt and Maven can share generated sources
if (sparkProtocExecPath.isDefined && connectPluginExecPath.isDefined) {
Seq(
(Compile / PB.targets) := Seq(
- PB.gens.java -> (Compile / sourceManaged).value,
- PB.gens.plugin(name = "grpc-java", path = connectPluginExecPath.get) -> (Compile / sourceManaged).value
+ PB.gens.java -> target.value / "generated-sources" / "protobuf" / "java",
+ PB.gens.plugin(name = "grpc-java", path = connectPluginExecPath.get) -> target.value / "generated-sources" / "protobuf" / "grpc-java"
),
PB.protocExecutable := file(sparkProtocExecPath.get)
)
} else {
Seq(
(Compile / PB.targets) := Seq(
- PB.gens.java -> (Compile / sourceManaged).value,
- PB.gens.plugin("grpc-java") -> (Compile / sourceManaged).value
+ PB.gens.java -> target.value / "generated-sources" / "protobuf" / "java",
+ PB.gens.plugin("grpc-java") -> target.value / "generated-sources" / "protobuf" / "grpc-java"
)
)
}
@@ -758,7 +764,7 @@ object SparkConnect {
libraryDependencies ++= {
val guavaVersion =
SbtPomKeys.effectivePom.value.getProperties.get(
- "connect.guava.version").asInstanceOf[String]
+ "guava.version").asInstanceOf[String]
val guavaFailureaccessVersion =
SbtPomKeys.effectivePom.value.getProperties.get(
"guava.failureaccess.version").asInstanceOf[String]
@@ -772,7 +778,7 @@ object SparkConnect {
dependencyOverrides ++= {
val guavaVersion =
SbtPomKeys.effectivePom.value.getProperties.get(
- "connect.guava.version").asInstanceOf[String]
+ "guava.version").asInstanceOf[String]
val guavaFailureaccessVersion =
SbtPomKeys.effectivePom.value.getProperties.get(
"guava.failureaccess.version").asInstanceOf[String]
@@ -790,41 +796,35 @@ object SparkConnect {
// Exclude `scala-library` from assembly.
(assembly / assemblyPackageScala / assembleArtifact) := false,
- // SPARK-46733: Include `spark-connect-*.jar`, `unused-*.jar`,`guava-*.jar`,
- // `failureaccess-*.jar`, `annotations-*.jar`, `grpc-*.jar`, `protobuf-*.jar`,
- // `gson-*.jar`, `error_prone_annotations-*.jar`, `j2objc-annotations-*.jar`,
- // `animal-sniffer-annotations-*.jar`, `perfmark-api-*.jar`,
- // `proto-google-common-protos-*.jar` in assembly.
+ // SPARK-46733: Include `spark-connect-*.jar`, `unused-*.jar`, `annotations-*.jar`,
+ // `grpc-*.jar`, `protobuf-*.jar`, `gson-*.jar`, `animal-sniffer-annotations-*.jar`,
+ // `perfmark-api-*.jar`, `proto-google-common-protos-*.jar` in assembly.
// This needs to be consistent with the content of `maven-shade-plugin`.
(assembly / assemblyExcludedJars) := {
val cp = (assembly / fullClasspath).value
- val validPrefixes = Set("spark-connect", "unused-", "guava-", "failureaccess-",
- "annotations-", "grpc-", "protobuf-", "gson", "error_prone_annotations",
- "j2objc-annotations", "animal-sniffer-annotations", "perfmark-api",
- "proto-google-common-protos")
+ val validPrefixes = Set("spark-connect", "unused-", "annotations-",
+ "grpc-", "protobuf-", "gson", "animal-sniffer-annotations",
+ "perfmark-api", "proto-google-common-protos")
cp filterNot { v =>
validPrefixes.exists(v.data.getName.startsWith)
}
},
(assembly / assemblyShadeRules) := Seq(
- ShadeRule.rename("io.grpc.**" -> "org.sparkproject.connect.grpc.@0").inAll,
- ShadeRule.rename("com.google.common.**" -> "org.sparkproject.connect.guava.@1").inAll,
- ShadeRule.rename("com.google.thirdparty.**" -> "org.sparkproject.connect.guava.@1").inAll,
+ ShadeRule.rename("io.grpc.**" -> "org.sparkproject.connect.grpc.@1").inAll,
ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.connect.protobuf.@1").inAll,
ShadeRule.rename("android.annotation.**" -> "org.sparkproject.connect.android_annotation.@1").inAll,
ShadeRule.rename("io.perfmark.**" -> "org.sparkproject.connect.io_perfmark.@1").inAll,
ShadeRule.rename("org.codehaus.mojo.animal_sniffer.**" -> "org.sparkproject.connect.animal_sniffer.@1").inAll,
- ShadeRule.rename("com.google.j2objc.annotations.**" -> "org.sparkproject.connect.j2objc_annotations.@1").inAll,
- ShadeRule.rename("com.google.errorprone.annotations.**" -> "org.sparkproject.connect.errorprone_annotations.@1").inAll,
- ShadeRule.rename("org.checkerframework.**" -> "org.sparkproject.connect.checkerframework.@1").inAll,
ShadeRule.rename("com.google.gson.**" -> "org.sparkproject.connect.gson.@1").inAll,
ShadeRule.rename("com.google.api.**" -> "org.sparkproject.connect.google_protos.api.@1").inAll,
+ ShadeRule.rename("com.google.apps.**" -> "org.sparkproject.connect.google_protos.apps.@1").inAll,
ShadeRule.rename("com.google.cloud.**" -> "org.sparkproject.connect.google_protos.cloud.@1").inAll,
ShadeRule.rename("com.google.geo.**" -> "org.sparkproject.connect.google_protos.geo.@1").inAll,
ShadeRule.rename("com.google.logging.**" -> "org.sparkproject.connect.google_protos.logging.@1").inAll,
ShadeRule.rename("com.google.longrunning.**" -> "org.sparkproject.connect.google_protos.longrunning.@1").inAll,
ShadeRule.rename("com.google.rpc.**" -> "org.sparkproject.connect.google_protos.rpc.@1").inAll,
+ ShadeRule.rename("com.google.shopping.**" -> "org.sparkproject.connect.google_protos.shopping.@1").inAll,
ShadeRule.rename("com.google.type.**" -> "org.sparkproject.connect.google_protos.type.@1").inAll
),
@@ -849,7 +849,7 @@ object SparkConnectJdbc {
libraryDependencies ++= {
val guavaVersion =
SbtPomKeys.effectivePom.value.getProperties.get(
- "connect.guava.version").asInstanceOf[String]
+ "guava.version").asInstanceOf[String]
Seq(
"com.google.guava" % "guava" % guavaVersion,
"com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf"
@@ -858,7 +858,7 @@ object SparkConnectJdbc {
dependencyOverrides ++= {
val guavaVersion =
SbtPomKeys.effectivePom.value.getProperties.get(
- "connect.guava.version").asInstanceOf[String]
+ "guava.version").asInstanceOf[String]
Seq(
"com.google.guava" % "guava" % guavaVersion,
"com.google.protobuf" % "protobuf-java" % protoVersion
@@ -886,14 +886,17 @@ object SparkConnectJdbc {
// Exclude `scala-library` from assembly.
(assembly / assemblyPackageScala / assembleArtifact) := false,
- // Exclude `pmml-model-*.jar`, `scala-collection-compat_*.jar`,`jsr305-*.jar` and
- // `netty-*.jar` and `unused-1.0.0.jar` from assembly.
+ // Exclude `pmml-model-*.jar`, `scala-collection-compat_*.jar`, `jspecify-*.jar`,
+ // `error_prone_annotations-*.jar`, `listenablefuture-9999.0-empty-to-avoid-conflict-with-guava.jar`,
+ // `j2objc-annotations-*.jar` and `unused-1.0.0.jar` from assembly.
(assembly / assemblyExcludedJars) := {
val cp = (assembly / fullClasspath).value
cp filter { v =>
val name = v.data.getName
name.startsWith("pmml-model-") || name.startsWith("scala-collection-compat_") ||
- name.startsWith("jsr305-") || name == "unused-1.0.0.jar"
+ name.startsWith("jspecify-") || name.startsWith("error_prone_annotations") ||
+ name.startsWith("listenablefuture") || name.startsWith("j2objc-annotations") ||
+ name == "unused-1.0.0.jar"
}
},
// Only include `spark-connect-client-jdbc-*.jar`
@@ -910,8 +913,6 @@ object SparkConnectJdbc {
ShadeRule.rename("io.grpc.**" -> "org.sparkproject.connect.client.io.grpc.@1").inAll,
ShadeRule.rename("com.google.**" -> "org.sparkproject.connect.client.com.google.@1").inAll,
ShadeRule.rename("io.netty.**" -> "org.sparkproject.connect.client.io.netty.@1").inAll,
- ShadeRule.rename("org.checkerframework.**" -> "org.sparkproject.connect.client.org.checkerframework.@1").inAll,
- ShadeRule.rename("javax.annotation.**" -> "org.sparkproject.connect.client.javax.annotation.@1").inAll,
ShadeRule.rename("io.perfmark.**" -> "org.sparkproject.connect.client.io.perfmark.@1").inAll,
ShadeRule.rename("org.codehaus.**" -> "org.sparkproject.connect.client.org.codehaus.@1").inAll,
ShadeRule.rename("android.annotation.**" -> "org.sparkproject.connect.client.android.annotation.@1").inAll
@@ -938,7 +939,7 @@ object SparkConnectClient {
libraryDependencies ++= {
val guavaVersion =
SbtPomKeys.effectivePom.value.getProperties.get(
- "connect.guava.version").asInstanceOf[String]
+ "guava.version").asInstanceOf[String]
Seq(
"com.google.guava" % "guava" % guavaVersion,
"com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf"
@@ -947,7 +948,7 @@ object SparkConnectClient {
dependencyOverrides ++= {
val guavaVersion =
SbtPomKeys.effectivePom.value.getProperties.get(
- "connect.guava.version").asInstanceOf[String]
+ "guava.version").asInstanceOf[String]
Seq(
"com.google.guava" % "guava" % guavaVersion,
"com.google.protobuf" % "protobuf-java" % protoVersion
@@ -975,14 +976,17 @@ object SparkConnectClient {
// Exclude `scala-library` from assembly.
(assembly / assemblyPackageScala / assembleArtifact) := false,
- // Exclude `pmml-model-*.jar`, `scala-collection-compat_*.jar`,`jsr305-*.jar` and
- // `netty-*.jar` and `unused-1.0.0.jar` from assembly.
+ // Exclude `pmml-model-*.jar`, `scala-collection-compat_*.jar`, `jspecify-*.jar`,
+ // `error_prone_annotations-*.jar`, `listenablefuture-9999.0-empty-to-avoid-conflict-with-guava.jar`,
+ // `j2objc-annotations-*.jar` and `unused-1.0.0.jar` from assembly.
(assembly / assemblyExcludedJars) := {
val cp = (assembly / fullClasspath).value
cp filter { v =>
val name = v.data.getName
name.startsWith("pmml-model-") || name.startsWith("scala-collection-compat_") ||
- name.startsWith("jsr305-") || name == "unused-1.0.0.jar"
+ name.startsWith("jspecify-") || name.startsWith("error_prone_annotations") ||
+ name.startsWith("listenablefuture") || name.startsWith("j2objc-annotations") ||
+ name == "unused-1.0.0.jar"
}
},
@@ -990,8 +994,6 @@ object SparkConnectClient {
ShadeRule.rename("io.grpc.**" -> "org.sparkproject.connect.client.io.grpc.@1").inAll,
ShadeRule.rename("com.google.**" -> "org.sparkproject.connect.client.com.google.@1").inAll,
ShadeRule.rename("io.netty.**" -> "org.sparkproject.connect.client.io.netty.@1").inAll,
- ShadeRule.rename("org.checkerframework.**" -> "org.sparkproject.connect.client.org.checkerframework.@1").inAll,
- ShadeRule.rename("javax.annotation.**" -> "org.sparkproject.connect.client.javax.annotation.@1").inAll,
ShadeRule.rename("io.perfmark.**" -> "org.sparkproject.connect.client.io.perfmark.@1").inAll,
ShadeRule.rename("org.codehaus.**" -> "org.sparkproject.connect.client.org.codehaus.@1").inAll,
ShadeRule.rename("android.annotation.**" -> "org.sparkproject.connect.client.android.annotation.@1").inAll
@@ -1224,8 +1226,15 @@ object ExcludedDependencies {
* client dependencies.
*/
object ExcludeShims {
+ import bloop.integrations.sbt.BloopKeys
+
val shimmedProjects = Set("spark-sql-api", "spark-connect-common", "spark-connect-client-jdbc", "spark-connect-client-jvm")
val classPathFilter = TaskKey[Classpath => Classpath]("filter for classpath")
+
+ // Filter for bloopInternalClasspath which is Seq[(File, File)]
+ type BloopClasspath = Seq[(java.io.File, java.io.File)]
+ val bloopClasspathFilter = TaskKey[BloopClasspath => BloopClasspath]("filter for bloop classpath")
+
lazy val settings = Seq(
classPathFilter := {
if (!shimmedProjects(moduleName.value)) {
@@ -1234,6 +1243,16 @@ object ExcludeShims {
identity _
}
},
+ bloopClasspathFilter := {
+ if (!shimmedProjects(moduleName.value)) {
+ // Note: bloop output directories use "connect-shims" (without "spark-" prefix)
+ cp => cp.filterNot { case (f1, f2) =>
+ f1.getPath.contains("connect-shims") || f2.getPath.contains("connect-shims")
+ }
+ } else {
+ identity _
+ }
+ },
Compile / internalDependencyClasspath :=
classPathFilter.value((Compile / internalDependencyClasspath).value),
Compile / internalDependencyAsJars :=
@@ -1246,6 +1265,13 @@ object ExcludeShims {
classPathFilter.value((Test / internalDependencyClasspath).value),
Test / internalDependencyAsJars :=
classPathFilter.value((Test / internalDependencyAsJars).value),
+ // Filter bloop's internal classpath for correct IDE integration
+ Compile / BloopKeys.bloopInternalClasspath :=
+ bloopClasspathFilter.value((Compile / BloopKeys.bloopInternalClasspath).value),
+ Runtime / BloopKeys.bloopInternalClasspath :=
+ bloopClasspathFilter.value((Runtime / BloopKeys.bloopInternalClasspath).value),
+ Test / BloopKeys.bloopInternalClasspath :=
+ bloopClasspathFilter.value((Test / BloopKeys.bloopInternalClasspath).value),
)
}
@@ -1285,7 +1311,9 @@ object SqlApi {
(Antlr4 / antlr4PackageName) := Some("org.apache.spark.sql.catalyst.parser"),
(Antlr4 / antlr4GenListener) := true,
(Antlr4 / antlr4GenVisitor) := true,
- (Antlr4 / antlr4TreatWarningsAsErrors) := true
+ (Antlr4 / antlr4TreatWarningsAsErrors) := true,
+ // Use Maven's output directory so sbt and Maven can share generated sources
+ (Antlr4 / javaSource) := target.value / "generated-sources" / "antlr4"
)
}
@@ -1302,8 +1330,10 @@ object SQL {
"com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf"
)
},
+ // Use Maven's output directory so sbt and Maven can share generated sources.
+ // sql/core uses protoc-jar-maven-plugin which outputs to target/generated-sources.
(Compile / PB.targets) := Seq(
- PB.gens.java -> (Compile / sourceManaged).value
+ PB.gens.java -> target.value / "generated-sources"
)
) ++ {
val sparkProtocExecPath = sys.props.get("spark.protoc.executable.path")
@@ -1533,6 +1563,7 @@ object Unidoc {
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/classic/")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/execution")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/internal")))
+ .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/pipelines")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/scripting")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/ml")))
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive")))
@@ -1872,25 +1903,25 @@ object TestSettings {
sys.props.get("test.default.exclude.tags").map(tags => tags.split(",").toSeq)
.map(tags => tags.filter(!_.trim.isEmpty)).getOrElse(defaultExcludedTags)
.flatMap(tag => Seq("-l", tag)): _*),
- (Test / testOptions) += Tests.Argument(TestFrameworks.JUnit,
+ (Test / testOptions) += Tests.Argument(jupiterTestFramework,
sys.props.get("test.exclude.tags").map { tags =>
- Seq("--exclude-categories=" + tags)
+ Seq("--exclude-tag=" + tags)
}.getOrElse(Nil): _*),
// Include tags defined in a system property
(Test / testOptions) += Tests.Argument(TestFrameworks.ScalaTest,
sys.props.get("test.include.tags").map { tags =>
tags.split(",").flatMap { tag => Seq("-n", tag) }.toSeq
}.getOrElse(Nil): _*),
- (Test / testOptions) += Tests.Argument(TestFrameworks.JUnit,
+ (Test / testOptions) += Tests.Argument(jupiterTestFramework,
sys.props.get("test.include.tags").map { tags =>
- Seq("--include-categories=" + tags)
+ Seq("--include-tags=" + tags)
}.getOrElse(Nil): _*),
// Show full stack trace and duration in test cases.
(Test / testOptions) += Tests.Argument("-oDF"),
// Slowpoke notifications: receive notifications every 5 minute of tests that have been running
// longer than two minutes.
(Test / testOptions) += Tests.Argument(TestFrameworks.ScalaTest, "-W", "120", "300"),
- (Test / testOptions) += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"),
+ (Test / testOptions) += Tests.Argument(jupiterTestFramework, "-v", "-a"),
// Enable Junit testing.
libraryDependencies += "com.github.sbt.junit" % "jupiter-interface" % "0.17.0" % "test",
// `parallelExecutionInTest` controls whether test suites belonging to the same SBT project
diff --git a/project/plugins.sbt b/project/plugins.sbt
index fe18d16c48227..0b7e04222e306 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -46,3 +46,5 @@ addSbtPlugin("com.github.sbt.junit" % "sbt-jupiter-interface" % "0.17.0")
addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.7")
addSbtPlugin("com.here.platform" % "sbt-bom" % "1.0.29")
+
+addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "2.0.17")
diff --git a/python/conf_viztracer/daemon_viztracer.py b/python/conf_viztracer/daemon_viztracer.py
new file mode 100644
index 0000000000000..2897c10bfe708
--- /dev/null
+++ b/python/conf_viztracer/daemon_viztracer.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+import os
+import sys
+
+import viztracer
+from viztracer.main import main
+
+import pyspark.worker
+
+
+def viztracer_wrapper(func):
+
+ def wrapper(*args, **kwargs):
+ result = func(*args, **kwargs)
+ tracer = viztracer.get_tracer()
+ if tracer is not None:
+ tracer.exit_routine()
+ return result
+ return wrapper
+
+
+if __name__ == "__main__":
+ pyspark.worker.main = viztracer_wrapper(pyspark.worker.main)
+
+ if os.getenv("SPARK_VIZTRACER_OUTPUT_DIR") is not None:
+ output_dir = os.getenv("SPARK_VIZTRACER_OUTPUT_DIR")
+ else:
+ output_dir = "./"
+
+ sys.argv[:] = ["viztracer", "-m", "pyspark.daemon", "--quiet", "-u",
+ "--output_dir", output_dir]
+ main()
diff --git a/python/conf_viztracer/spark-defaults.conf b/python/conf_viztracer/spark-defaults.conf
new file mode 100644
index 0000000000000..ad84770ff6e0f
--- /dev/null
+++ b/python/conf_viztracer/spark-defaults.conf
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Default system properties included when running spark-submit.
+# This is useful for setting default environmental settings.
+
+spark.python.daemon.module=daemon_viztracer
diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst
index 82db489651ff9..6b5a09205e4aa 100644
--- a/python/docs/source/getting_started/install.rst
+++ b/python/docs/source/getting_started/install.rst
@@ -227,9 +227,10 @@ Package Supported version Note
========================== ================= ==========================
`pandas` >=2.2.0 Required for Spark Connect
`pyarrow` >=15.0.0 Required for Spark Connect
-`grpcio` >=1.67.0 Required for Spark Connect
-`grpcio-status` >=1.67.0 Required for Spark Connect
-`googleapis-common-protos` >=1.65.0 Required for Spark Connect
+`grpcio` >=1.76.0 Required for Spark Connect
+`grpcio-status` >=1.76.0 Required for Spark Connect
+`googleapis-common-protos` >=1.71.0 Required for Spark Connect
+`zstandard` >=0.25.0 Required for Spark Connect
`graphviz` >=0.20 Optional for Spark Connect
========================== ================= ==========================
@@ -310,9 +311,10 @@ Package Supported version Note
========================== ================= ===================================================
`pandas` >=2.2.0 Required for Spark Connect and Spark SQL
`pyarrow` >=15.0.0 Required for Spark Connect and Spark SQL
-`grpcio` >=1.67.0 Required for Spark Connect
-`grpcio-status` >=1.67.0 Required for Spark Connect
-`googleapis-common-protos` >=1.65.0 Required for Spark Connect
+`grpcio` >=1.76.0 Required for Spark Connect
+`grpcio-status` >=1.76.0 Required for Spark Connect
+`googleapis-common-protos` >=1.71.0 Required for Spark Connect
+`zstandard` >=0.25.0 Required for Spark Connect
`pyyaml` >=3.11 Required for spark-pipelines command line interface
`graphviz` >=0.20 Optional for Spark Connect
========================== ================= ===================================================
diff --git a/python/docs/source/index.rst b/python/docs/source/index.rst
index c8d3fe62bf3b3..b412f2746bdd2 100644
--- a/python/docs/source/index.rst
+++ b/python/docs/source/index.rst
@@ -34,7 +34,7 @@ PySpark combines Python's learnability and ease of use with the power of Apache
to enable processing and analysis of data at any size for everyone familiar with Python.
PySpark supports all of Spark's features such as Spark SQL,
-DataFrames, Structured Streaming, Machine Learning (MLlib) and Spark Core.
+DataFrames, Structured Streaming, Machine Learning (MLlib), Pipelines and Spark Core.
.. list-table::
:widths: 10 80 10
@@ -151,6 +151,15 @@ learning pipelines.
- `Machine Learning Library (MLlib) Programming Guide `_
- :ref:`Machine Learning (MLlib) API Reference`
+**Declarative Pipelines**
+
+Spark Declarative Pipelines (SDP) is a declarative framework for building reliable,
+maintainable, and testable data pipelines on Spark. SDP simplifies ETL development by allowing
+you to focus on the transformations you want to apply to your data, rather than the mechanics
+of pipeline execution.
+
+- :ref:`Pipelines API Reference`
+
**Spark Core and RDDs**
Spark Core is the underlying general execution engine for the Spark platform that all
diff --git a/python/docs/source/reference/index.rst b/python/docs/source/reference/index.rst
index 0068c0b2322e2..11c180c4825e7 100644
--- a/python/docs/source/reference/index.rst
+++ b/python/docs/source/reference/index.rst
@@ -36,6 +36,7 @@ This page lists an overview of all public PySpark modules, classes, functions an
pyspark.streaming
pyspark.mllib
pyspark
+ pyspark.pipelines
pyspark.resource
pyspark.errors
pyspark.logger
diff --git a/python/docs/source/reference/pyspark.pipelines.rst b/python/docs/source/reference/pyspark.pipelines.rst
new file mode 100644
index 0000000000000..e3f7384334442
--- /dev/null
+++ b/python/docs/source/reference/pyspark.pipelines.rst
@@ -0,0 +1,33 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+==================
+PySpark Pipelines
+==================
+
+.. currentmodule:: pyspark.pipelines
+
+.. autosummary::
+ :toctree: api/
+
+ materialized_view
+ table
+ temporary_view
+ create_streaming_table
+ append_flow
+ create_sink
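As a rough, hedged sketch of how the `pyspark.pipelines` entry points listed in this autosummary are typically applied (dataset names such as `raw_orders` are purely illustrative, and the `spark` session is assumed to be injected by the `spark-pipelines` runner rather than created here):

```python
# Hedged sketch only: assumes `pyspark.pipelines` exposes the decorators listed in the
# autosummary above and that `spark` is provided by the spark-pipelines runner.
from pyspark import pipelines as dp

@dp.materialized_view
def clean_orders():
    # The decorated function returns a DataFrame describing the dataset's contents.
    return spark.read.table("raw_orders").where("order_id IS NOT NULL")

@dp.table
def orders_per_day():
    return spark.read.table("clean_orders").groupBy("order_date").count()
```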
diff --git a/python/docs/source/reference/pyspark.sql/functions.rst b/python/docs/source/reference/pyspark.sql/functions.rst
index e4175707aecd7..9fcdac38e7d5a 100644
--- a/python/docs/source/reference/pyspark.sql/functions.rst
+++ b/python/docs/source/reference/pyspark.sql/functions.rst
@@ -458,6 +458,12 @@ Aggregate Functions
histogram_numeric
hll_sketch_agg
hll_union_agg
+ kll_sketch_agg_bigint
+ kll_sketch_agg_double
+ kll_sketch_agg_float
+ kll_merge_agg_bigint
+ kll_merge_agg_float
+ kll_merge_agg_double
kurtosis
last
last_value
@@ -631,6 +637,21 @@ Misc Functions
current_user
hll_sketch_estimate
hll_union
+ kll_sketch_get_n_bigint
+ kll_sketch_get_n_double
+ kll_sketch_get_n_float
+ kll_sketch_get_quantile_bigint
+ kll_sketch_get_quantile_double
+ kll_sketch_get_quantile_float
+ kll_sketch_get_rank_bigint
+ kll_sketch_get_rank_double
+ kll_sketch_get_rank_float
+ kll_sketch_merge_bigint
+ kll_sketch_merge_double
+ kll_sketch_merge_float
+ kll_sketch_to_string_bigint
+ kll_sketch_to_string_double
+ kll_sketch_to_string_float
input_file_block_length
input_file_block_start
input_file_name
@@ -652,6 +673,18 @@ Misc Functions
version
+Geospatial ST Functions
+-----------------------
+.. autosummary::
+ :toctree: api/
+
+ st_asbinary
+ st_geogfromwkb
+ st_geomfromwkb
+ st_setsrid
+ st_srid
+
+
UDF, UDTF and UDT
-----------------
.. autosummary::
@@ -681,6 +714,7 @@ Table-Valued Functions
TableValuedFunction.json_tuple
TableValuedFunction.posexplode
TableValuedFunction.posexplode_outer
+ TableValuedFunction.python_worker_logs
TableValuedFunction.range
TableValuedFunction.sql_keywords
TableValuedFunction.stack
diff --git a/python/docs/source/tutorial/sql/python_data_source.rst b/python/docs/source/tutorial/sql/python_data_source.rst
index 78ffeda0db1ce..1cc1811000ca4 100644
--- a/python/docs/source/tutorial/sql/python_data_source.rst
+++ b/python/docs/source/tutorial/sql/python_data_source.rst
@@ -305,7 +305,14 @@ This is the same dummy streaming reader that generate 2 rows every batch impleme
def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:
"""
- Takes start offset as an input, return an iterator of tuples and the start offset of next read.
+ Takes the start offset as input and returns an iterator of tuples together with
+ the end offset (the start offset for the next read). The end offset must
+ advance past the start offset whenever data is returned; otherwise Spark
+ raises a validation exception.
+ For example, returning 2 records from start_idx 0 means the end offset should
+ be {"offset": 2} (i.e. start + 2).
+ When there is no data to read, you may return the same offset for end as for
+ start, but you must provide an empty iterator.
"""
start_idx = start["offset"]
it = iter([(i,) for i in range(start_idx, start_idx + 2)])
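The hunk ends before the method's return statement; as a minimal, self-contained sketch of the contract the reworded docstring describes (built on the existing `pyspark.sql.datasource.SimpleDataSourceStreamReader` interface, with the counter logic itself only illustrative):

```python
from typing import Iterator, Tuple
from pyspark.sql.datasource import SimpleDataSourceStreamReader

class CounterStreamReader(SimpleDataSourceStreamReader):
    def initialOffset(self) -> dict:
        return {"offset": 0}

    def read(self, start: dict) -> Tuple[Iterator[Tuple], dict]:
        start_idx = start["offset"]
        rows = [(i,) for i in range(start_idx, start_idx + 2)]
        # Two records are returned, so the end offset must advance past the start
        # offset; returning {"offset": start_idx} here would trigger the
        # SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE validation error.
        return iter(rows), {"offset": start_idx + 2}

    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
        # Deterministically replay the rows between two offsets for recovery.
        return iter([(i,) for i in range(start["offset"], end["offset"])])
```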
diff --git a/python/docs/source/user_guide/bugbusting.ipynb b/python/docs/source/user_guide/bugbusting.ipynb
index 8e64bda1175c5..407bcbf18e9b3 100644
--- a/python/docs/source/user_guide/bugbusting.ipynb
+++ b/python/docs/source/user_guide/bugbusting.ipynb
@@ -792,7 +792,7 @@
"id": "09b420ba",
"metadata": {},
"source": [
- "## Disply Stacktraces"
+ "## Display Stacktraces"
]
},
{
@@ -900,6 +900,363 @@
"See also [Stack Traces](https://spark.apache.org/docs/latest/api/python/development/debugging.html#stack-traces) for more details."
]
},
+ {
+ "cell_type": "markdown",
+ "id": "cff22ba8",
+ "metadata": {},
+ "source": [
+ "## Python Worker Logging\n",
+ "\n",
+ "\n",
+ "Note: This section applies to Spark 4.1\n",
+ "\n",
+ "\n",
+ "PySpark provides a logging mechanism for Python workers that execute UDFs, UDTFs, Pandas UDFs, and Python data sources. When enabled, all logging output (including `print` statements, standard logging, and exceptions) is captured and made available for querying and analysis.\n",
+ "\n",
+ "### Enabling Worker Logging\n",
+ "\n",
+ "Worker logging is **disabled by default**. Enable it by setting the Spark SQL configuration:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "74786d45",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "spark.conf.set(\"spark.sql.pyspark.worker.logging.enabled\", \"true\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0f23fee2",
+ "metadata": {},
+ "source": [
+ "### Accessing Logs\n",
+ "\n",
+ "All captured logs can be queried as a DataFrame:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "9db0c509",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "logs = spark.tvf.python_worker_logs()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "34bca836",
+ "metadata": {},
+ "source": [
+ "The logs DataFrame contains the following columns:\n",
+ "\n",
+ "- **ts**: Timestamp of the log entry\n",
+ "- **level**: Log level (e.g., `\"INFO\"`, `\"WARNING\"`, `\"ERROR\"`)\n",
+ "- **logger**: Logger name (e.g., custom logger name, `\"stdout\"`, `\"stderr\"`)\n",
+ "- **msg**: The log message\n",
+ "- **context**: A map containing contextual information (e.g., `func_name`, `class_name`, custom fields)\n",
+ "- **exception**: Exception details (if an exception was logged)\n",
+ "\n",
+ "### Examples\n",
+ "\n",
+ "#### Basic UDF Logging"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "4cb5bbca",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+------------+\n",
+ "|my_udf(text)|\n",
+ "+------------+\n",
+ "| HELLO|\n",
+ "| WORLD|\n",
+ "+------------+\n",
+ "\n",
+ "+-------+------------------------+----------------+---------------------+\n",
+ "|level |msg |logger |context |\n",
+ "+-------+------------------------+----------------+---------------------+\n",
+ "|INFO |Processing value: hello |my_custom_logger|{func_name -> my_udf}|\n",
+ "|WARNING|This is a warning |my_custom_logger|{func_name -> my_udf}|\n",
+ "|INFO |This is a stdout message|stdout |{func_name -> my_udf}|\n",
+ "|ERROR |This is a stderr message|stderr |{func_name -> my_udf}|\n",
+ "|INFO |Processing value: world |my_custom_logger|{func_name -> my_udf}|\n",
+ "|WARNING|This is a warning |my_custom_logger|{func_name -> my_udf}|\n",
+ "|INFO |This is a stdout message|stdout |{func_name -> my_udf}|\n",
+ "|ERROR |This is a stderr message|stderr |{func_name -> my_udf}|\n",
+ "+-------+------------------------+----------------+---------------------+\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from pyspark.sql.functions import udf\n",
+ "import logging\n",
+ "import sys\n",
+ "\n",
+ "@udf(\"string\")\n",
+ "def my_udf(value):\n",
+ " logger = logging.getLogger(\"my_custom_logger\")\n",
+ " logger.setLevel(logging.INFO) # Set level to INFO to capture info messages\n",
+ " logger.info(f\"Processing value: {value}\")\n",
+ " logger.warning(\"This is a warning\")\n",
+ " print(\"This is a stdout message\") # INFO level, logger=stdout\n",
+ " print(\"This is a stderr message\", file=sys.stderr) # ERROR level, logger=stderr\n",
+ " return value.upper()\n",
+ "\n",
+ "# Enable logging and execute\n",
+ "spark.conf.set(\"spark.sql.pyspark.worker.logging.enabled\", \"true\")\n",
+ "df = spark.createDataFrame([(\"hello\",), (\"world\",)], [\"text\"])\n",
+ "df.select(my_udf(\"text\")).show()\n",
+ "\n",
+ "# Query the logs\n",
+ "logs = spark.tvf.python_worker_logs()\n",
+ "logs.select(\"level\", \"msg\", \"logger\", \"context\").show(truncate=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "15a80ffb",
+ "metadata": {},
+ "source": [
+ "#### Logging with Custom Context\n",
+ "\n",
+ "You can add custom context information to your logs:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "427a06c5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+--------------------+\n",
+ "|contextual_udf(test)|\n",
+ "+--------------------+\n",
+ "| test|\n",
+ "+--------------------+\n",
+ "\n",
+ "+-----------------------------+---------------------------------------------------------------------+\n",
+ "|msg |context |\n",
+ "+-----------------------------+---------------------------------------------------------------------+\n",
+ "|Processing with extra context|{func_name -> contextual_udf, user_id -> 123, operation -> transform}|\n",
+ "+-----------------------------+---------------------------------------------------------------------+\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from pyspark.sql.functions import lit, udf\n",
+ "import logging\n",
+ "\n",
+ "@udf(\"string\")\n",
+ "def contextual_udf(value):\n",
+ " logger = logging.getLogger(\"contextual\")\n",
+ " logger.warning(\n",
+ " \"Processing with extra context\",\n",
+ " extra={\"context\": {\"user_id\": 123, \"operation\": \"transform\"}}\n",
+ " )\n",
+ " return value\n",
+ "\n",
+ "spark.conf.set(\"spark.sql.pyspark.worker.logging.enabled\", \"true\")\n",
+ "spark.range(1).select(contextual_udf(lit(\"test\"))).show()\n",
+ "\n",
+ "logs = spark.tvf.python_worker_logs()\n",
+ "logs.filter(\"logger = 'contextual'\").select(\"msg\", \"context\").show(truncate=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a19db296",
+ "metadata": {},
+ "source": [
+ "The context includes both automatic fields (like `func_name`) and custom fields (like `user_id`, `operation`).\n",
+ "\n",
+ "#### Exception Logging\n",
+ "\n",
+ "Exceptions are automatically captured with full stack traces:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "3ab34a4c",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+------------------+\n",
+ "|failing_udf(value)|\n",
+ "+------------------+\n",
+ "| -1|\n",
+ "| 20|\n",
+ "+------------------+\n",
+ "\n",
+ "+-------------------------+----------------------------------------------------------------------------------------------------------------------------------------------+\n",
+ "|msg |exception |\n",
+ "+-------------------------+----------------------------------------------------------------------------------------------------------------------------------------------+\n",
+ "|Division by zero occurred|{ZeroDivisionError, division by zero, [{NULL, failing_udf, /var/folders/r8/0v7zwfbd59q4ym2gn6kxjq8h0000gp/T/ipykernel_79089/916837455.py, 8}]}|\n",
+ "+-------------------------+----------------------------------------------------------------------------------------------------------------------------------------------+\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from pyspark.sql.functions import udf\n",
+ "import logging\n",
+ "\n",
+ "@udf(\"int\")\n",
+ "def failing_udf(x):\n",
+ " logger = logging.getLogger(\"error_handler\")\n",
+ " try:\n",
+ " result = 100 / x\n",
+ " except ZeroDivisionError:\n",
+ " logger.exception(\"Division by zero occurred\")\n",
+ " return -1\n",
+ " return int(result)\n",
+ "\n",
+ "spark.conf.set(\"spark.sql.pyspark.worker.logging.enabled\", \"true\")\n",
+ "spark.createDataFrame([(0,), (5,)], [\"value\"]).select(failing_udf(\"value\")).show()\n",
+ "\n",
+ "logs = spark.tvf.python_worker_logs()\n",
+ "logs.filter(\"logger = 'error_handler'\").select(\"msg\", \"exception\").show(truncate=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e54f6ac3",
+ "metadata": {},
+ "source": [
+ "#### UDTF and Python Data Source Logging\n",
+ "\n",
+ "Worker logging also works with UDTFs and Python Data Sources, capturing both the class and function names:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "02d454b0",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+-----------+-----+------+\n",
+ "| text| word|length|\n",
+ "+-----------+-----+------+\n",
+ "|hello world|hello| 5|\n",
+ "|hello world|world| 5|\n",
+ "+-----------+-----+------+\n",
+ "\n",
+ "+-----------------------------+---------------------------------------------------------------------+\n",
+ "|msg |context |\n",
+ "+-----------------------------+---------------------------------------------------------------------+\n",
+ "|Processing 2 words |{func_name -> eval, class_name -> WordSplitter} |\n",
+ "+-----------------------------+---------------------------------------------------------------------+\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from pyspark.sql.functions import col, udtf\n",
+ "import logging\n",
+ "\n",
+ "@udtf(returnType=\"word: string, length: int\")\n",
+ "class WordSplitter:\n",
+ " def eval(self, text: str):\n",
+ " logger = logging.getLogger(\"udtf_logger\")\n",
+ " logger.setLevel(logging.INFO) # Set level to INFO to capture info messages\n",
+ " words = text.split()\n",
+ " logger.info(f\"Processing {len(words)} words\")\n",
+ " for word in words:\n",
+ " yield (word, len(word))\n",
+ "\n",
+ "spark.conf.set(\"spark.sql.pyspark.worker.logging.enabled\", \"true\")\n",
+ "df = spark.createDataFrame([(\"hello world\",)], [\"text\"])\n",
+ "df.lateralJoin(WordSplitter(col(\"text\").outer())).show()\n",
+ "\n",
+ "logs = spark.tvf.python_worker_logs()\n",
+ "logs.filter(\"logger = 'udtf_logger'\").select(\"msg\", \"context\").show(truncate=False)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9d4c119b",
+ "metadata": {},
+ "source": [
+ "### Querying and Analyzing Logs\n",
+ "\n",
+ "You can use standard DataFrame operations to analyze logs:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "5b061011",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "+-------+-----+\n",
+ "| level|count|\n",
+ "+-------+-----+\n",
+ "| INFO| 5|\n",
+ "|WARNING| 3|\n",
+ "| ERROR| 3|\n",
+ "+-------+-----+\n",
+ "\n",
+ "...\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "logs = spark.tvf.python_worker_logs()\n",
+ "\n",
+ "# Count logs by level\n",
+ "logs.groupBy(\"level\").count().show()\n",
+ "\n",
+ "# Find all errors\n",
+ "logs.filter(\"level = 'ERROR'\").show()\n",
+ "\n",
+ "# Logs from a specific function\n",
+ "logs.filter(\"context.func_name = 'my_udf'\").show()\n",
+ "\n",
+ "# Logs with exceptions\n",
+ "logs.filter(\"exception is not null\").show()\n",
+ "\n",
+ "# Time-based analysis\n",
+ "logs.orderBy(\"ts\").show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7eaa72b9",
+ "metadata": {},
+ "source": [
+ "\n"
+ ]
+ },
{
"attachments": {},
"cell_type": "markdown",
diff --git a/python/packaging/classic/setup.py b/python/packaging/classic/setup.py
index eac97af2e8c89..54ec4abe3be91 100755
--- a/python/packaging/classic/setup.py
+++ b/python/packaging/classic/setup.py
@@ -153,9 +153,10 @@ def _supports_symlinks():
_minimum_pandas_version = "2.2.0"
_minimum_numpy_version = "1.21"
_minimum_pyarrow_version = "15.0.0"
-_minimum_grpc_version = "1.67.0"
-_minimum_googleapis_common_protos_version = "1.65.0"
+_minimum_grpc_version = "1.76.0"
+_minimum_googleapis_common_protos_version = "1.71.0"
_minimum_pyyaml_version = "3.11"
+_minimum_zstandard_version = "0.25.0"
class InstallCommand(install):
@@ -366,6 +367,7 @@ def run(self):
"grpcio>=%s" % _minimum_grpc_version,
"grpcio-status>=%s" % _minimum_grpc_version,
"googleapis-common-protos>=%s" % _minimum_googleapis_common_protos_version,
+ "zstandard>=%s" % _minimum_zstandard_version,
"numpy>=%s" % _minimum_numpy_version,
],
"pipelines": [
@@ -375,6 +377,7 @@ def run(self):
"grpcio>=%s" % _minimum_grpc_version,
"grpcio-status>=%s" % _minimum_grpc_version,
"googleapis-common-protos>=%s" % _minimum_googleapis_common_protos_version,
+ "zstandard>=%s" % _minimum_zstandard_version,
"pyyaml>=%s" % _minimum_pyyaml_version,
],
},
diff --git a/python/packaging/client/setup.py b/python/packaging/client/setup.py
index 7ec7e45a31604..ee404210f2932 100755
--- a/python/packaging/client/setup.py
+++ b/python/packaging/client/setup.py
@@ -136,9 +136,10 @@
_minimum_pandas_version = "2.2.0"
_minimum_numpy_version = "1.21"
_minimum_pyarrow_version = "15.0.0"
- _minimum_grpc_version = "1.67.0"
- _minimum_googleapis_common_protos_version = "1.65.0"
+ _minimum_grpc_version = "1.76.0"
+ _minimum_googleapis_common_protos_version = "1.71.0"
_minimum_pyyaml_version = "3.11"
+ _minimum_zstandard_version = "0.25.0"
with open("README.md") as f:
long_description = f.read()
@@ -211,6 +212,7 @@
"grpcio>=%s" % _minimum_grpc_version,
"grpcio-status>=%s" % _minimum_grpc_version,
"googleapis-common-protos>=%s" % _minimum_googleapis_common_protos_version,
+ "zstandard>=%s" % _minimum_zstandard_version,
"numpy>=%s" % _minimum_numpy_version,
"pyyaml>=%s" % _minimum_pyyaml_version,
],
@@ -221,6 +223,7 @@
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
+ "Programming Language :: Python :: 3.14",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
"Typing :: Typed",
diff --git a/python/packaging/connect/setup.py b/python/packaging/connect/setup.py
index f2b53211b3a0d..fc3cfca739aaa 100755
--- a/python/packaging/connect/setup.py
+++ b/python/packaging/connect/setup.py
@@ -86,12 +86,13 @@
# binary format protocol with the Java version, see ARROW_HOME/format/* for specifications.
# Also don't forget to update python/docs/source/getting_started/install.rst,
# python/packaging/classic/setup.py, and python/packaging/client/setup.py
- _minimum_pandas_version = "2.0.0"
+ _minimum_pandas_version = "2.2.0"
_minimum_numpy_version = "1.21"
_minimum_pyarrow_version = "11.0.0"
- _minimum_grpc_version = "1.67.0"
- _minimum_googleapis_common_protos_version = "1.65.0"
+ _minimum_grpc_version = "1.76.0"
+ _minimum_googleapis_common_protos_version = "1.71.0"
_minimum_pyyaml_version = "3.11"
+ _minimum_zstandard_version = "0.25.0"
with open("README.md") as f:
long_description = f.read()
@@ -121,17 +122,18 @@
"grpcio>=%s" % _minimum_grpc_version,
"grpcio-status>=%s" % _minimum_grpc_version,
"googleapis-common-protos>=%s" % _minimum_googleapis_common_protos_version,
+ "zstandard>=%s" % _minimum_zstandard_version,
"numpy>=%s" % _minimum_numpy_version,
"pyyaml>=%s" % _minimum_pyyaml_version,
],
- python_requires=">=3.9",
+ python_requires=">=3.10",
classifiers=[
"Development Status :: 5 - Production/Stable",
- "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
+ "Programming Language :: Python :: 3.14",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
"Typing :: Typed",
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index 59f7856688ee9..e557fe1cd8fb8 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -321,21 +321,36 @@ def shutdown(self) -> None:
self.server_close()
-class AccumulatorUnixServer(socketserver.UnixStreamServer):
- server_shutdown = False
-
- def __init__(
- self, socket_path: str, RequestHandlerClass: Type[socketserver.BaseRequestHandler]
- ):
- super().__init__(socket_path, RequestHandlerClass)
- self.auth_token = None
-
- def shutdown(self) -> None:
- self.server_shutdown = True
- super().shutdown()
- self.server_close()
- if os.path.exists(self.server_address): # type: ignore[arg-type]
- os.remove(self.server_address) # type: ignore[arg-type]
+# socketserver.UnixStreamServer is not available on Windows yet
+# (https://github.com/python/cpython/issues/77589).
+if hasattr(socketserver, "UnixStreamServer"):
+
+ class AccumulatorUnixServer(socketserver.UnixStreamServer):
+ server_shutdown = False
+
+ def __init__(
+ self, socket_path: str, RequestHandlerClass: Type[socketserver.BaseRequestHandler]
+ ):
+ super().__init__(socket_path, RequestHandlerClass)
+ self.auth_token = None
+
+ def shutdown(self) -> None:
+ self.server_shutdown = True
+ super().shutdown()
+ self.server_close()
+ if os.path.exists(self.server_address): # type: ignore[arg-type]
+ os.remove(self.server_address) # type: ignore[arg-type]
+
+else:
+
+ class AccumulatorUnixServer(socketserver.TCPServer): # type: ignore[no-redef]
+ def __init__(
+ self, socket_path: str, RequestHandlerClass: Type[socketserver.BaseRequestHandler]
+ ):
+ raise NotImplementedError(
+ "Unix Domain Sockets are not supported on this platform. "
+ "Please disable it by setting spark.python.unix.domain.socket.enabled to false."
+ )
def _start_update_server(
diff --git a/python/pyspark/cloudpickle/__init__.py b/python/pyspark/cloudpickle/__init__.py
index bdb1738611b3b..052b6e975a772 100644
--- a/python/pyspark/cloudpickle/__init__.py
+++ b/python/pyspark/cloudpickle/__init__.py
@@ -3,7 +3,7 @@
__doc__ = cloudpickle.__doc__
-__version__ = "3.1.1"
+__version__ = "3.1.2"
__all__ = [ # noqa
"__version__",
diff --git a/python/pyspark/cloudpickle/cloudpickle.py b/python/pyspark/cloudpickle/cloudpickle.py
index 4d532e5de9f2c..e600b35f28422 100644
--- a/python/pyspark/cloudpickle/cloudpickle.py
+++ b/python/pyspark/cloudpickle/cloudpickle.py
@@ -783,6 +783,12 @@ def _class_getstate(obj):
clsdict.pop("__dict__", None) # unpicklable property object
+ if sys.version_info >= (3, 14):
+ # PEP-649/749: __annotate_func__ contains a closure that references the class
+ # dict. We need to exclude it from pickling. Python will recreate it when
+ # __annotations__ is accessed at unpickling time.
+ clsdict.pop("__annotate_func__", None)
+
return (clsdict, {})
@@ -1190,6 +1196,10 @@ def _class_setstate(obj, state):
for subclass in registry:
obj.register(subclass)
+ # PEP-649/749: During pickling we excluded the __annotate_func__ attribute; Python
+ # recreates it automatically, so annotations are regenerated when __annotations__
+ # is accessed.
+
return obj
@@ -1301,12 +1311,9 @@ def _function_getnewargs(self, func):
def dump(self, obj):
try:
return super().dump(obj)
- except RuntimeError as e:
- if len(e.args) > 0 and "recursion" in e.args[0]:
- msg = "Could not pickle object as excessively deep recursion required."
- raise pickle.PicklingError(msg) from e
- else:
- raise
+ except RecursionError as e:
+ msg = "Could not pickle object as excessively deep recursion required."
+ raise pickle.PicklingError(msg) from e
def __init__(self, file, protocol=None, buffer_callback=None):
if protocol is None:
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index ca33ce2c39ef7..e75eca68fd0e7 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -24,6 +24,7 @@
import traceback
import time
import gc
+import faulthandler
from errno import EINTR, EAGAIN
from socket import AF_INET, AF_INET6, SOCK_STREAM, SOMAXCONN
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
@@ -85,7 +86,19 @@ def worker(sock, authenticated):
try:
outfile.flush()
except Exception:
- pass
+ if os.environ.get("PYTHON_DAEMON_KILL_WORKER_ON_FLUSH_FAILURE", False):
+ faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
+ if faulthandler_log_path:
+ faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
+ with open(faulthandler_log_path, "w") as faulthandler_log_file:
+ faulthandler.dump_traceback(file=faulthandler_log_file)
+ raise
+ else:
+ print(
+ "PySpark daemon failed to flush the output to the worker process:\n"
+ + traceback.format_exc(),
+ file=sys.stderr,
+ )
return exit_code
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index d169e6293a1ba..326671c0d5ad2 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -549,6 +549,16 @@
" and should be of the same length, got and ."
]
},
+ "MALFORMED_GEOGRAPHY": {
+ "message": [
+ "Geography binary is malformed. Please check the data source is valid."
+ ]
+ },
+ "MALFORMED_GEOMETRY": {
+ "message": [
+ "Geometry binary is malformed. Please check the data source is valid."
+ ]
+ },
"MALFORMED_VARIANT": {
"message": [
"Variant binary is malformed. Please check the data source is valid."
@@ -898,7 +908,7 @@
},
"PIPELINE_SPEC_FILE_NOT_FOUND": {
"message": [
- "No pipeline.yaml or pipeline.yml file provided in arguments or found in directory `` or readable ancestor directories."
+ "No spark-pipeline.yaml or spark-pipeline.yml file provided in arguments or found in directory `` or readable ancestor directories."
]
},
"PIPELINE_SPEC_INVALID_GLOB_PATTERN": {
@@ -1109,6 +1119,11 @@
"SparkContext or SparkSession should be created first."
]
},
+ "SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE": {
+ "message": [
+ "SimpleDataSourceStreamReader.read() returned a non-empty batch but the end offset: did not advance past the start offset: . The end offset must represent the position after the last record returned."
+ ]
+ },
"SLICE_WITH_STEP": {
"message": [
"Slice with step is not supported."
diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index 56892db91f3be..0f76e3b5f6a07 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -107,7 +107,8 @@ def getCondition(self) -> Optional[str]:
if self._origin is not None and is_instance_of(
gw, self._origin, "org.apache.spark.SparkThrowable"
):
- return self._origin.getCondition()
+ utils = SparkContext._jvm.PythonErrorUtils # type: ignore[union-attr]
+ return utils.getCondition(self._origin)
else:
return None
@@ -118,7 +119,6 @@ def getErrorClass(self) -> Optional[str]:
def getMessageParameters(self) -> Optional[Dict[str, str]]:
from pyspark import SparkContext
from py4j.java_gateway import is_instance_of
- from py4j.protocol import Py4JError
assert SparkContext._gateway is not None
@@ -126,38 +126,28 @@ def getMessageParameters(self) -> Optional[Dict[str, str]]:
if self._origin is not None and is_instance_of(
gw, self._origin, "org.apache.spark.SparkThrowable"
):
- try:
- return dict(self._origin.getMessageParameters())
- except Py4JError as e:
- if "py4j.Py4JException" in str(e) and "Method getMessageParameters" in str(e):
- return None
- raise e
+ utils = SparkContext._jvm.PythonErrorUtils # type: ignore[union-attr]
+ return dict(utils.getMessageParameters(self._origin))
else:
return None
def getSqlState(self) -> Optional[str]:
from pyspark import SparkContext
from py4j.java_gateway import is_instance_of
- from py4j.protocol import Py4JError
assert SparkContext._gateway is not None
gw = SparkContext._gateway
if self._origin is not None and is_instance_of(
gw, self._origin, "org.apache.spark.SparkThrowable"
):
- try:
- return self._origin.getSqlState()
- except Py4JError as e:
- if "py4j.Py4JException" in str(e) and "Method getSqlState" in str(e):
- return None
- raise e
+ utils = SparkContext._jvm.PythonErrorUtils # type: ignore[union-attr]
+ return utils.getSqlState(self._origin)
else:
return None
def getMessage(self) -> str:
from pyspark import SparkContext
from py4j.java_gateway import is_instance_of
- from py4j.protocol import Py4JError
assert SparkContext._gateway is not None
gw = SparkContext._gateway
@@ -165,21 +155,12 @@ def getMessage(self) -> str:
if self._origin is not None and is_instance_of(
gw, self._origin, "org.apache.spark.SparkThrowable"
):
- try:
- error_class = self._origin.getCondition()
- except Py4JError as e:
- if "py4j.Py4JException" in str(e) and "Method getCondition" in str(e):
- return ""
- raise e
- try:
- message_parameters = self._origin.getMessageParameters()
- except Py4JError as e:
- if "py4j.Py4JException" in str(e) and "Method getMessageParameters" in str(e):
- return ""
- raise e
+ utils = SparkContext._jvm.PythonErrorUtils # type: ignore[union-attr]
+ errorClass = utils.getCondition(self._origin)
+ messageParameters = utils.getMessageParameters(self._origin)
error_message = getattr(gw.jvm, "org.apache.spark.SparkThrowableHelper").getMessage(
- error_class, message_parameters
+ errorClass, messageParameters
)
return error_message
@@ -189,7 +170,6 @@ def getMessage(self) -> str:
def getQueryContext(self) -> List[BaseQueryContext]:
from pyspark import SparkContext
from py4j.java_gateway import is_instance_of
- from py4j.protocol import Py4JError
assert SparkContext._gateway is not None
@@ -198,13 +178,8 @@ def getQueryContext(self) -> List[BaseQueryContext]:
gw, self._origin, "org.apache.spark.SparkThrowable"
):
contexts: List[BaseQueryContext] = []
- try:
- context = self._origin.getQueryContext()
- except Py4JError as e:
- if "py4j.Py4JException" in str(e) and "Method getQueryContext" in str(e):
- return []
- raise e
- for q in context:
+ utils = SparkContext._jvm.PythonErrorUtils # type: ignore[union-attr]
+ for q in utils.getQueryContext(self._origin):
if q.contextType().toString() == "SQL":
contexts.append(SQLQueryContext(q))
else:
diff --git a/python/pyspark/logger/logger.py b/python/pyspark/logger/logger.py
index b60561f24c99c..72179b033bb36 100644
--- a/python/pyspark/logger/logger.py
+++ b/python/pyspark/logger/logger.py
@@ -140,7 +140,19 @@ class PySparkLogger(logging.Logger):
"""
def __init__(self, name: str = "PySparkLogger"):
+ from pyspark.logger.worker_io import JSONFormatterWithMarker
+
super().__init__(name, level=logging.WARN)
+
+ root_logger = logging.getLogger()
+ if any(
+ isinstance(h, logging.StreamHandler)
+ and isinstance(h.formatter, JSONFormatterWithMarker)
+ for h in root_logger.handlers
+ ):
+ # Likely in the `capture_outputs` context, so don't add a handler
+ return
+
_handler = logging.StreamHandler()
self.addHandler(_handler)
diff --git a/python/pyspark/logger/worker_io.py b/python/pyspark/logger/worker_io.py
index 2e5ced2e84ad3..79684b7aca624 100644
--- a/python/pyspark/logger/worker_io.py
+++ b/python/pyspark/logger/worker_io.py
@@ -164,6 +164,7 @@ def formatTime(self, record: logging.LogRecord, datefmt: Optional[str] = None) -
)
elif self.default_msec_format:
s = self.default_msec_format % (s, record.msecs)
+ s = f"{s}{time.strftime('%z', ct)}"
return s
diff --git a/python/pyspark/ml/connect/feature.py b/python/pyspark/ml/connect/feature.py
index b0e2028e43faa..2184b3c7f332f 100644
--- a/python/pyspark/ml/connect/feature.py
+++ b/python/pyspark/ml/connect/feature.py
@@ -15,11 +15,11 @@
# limitations under the License.
#
-import pickle
from typing import Any, Union, List, Tuple, Callable, Dict, Optional
import numpy as np
import pandas as pd
+import pyarrow as pa
from pyspark import keyword_only
from pyspark.sql import DataFrame
@@ -133,27 +133,29 @@ def map_value(x: "np.ndarray") -> "np.ndarray":
return transform_fn
def _get_core_model_filename(self) -> str:
- return self.__class__.__name__ + ".sklearn.pkl"
+ return self.__class__.__name__ + ".arrow.parquet"
def _save_core_model(self, path: str) -> None:
- from sklearn.preprocessing import MaxAbsScaler as sk_MaxAbsScaler
-
- sk_model = sk_MaxAbsScaler()
- sk_model.scale_ = self.scale_values
- sk_model.max_abs_ = self.max_abs_values
- sk_model.n_features_in_ = len(self.max_abs_values) # type: ignore[arg-type]
- sk_model.n_samples_seen_ = self.n_samples_seen
-
- with open(path, "wb") as fp:
- pickle.dump(sk_model, fp)
+ import pyarrow.parquet as pq
+
+ table = pa.Table.from_arrays(
+ [
+ pa.array([self.scale_values], pa.list_(pa.float64())),
+ pa.array([self.max_abs_values], pa.list_(pa.float64())),
+ pa.array([self.n_samples_seen], pa.int64()),
+ ],
+ names=["scale", "max_abs", "n_samples"],
+ )
+ pq.write_table(table, path)
def _load_core_model(self, path: str) -> None:
- with open(path, "rb") as fp:
- sk_model = pickle.load(fp)
+ import pyarrow.parquet as pq
+
+ table = pq.read_table(path)
- self.max_abs_values = sk_model.max_abs_
- self.scale_values = sk_model.scale_
- self.n_samples_seen = sk_model.n_samples_seen_
+ self.max_abs_values = np.array(table.column("max_abs")[0].as_py())
+ self.scale_values = np.array(table.column("scale")[0].as_py())
+ self.n_samples_seen = table.column("n_samples")[0].as_py()
class StandardScaler(Estimator, HasInputCol, HasOutputCol, ParamsReadWrite):
@@ -254,29 +256,31 @@ def map_value(x: "np.ndarray") -> "np.ndarray":
return transform_fn
def _get_core_model_filename(self) -> str:
- return self.__class__.__name__ + ".sklearn.pkl"
+ return self.__class__.__name__ + ".arrow.parquet"
def _save_core_model(self, path: str) -> None:
- from sklearn.preprocessing import StandardScaler as sk_StandardScaler
-
- sk_model = sk_StandardScaler(with_mean=True, with_std=True)
- sk_model.scale_ = self.scale_values
- sk_model.var_ = self.std_values * self.std_values # type: ignore[operator]
- sk_model.mean_ = self.mean_values
- sk_model.n_features_in_ = len(self.std_values) # type: ignore[arg-type]
- sk_model.n_samples_seen_ = self.n_samples_seen
-
- with open(path, "wb") as fp:
- pickle.dump(sk_model, fp)
+ import pyarrow.parquet as pq
+
+ table = pa.Table.from_arrays(
+ [
+ pa.array([self.scale_values], pa.list_(pa.float64())),
+ pa.array([self.mean_values], pa.list_(pa.float64())),
+ pa.array([self.std_values], pa.list_(pa.float64())),
+ pa.array([self.n_samples_seen], pa.int64()),
+ ],
+ names=["scale", "mean", "std", "n_samples"],
+ )
+ pq.write_table(table, path)
def _load_core_model(self, path: str) -> None:
- with open(path, "rb") as fp:
- sk_model = pickle.load(fp)
+ import pyarrow.parquet as pq
+
+ table = pq.read_table(path)
- self.std_values = np.sqrt(sk_model.var_)
- self.scale_values = sk_model.scale_
- self.mean_values = sk_model.mean_
- self.n_samples_seen = sk_model.n_samples_seen_
+ self.scale_values = np.array(table.column("scale")[0].as_py())
+ self.mean_values = np.array(table.column("mean")[0].as_py())
+ self.std_values = np.array(table.column("std")[0].as_py())
+ self.n_samples_seen = table.column("n_samples")[0].as_py()
class ArrayAssembler(
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
index 2d0a37aca5c8c..9fbf24d253424 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_feature.py
@@ -17,7 +17,6 @@
#
import os
-import pickle
import tempfile
import unittest
@@ -85,12 +84,6 @@ def test_max_abs_scaler(self):
np.testing.assert_allclose(model.max_abs_values, loaded_model.max_abs_values)
assert model.n_samples_seen == loaded_model.n_samples_seen
- # Test loading core model as scikit-learn model
- with open(os.path.join(model_path, "MaxAbsScalerModel.sklearn.pkl"), "rb") as f:
- sk_model = pickle.load(f)
- sk_result = sk_model.transform(np.stack(list(local_df1.features)))
- np.testing.assert_allclose(sk_result, expected_result)
-
def test_standard_scaler(self):
df1 = self.spark.createDataFrame(
[
@@ -141,12 +134,6 @@ def test_standard_scaler(self):
np.testing.assert_allclose(model.scale_values, loaded_model.scale_values)
assert model.n_samples_seen == loaded_model.n_samples_seen
- # Test loading core model as scikit-learn model
- with open(os.path.join(model_path, "StandardScalerModel.sklearn.pkl"), "rb") as f:
- sk_model = pickle.load(f)
- sk_result = sk_model.transform(np.stack(list(local_df1.features)))
- np.testing.assert_allclose(sk_result, expected_result)
-
def test_array_assembler(self):
spark_df = self.spark.createDataFrame(
[
diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py
index af89d18a0edea..05e6de554f7ad 100644
--- a/python/pyspark/pandas/frame.py
+++ b/python/pyspark/pandas/frame.py
@@ -2699,8 +2699,20 @@ def to_feather(
# Make sure locals() call is at the top of the function so we don't capture local variables.
args = locals()
+ pdf = self._to_internal_pandas()
+ # SPARK-54068: PyArrow >= 22.0.0 serializes DataFrame.attrs to JSON metadata,
+ # but PlanMetrics/PlanObservedMetrics objects from Spark Connect are not
+ # JSON serializable. We filter these internal attrs only for affected versions.
+ import pyarrow as pa
+ from pyspark.loose_version import LooseVersion
+
+ if LooseVersion(pa.__version__) >= LooseVersion("22.0.0"):
+ pdf.attrs = {
+ k: v for k, v in pdf.attrs.items() if k not in ("metrics", "observed_metrics")
+ }
+
return validate_arguments_and_invoke_function(
- self._to_internal_pandas(), self.to_feather, pd.DataFrame.to_feather, args
+ pdf, self.to_feather, pd.DataFrame.to_feather, args
)
def to_stata(
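The SPARK-54068 workaround above can be reproduced on a plain pandas DataFrame; the filtered keys and the 22.0.0 cutoff follow the comment in the hunk, and the output path is illustrative:

import pandas as pd
import pyarrow as pa
from pyspark.loose_version import LooseVersion

pdf = pd.DataFrame({"a": [1, 2, 3]})
pdf.attrs = {"metrics": object(), "note": "kept"}  # object() is not JSON serializable

# PyArrow >= 22.0.0 writes DataFrame.attrs into the Feather metadata as JSON,
# so drop the internal, non-serializable entries before writing.
if LooseVersion(pa.__version__) >= LooseVersion("22.0.0"):
    pdf.attrs = {k: v for k, v in pdf.attrs.items() if k not in ("metrics", "observed_metrics")}

pdf.to_feather("/tmp/example.feather")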
diff --git a/python/pyspark/pandas/supported_api_gen.py b/python/pyspark/pandas/supported_api_gen.py
index 595c11c559d05..3f7efa7784ab7 100644
--- a/python/pyspark/pandas/supported_api_gen.py
+++ b/python/pyspark/pandas/supported_api_gen.py
@@ -38,7 +38,7 @@
MAX_MISSING_PARAMS_SIZE = 5
COMMON_PARAMETER_SET = {"kwargs", "args", "cls"}
MODULE_GROUP_MATCH = [(pd, ps), (pdw, psw), (pdg, psg)]
-PANDAS_LATEST_VERSION = "2.3.2"
+PANDAS_LATEST_VERSION = "2.3.3"
RST_HEADER = """
=====================
diff --git a/python/pyspark/pandas/tests/io/test_feather.py b/python/pyspark/pandas/tests/io/test_feather.py
index 3ddf0a2aad925..74fa6bc7d7b64 100644
--- a/python/pyspark/pandas/tests/io/test_feather.py
+++ b/python/pyspark/pandas/tests/io/test_feather.py
@@ -17,7 +17,6 @@
import unittest
import pandas as pd
-import sys
from pyspark import pandas as ps
from pyspark.testing.pandasutils import PandasOnSparkTestCase, TestUtils
@@ -35,7 +34,6 @@ def pdf(self):
def psdf(self):
return ps.from_pandas(self.pdf)
- @unittest.skipIf(sys.version_info > (3, 13), "SPARK-54068")
def test_to_feather(self):
with self.temp_dir() as dirpath:
path1 = f"{dirpath}/file1.feather"
diff --git a/python/pyspark/pandas/tests/test_typedef.py b/python/pyspark/pandas/tests/test_typedef.py
index 329af01a39440..54a03200c05c6 100644
--- a/python/pyspark/pandas/tests/test_typedef.py
+++ b/python/pyspark/pandas/tests/test_typedef.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+import os
import sys
import unittest
import datetime
diff --git a/python/pyspark/pandas/typedef/typehints.py b/python/pyspark/pandas/typedef/typehints.py
index 48545d124b2d8..a4ed9f996fe47 100644
--- a/python/pyspark/pandas/typedef/typehints.py
+++ b/python/pyspark/pandas/typedef/typehints.py
@@ -342,7 +342,7 @@ def pandas_on_spark_type(tpe: Union[str, type, Dtype]) -> Tuple[Dtype, types.Dat
try:
dtype = pandas_dtype(tpe)
spark_type = as_spark_type(dtype)
- except TypeError:
+ except (TypeError, ValueError):
spark_type = as_spark_type(tpe)
dtype = spark_type_to_pandas_dtype(spark_type)
return dtype, spark_type
diff --git a/python/pyspark/pipelines/add_pipeline_analysis_context.py b/python/pyspark/pipelines/add_pipeline_analysis_context.py
new file mode 100644
index 0000000000000..6ecabdf43b072
--- /dev/null
+++ b/python/pyspark/pipelines/add_pipeline_analysis_context.py
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from contextlib import contextmanager
+from typing import Generator, Optional
+from pyspark.sql import SparkSession
+
+from typing import Any, cast
+
+
+@contextmanager
+def add_pipeline_analysis_context(
+ spark: SparkSession, dataflow_graph_id: str, flow_name: Optional[str]
+) -> Generator[None, None, None]:
+ """
+ Context manager that adds a PipelineAnalysisContext extension to the user context,
+ used for pipeline-specific analysis.
+ """
+ extension_id = None
+ # Cast because mypy seems to think `spark` is a function, not an object.
+ # Likely related to SPARK-47544.
+ client = cast(Any, spark).client
+ try:
+ import pyspark.sql.connect.proto as pb2
+ from google.protobuf import any_pb2
+
+ analysis_context = pb2.PipelineAnalysisContext(
+ dataflow_graph_id=dataflow_graph_id, flow_name=flow_name
+ )
+ extension = any_pb2.Any()
+ extension.Pack(analysis_context)
+ extension_id = client.add_threadlocal_user_context_extension(extension)
+ yield
+ finally:
+ client.remove_user_context_extension(extension_id)
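A minimal usage sketch of the new context manager, assuming an active Spark Connect session bound to `spark` (mirroring the tests added later in this patch):

from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context

# Requests issued inside the block carry the packed PipelineAnalysisContext in
# user_context.extensions; the extension is removed again when the block exits.
with add_pipeline_analysis_context(spark, "example-graph-id", flow_name=None):
    spark.sql("SELECT 1").collect()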
diff --git a/python/pyspark/pipelines/api.py b/python/pyspark/pipelines/api.py
index b68cc30b43a7d..f109841d657b8 100644
--- a/python/pyspark/pipelines/api.py
+++ b/python/pyspark/pipelines/api.py
@@ -76,6 +76,7 @@ def _validate_stored_dataset_args(
name: Optional[str],
table_properties: Optional[Dict[str, str]],
partition_cols: Optional[List[str]],
+ cluster_by: Optional[List[str]],
) -> None:
if name is not None and type(name) is not str:
raise PySparkTypeError(
@@ -91,6 +92,7 @@ def _validate_stored_dataset_args(
},
)
validate_optional_list_of_str_arg(arg_name="partition_cols", arg_value=partition_cols)
+ validate_optional_list_of_str_arg(arg_name="cluster_by", arg_value=cluster_by)
@overload
@@ -107,6 +109,7 @@ def table(
spark_conf: Optional[Dict[str, str]] = None,
table_properties: Optional[Dict[str, str]] = None,
partition_cols: Optional[List[str]] = None,
+ cluster_by: Optional[List[str]] = None,
schema: Optional[Union[StructType, str]] = None,
) -> Callable[[QueryFunction], None]:
...
@@ -120,6 +123,7 @@ def table(
spark_conf: Optional[Dict[str, str]] = None,
table_properties: Optional[Dict[str, str]] = None,
partition_cols: Optional[List[str]] = None,
+ cluster_by: Optional[List[str]] = None,
schema: Optional[Union[StructType, str]] = None,
format: Optional[str] = None,
) -> Union[Callable[[QueryFunction], None], None]:
@@ -142,11 +146,12 @@ def table(
:param table_properties: A dict where the keys are the property names and the values are the \
property values. These properties will be set on the table.
:param partition_cols: A list containing the column names of the partition columns.
+ :param cluster_by: A list containing the column names of the cluster columns.
:param schema: Explicit Spark SQL schema to materialize this table with. Supports either a \
Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
:param format: The format of the table, e.g. "parquet".
"""
- _validate_stored_dataset_args(name, table_properties, partition_cols)
+ _validate_stored_dataset_args(name, table_properties, partition_cols, cluster_by)
source_code_location = get_caller_source_code_location(stacklevel=1)
@@ -163,6 +168,7 @@ def outer(
name=resolved_name,
table_properties=table_properties or {},
partition_cols=partition_cols,
+ cluster_by=cluster_by,
schema=schema,
source_code_location=source_code_location,
format=format,
@@ -209,6 +215,7 @@ def materialized_view(
spark_conf: Optional[Dict[str, str]] = None,
table_properties: Optional[Dict[str, str]] = None,
partition_cols: Optional[List[str]] = None,
+ cluster_by: Optional[List[str]] = None,
schema: Optional[Union[StructType, str]] = None,
) -> Callable[[QueryFunction], None]:
...
@@ -222,6 +229,7 @@ def materialized_view(
spark_conf: Optional[Dict[str, str]] = None,
table_properties: Optional[Dict[str, str]] = None,
partition_cols: Optional[List[str]] = None,
+ cluster_by: Optional[List[str]] = None,
schema: Optional[Union[StructType, str]] = None,
format: Optional[str] = None,
) -> Union[Callable[[QueryFunction], None], None]:
@@ -244,11 +252,12 @@ def materialized_view(
:param table_properties: A dict where the keys are the property names and the values are the \
property values. These properties will be set on the table.
:param partition_cols: A list containing the column names of the partition columns.
+ :param cluster_by: A list containing the column names of the cluster columns.
:param schema: Explicit Spark SQL schema to materialize this table with. Supports either a \
Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
:param format: The format of the table, e.g. "parquet".
"""
- _validate_stored_dataset_args(name, table_properties, partition_cols)
+ _validate_stored_dataset_args(name, table_properties, partition_cols, cluster_by)
source_code_location = get_caller_source_code_location(stacklevel=1)
@@ -265,6 +274,7 @@ def outer(
name=resolved_name,
table_properties=table_properties or {},
partition_cols=partition_cols,
+ cluster_by=cluster_by,
schema=schema,
source_code_location=source_code_location,
format=format,
@@ -403,6 +413,7 @@ def create_streaming_table(
comment: Optional[str] = None,
table_properties: Optional[Dict[str, str]] = None,
partition_cols: Optional[List[str]] = None,
+ cluster_by: Optional[List[str]] = None,
schema: Optional[Union[StructType, str]] = None,
format: Optional[str] = None,
) -> None:
@@ -410,14 +421,15 @@ def create_streaming_table(
Creates a table that can be targeted by append flows.
Example:
- create_streaming_table("target")
+ create_streaming_table("target")
:param name: The name of the table.
:param comment: Description of the table.
:param table_properties: A dict where the keys are the property names and the values are the \
property values. These properties will be set on the table.
:param partition_cols: A list containing the column names of the partition columns.
- :param schema Explicit Spark SQL schema to materialize this table with. Supports either a \
+ :param cluster_by: A list containing the column names of the cluster columns.
+ :param schema: Explicit Spark SQL schema to materialize this table with. Supports either a \
Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
:param format: The format of the table, e.g. "parquet".
"""
@@ -435,6 +447,7 @@ def create_streaming_table(
},
)
validate_optional_list_of_str_arg(arg_name="partition_cols", arg_value=partition_cols)
+ validate_optional_list_of_str_arg(arg_name="cluster_by", arg_value=cluster_by)
source_code_location = get_caller_source_code_location(stacklevel=1)
@@ -444,6 +457,7 @@ def create_streaming_table(
source_code_location=source_code_location,
table_properties=table_properties or {},
partition_cols=partition_cols,
+ cluster_by=cluster_by,
schema=schema,
format=format,
)
@@ -456,12 +470,12 @@ def create_sink(
options: Optional[Dict[str, str]] = None,
) -> None:
"""
- Creates a sink that can be targeted by streaming flows, providing a generic destination \
+ Creates a sink that can be targeted by streaming flows, providing a generic destination
for flows to send data external to the pipeline.
:param name: The name of the sink.
:param format: The format of the sink, e.g. "parquet".
- :param options: A dict where the keys are the property names and the values are the \
+ :param options: A dict where the keys are the property names and the values are the
property values. These properties will be set on the sink.
"""
if type(name) is not str:
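A hypothetical pipeline definition using the new cluster_by argument; the table and column names are illustrative, and `spark` is assumed to be the session available to pipeline query functions:

from pyspark.pipelines.api import table

@table(
    name="events_clustered",
    comment="Events clustered by user id for faster point lookups.",
    partition_cols=["event_date"],
    cluster_by=["user_id"],
)
def events_clustered():
    return spark.read.table("raw_events")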
diff --git a/python/pyspark/pipelines/block_connect_access.py b/python/pyspark/pipelines/block_connect_access.py
index c5dacbbc2c5cb..696d0e39b005d 100644
--- a/python/pyspark/pipelines/block_connect_access.py
+++ b/python/pyspark/pipelines/block_connect_access.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
from contextlib import contextmanager
-from typing import Callable, Generator, NoReturn
+from typing import Any, Callable, Generator
from pyspark.errors import PySparkException
from pyspark.sql.connect.proto.base_pb2_grpc import SparkConnectServiceStub
@@ -24,6 +24,27 @@
BLOCKED_RPC_NAMES = ["AnalyzePlan", "ExecutePlan"]
+def _is_sql_command_request(rpc_name: str, args: tuple) -> bool:
+ """
+ Check if the RPC call is a spark.sql() command (ExecutePlan with sql_command).
+
+ :param rpc_name: Name of the RPC being called
+ :param args: Arguments passed to the RPC
+ :return: True if this is an ExecutePlan request with a sql_command
+ """
+ if rpc_name != "ExecutePlan" or len(args) == 0:
+ return False
+
+ request = args[0]
+ if not hasattr(request, "plan"):
+ return False
+ plan = request.plan
+ if not plan.HasField("command"):
+ return False
+ command = plan.command
+ return command.HasField("sql_command")
+
+
@contextmanager
def block_spark_connect_execution_and_analysis() -> Generator[None, None, None]:
"""
@@ -38,16 +59,23 @@ def block_spark_connect_execution_and_analysis() -> Generator[None, None, None]:
# Define a new __getattribute__ method that blocks RPC calls
def blocked_getattr(self: SparkConnectServiceStub, name: str) -> Callable:
- if name not in BLOCKED_RPC_NAMES:
- return original_getattr(self, name)
+ original_method = original_getattr(self, name)
- def blocked_method(*args: object, **kwargs: object) -> NoReturn:
- raise PySparkException(
- errorClass="ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION",
- messageParameters={},
- )
+ def intercepted_method(*args: object, **kwargs: object) -> Any:
+ # Allow all RPCs that are not AnalyzePlan or ExecutePlan
+ if name not in BLOCKED_RPC_NAMES:
+ return original_method(*args, **kwargs)
+ # Allow spark.sql() commands (ExecutePlan with sql_command)
+ elif _is_sql_command_request(name, args):
+ return original_method(*args, **kwargs)
+ # Block all other AnalyzePlan and ExecutePlan calls
+ else:
+ raise PySparkException(
+ errorClass="ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION",
+ messageParameters={},
+ )
- return blocked_method
+ return intercepted_method
try:
# Apply our custom __getattribute__ method
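To illustrate the allow-list above, the private helper can be exercised on hand-built requests (shown for illustration only; it assumes the Spark Connect protos are importable):

import pyspark.sql.connect.proto as pb2
from pyspark.pipelines.block_connect_access import _is_sql_command_request

# An ExecutePlan request whose plan carries a sql_command is allowed through.
sql_req = pb2.ExecutePlanRequest()
sql_req.plan.command.sql_command.SetInParent()  # marks the sql_command oneof as set
assert _is_sql_command_request("ExecutePlan", (sql_req,))

# A plain relation (e.g. a DataFrame action) is still blocked.
rel_req = pb2.ExecutePlanRequest()
rel_req.plan.root.SetInParent()
assert not _is_sql_command_request("ExecutePlan", (rel_req,))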
diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py
index ca198f1c3aff3..8c70c83311201 100644
--- a/python/pyspark/pipelines/cli.py
+++ b/python/pyspark/pipelines/cli.py
@@ -49,7 +49,9 @@
handle_pipeline_events,
)
-PIPELINE_SPEC_FILE_NAMES = ["pipeline.yaml", "pipeline.yml"]
+from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context
+
+PIPELINE_SPEC_FILE_NAMES = ["spark-pipeline.yaml", "spark-pipeline.yml"]
@dataclass(frozen=True)
@@ -216,13 +218,18 @@ def validate_str_dict(d: Mapping[str, str], field_name: str) -> Mapping[str, str
def register_definitions(
- spec_path: Path, registry: GraphElementRegistry, spec: PipelineSpec
+ spec_path: Path,
+ registry: GraphElementRegistry,
+ spec: PipelineSpec,
+ spark: SparkSession,
+ dataflow_graph_id: str,
) -> None:
"""Register the graph element definitions in the pipeline spec with the given registry.
- - Looks for Python files matching the glob patterns in the spec and imports them.
- - Looks for SQL files matching the blob patterns in the spec and registers thems.
+ - Import Python files matching the glob patterns in the spec.
+ - Register SQL files matching the glob patterns in the spec.
"""
- path = spec_path.parent
+ path = spec_path.parent.resolve()
+
with change_dir(path):
with graph_element_registration_context(registry):
log_with_curr_timestamp(f"Loading definitions. Root directory: '{path}'.")
@@ -245,13 +252,16 @@ def register_definitions(
assert (
module_spec.loader is not None
), f"Module spec has no loader for {file}"
- with block_session_mutations():
- module_spec.loader.exec_module(module)
+ with add_pipeline_analysis_context(
+ spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name=None
+ ):
+ with block_session_mutations():
+ module_spec.loader.exec_module(module)
elif file.suffix == ".sql":
log_with_curr_timestamp(f"Registering SQL file {file}...")
with file.open("r") as f:
sql = f.read()
- file_path_relative_to_spec = file.relative_to(spec_path.parent)
+ file_path_relative_to_spec = file.relative_to(path)
registry.register_sql(sql, file_path_relative_to_spec)
else:
raise PySparkException(
@@ -324,7 +334,7 @@ def run(
log_with_curr_timestamp("Registering graph elements...")
registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id)
- register_definitions(spec_path, registry, spec)
+ register_definitions(spec_path, registry, spec, spark, dataflow_graph_id)
log_with_curr_timestamp("Starting run...")
result_iter = start_run(
@@ -347,8 +357,9 @@ def parse_table_list(value: str) -> List[str]:
return [table.strip() for table in value.split(",") if table.strip()]
-if __name__ == "__main__":
- parser = argparse.ArgumentParser(description="Pipeline CLI")
+def main() -> None:
+ """The entry point of spark-pipelines CLI."""
+ parser = argparse.ArgumentParser(description="Pipelines CLI")
subparsers = parser.add_subparsers(dest="command", required=True)
# "run" subcommand
@@ -366,7 +377,9 @@ def parse_table_list(value: str) -> List[str]:
default=[],
)
run_parser.add_argument(
- "--full-refresh-all", action="store_true", help="Perform a full graph reset and recompute."
+ "--full-refresh-all",
+ action="store_true",
+ help="Perform a full graph reset and recompute.",
)
run_parser.add_argument(
"--refresh",
@@ -386,7 +399,7 @@ def parse_table_list(value: str) -> List[str]:
# "init" subcommand
init_parser = subparsers.add_parser(
"init",
- help="Generates a simple pipeline project, including a spec file and example definitions.",
+ help="Generate a sample pipeline project, with a spec file and example transformations.",
)
init_parser.add_argument(
"--name",
@@ -415,7 +428,7 @@ def parse_table_list(value: str) -> List[str]:
full_refresh=args.full_refresh,
full_refresh_all=args.full_refresh_all,
refresh=args.refresh,
- dry=args.command == "dry-run",
+ dry=False,
)
else:
assert args.command == "dry-run"
@@ -428,3 +441,7 @@ def parse_table_list(value: str) -> List[str]:
)
elif args.command == "init":
init(args.name)
+
+
+if __name__ == "__main__":
+ main()
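Because the spec file names changed, discovery now keys off spark-pipeline.yaml / spark-pipeline.yml. A small sketch, assuming find_pipeline_spec returns the matched path as the tests below rely on:

import tempfile
from pathlib import Path
from pyspark.pipelines.cli import find_pipeline_spec

with tempfile.TemporaryDirectory() as d:
    (Path(d) / "spark-pipeline.yaml").write_text("name: demo_pipeline\n")
    assert find_pipeline_spec(Path(d)).name == "spark-pipeline.yaml"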
diff --git a/python/pyspark/pipelines/init_cli.py b/python/pyspark/pipelines/init_cli.py
index ffe5d3c12b636..a1dbdfd9d5586 100644
--- a/python/pyspark/pipelines/init_cli.py
+++ b/python/pyspark/pipelines/init_cli.py
@@ -19,7 +19,7 @@
SPEC = """
name: {{ name }}
-storage: storage-root
+storage: {{ storage_root }}
libraries:
- glob:
include: transformations/**
@@ -44,12 +44,25 @@ def example_python_materialized_view() -> DataFrame:
def init(name: str) -> None:
"""Generates a simple pipeline project."""
project_dir = Path.cwd() / name
+ if project_dir.exists():
+ raise FileExistsError(
+ f"Directory '{name}' already exists. "
+ "Please choose a different name or remove the existing directory."
+ )
project_dir.mkdir(parents=True, exist_ok=False)
+ # Create the storage directory
+ storage_dir = project_dir / "pipeline-storage"
+ storage_dir.mkdir(parents=True)
+
+ # Create absolute file URI for storage path
+ storage_path = f"file://{storage_dir.resolve()}"
+
# Write the spec file to the project directory
- spec_file = project_dir / "pipeline.yml"
+ spec_file = project_dir / "spark-pipeline.yml"
with open(spec_file, "w") as f:
- f.write(SPEC.replace("{{ name }}", name))
+ spec_content = SPEC.replace("{{ name }}", name).replace("{{ storage_root }}", storage_path)
+ f.write(spec_content)
# Create the transformations directory
transformations_dir = project_dir / "transformations"
diff --git a/python/pyspark/pipelines/output.py b/python/pyspark/pipelines/output.py
index 84e950f161742..92058e68721f4 100644
--- a/python/pyspark/pipelines/output.py
+++ b/python/pyspark/pipelines/output.py
@@ -45,6 +45,7 @@ class Table(Output):
:param table_properties: A dict where the keys are the property names and the values are the
property values. These properties will be set on the table.
:param partition_cols: A list containing the column names of the partition columns.
+ :param cluster_by: A list containing the column names of the cluster columns.
:param schema Explicit Spark SQL schema to materialize this table with. Supports either a
Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
:param format: The format of the table, e.g. "parquet".
@@ -52,6 +53,7 @@ class Table(Output):
table_properties: Mapping[str, str]
partition_cols: Optional[Sequence[str]]
+ cluster_by: Optional[Sequence[str]]
schema: Optional[Union[StructType, str]]
format: Optional[str]
diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
index 5c5ef9fc30401..b8d297fced3fb 100644
--- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py
+++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py
@@ -35,6 +35,7 @@
from pyspark.sql.types import StructType
from typing import Any, cast
import pyspark.sql.connect.proto as pb2
+from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context
class SparkConnectGraphElementRegistry(GraphElementRegistry):
@@ -43,6 +44,7 @@ class SparkConnectGraphElementRegistry(GraphElementRegistry):
def __init__(self, spark: SparkSession, dataflow_graph_id: str) -> None:
# Cast because mypy seems to think `spark`` is a function, not an object. Likely related to
# SPARK-47544.
+ self._spark = spark
self._client = cast(Any, spark).client
self._dataflow_graph_id = dataflow_graph_id
@@ -63,6 +65,7 @@ def register_output(self, output: Output) -> None:
table_details = pb2.PipelineCommand.DefineOutput.TableDetails(
table_properties=output.table_properties,
partition_cols=output.partition_cols,
+ clustering_columns=output.cluster_by,
format=output.format,
# Even though schema_string is not required, the generated Python code seems to
# erroneously think it is required.
@@ -109,8 +112,11 @@ def register_output(self, output: Output) -> None:
self._client.execute_command(command)
def register_flow(self, flow: Flow) -> None:
- with block_spark_connect_execution_and_analysis():
- df = flow.func()
+ with add_pipeline_analysis_context(
+ spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, flow_name=flow.name
+ ):
+ with block_spark_connect_execution_and_analysis():
+ df = flow.func()
relation = cast(ConnectDataFrame, df)._plan.plan(self._client)
relation_flow_details = pb2.PipelineCommand.DefineFlow.WriteRelationFlowDetails(
diff --git a/python/pyspark/pipelines/spark_connect_pipeline.py b/python/pyspark/pipelines/spark_connect_pipeline.py
index e3c1184cea39d..2fd11d7cf6692 100644
--- a/python/pyspark/pipelines/spark_connect_pipeline.py
+++ b/python/pyspark/pipelines/spark_connect_pipeline.py
@@ -29,7 +29,7 @@ def create_dataflow_graph(
default_database: Optional[str],
sql_conf: Optional[Mapping[str, str]],
) -> str:
- """Create a dataflow graph in in the Spark Connect server.
+ """Create a dataflow graph in the Spark Connect server.
:returns: The ID of the created dataflow graph.
"""
@@ -57,7 +57,7 @@ def handle_pipeline_events(iter: Iterator[Dict[str, Any]]) -> None:
continue
elif "pipeline_event_result" not in result.keys():
raise PySparkValueError(
- "Pipeline logs stream handler received an unexpected result: " f"{result}"
+ f"Pipeline logs stream handler received an unexpected result: {result}"
)
else:
event = result["pipeline_event_result"].event
@@ -76,6 +76,7 @@ def start_run(
) -> Iterator[Dict[str, Any]]:
"""Start a run of the dataflow graph in the Spark Connect server.
+ :param spark: SparkSession.
:param dataflow_graph_id: The ID of the dataflow graph to start.
:param full_refresh: List of datasets to reset and recompute.
:param full_refresh_all: Perform a full graph reset and recompute.
diff --git a/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py b/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py
new file mode 100644
index 0000000000000..57c5da22d4601
--- /dev/null
+++ b/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py
@@ -0,0 +1,100 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.testing.connectutils import (
+ ReusedConnectTestCase,
+ should_test_connect,
+ connect_requirement_message,
+)
+
+if should_test_connect:
+ from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class AddPipelineAnalysisContextTests(ReusedConnectTestCase):
+ def test_add_pipeline_analysis_context_with_flow_name(self):
+ with add_pipeline_analysis_context(self.spark, "test_dataflow_graph_id", "test_flow_name"):
+ import pyspark.sql.connect.proto as pb2
+
+ thread_local_extensions = self.spark.client.thread_local.user_context_extensions
+ self.assertEqual(len(thread_local_extensions), 1)
+ # Extension is stored as (id, extension), unpack the extension
+ _extension_id, extension = thread_local_extensions[0]
+ context = pb2.PipelineAnalysisContext()
+ extension.Unpack(context)
+ self.assertEqual(context.dataflow_graph_id, "test_dataflow_graph_id")
+ self.assertEqual(context.flow_name, "test_flow_name")
+ thread_local_extensions_after = self.spark.client.thread_local.user_context_extensions
+ self.assertEqual(len(thread_local_extensions_after), 0)
+
+ def test_add_pipeline_analysis_context_without_flow_name(self):
+ with add_pipeline_analysis_context(self.spark, "test_dataflow_graph_id", None):
+ import pyspark.sql.connect.proto as pb2
+
+ thread_local_extensions = self.spark.client.thread_local.user_context_extensions
+ self.assertEqual(len(thread_local_extensions), 1)
+ # Extension is stored as (id, extension), unpack the extension
+ _extension_id, extension = thread_local_extensions[0]
+ context = pb2.PipelineAnalysisContext()
+ extension.Unpack(context)
+ self.assertEqual(context.dataflow_graph_id, "test_dataflow_graph_id")
+ # Empty string means no flow name
+ self.assertEqual(context.flow_name, "")
+ thread_local_extensions_after = self.spark.client.thread_local.user_context_extensions
+ self.assertEqual(len(thread_local_extensions_after), 0)
+
+ def test_nested_add_pipeline_analysis_context(self):
+ import pyspark.sql.connect.proto as pb2
+
+ with add_pipeline_analysis_context(self.spark, "test_dataflow_graph_id_1", flow_name=None):
+ with add_pipeline_analysis_context(
+ self.spark, "test_dataflow_graph_id_2", flow_name="test_flow_name"
+ ):
+ thread_local_extensions = self.spark.client.thread_local.user_context_extensions
+ self.assertEqual(len(thread_local_extensions), 2)
+ # Extension is stored as (id, extension), unpack the extensions
+ _, extension_1 = thread_local_extensions[0]
+ context_1 = pb2.PipelineAnalysisContext()
+ extension_1.Unpack(context_1)
+ self.assertEqual(context_1.dataflow_graph_id, "test_dataflow_graph_id_1")
+ self.assertEqual(context_1.flow_name, "")
+ _, extension_2 = thread_local_extensions[1]
+ context_2 = pb2.PipelineAnalysisContext()
+ extension_2.Unpack(context_2)
+ self.assertEqual(context_2.dataflow_graph_id, "test_dataflow_graph_id_2")
+ self.assertEqual(context_2.flow_name, "test_flow_name")
+ thread_local_extensions_after_1 = self.spark.client.thread_local.user_context_extensions
+ self.assertEqual(len(thread_local_extensions_after_1), 1)
+ _, extension_3 = thread_local_extensions_after_1[0]
+ context_3 = pb2.PipelineAnalysisContext()
+ extension_3.Unpack(context_3)
+ self.assertEqual(context_3.dataflow_graph_id, "test_dataflow_graph_id_1")
+ self.assertEqual(context_3.flow_name, "")
+ thread_local_extensions_after_2 = self.spark.client.thread_local.user_context_extensions
+ self.assertEqual(len(thread_local_extensions_after_2), 0)
+
+
+if __name__ == "__main__":
+ try:
+ import xmlrunner # type: ignore
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/pipelines/tests/test_cli.py b/python/pyspark/pipelines/tests/test_cli.py
index e8445e63d439d..f810ab099b7be 100644
--- a/python/pyspark/pipelines/tests/test_cli.py
+++ b/python/pyspark/pipelines/tests/test_cli.py
@@ -22,6 +22,7 @@
from pyspark.errors import PySparkException
from pyspark.testing.connectutils import (
+ ReusedConnectTestCase,
should_test_connect,
connect_requirement_message,
)
@@ -45,7 +46,7 @@
not should_test_connect or not have_yaml,
connect_requirement_message or yaml_requirement_message,
)
-class CLIUtilityTests(unittest.TestCase):
+class CLIUtilityTests(ReusedConnectTestCase):
def test_load_pipeline_spec(self):
with tempfile.NamedTemporaryFile(mode="w") as tmpfile:
tmpfile.write(
@@ -190,7 +191,7 @@ def test_unpack_pipeline_spec_bad_configuration(self):
def test_find_pipeline_spec_in_current_directory(self):
with tempfile.TemporaryDirectory() as temp_dir:
- spec_path = Path(temp_dir) / "pipeline.yaml"
+ spec_path = Path(temp_dir) / "spark-pipeline.yaml"
with spec_path.open("w") as f:
f.write(
"""
@@ -207,7 +208,7 @@ def test_find_pipeline_spec_in_current_directory(self):
def test_find_pipeline_spec_in_current_directory_yml(self):
with tempfile.TemporaryDirectory() as temp_dir:
- spec_path = Path(temp_dir) / "pipeline.yml"
+ spec_path = Path(temp_dir) / "spark-pipeline.yml"
with spec_path.open("w") as f:
f.write(
"""
@@ -224,10 +225,10 @@ def test_find_pipeline_spec_in_current_directory_yml(self):
def test_find_pipeline_spec_in_current_directory_yml_and_yaml(self):
with tempfile.TemporaryDirectory() as temp_dir:
- with (Path(temp_dir) / "pipeline.yml").open("w") as f:
+ with (Path(temp_dir) / "spark-pipeline.yml").open("w") as f:
f.write("")
- with (Path(temp_dir) / "pipeline.yaml").open("w") as f:
+ with (Path(temp_dir) / "spark-pipeline.yaml").open("w") as f:
f.write("")
with self.assertRaises(PySparkException) as context:
@@ -240,7 +241,7 @@ def test_find_pipeline_spec_in_parent_directory(self):
parent_dir = Path(temp_dir)
child_dir = Path(temp_dir) / "child"
child_dir.mkdir()
- spec_path = parent_dir / "pipeline.yaml"
+ spec_path = parent_dir / "spark-pipeline.yaml"
with spec_path.open("w") as f:
f.write(
"""
@@ -294,7 +295,9 @@ def mv2():
)
registry = LocalGraphElementRegistry()
- register_definitions(outer_dir / "pipeline.yaml", registry, spec)
+ register_definitions(
+ outer_dir / "spark-pipeline.yaml", registry, spec, self.spark, "test_graph_id"
+ )
self.assertEqual(len(registry.outputs), 1)
self.assertEqual(registry.outputs[0].name, "mv1")
@@ -315,7 +318,9 @@ def test_register_definitions_file_raises_error(self):
registry = LocalGraphElementRegistry()
with self.assertRaises(RuntimeError) as context:
- register_definitions(outer_dir / "pipeline.yml", registry, spec)
+ register_definitions(
+ outer_dir / "spark-pipeline.yml", registry, spec, self.spark, "test_graph_id"
+ )
self.assertIn("This is a test exception", str(context.exception))
def test_register_definitions_unsupported_file_extension_matches_glob(self):
@@ -334,7 +339,7 @@ def test_register_definitions_unsupported_file_extension_matches_glob(self):
registry = LocalGraphElementRegistry()
with self.assertRaises(PySparkException) as context:
- register_definitions(outer_dir, registry, spec)
+ register_definitions(outer_dir, registry, spec, self.spark, "test_graph_id")
self.assertEqual(
context.exception.getCondition(), "PIPELINE_UNSUPPORTED_DEFINITIONS_FILE_EXTENSION"
)
@@ -372,7 +377,7 @@ def test_python_import_current_directory(self):
registry = LocalGraphElementRegistry()
with change_dir(inner_dir2):
register_definitions(
- inner_dir1 / "pipeline.yaml",
+ inner_dir1 / "spark-pipeline.yaml",
registry,
PipelineSpec(
name="test_pipeline",
@@ -382,12 +387,14 @@ def test_python_import_current_directory(self):
configuration={},
libraries=[LibrariesGlob(include="defs.py")],
),
+ self.spark,
+ "test_graph_id",
)
def test_full_refresh_all_conflicts_with_full_refresh(self):
with tempfile.TemporaryDirectory() as temp_dir:
# Create a minimal pipeline spec
- spec_path = Path(temp_dir) / "pipeline.yaml"
+ spec_path = Path(temp_dir) / "spark-pipeline.yaml"
with spec_path.open("w") as f:
f.write('{"name": "test_pipeline"}')
@@ -411,7 +418,7 @@ def test_full_refresh_all_conflicts_with_full_refresh(self):
def test_full_refresh_all_conflicts_with_refresh(self):
with tempfile.TemporaryDirectory() as temp_dir:
# Create a minimal pipeline spec
- spec_path = Path(temp_dir) / "pipeline.yaml"
+ spec_path = Path(temp_dir) / "spark-pipeline.yaml"
with spec_path.open("w") as f:
f.write('{"name": "test_pipeline"}')
@@ -436,7 +443,7 @@ def test_full_refresh_all_conflicts_with_refresh(self):
def test_full_refresh_all_conflicts_with_both(self):
with tempfile.TemporaryDirectory() as temp_dir:
# Create a minimal pipeline spec
- spec_path = Path(temp_dir) / "pipeline.yaml"
+ spec_path = Path(temp_dir) / "spark-pipeline.yaml"
with spec_path.open("w") as f:
f.write('{"name": "test_pipeline"}')
diff --git a/python/pyspark/pipelines/tests/test_init_cli.py b/python/pyspark/pipelines/tests/test_init_cli.py
index 49c949200821a..f88956b647acc 100644
--- a/python/pyspark/pipelines/tests/test_init_cli.py
+++ b/python/pyspark/pipelines/tests/test_init_cli.py
@@ -51,8 +51,16 @@ def test_init(self):
spec_path = find_pipeline_spec(Path.cwd())
spec = load_pipeline_spec(spec_path)
assert spec.name == project_name
+
+ # Verify that the storage path is an absolute URI with file scheme
+ expected_storage_path = f"file://{Path.cwd() / 'pipeline-storage'}"
+ self.assertEqual(spec.storage, expected_storage_path)
+
+ # Verify that the storage directory was created
+ self.assertTrue((Path.cwd() / "pipeline-storage").exists())
+
registry = LocalGraphElementRegistry()
- register_definitions(spec_path, registry, spec)
+ register_definitions(spec_path, registry, spec, self.spark, "test_graph_id")
self.assertEqual(len(registry.outputs), 1)
self.assertEqual(registry.outputs[0].name, "example_python_materialized_view")
self.assertEqual(len(registry.flows), 1)
@@ -64,6 +72,21 @@ def test_init(self):
Path("transformations") / "example_sql_materialized_view.sql",
)
+ def test_init_existing_directory(self):
+ with tempfile.TemporaryDirectory() as temp_dir:
+ project_name = "test_project"
+ with change_dir(Path(temp_dir)):
+ init(project_name)
+
+ with self.assertRaises(FileExistsError) as context:
+ init(project_name)
+
+ expected_message = (
+ f"Directory '{project_name}' already exists. "
+ "Please choose a different name or remove the existing directory."
+ )
+ self.assertEqual(str(context.exception), expected_message)
+
if __name__ == "__main__":
try:
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index a0a6e8ef70c8d..eeeeddd00e3af 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -39,7 +39,7 @@
- :class:`pyspark.sql.Window`
For working with window functions.
"""
-from pyspark.sql.types import Row, VariantVal
+from pyspark.sql.types import Geography, Geometry, Row, VariantVal
from pyspark.sql.context import SQLContext, HiveContext, UDFRegistration, UDTFRegistration
from pyspark.sql.session import SparkSession
from pyspark.sql.column import Column
@@ -69,6 +69,8 @@
"DataFrameNaFunctions",
"DataFrameStatFunctions",
"VariantVal",
+ "Geography",
+ "Geometry",
"Window",
"WindowSpec",
"DataFrameReader",
diff --git a/python/pyspark/sql/connect/client/artifact.py b/python/pyspark/sql/connect/client/artifact.py
index 72a6ffa8bf68b..a37642186fdaf 100644
--- a/python/pyspark/sql/connect/client/artifact.py
+++ b/python/pyspark/sql/connect/client/artifact.py
@@ -29,7 +29,7 @@
from itertools import chain
from typing import List, Iterable, BinaryIO, Iterator, Optional, Tuple
import abc
-from pathlib import Path
+from pathlib import Path, PureWindowsPath
from urllib.parse import urlparse
from urllib.request import url2pathname
from functools import cached_property
@@ -184,7 +184,17 @@ def __init__(
def _parse_artifacts(
self, path_or_uri: str, pyfile: bool, archive: bool, file: bool
) -> List[Artifact]:
- # Currently only local files with .jar extension is supported.
+ # Handle Windows absolute paths (e.g., C:\path\to\file) which urlparse
+ # incorrectly interprets as having URI scheme 'C' instead of being a local path.
+ # First check whether path_or_uri is a Windows path; if so, convert it to a file:// URI.
+ try:
+ win_path = PureWindowsPath(path_or_uri)
+ if win_path.is_absolute() and win_path.drive:
+ # Convert Windows path to file:// URI so urlparse handles it correctly
+ path_or_uri = Path(path_or_uri).resolve().as_uri()
+ except Exception:
+ pass
+
parsed = urlparse(path_or_uri)
# Check if it is a file from the scheme
if parsed.scheme == "":
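The misparse described in the comment above is easy to see directly: urlparse treats a Windows drive letter as a URI scheme, while PureWindowsPath recognizes the same string as an absolute local path, which is the signal used to rewrite it as a file:// URI:

from pathlib import PureWindowsPath
from urllib.parse import urlparse

path = r"C:\Users\me\libs\app.jar"
print(urlparse(path).scheme)                   # "c" -- a bogus URI scheme
win = PureWindowsPath(path)
print(win.is_absolute() and bool(win.drive))   # True -> treat it as a local file path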
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 414781d67cd45..80d83c69c45c8 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -34,6 +34,7 @@
import urllib.parse
import uuid
import sys
+import time
from typing import (
Iterable,
Iterator,
@@ -113,6 +114,19 @@
from pyspark.sql.datasource import DataSource
+def _import_zstandard_if_available() -> Optional[Any]:
+ """
+ Import zstandard if available; otherwise return None.
+ This is used to handle the case when zstandard is not installed.
+ """
+ try:
+ import zstandard
+
+ return zstandard
+ except ImportError:
+ return None
+
+
class ChannelBuilder:
"""
This is a helper class that is used to create a GRPC channel based on the given
@@ -487,6 +501,20 @@ def pairs(self) -> dict[str, Any]:
def keys(self) -> List[str]:
return self._keys
+ def to_dict(self) -> dict[str, Any]:
+ """Return a JSON-serializable dictionary representation of this observed metrics.
+
+ Returns
+ -------
+ dict
+ A dictionary with keys 'name', 'keys', and 'pairs'.
+ """
+ return {
+ "name": self._name,
+ "keys": self._keys,
+ "pairs": self.pairs,
+ }
+
class AnalyzeResult:
def __init__(
@@ -706,9 +734,16 @@ def __init__(
self._progress_handlers: List[ProgressHandler] = []
+ self._zstd_module = _import_zstandard_if_available()
+ self._plan_compression_threshold: Optional[int] = None # Will be fetched lazily
+ self._plan_compression_algorithm: Optional[str] = None # Will be fetched lazily
+
# cleanup ml cache if possible
atexit.register(self._cleanup_ml_cache)
+ self.global_user_context_extensions: List[Tuple[str, any_pb2.Any]] = []
+ self.global_user_context_extensions_lock = threading.Lock()
+
@property
def _stub(self) -> grpc_lib.SparkConnectServiceStub:
if self.is_closed:
@@ -1156,7 +1191,7 @@ def execute_command(
req = self._execute_plan_request_with_metadata()
if self._user_id:
req.user_context.user_id = self._user_id
- req.plan.command.CopyFrom(command)
+ self._set_command_in_plan(req.plan, command)
data, _, metrics, observed_metrics, properties = self._execute_and_fetch(
req, observations or {}
)
@@ -1182,7 +1217,7 @@ def execute_command_as_iterator(
req = self._execute_plan_request_with_metadata()
if self._user_id:
req.user_context.user_id = self._user_id
- req.plan.command.CopyFrom(command)
+ self._set_command_in_plan(req.plan, command)
for response in self._execute_and_fetch_as_iterator(req, observations or {}):
if isinstance(response, dict):
yield response
@@ -1240,6 +1275,24 @@ def token(self) -> Optional[str]:
"""
return self._builder.token
+ def _update_request_with_user_context_extensions(
+ self,
+ req: Union[
+ pb2.AnalyzePlanRequest,
+ pb2.ConfigRequest,
+ pb2.ExecutePlanRequest,
+ pb2.FetchErrorDetailsRequest,
+ pb2.InterruptRequest,
+ ],
+ ) -> None:
+ with self.global_user_context_extensions_lock:
+ for _, extension in self.global_user_context_extensions:
+ req.user_context.extensions.append(extension)
+ if not hasattr(self.thread_local, "user_context_extensions"):
+ return
+ for _, extension in self.thread_local.user_context_extensions:
+ req.user_context.extensions.append(extension)
+
def _execute_plan_request_with_metadata(
self, operation_id: Optional[str] = None
) -> pb2.ExecutePlanRequest:
@@ -1270,6 +1323,7 @@ def _execute_plan_request_with_metadata(
messageParameters={"arg_name": "operation_id", "origin": str(ve)},
)
req.operation_id = operation_id
+ self._update_request_with_user_context_extensions(req)
return req
def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
@@ -1280,6 +1334,7 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
req.client_type = self._builder.userAgent
if self._user_id:
req.user_context.user_id = self._user_id
+ self._update_request_with_user_context_extensions(req)
return req
def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
@@ -1694,6 +1749,7 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest:
req.client_type = self._builder.userAgent
if self._user_id:
req.user_context.user_id = self._user_id
+ self._update_request_with_user_context_extensions(req)
return req
def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
@@ -1770,6 +1826,7 @@ def _interrupt_request(
)
if self._user_id:
req.user_context.user_id = self._user_id
+ self._update_request_with_user_context_extensions(req)
return req
def interrupt_all(self) -> Optional[List[str]]:
@@ -1868,6 +1925,38 @@ def _throw_if_invalid_tag(self, tag: str) -> None:
messageParameters={"arg_name": "Spark Connect tag", "arg_value": tag},
)
+ def add_threadlocal_user_context_extension(self, extension: any_pb2.Any) -> str:
+ if not hasattr(self.thread_local, "user_context_extensions"):
+ self.thread_local.user_context_extensions = list()
+ extension_id = "threadlocal_" + str(uuid.uuid4())
+ self.thread_local.user_context_extensions.append((extension_id, extension))
+ return extension_id
+
+ def add_global_user_context_extension(self, extension: any_pb2.Any) -> str:
+ extension_id = "global_" + str(uuid.uuid4())
+ with self.global_user_context_extensions_lock:
+ self.global_user_context_extensions.append((extension_id, extension))
+ return extension_id
+
+ def remove_user_context_extension(self, extension_id: str) -> None:
+ if extension_id.find("threadlocal_") == 0:
+ if not hasattr(self.thread_local, "user_context_extensions"):
+ return
+ self.thread_local.user_context_extensions = list(
+ filter(lambda ex: ex[0] != extension_id, self.thread_local.user_context_extensions)
+ )
+ elif extension_id.find("global_") == 0:
+ with self.global_user_context_extensions_lock:
+ self.global_user_context_extensions = list(
+ filter(lambda ex: ex[0] != extension_id, self.global_user_context_extensions)
+ )
+
+ def clear_user_context_extensions(self) -> None:
+ if hasattr(self.thread_local, "user_context_extensions"):
+ self.thread_local.user_context_extensions = list()
+ with self.global_user_context_extensions_lock:
+ self.global_user_context_extensions = list()
+
def _handle_error(self, error: Exception) -> NoReturn:
"""
Handle errors that occur during RPC calls.
@@ -1908,7 +1997,7 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet
req.client_observed_server_side_session_id = self._server_session_id
if self._user_id:
req.user_context.user_id = self._user_id
-
+ self._update_request_with_user_context_extensions(req)
try:
return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata())
except grpc.RpcError:
@@ -1963,6 +2052,17 @@ def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
if info.metadata.get("errorClass") == "INVALID_HANDLE.SESSION_CHANGED":
self._closed = True
+ if info.metadata.get("errorClass") == "CONNECT_INVALID_PLAN.CANNOT_PARSE":
+ # Disable plan compression if the server fails to interpret the plan.
+ logger.info(
+ "Disabling plan compression for the session due to "
+ "CONNECT_INVALID_PLAN.CANNOT_PARSE error."
+ )
+ self._plan_compression_threshold, self._plan_compression_algorithm = (
+ -1,
+ "NONE",
+ )
+
raise convert_exception(
info,
status.message,
@@ -2112,6 +2212,104 @@ def _query_model_size(self, model_ref_id: str) -> int:
ml_command_result = properties["ml_command_result"]
return ml_command_result.param.long
+ def _set_relation_in_plan(self, plan: pb2.Plan, relation: pb2.Relation) -> None:
+ """Sets the relation in the plan, attempting compression if configured."""
+ self._try_compress_and_set_plan(
+ plan=plan,
+ message=relation,
+ op_type=pb2.Plan.CompressedOperation.OpType.OP_TYPE_RELATION,
+ )
+
+ def _set_command_in_plan(self, plan: pb2.Plan, command: pb2.Command) -> None:
+ """Sets the command in the plan, attempting compression if configured."""
+ self._try_compress_and_set_plan(
+ plan=plan,
+ message=command,
+ op_type=pb2.Plan.CompressedOperation.OpType.OP_TYPE_COMMAND,
+ )
+
+ def _try_compress_and_set_plan(
+ self,
+ plan: pb2.Plan,
+ message: google.protobuf.message.Message,
+ op_type: pb2.Plan.CompressedOperation.OpType.ValueType,
+ ) -> None:
+ """
+ Tries to compress a protobuf message and sets it on the plan.
+ If compression is not enabled, not effective, or not available,
+ it falls back to the original message.
+ """
+ (
+ plan_compression_threshold,
+ plan_compression_algorithm,
+ ) = self._get_plan_compression_threshold_and_algorithm()
+ plan_compression_enabled = (
+ plan_compression_threshold is not None
+ and plan_compression_threshold >= 0
+ and plan_compression_algorithm is not None
+ and plan_compression_algorithm != "NONE"
+ )
+ if plan_compression_enabled:
+ serialized_msg = message.SerializeToString()
+ original_size = len(serialized_msg)
+ if (
+ original_size > plan_compression_threshold
+ and plan_compression_algorithm == "ZSTD"
+ and self._zstd_module
+ ):
+ start_time = time.time()
+ compressed_operation = pb2.Plan.CompressedOperation(
+ data=self._zstd_module.compress(serialized_msg),
+ op_type=op_type,
+ compression_codec=pb2.CompressionCodec.COMPRESSION_CODEC_ZSTD,
+ )
+ duration = time.time() - start_time
+ compressed_size = len(compressed_operation.data)
+ logger.debug(
+ f"Plan compression: original_size={original_size}, "
+ f"compressed_size={compressed_size}, "
+ f"saving_ratio={1 - compressed_size / original_size:.2f}, "
+ f"duration_s={duration:.1f}"
+ )
+ if compressed_size < original_size:
+ plan.compressed_operation.CopyFrom(compressed_operation)
+ return
+ else:
+ logger.debug("Plan compression not effective. Using original plan.")
+
+ if op_type == pb2.Plan.CompressedOperation.OpType.OP_TYPE_RELATION:
+ plan.root.CopyFrom(message) # type: ignore[arg-type]
+ else:
+ plan.command.CopyFrom(message) # type: ignore[arg-type]
+
+ def _get_plan_compression_threshold_and_algorithm(self) -> Tuple[int, str]:
+ if self._plan_compression_threshold is None or self._plan_compression_algorithm is None:
+ try:
+ (
+ plan_compression_threshold_str,
+ self._plan_compression_algorithm,
+ ) = self.get_configs(
+ "spark.connect.session.planCompression.threshold",
+ "spark.connect.session.planCompression.defaultAlgorithm",
+ )
+ self._plan_compression_threshold = (
+ int(plan_compression_threshold_str) if plan_compression_threshold_str else -1
+ )
+ logger.debug(
+ f"Plan compression threshold: {self._plan_compression_threshold}, "
+ f"algorithm: {self._plan_compression_algorithm}"
+ )
+ except Exception as e:
+ self._plan_compression_threshold = -1
+ self._plan_compression_algorithm = "NONE"
+ logger.debug(
+ "Plan compression is disabled because the server does not support it.", e
+ )
+ return (
+ self._plan_compression_threshold,
+ self._plan_compression_algorithm,
+ ) # type: ignore[return-value]
+
def clone(self, new_session_id: Optional[str] = None) -> "SparkConnectClient":
"""
Clone this client session on the server side. The server-side session is cloned with
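A hedged sketch of the decision implemented by _try_compress_and_set_plan above, written as a standalone function: compress only when the feature is enabled, the serialized plan exceeds the threshold, and the result is actually smaller; otherwise fall back to the original bytes. Threshold and algorithm values here are illustrative:

def maybe_compress_plan(serialized: bytes, threshold: int, algorithm: str) -> bytes:
    # Compression disabled by configuration (negative threshold or algorithm NONE).
    if threshold < 0 or algorithm != "ZSTD":
        return serialized
    try:
        import zstandard
    except ImportError:
        # zstandard is not installed; keep the uncompressed plan.
        return serialized
    if len(serialized) <= threshold:
        return serialized
    compressed = zstandard.compress(serialized)
    # Only use the compressed form if it is actually smaller.
    return compressed if len(compressed) < len(serialized) else serialized

payload = maybe_compress_plan(b"x" * 10_000, threshold=1024, algorithm="ZSTD")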
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 71a499afd2ff7..6a448025932b6 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -587,7 +587,7 @@ def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
if isinstance(c, Column):
_cols.append(c)
elif isinstance(c, str):
- _cols.append(self[c])
+ _cols.append(F.col(c))
elif isinstance(c, int) and not isinstance(c, bool):
if c < 1:
raise PySparkIndexError(
@@ -619,7 +619,7 @@ def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: igno
if isinstance(c, Column):
_cols.append(c)
elif isinstance(c, str):
- _cols.append(self[c])
+ _cols.append(F.col(c))
elif isinstance(c, int) and not isinstance(c, bool):
if c < 1:
raise PySparkIndexError(
@@ -649,7 +649,7 @@ def cube(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc]
if isinstance(c, Column):
_cols.append(c)
elif isinstance(c, str):
- _cols.append(self[c])
+ _cols.append(F.col(c))
elif isinstance(c, int) and not isinstance(c, bool):
if c < 1:
raise PySparkIndexError(
@@ -675,7 +675,7 @@ def groupingSets(
if isinstance(c, Column):
gset.append(c)
elif isinstance(c, str):
- gset.append(self[c])
+ gset.append(F.col(c))
else:
raise PySparkTypeError(
errorClass="NOT_COLUMN_OR_STR",
@@ -691,7 +691,7 @@ def groupingSets(
if isinstance(c, Column):
gcols.append(c)
elif isinstance(c, str):
- gcols.append(self[c])
+ gcols.append(F.col(c))
else:
raise PySparkTypeError(
errorClass="NOT_COLUMN_OR_STR",
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index 1198596fbb5db..a2db7e172b5da 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -4529,6 +4529,210 @@ def theta_intersection_agg(
theta_intersection_agg.__doc__ = pysparkfuncs.theta_intersection_agg.__doc__
+def kll_sketch_agg_bigint(
+ col: "ColumnOrName",
+ k: Optional[Union[int, Column]] = None,
+) -> Column:
+ fn = "kll_sketch_agg_bigint"
+ if k is None:
+ return _invoke_function_over_columns(fn, col)
+ else:
+ return _invoke_function_over_columns(fn, col, lit(k))
+
+
+kll_sketch_agg_bigint.__doc__ = pysparkfuncs.kll_sketch_agg_bigint.__doc__
+
+
+def kll_sketch_agg_float(
+ col: "ColumnOrName",
+ k: Optional[Union[int, Column]] = None,
+) -> Column:
+ fn = "kll_sketch_agg_float"
+ if k is None:
+ return _invoke_function_over_columns(fn, col)
+ else:
+ return _invoke_function_over_columns(fn, col, lit(k))
+
+
+kll_sketch_agg_float.__doc__ = pysparkfuncs.kll_sketch_agg_float.__doc__
+
+
+def kll_sketch_agg_double(
+ col: "ColumnOrName",
+ k: Optional[Union[int, Column]] = None,
+) -> Column:
+ fn = "kll_sketch_agg_double"
+ if k is None:
+ return _invoke_function_over_columns(fn, col)
+ else:
+ return _invoke_function_over_columns(fn, col, lit(k))
+
+
+kll_sketch_agg_double.__doc__ = pysparkfuncs.kll_sketch_agg_double.__doc__
+
+
+def kll_merge_agg_bigint(
+ col: "ColumnOrName",
+ k: Optional[Union[int, Column]] = None,
+) -> Column:
+ fn = "kll_merge_agg_bigint"
+ if k is None:
+ return _invoke_function_over_columns(fn, col)
+ else:
+ return _invoke_function_over_columns(fn, col, lit(k))
+
+
+kll_merge_agg_bigint.__doc__ = pysparkfuncs.kll_merge_agg_bigint.__doc__
+
+
+def kll_merge_agg_float(
+ col: "ColumnOrName",
+ k: Optional[Union[int, Column]] = None,
+) -> Column:
+ fn = "kll_merge_agg_float"
+ if k is None:
+ return _invoke_function_over_columns(fn, col)
+ else:
+ return _invoke_function_over_columns(fn, col, lit(k))
+
+
+kll_merge_agg_float.__doc__ = pysparkfuncs.kll_merge_agg_float.__doc__
+
+
+def kll_merge_agg_double(
+ col: "ColumnOrName",
+ k: Optional[Union[int, Column]] = None,
+) -> Column:
+ fn = "kll_merge_agg_double"
+ if k is None:
+ return _invoke_function_over_columns(fn, col)
+ else:
+ return _invoke_function_over_columns(fn, col, lit(k))
+
+
+kll_merge_agg_double.__doc__ = pysparkfuncs.kll_merge_agg_double.__doc__
+
+
+def kll_sketch_to_string_bigint(col: "ColumnOrName") -> Column:
+ fn = "kll_sketch_to_string_bigint"
+ return _invoke_function_over_columns(fn, col)
+
+
+kll_sketch_to_string_bigint.__doc__ = pysparkfuncs.kll_sketch_to_string_bigint.__doc__
+
+
+def kll_sketch_to_string_float(col: "ColumnOrName") -> Column:
+ fn = "kll_sketch_to_string_float"
+ return _invoke_function_over_columns(fn, col)
+
+
+kll_sketch_to_string_float.__doc__ = pysparkfuncs.kll_sketch_to_string_float.__doc__
+
+
+def kll_sketch_to_string_double(col: "ColumnOrName") -> Column:
+ fn = "kll_sketch_to_string_double"
+ return _invoke_function_over_columns(fn, col)
+
+
+kll_sketch_to_string_double.__doc__ = pysparkfuncs.kll_sketch_to_string_double.__doc__
+
+
+def kll_sketch_get_n_bigint(col: "ColumnOrName") -> Column:
+ fn = "kll_sketch_get_n_bigint"
+ return _invoke_function_over_columns(fn, col)
+
+
+kll_sketch_get_n_bigint.__doc__ = pysparkfuncs.kll_sketch_get_n_bigint.__doc__
+
+
+def kll_sketch_get_n_float(col: "ColumnOrName") -> Column:
+ fn = "kll_sketch_get_n_float"
+ return _invoke_function_over_columns(fn, col)
+
+
+kll_sketch_get_n_float.__doc__ = pysparkfuncs.kll_sketch_get_n_float.__doc__
+
+
+def kll_sketch_get_n_double(col: "ColumnOrName") -> Column:
+ fn = "kll_sketch_get_n_double"
+ return _invoke_function_over_columns(fn, col)
+
+
+kll_sketch_get_n_double.__doc__ = pysparkfuncs.kll_sketch_get_n_double.__doc__
+
+
+def kll_sketch_merge_bigint(left: "ColumnOrName", right: "ColumnOrName") -> Column:
+ fn = "kll_sketch_merge_bigint"
+ return _invoke_function_over_columns(fn, left, right)
+
+
+kll_sketch_merge_bigint.__doc__ = pysparkfuncs.kll_sketch_merge_bigint.__doc__
+
+
+def kll_sketch_merge_float(left: "ColumnOrName", right: "ColumnOrName") -> Column:
+ fn = "kll_sketch_merge_float"
+ return _invoke_function_over_columns(fn, left, right)
+
+
+kll_sketch_merge_float.__doc__ = pysparkfuncs.kll_sketch_merge_float.__doc__
+
+
+def kll_sketch_merge_double(left: "ColumnOrName", right: "ColumnOrName") -> Column:
+ fn = "kll_sketch_merge_double"
+ return _invoke_function_over_columns(fn, left, right)
+
+
+kll_sketch_merge_double.__doc__ = pysparkfuncs.kll_sketch_merge_double.__doc__
+
+
+def kll_sketch_get_quantile_bigint(sketch: "ColumnOrName", rank: "ColumnOrName") -> Column:
+ fn = "kll_sketch_get_quantile_bigint"
+ return _invoke_function_over_columns(fn, sketch, rank)
+
+
+kll_sketch_get_quantile_bigint.__doc__ = pysparkfuncs.kll_sketch_get_quantile_bigint.__doc__
+
+
+def kll_sketch_get_quantile_float(sketch: "ColumnOrName", rank: "ColumnOrName") -> Column:
+ fn = "kll_sketch_get_quantile_float"
+ return _invoke_function_over_columns(fn, sketch, rank)
+
+
+kll_sketch_get_quantile_float.__doc__ = pysparkfuncs.kll_sketch_get_quantile_float.__doc__
+
+
+def kll_sketch_get_quantile_double(sketch: "ColumnOrName", rank: "ColumnOrName") -> Column:
+ fn = "kll_sketch_get_quantile_double"
+ return _invoke_function_over_columns(fn, sketch, rank)
+
+
+kll_sketch_get_quantile_double.__doc__ = pysparkfuncs.kll_sketch_get_quantile_double.__doc__
+
+
+def kll_sketch_get_rank_bigint(sketch: "ColumnOrName", quantile: "ColumnOrName") -> Column:
+ fn = "kll_sketch_get_rank_bigint"
+ return _invoke_function_over_columns(fn, sketch, quantile)
+
+
+kll_sketch_get_rank_bigint.__doc__ = pysparkfuncs.kll_sketch_get_rank_bigint.__doc__
+
+
+def kll_sketch_get_rank_float(sketch: "ColumnOrName", quantile: "ColumnOrName") -> Column:
+ fn = "kll_sketch_get_rank_float"
+ return _invoke_function_over_columns(fn, sketch, quantile)
+
+
+kll_sketch_get_rank_float.__doc__ = pysparkfuncs.kll_sketch_get_rank_float.__doc__
+
+
+def kll_sketch_get_rank_double(sketch: "ColumnOrName", quantile: "ColumnOrName") -> Column:
+ fn = "kll_sketch_get_rank_double"
+ return _invoke_function_over_columns(fn, sketch, quantile)
+
+
+kll_sketch_get_rank_double.__doc__ = pysparkfuncs.kll_sketch_get_rank_double.__doc__
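+
+
+# A minimal usage sketch (illustrative comment only; the column name "sk" and the
+# upstream sketch-building aggregate are assumptions, not part of this change):
+#
+#   merged = df.agg(kll_merge_agg_double("sk").alias("sk"))
+#   merged.select(
+#       kll_sketch_get_n_double("sk"),
+#       kll_sketch_get_quantile_double("sk", lit(0.5)),
+#       kll_sketch_get_rank_double("sk", lit(100.0)),
+#   )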
+
+
def theta_sketch_estimate(col: "ColumnOrName") -> Column:
fn = "theta_sketch_estimate"
return _invoke_function_over_columns(fn, col)
@@ -4783,6 +4987,46 @@ def bitmap_and_agg(col: "ColumnOrName") -> Column:
bitmap_and_agg.__doc__ = pysparkfuncs.bitmap_and_agg.__doc__
+# Geospatial ST Functions
+
+
+def st_asbinary(geo: "ColumnOrName") -> Column:
+ return _invoke_function_over_columns("st_asbinary", geo)
+
+
+st_asbinary.__doc__ = pysparkfuncs.st_asbinary.__doc__
+
+
+def st_geogfromwkb(wkb: "ColumnOrName") -> Column:
+ return _invoke_function_over_columns("st_geogfromwkb", wkb)
+
+
+st_geogfromwkb.__doc__ = pysparkfuncs.st_geogfromwkb.__doc__
+
+
+def st_geomfromwkb(wkb: "ColumnOrName") -> Column:
+ return _invoke_function_over_columns("st_geomfromwkb", wkb)
+
+
+st_geomfromwkb.__doc__ = pysparkfuncs.st_geomfromwkb.__doc__
+
+
+def st_setsrid(geo: "ColumnOrName", srid: Union["ColumnOrName", int]) -> Column:
+ srid = _enum_to_value(srid)
+ srid = lit(srid) if isinstance(srid, int) else srid
+ return _invoke_function_over_columns("st_setsrid", geo, srid)
+
+
+st_setsrid.__doc__ = pysparkfuncs.st_setsrid.__doc__
+
+
+def st_srid(geo: "ColumnOrName") -> Column:
+ return _invoke_function_over_columns("st_srid", geo)
+
+
+st_srid.__doc__ = pysparkfuncs.st_srid.__doc__
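+
+
+# A minimal usage sketch (illustrative comment only; the input column name "wkb" and
+# SRID 4326 are assumptions): round-trip a geometry through WKB and tag it with an SRID:
+#
+#   df.select(
+#       st_srid(st_setsrid(st_geomfromwkb("wkb"), 4326)).alias("srid"),
+#       st_asbinary(st_geomfromwkb("wkb")).alias("wkb_out"),
+#   )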
+
+
# Call Functions
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index 52d280c2c2646..d540e721f149e 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -158,6 +158,7 @@ def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
def _numeric_agg(self, function: str, cols: Sequence[str]) -> "DataFrame":
from pyspark.sql.connect.dataframe import DataFrame
+ from pyspark.sql.connect.types import verify_numeric_col_name
assert isinstance(function, str) and function in ["min", "max", "avg", "sum"]
@@ -165,12 +166,8 @@ def _numeric_agg(self, function: str, cols: Sequence[str]) -> "DataFrame":
schema = self._df.schema
- numerical_cols: List[str] = [
- field.name for field in schema.fields if isinstance(field.dataType, NumericType)
- ]
-
if len(cols) > 0:
- invalid_cols = [c for c in cols if c not in numerical_cols]
+ invalid_cols = [c for c in cols if not verify_numeric_col_name(c, schema)]
if len(invalid_cols) > 0:
raise PySparkTypeError(
errorClass="NOT_NUMERIC_COLUMNS",
@@ -179,7 +176,9 @@ def _numeric_agg(self, function: str, cols: Sequence[str]) -> "DataFrame":
agg_cols = cols
else:
# if no column is provided, then all numerical columns are selected
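+ # (e.g. df.groupBy("k").sum() with no column arguments aggregates every NumericType
+ # column, whereas passing a non-numeric column name raises NOT_NUMERIC_COLUMNS above)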
- agg_cols = numerical_cols
+ agg_cols = [
+ field.name for field in schema.fields if isinstance(field.dataType, NumericType)
+ ]
return DataFrame(
plan.Aggregate(
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 82a6326c7dc58..6630d96f21ded 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -24,6 +24,7 @@
from typing import (
Any,
+ Iterator,
List,
Optional,
Type,
@@ -143,7 +144,8 @@ def to_proto(self, session: "SparkConnectClient", debug: bool = False) -> proto.
if enabled, the proto plan will be printed.
"""
plan = proto.Plan()
- plan.root.CopyFrom(self.plan(session))
+ relation = self.plan(session)
+ session._set_relation_in_plan(plan, relation)
if debug:
print(plan)
@@ -436,7 +438,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:
return plan
def _serialize_table(self) -> bytes:
- assert self._table is not None
+ assert self._table is not None, "table cannot be None"
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, self._table.schema) as writer:
batches = self._table.to_batches()
@@ -448,7 +450,7 @@ def _serialize_table_chunks(
self,
max_chunk_size_rows: int,
max_chunk_size_bytes: int,
- ) -> list[bytes]:
+ ) -> Iterator[bytes]:
"""
Serialize the table into multiple chunks, each up to max_chunk_size_bytes bytes
and max_chunk_size_rows rows.
@@ -456,49 +458,52 @@ def _serialize_table_chunks(
This method processes the table in fixed-size batches (1024 rows) for
efficiency, matching the Scala implementation's batchSizeCheckInterval.
+
+ Yields chunks one at a time to avoid materializing all chunks in memory.
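+
+ For example, with a max_chunk_size_bytes of a few megabytes, a large table may be
+ emitted as several independent IPC streams, each of which re-encodes the schema.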
"""
- assert self._table is not None
- chunks = []
+ assert self._table is not None, "table cannot be None"
+ assert self._table.num_rows > 0, "table must have at least one row"
schema = self._table.schema
- # Calculate schema serialization size once
- schema_buffer = pa.BufferOutputStream()
- with pa.ipc.new_stream(schema_buffer, schema):
- pass # Just write schema
- schema_size = len(schema_buffer.getvalue())
+ # Calculate schema serialization size once (an empty table serializes to just the schema)
+ schema_size = len(self._serialize_batches_to_ipc([], schema))
current_batches: list[pa.RecordBatch] = []
current_size = schema_size
for batch in self._table.to_batches(max_chunksize=min(1024, max_chunk_size_rows)):
+ # Approximate the batch size using raw column data (fast, but ignores IPC overhead).
+ # Computing the exact IPC-encoded size would require serializing each batch
+ # separately, which adds overhead.
batch_size = sum(arr.nbytes for arr in batch.columns)
# If adding this batch would exceed the limit and we already have data, flush the current chunk
- if current_size > schema_size and current_size + batch_size > max_chunk_size_bytes:
- combined = pa.Table.from_batches(current_batches, schema=schema)
- sink = pa.BufferOutputStream()
- with pa.ipc.new_stream(sink, schema) as writer:
- writer.write_table(combined)
- chunks.append(sink.getvalue().to_pybytes())
+ if len(current_batches) > 0 and current_size + batch_size > max_chunk_size_bytes:
+ yield self._serialize_batches_to_ipc(current_batches, schema)
current_batches = []
current_size = schema_size
current_batches.append(batch)
current_size += batch_size
- # Flush remaining batches
- if current_batches:
- combined = pa.Table.from_batches(current_batches, schema=schema)
- sink = pa.BufferOutputStream()
- with pa.ipc.new_stream(sink, schema) as writer:
- writer.write_table(combined)
- chunks.append(sink.getvalue().to_pybytes())
+ # Flush remaining batches (guaranteed non-empty by the assertion above)
+ yield self._serialize_batches_to_ipc(current_batches, schema)
- return chunks
+ def _serialize_batches_to_ipc(
+ self,
+ batches: list[pa.RecordBatch],
+ schema: pa.Schema,
+ ) -> bytes:
+ """Helper method to serialize Arrow batches to IPC stream format."""
+ combined = pa.Table.from_batches(batches, schema=schema)
+ sink = pa.BufferOutputStream()
+ with pa.ipc.new_stream(sink, schema) as writer:
+ writer.write_table(combined)
+ return sink.getvalue().to_pybytes()
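+
+ # Illustrative note (not asserted by this change): each chunk yielded above is a
+ # self-contained Arrow IPC stream, so a consumer could, for example, rebuild the
+ # table with pyarrow:
+ #
+ #   tables = [pa.ipc.open_stream(chunk).read_all() for chunk in chunks]
+ #   table = pa.concat_tables(tables)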
def _serialize_schema(self) -> bytes:
# the server uses UTF-8 for decoding the schema
- assert self._schema is not None
+ assert self._schema is not None, "schema cannot be None"
return self._schema.encode("utf-8")
def serialize(self, session: "SparkConnectClient") -> bytes:
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py
index 0fe992332de71..32b2840dffadc 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -18,7 +18,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: spark/connect/base.proto
-# Protobuf Python Version: 5.29.5
+# Protobuf Python Version: 6.33.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -27,7 +27,7 @@
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
- _runtime_version.Domain.PUBLIC, 5, 29, 5, "", "spark/connect/base.proto"
+ _runtime_version.Domain.PUBLIC, 6, 33, 0, "", "spark/connect/base.proto"
)
# @@protoc_insertion_point(imports)
@@ -45,7 +45,7 @@
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto\x1a\x16spark/connect/ml.proto\x1a\x1dspark/connect/pipelines.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x14\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x11 \x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x02R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x12M\n\x0bjson_to_ddl\x18\x12 \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.JsonToDDLH\x00R\tjsonToDdl\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 \x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 
\x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1a,\n\tJsonToDDL\x12\x1f\n\x0bjson_string\x18\x01 \x01(\tR\njsonStringB\t\n\x07\x61nalyzeB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xca\x0e\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x0f \x01(\tR\x13serverSideSessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e \x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x12N\n\x0bjson_to_ddl\x18\x10 \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.JsonToDDLH\x00R\tjsonToDdl\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 
\x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevel\x1a*\n\tJsonToDDL\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlStringB\x08\n\x06result"\x83\x06\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x08 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x01R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x02R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\x85\x02\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12^\n\x17result_chunking_options\x18\x02 \x01(\x0b\x32$.spark.connect.ResultChunkingOptionsH\x00R\x15resultChunkingOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB)\n\'_client_observed_server_side_session_idB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\x81\x1b\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x0f \x01(\tR\x13serverSideSessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t \x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n \x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\x87\x01\n&streaming_query_listener_events_result\x18\x10 \x01(\x0b\x32\x31.spark.connect.StreamingQueryListenerEventsResultH\x00R"streamingQueryListenerEventsResult\x12\\\n\x0fresult_complete\x18\x0e \x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x87\x01\n&create_resource_profile_command_result\x18\x11 \x01(\x0b\x32\x31.spark.connect.CreateResourceProfileCommandResultH\x00R"createResourceProfileCommandResult\x12\x65\n\x12\x65xecution_progress\x18\x12 \x01(\x0b\x32\x34.spark.connect.ExecutePlanResponse.ExecutionProgressH\x00R\x11\x65xecutionProgress\x12\x64\n\x19\x63heckpoint_command_result\x18\x13 \x01(\x0b\x32&.spark.connect.CheckpointCommandResultH\x00R\x17\x63heckpointCommandResult\x12L\n\x11ml_command_result\x18\x14 
\x01(\x0b\x32\x1e.spark.connect.MlCommandResultH\x00R\x0fmlCommandResult\x12X\n\x15pipeline_event_result\x18\x15 \x01(\x0b\x32".spark.connect.PipelineEventResultH\x00R\x13pipelineEventResult\x12^\n\x17pipeline_command_result\x18\x16 \x01(\x0b\x32$.spark.connect.PipelineCommandResultH\x00R\x15pipelineCommandResult\x12\x8d\x01\n(pipeline_query_function_execution_signal\x18\x17 \x01(\x0b\x32\x33.spark.connect.PipelineQueryFunctionExecutionSignalH\x00R$pipelineQueryFunctionExecutionSignal\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1a\xf8\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 \x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x12$\n\x0b\x63hunk_index\x18\x04 \x01(\x03H\x01R\nchunkIndex\x88\x01\x01\x12\x32\n\x13num_chunks_in_batch\x18\x05 \x01(\x03H\x02R\x10numChunksInBatch\x88\x01\x01\x42\x0f\n\r_start_offsetB\x0e\n\x0c_chunk_indexB\x16\n\x14_num_chunks_in_batch\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1a\x8d\x01\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x12\x17\n\x07plan_id\x18\x04 \x01(\x03R\x06planId\x1a\x10\n\x0eResultComplete\x1a\xcd\x02\n\x11\x45xecutionProgress\x12V\n\x06stages\x18\x01 \x03(\x0b\x32>.spark.connect.ExecutePlanResponse.ExecutionProgress.StageInfoR\x06stages\x12,\n\x12num_inflight_tasks\x18\x02 \x01(\x03R\x10numInflightTasks\x1a\xb1\x01\n\tStageInfo\x12\x19\n\x08stage_id\x18\x01 \x01(\x03R\x07stageId\x12\x1b\n\tnum_tasks\x18\x02 \x01(\x03R\x08numTasks\x12.\n\x13num_completed_tasks\x18\x03 \x01(\x03R\x11numCompletedTasks\x12(\n\x10input_bytes_read\x18\x04 \x01(\x03R\x0einputBytesRead\x12\x12\n\x04\x64one\x18\x05 \x01(\x08R\x04\x64oneB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\xaf\t\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x08 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 
\x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\\\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1b\n\x06silent\x18\x02 \x01(\x08H\x00R\x06silent\x88\x01\x01\x42\t\n\x07_silent\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xaf\x01\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x04 \x01(\tR\x13serverSideSessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xea\x07\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12V\n&client_observed_server_side_session_id\x18\x07 \x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x02R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\x90\x02\n\x14\x41\x64\x64\x41rtifactsResponse\x12\x1d\n\nsession_id\x18\x02 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 
\x01(\tR\x13serverSideSessionId\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc6\x02\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x05 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xe0\x02\n\x18\x41rtifactStatusesResponse\x12\x1d\n\nsession_id\x18\x02 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists"\xdb\x04\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x07 \x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x02R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 \x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\x90\x01\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\xb5\x01\n\x15ResultChunkingOptions\x12;\n\x1a\x61llow_arrow_batch_chunking\x18\x01 \x01(\x08R\x17\x61llowArrowBatchChunking\x12@\n\x1apreferred_arrow_chunk_size\x18\x02 \x01(\x03H\x00R\x17preferredArrowChunkSize\x88\x01\x01\x42\x1d\n\x1b_preferred_arrow_chunk_size"\x96\x03\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x06 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x02R\x0elastResponseId\x88\x01\x01\x42)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc9\x04\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x07 
\x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x02R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 \x01(\tR\nresponseIdB\t\n\x07releaseB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xa5\x01\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12&\n\x0coperation_id\x18\x02 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xd4\x01\n\x15ReleaseSessionRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\'\n\x0f\x61llow_reconnect\x18\x04 \x01(\x08R\x0e\x61llowReconnectB\x0e\n\x0c_client_type"l\n\x16ReleaseSessionResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x02 \x01(\tR\x13serverSideSessionId"\xcc\x02\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x05 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x42)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xd9\x0f\n\x19\x46\x65tchErrorDetailsResponse\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12\x1d\n\nsession_id\x18\x04 \x01(\tR\tsessionId\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xf0\x02\n\x0cQueryContext\x12\x64\n\x0c\x63ontext_type\x18\n \x01(\x0e\x32\x41.spark.connect.FetchErrorDetailsResponse.QueryContext.ContextTypeR\x0b\x63ontextType\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x12\x1b\n\tcall_site\x18\x06 \x01(\tR\x08\x63\x61llSite\x12\x18\n\x07summary\x18\x07 \x01(\tR\x07summary"%\n\x0b\x43ontextType\x12\x07\n\x03SQL\x10\x00\x12\r\n\tDATAFRAME\x10\x01\x1a\xa6\x04\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 \x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 
\x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x12r\n\x14\x62reaking_change_info\x18\x05 \x01(\x0b\x32;.spark.connect.FetchErrorDetailsResponse.BreakingChangeInfoH\x02R\x12\x62reakingChangeInfo\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_stateB\x17\n\x15_breaking_change_info\x1a\xfa\x01\n\x12\x42reakingChangeInfo\x12+\n\x11migration_message\x18\x01 \x03(\tR\x10migrationMessage\x12k\n\x11mitigation_config\x18\x02 \x01(\x0b\x32\x39.spark.connect.FetchErrorDetailsResponse.MitigationConfigH\x00R\x10mitigationConfig\x88\x01\x01\x12$\n\x0bneeds_audit\x18\x03 \x01(\x08H\x01R\nneedsAudit\x88\x01\x01\x42\x14\n\x12_mitigation_configB\x0e\n\x0c_needs_audit\x1a:\n\x10MitigationConfig\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 \x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx"Z\n\x17\x43heckpointCommandResult\x12?\n\x08relation\x18\x01 \x01(\x0b\x32#.spark.connect.CachedRemoteRelationR\x08relation"\xea\x02\n\x13\x43loneSessionRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x05 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12)\n\x0enew_session_id\x18\x04 \x01(\tH\x02R\x0cnewSessionId\x88\x01\x01\x42)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_typeB\x11\n\x0f_new_session_id"\xcc\x01\n\x14\x43loneSessionResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x02 \x01(\tR\x13serverSideSessionId\x12$\n\x0enew_session_id\x18\x03 \x01(\tR\x0cnewSessionId\x12:\n\x1anew_server_side_session_id\x18\x04 \x01(\tR\x16newServerSideSessionId2\x8d\x08\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a 
.spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12_\n\x0eReleaseSession\x12$.spark.connect.ReleaseSessionRequest\x1a%.spark.connect.ReleaseSessionResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x12Y\n\x0c\x43loneSession\x12".spark.connect.CloneSessionRequest\x1a#.spark.connect.CloneSessionResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
+ b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto\x1a\x16spark/connect/ml.proto\x1a\x1dspark/connect/pipelines.proto"\xe3\x03\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommand\x12\\\n\x14\x63ompressed_operation\x18\x03 \x01(\x0b\x32\'.spark.connect.Plan.CompressedOperationH\x00R\x13\x63ompressedOperation\x1a\x8e\x02\n\x13\x43ompressedOperation\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12G\n\x07op_type\x18\x02 \x01(\x0e\x32..spark.connect.Plan.CompressedOperation.OpTypeR\x06opType\x12L\n\x11\x63ompression_codec\x18\x03 \x01(\x0e\x32\x1f.spark.connect.CompressionCodecR\x10\x63ompressionCodec"L\n\x06OpType\x12\x17\n\x13OP_TYPE_UNSPECIFIED\x10\x00\x12\x14\n\x10OP_TYPE_RELATION\x10\x01\x12\x13\n\x0fOP_TYPE_COMMAND\x10\x02\x42\t\n\x07op_type"z\n\x0bUserContext\x12\x17\n\x07user_id\x18\x01 \x01(\tR\x06userId\x12\x1b\n\tuser_name\x18\x02 \x01(\tR\x08userName\x12\x35\n\nextensions\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\nextensions"\xf5\x14\n\x12\x41nalyzePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x11 \x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x02R\nclientType\x88\x01\x01\x12\x42\n\x06schema\x18\x04 \x01(\x0b\x32(.spark.connect.AnalyzePlanRequest.SchemaH\x00R\x06schema\x12\x45\n\x07\x65xplain\x18\x05 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.ExplainH\x00R\x07\x65xplain\x12O\n\x0btree_string\x18\x06 \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.TreeStringH\x00R\ntreeString\x12\x46\n\x08is_local\x18\x07 \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.IsLocalH\x00R\x07isLocal\x12R\n\x0cis_streaming\x18\x08 \x01(\x0b\x32-.spark.connect.AnalyzePlanRequest.IsStreamingH\x00R\x0bisStreaming\x12O\n\x0binput_files\x18\t \x01(\x0b\x32,.spark.connect.AnalyzePlanRequest.InputFilesH\x00R\ninputFiles\x12U\n\rspark_version\x18\n \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SparkVersionH\x00R\x0csparkVersion\x12I\n\tddl_parse\x18\x0b \x01(\x0b\x32*.spark.connect.AnalyzePlanRequest.DDLParseH\x00R\x08\x64\x64lParse\x12X\n\x0esame_semantics\x18\x0c \x01(\x0b\x32/.spark.connect.AnalyzePlanRequest.SameSemanticsH\x00R\rsameSemantics\x12U\n\rsemantic_hash\x18\r \x01(\x0b\x32..spark.connect.AnalyzePlanRequest.SemanticHashH\x00R\x0csemanticHash\x12\x45\n\x07persist\x18\x0e \x01(\x0b\x32).spark.connect.AnalyzePlanRequest.PersistH\x00R\x07persist\x12K\n\tunpersist\x18\x0f \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.UnpersistH\x00R\tunpersist\x12_\n\x11get_storage_level\x18\x10 \x01(\x0b\x32\x31.spark.connect.AnalyzePlanRequest.GetStorageLevelH\x00R\x0fgetStorageLevel\x12M\n\x0bjson_to_ddl\x18\x12 \x01(\x0b\x32+.spark.connect.AnalyzePlanRequest.JsonToDDLH\x00R\tjsonToDdl\x1a\x31\n\x06Schema\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\xbb\x02\n\x07\x45xplain\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12X\n\x0c\x65xplain_mode\x18\x02 
\x01(\x0e\x32\x35.spark.connect.AnalyzePlanRequest.Explain.ExplainModeR\x0b\x65xplainMode"\xac\x01\n\x0b\x45xplainMode\x12\x1c\n\x18\x45XPLAIN_MODE_UNSPECIFIED\x10\x00\x12\x17\n\x13\x45XPLAIN_MODE_SIMPLE\x10\x01\x12\x19\n\x15\x45XPLAIN_MODE_EXTENDED\x10\x02\x12\x18\n\x14\x45XPLAIN_MODE_CODEGEN\x10\x03\x12\x15\n\x11\x45XPLAIN_MODE_COST\x10\x04\x12\x1a\n\x16\x45XPLAIN_MODE_FORMATTED\x10\x05\x1aZ\n\nTreeString\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12\x19\n\x05level\x18\x02 \x01(\x05H\x00R\x05level\x88\x01\x01\x42\x08\n\x06_level\x1a\x32\n\x07IsLocal\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x36\n\x0bIsStreaming\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x35\n\nInputFiles\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x0e\n\x0cSparkVersion\x1a)\n\x08\x44\x44LParse\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlString\x1ay\n\rSameSemantics\x12\x34\n\x0btarget_plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\ntargetPlan\x12\x32\n\nother_plan\x18\x02 \x01(\x0b\x32\x13.spark.connect.PlanR\totherPlan\x1a\x37\n\x0cSemanticHash\x12\'\n\x04plan\x18\x01 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x1a\x97\x01\n\x07Persist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x45\n\rstorage_level\x18\x02 \x01(\x0b\x32\x1b.spark.connect.StorageLevelH\x00R\x0cstorageLevel\x88\x01\x01\x42\x10\n\x0e_storage_level\x1an\n\tUnpersist\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x12\x1f\n\x08\x62locking\x18\x02 \x01(\x08H\x00R\x08\x62locking\x88\x01\x01\x42\x0b\n\t_blocking\x1a\x46\n\x0fGetStorageLevel\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1a,\n\tJsonToDDL\x12\x1f\n\x0bjson_string\x18\x01 \x01(\tR\njsonStringB\t\n\x07\x61nalyzeB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xca\x0e\n\x13\x41nalyzePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x0f \x01(\tR\x13serverSideSessionId\x12\x43\n\x06schema\x18\x02 \x01(\x0b\x32).spark.connect.AnalyzePlanResponse.SchemaH\x00R\x06schema\x12\x46\n\x07\x65xplain\x18\x03 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.ExplainH\x00R\x07\x65xplain\x12P\n\x0btree_string\x18\x04 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.TreeStringH\x00R\ntreeString\x12G\n\x08is_local\x18\x05 \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.IsLocalH\x00R\x07isLocal\x12S\n\x0cis_streaming\x18\x06 \x01(\x0b\x32..spark.connect.AnalyzePlanResponse.IsStreamingH\x00R\x0bisStreaming\x12P\n\x0binput_files\x18\x07 \x01(\x0b\x32-.spark.connect.AnalyzePlanResponse.InputFilesH\x00R\ninputFiles\x12V\n\rspark_version\x18\x08 \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SparkVersionH\x00R\x0csparkVersion\x12J\n\tddl_parse\x18\t \x01(\x0b\x32+.spark.connect.AnalyzePlanResponse.DDLParseH\x00R\x08\x64\x64lParse\x12Y\n\x0esame_semantics\x18\n \x01(\x0b\x32\x30.spark.connect.AnalyzePlanResponse.SameSemanticsH\x00R\rsameSemantics\x12V\n\rsemantic_hash\x18\x0b \x01(\x0b\x32/.spark.connect.AnalyzePlanResponse.SemanticHashH\x00R\x0csemanticHash\x12\x46\n\x07persist\x18\x0c \x01(\x0b\x32*.spark.connect.AnalyzePlanResponse.PersistH\x00R\x07persist\x12L\n\tunpersist\x18\r \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.UnpersistH\x00R\tunpersist\x12`\n\x11get_storage_level\x18\x0e 
\x01(\x0b\x32\x32.spark.connect.AnalyzePlanResponse.GetStorageLevelH\x00R\x0fgetStorageLevel\x12N\n\x0bjson_to_ddl\x18\x10 \x01(\x0b\x32,.spark.connect.AnalyzePlanResponse.JsonToDDLH\x00R\tjsonToDdl\x1a\x39\n\x06Schema\x12/\n\x06schema\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1a\x30\n\x07\x45xplain\x12%\n\x0e\x65xplain_string\x18\x01 \x01(\tR\rexplainString\x1a-\n\nTreeString\x12\x1f\n\x0btree_string\x18\x01 \x01(\tR\ntreeString\x1a$\n\x07IsLocal\x12\x19\n\x08is_local\x18\x01 \x01(\x08R\x07isLocal\x1a\x30\n\x0bIsStreaming\x12!\n\x0cis_streaming\x18\x01 \x01(\x08R\x0bisStreaming\x1a"\n\nInputFiles\x12\x14\n\x05\x66iles\x18\x01 \x03(\tR\x05\x66iles\x1a(\n\x0cSparkVersion\x12\x18\n\x07version\x18\x01 \x01(\tR\x07version\x1a;\n\x08\x44\x44LParse\x12/\n\x06parsed\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06parsed\x1a\'\n\rSameSemantics\x12\x16\n\x06result\x18\x01 \x01(\x08R\x06result\x1a&\n\x0cSemanticHash\x12\x16\n\x06result\x18\x01 \x01(\x05R\x06result\x1a\t\n\x07Persist\x1a\x0b\n\tUnpersist\x1aS\n\x0fGetStorageLevel\x12@\n\rstorage_level\x18\x01 \x01(\x0b\x32\x1b.spark.connect.StorageLevelR\x0cstorageLevel\x1a*\n\tJsonToDDL\x12\x1d\n\nddl_string\x18\x01 \x01(\tR\tddlStringB\x08\n\x06result"\x83\x06\n\x12\x45xecutePlanRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x08 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12&\n\x0coperation_id\x18\x06 \x01(\tH\x01R\x0boperationId\x88\x01\x01\x12\'\n\x04plan\x18\x03 \x01(\x0b\x32\x13.spark.connect.PlanR\x04plan\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x02R\nclientType\x88\x01\x01\x12X\n\x0frequest_options\x18\x05 \x03(\x0b\x32/.spark.connect.ExecutePlanRequest.RequestOptionR\x0erequestOptions\x12\x12\n\x04tags\x18\x07 \x03(\tR\x04tags\x1a\x85\x02\n\rRequestOption\x12K\n\x10reattach_options\x18\x01 \x01(\x0b\x32\x1e.spark.connect.ReattachOptionsH\x00R\x0freattachOptions\x12^\n\x17result_chunking_options\x18\x02 \x01(\x0b\x32$.spark.connect.ResultChunkingOptionsH\x00R\x15resultChunkingOptions\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textensionB\x10\n\x0erequest_optionB)\n\'_client_observed_server_side_session_idB\x0f\n\r_operation_idB\x0e\n\x0c_client_type"\x81\x1b\n\x13\x45xecutePlanResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x0f \x01(\tR\x13serverSideSessionId\x12!\n\x0coperation_id\x18\x0c \x01(\tR\x0boperationId\x12\x1f\n\x0bresponse_id\x18\r \x01(\tR\nresponseId\x12P\n\x0b\x61rrow_batch\x18\x02 \x01(\x0b\x32-.spark.connect.ExecutePlanResponse.ArrowBatchH\x00R\narrowBatch\x12\x63\n\x12sql_command_result\x18\x05 \x01(\x0b\x32\x33.spark.connect.ExecutePlanResponse.SqlCommandResultH\x00R\x10sqlCommandResult\x12~\n#write_stream_operation_start_result\x18\x08 \x01(\x0b\x32..spark.connect.WriteStreamOperationStartResultH\x00R\x1fwriteStreamOperationStartResult\x12q\n\x1estreaming_query_command_result\x18\t \x01(\x0b\x32*.spark.connect.StreamingQueryCommandResultH\x00R\x1bstreamingQueryCommandResult\x12k\n\x1cget_resources_command_result\x18\n \x01(\x0b\x32(.spark.connect.GetResourcesCommandResultH\x00R\x19getResourcesCommandResult\x12\x87\x01\n&streaming_query_manager_command_result\x18\x0b \x01(\x0b\x32\x31.spark.connect.StreamingQueryManagerCommandResultH\x00R"streamingQueryManagerCommandResult\x12\x87\x01\n&streaming_query_listener_events_result\x18\x10 
\x01(\x0b\x32\x31.spark.connect.StreamingQueryListenerEventsResultH\x00R"streamingQueryListenerEventsResult\x12\\\n\x0fresult_complete\x18\x0e \x01(\x0b\x32\x31.spark.connect.ExecutePlanResponse.ResultCompleteH\x00R\x0eresultComplete\x12\x87\x01\n&create_resource_profile_command_result\x18\x11 \x01(\x0b\x32\x31.spark.connect.CreateResourceProfileCommandResultH\x00R"createResourceProfileCommandResult\x12\x65\n\x12\x65xecution_progress\x18\x12 \x01(\x0b\x32\x34.spark.connect.ExecutePlanResponse.ExecutionProgressH\x00R\x11\x65xecutionProgress\x12\x64\n\x19\x63heckpoint_command_result\x18\x13 \x01(\x0b\x32&.spark.connect.CheckpointCommandResultH\x00R\x17\x63heckpointCommandResult\x12L\n\x11ml_command_result\x18\x14 \x01(\x0b\x32\x1e.spark.connect.MlCommandResultH\x00R\x0fmlCommandResult\x12X\n\x15pipeline_event_result\x18\x15 \x01(\x0b\x32".spark.connect.PipelineEventResultH\x00R\x13pipelineEventResult\x12^\n\x17pipeline_command_result\x18\x16 \x01(\x0b\x32$.spark.connect.PipelineCommandResultH\x00R\x15pipelineCommandResult\x12\x8d\x01\n(pipeline_query_function_execution_signal\x18\x17 \x01(\x0b\x32\x33.spark.connect.PipelineQueryFunctionExecutionSignalH\x00R$pipelineQueryFunctionExecutionSignal\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x44\n\x07metrics\x18\x04 \x01(\x0b\x32*.spark.connect.ExecutePlanResponse.MetricsR\x07metrics\x12]\n\x10observed_metrics\x18\x06 \x03(\x0b\x32\x32.spark.connect.ExecutePlanResponse.ObservedMetricsR\x0fobservedMetrics\x12/\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema\x1aG\n\x10SqlCommandResult\x12\x33\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x08relation\x1a\xf8\x01\n\nArrowBatch\x12\x1b\n\trow_count\x18\x01 \x01(\x03R\x08rowCount\x12\x12\n\x04\x64\x61ta\x18\x02 \x01(\x0cR\x04\x64\x61ta\x12&\n\x0cstart_offset\x18\x03 \x01(\x03H\x00R\x0bstartOffset\x88\x01\x01\x12$\n\x0b\x63hunk_index\x18\x04 \x01(\x03H\x01R\nchunkIndex\x88\x01\x01\x12\x32\n\x13num_chunks_in_batch\x18\x05 \x01(\x03H\x02R\x10numChunksInBatch\x88\x01\x01\x42\x0f\n\r_start_offsetB\x0e\n\x0c_chunk_indexB\x16\n\x14_num_chunks_in_batch\x1a\x85\x04\n\x07Metrics\x12Q\n\x07metrics\x18\x01 \x03(\x0b\x32\x37.spark.connect.ExecutePlanResponse.Metrics.MetricObjectR\x07metrics\x1a\xcc\x02\n\x0cMetricObject\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x17\n\x07plan_id\x18\x02 \x01(\x03R\x06planId\x12\x16\n\x06parent\x18\x03 \x01(\x03R\x06parent\x12z\n\x11\x65xecution_metrics\x18\x04 \x03(\x0b\x32M.spark.connect.ExecutePlanResponse.Metrics.MetricObject.ExecutionMetricsEntryR\x10\x65xecutionMetrics\x1a{\n\x15\x45xecutionMetricsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ExecutePlanResponse.Metrics.MetricValueR\x05value:\x02\x38\x01\x1aX\n\x0bMetricValue\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n\x05value\x18\x02 \x01(\x03R\x05value\x12\x1f\n\x0bmetric_type\x18\x03 \x01(\tR\nmetricType\x1a\x8d\x01\n\x0fObservedMetrics\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x12\x12\n\x04keys\x18\x03 \x03(\tR\x04keys\x12\x17\n\x07plan_id\x18\x04 \x01(\x03R\x06planId\x1a\x10\n\x0eResultComplete\x1a\xcd\x02\n\x11\x45xecutionProgress\x12V\n\x06stages\x18\x01 \x03(\x0b\x32>.spark.connect.ExecutePlanResponse.ExecutionProgress.StageInfoR\x06stages\x12,\n\x12num_inflight_tasks\x18\x02 \x01(\x03R\x10numInflightTasks\x1a\xb1\x01\n\tStageInfo\x12\x19\n\x08stage_id\x18\x01 
\x01(\x03R\x07stageId\x12\x1b\n\tnum_tasks\x18\x02 \x01(\x03R\x08numTasks\x12.\n\x13num_completed_tasks\x18\x03 \x01(\x03R\x11numCompletedTasks\x12(\n\x10input_bytes_read\x18\x04 \x01(\x03R\x0einputBytesRead\x12\x12\n\x04\x64one\x18\x05 \x01(\x08R\x04\x64oneB\x0f\n\rresponse_type"A\n\x08KeyValue\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x19\n\x05value\x18\x02 \x01(\tH\x00R\x05value\x88\x01\x01\x42\x08\n\x06_value"\xaf\t\n\rConfigRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x08 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x44\n\toperation\x18\x03 \x01(\x0b\x32&.spark.connect.ConfigRequest.OperationR\toperation\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x1a\xf2\x03\n\tOperation\x12\x34\n\x03set\x18\x01 \x01(\x0b\x32 .spark.connect.ConfigRequest.SetH\x00R\x03set\x12\x34\n\x03get\x18\x02 \x01(\x0b\x32 .spark.connect.ConfigRequest.GetH\x00R\x03get\x12W\n\x10get_with_default\x18\x03 \x01(\x0b\x32+.spark.connect.ConfigRequest.GetWithDefaultH\x00R\x0egetWithDefault\x12G\n\nget_option\x18\x04 \x01(\x0b\x32&.spark.connect.ConfigRequest.GetOptionH\x00R\tgetOption\x12>\n\x07get_all\x18\x05 \x01(\x0b\x32#.spark.connect.ConfigRequest.GetAllH\x00R\x06getAll\x12:\n\x05unset\x18\x06 \x01(\x0b\x32".spark.connect.ConfigRequest.UnsetH\x00R\x05unset\x12P\n\ris_modifiable\x18\x07 \x01(\x0b\x32).spark.connect.ConfigRequest.IsModifiableH\x00R\x0cisModifiableB\t\n\x07op_type\x1a\\\n\x03Set\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1b\n\x06silent\x18\x02 \x01(\x08H\x00R\x06silent\x88\x01\x01\x42\t\n\x07_silent\x1a\x19\n\x03Get\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a?\n\x0eGetWithDefault\x12-\n\x05pairs\x18\x01 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x1a\x1f\n\tGetOption\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a\x30\n\x06GetAll\x12\x1b\n\x06prefix\x18\x01 \x01(\tH\x00R\x06prefix\x88\x01\x01\x42\t\n\x07_prefix\x1a\x1b\n\x05Unset\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keys\x1a"\n\x0cIsModifiable\x12\x12\n\x04keys\x18\x01 \x03(\tR\x04keysB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xaf\x01\n\x0e\x43onfigResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x04 \x01(\tR\x13serverSideSessionId\x12-\n\x05pairs\x18\x02 \x03(\x0b\x32\x17.spark.connect.KeyValueR\x05pairs\x12\x1a\n\x08warnings\x18\x03 \x03(\tR\x08warnings"\xea\x07\n\x13\x41\x64\x64\x41rtifactsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12V\n&client_observed_server_side_session_id\x18\x07 \x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12$\n\x0b\x63lient_type\x18\x06 \x01(\tH\x02R\nclientType\x88\x01\x01\x12@\n\x05\x62\x61tch\x18\x03 \x01(\x0b\x32(.spark.connect.AddArtifactsRequest.BatchH\x00R\x05\x62\x61tch\x12Z\n\x0b\x62\x65gin_chunk\x18\x04 \x01(\x0b\x32\x37.spark.connect.AddArtifactsRequest.BeginChunkedArtifactH\x00R\nbeginChunk\x12H\n\x05\x63hunk\x18\x05 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkH\x00R\x05\x63hunk\x1a\x35\n\rArtifactChunk\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta\x12\x10\n\x03\x63rc\x18\x02 \x01(\x03R\x03\x63rc\x1ao\n\x13SingleChunkArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x44\n\x04\x64\x61ta\x18\x02 
\x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x04\x64\x61ta\x1a]\n\x05\x42\x61tch\x12T\n\tartifacts\x18\x01 \x03(\x0b\x32\x36.spark.connect.AddArtifactsRequest.SingleChunkArtifactR\tartifacts\x1a\xc1\x01\n\x14\x42\x65ginChunkedArtifact\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x1f\n\x0btotal_bytes\x18\x02 \x01(\x03R\ntotalBytes\x12\x1d\n\nnum_chunks\x18\x03 \x01(\x03R\tnumChunks\x12U\n\rinitial_chunk\x18\x04 \x01(\x0b\x32\x30.spark.connect.AddArtifactsRequest.ArtifactChunkR\x0cinitialChunkB\t\n\x07payloadB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\x90\x02\n\x14\x41\x64\x64\x41rtifactsResponse\x12\x1d\n\nsession_id\x18\x02 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12Q\n\tartifacts\x18\x01 \x03(\x0b\x32\x33.spark.connect.AddArtifactsResponse.ArtifactSummaryR\tartifacts\x1aQ\n\x0f\x41rtifactSummary\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12*\n\x11is_crc_successful\x18\x02 \x01(\x08R\x0fisCrcSuccessful"\xc6\x02\n\x17\x41rtifactStatusesRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x05 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12\x14\n\x05names\x18\x04 \x03(\tR\x05namesB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xe0\x02\n\x18\x41rtifactStatusesResponse\x12\x1d\n\nsession_id\x18\x02 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12Q\n\x08statuses\x18\x01 \x03(\x0b\x32\x35.spark.connect.ArtifactStatusesResponse.StatusesEntryR\x08statuses\x1as\n\rStatusesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12L\n\x05value\x18\x02 \x01(\x0b\x32\x36.spark.connect.ArtifactStatusesResponse.ArtifactStatusR\x05value:\x02\x38\x01\x1a(\n\x0e\x41rtifactStatus\x12\x16\n\x06\x65xists\x18\x01 \x01(\x08R\x06\x65xists"\xdb\x04\n\x10InterruptRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x07 \x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x02R\nclientType\x88\x01\x01\x12T\n\x0einterrupt_type\x18\x04 \x01(\x0e\x32-.spark.connect.InterruptRequest.InterruptTypeR\rinterruptType\x12%\n\roperation_tag\x18\x05 \x01(\tH\x00R\x0coperationTag\x12#\n\x0coperation_id\x18\x06 \x01(\tH\x00R\x0boperationId"\x80\x01\n\rInterruptType\x12\x1e\n\x1aINTERRUPT_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12INTERRUPT_TYPE_ALL\x10\x01\x12\x16\n\x12INTERRUPT_TYPE_TAG\x10\x02\x12\x1f\n\x1bINTERRUPT_TYPE_OPERATION_ID\x10\x03\x42\x0b\n\tinterruptB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\x90\x01\n\x11InterruptResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12\'\n\x0finterrupted_ids\x18\x02 \x03(\tR\x0einterruptedIds"5\n\x0fReattachOptions\x12"\n\x0creattachable\x18\x01 \x01(\x08R\x0creattachable"\xb5\x01\n\x15ResultChunkingOptions\x12;\n\x1a\x61llow_arrow_batch_chunking\x18\x01 \x01(\x08R\x17\x61llowArrowBatchChunking\x12@\n\x1apreferred_arrow_chunk_size\x18\x02 
\x01(\x03H\x00R\x17preferredArrowChunkSize\x88\x01\x01\x42\x1d\n\x1b_preferred_arrow_chunk_size"\x96\x03\n\x16ReattachExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x06 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x12-\n\x10last_response_id\x18\x05 \x01(\tH\x02R\x0elastResponseId\x88\x01\x01\x42)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_typeB\x13\n\x11_last_response_id"\xc9\x04\n\x15ReleaseExecuteRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x07 \x01(\tH\x01R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12!\n\x0coperation_id\x18\x03 \x01(\tR\x0boperationId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x02R\nclientType\x88\x01\x01\x12R\n\x0brelease_all\x18\x05 \x01(\x0b\x32/.spark.connect.ReleaseExecuteRequest.ReleaseAllH\x00R\nreleaseAll\x12X\n\rrelease_until\x18\x06 \x01(\x0b\x32\x31.spark.connect.ReleaseExecuteRequest.ReleaseUntilH\x00R\x0creleaseUntil\x1a\x0c\n\nReleaseAll\x1a/\n\x0cReleaseUntil\x12\x1f\n\x0bresponse_id\x18\x01 \x01(\tR\nresponseIdB\t\n\x07releaseB)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xa5\x01\n\x16ReleaseExecuteResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12&\n\x0coperation_id\x18\x02 \x01(\tH\x00R\x0boperationId\x88\x01\x01\x42\x0f\n\r_operation_id"\xd4\x01\n\x15ReleaseSessionRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x00R\nclientType\x88\x01\x01\x12\'\n\x0f\x61llow_reconnect\x18\x04 \x01(\x08R\x0e\x61llowReconnectB\x0e\n\x0c_client_type"l\n\x16ReleaseSessionResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x02 \x01(\tR\x13serverSideSessionId"\xcc\x02\n\x18\x46\x65tchErrorDetailsRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x05 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12\x19\n\x08\x65rror_id\x18\x03 \x01(\tR\x07\x65rrorId\x12$\n\x0b\x63lient_type\x18\x04 \x01(\tH\x01R\nclientType\x88\x01\x01\x42)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_type"\xd9\x0f\n\x19\x46\x65tchErrorDetailsResponse\x12\x33\n\x16server_side_session_id\x18\x03 \x01(\tR\x13serverSideSessionId\x12\x1d\n\nsession_id\x18\x04 \x01(\tR\tsessionId\x12)\n\x0eroot_error_idx\x18\x01 \x01(\x05H\x00R\x0crootErrorIdx\x88\x01\x01\x12\x46\n\x06\x65rrors\x18\x02 \x03(\x0b\x32..spark.connect.FetchErrorDetailsResponse.ErrorR\x06\x65rrors\x1a\xae\x01\n\x11StackTraceElement\x12\'\n\x0f\x64\x65\x63laring_class\x18\x01 \x01(\tR\x0e\x64\x65\x63laringClass\x12\x1f\n\x0bmethod_name\x18\x02 \x01(\tR\nmethodName\x12 \n\tfile_name\x18\x03 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12\x1f\n\x0bline_number\x18\x04 \x01(\x05R\nlineNumberB\x0c\n\n_file_name\x1a\xf0\x02\n\x0cQueryContext\x12\x64\n\x0c\x63ontext_type\x18\n 
\x01(\x0e\x32\x41.spark.connect.FetchErrorDetailsResponse.QueryContext.ContextTypeR\x0b\x63ontextType\x12\x1f\n\x0bobject_type\x18\x01 \x01(\tR\nobjectType\x12\x1f\n\x0bobject_name\x18\x02 \x01(\tR\nobjectName\x12\x1f\n\x0bstart_index\x18\x03 \x01(\x05R\nstartIndex\x12\x1d\n\nstop_index\x18\x04 \x01(\x05R\tstopIndex\x12\x1a\n\x08\x66ragment\x18\x05 \x01(\tR\x08\x66ragment\x12\x1b\n\tcall_site\x18\x06 \x01(\tR\x08\x63\x61llSite\x12\x18\n\x07summary\x18\x07 \x01(\tR\x07summary"%\n\x0b\x43ontextType\x12\x07\n\x03SQL\x10\x00\x12\r\n\tDATAFRAME\x10\x01\x1a\xa6\x04\n\x0eSparkThrowable\x12$\n\x0b\x65rror_class\x18\x01 \x01(\tH\x00R\nerrorClass\x88\x01\x01\x12}\n\x12message_parameters\x18\x02 \x03(\x0b\x32N.spark.connect.FetchErrorDetailsResponse.SparkThrowable.MessageParametersEntryR\x11messageParameters\x12\\\n\x0equery_contexts\x18\x03 \x03(\x0b\x32\x35.spark.connect.FetchErrorDetailsResponse.QueryContextR\rqueryContexts\x12 \n\tsql_state\x18\x04 \x01(\tH\x01R\x08sqlState\x88\x01\x01\x12r\n\x14\x62reaking_change_info\x18\x05 \x01(\x0b\x32;.spark.connect.FetchErrorDetailsResponse.BreakingChangeInfoH\x02R\x12\x62reakingChangeInfo\x88\x01\x01\x1a\x44\n\x16MessageParametersEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0e\n\x0c_error_classB\x0c\n\n_sql_stateB\x17\n\x15_breaking_change_info\x1a\xfa\x01\n\x12\x42reakingChangeInfo\x12+\n\x11migration_message\x18\x01 \x03(\tR\x10migrationMessage\x12k\n\x11mitigation_config\x18\x02 \x01(\x0b\x32\x39.spark.connect.FetchErrorDetailsResponse.MitigationConfigH\x00R\x10mitigationConfig\x88\x01\x01\x12$\n\x0bneeds_audit\x18\x03 \x01(\x08H\x01R\nneedsAudit\x88\x01\x01\x42\x14\n\x12_mitigation_configB\x0e\n\x0c_needs_audit\x1a:\n\x10MitigationConfig\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value\x1a\xdb\x02\n\x05\x45rror\x12\x30\n\x14\x65rror_type_hierarchy\x18\x01 \x03(\tR\x12\x65rrorTypeHierarchy\x12\x18\n\x07message\x18\x02 \x01(\tR\x07message\x12[\n\x0bstack_trace\x18\x03 \x03(\x0b\x32:.spark.connect.FetchErrorDetailsResponse.StackTraceElementR\nstackTrace\x12 \n\tcause_idx\x18\x04 \x01(\x05H\x00R\x08\x63\x61useIdx\x88\x01\x01\x12\x65\n\x0fspark_throwable\x18\x05 \x01(\x0b\x32\x37.spark.connect.FetchErrorDetailsResponse.SparkThrowableH\x01R\x0esparkThrowable\x88\x01\x01\x42\x0c\n\n_cause_idxB\x12\n\x10_spark_throwableB\x11\n\x0f_root_error_idx"Z\n\x17\x43heckpointCommandResult\x12?\n\x08relation\x18\x01 \x01(\x0b\x32#.spark.connect.CachedRemoteRelationR\x08relation"\xea\x02\n\x13\x43loneSessionRequest\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12V\n&client_observed_server_side_session_id\x18\x05 \x01(\tH\x00R!clientObservedServerSideSessionId\x88\x01\x01\x12=\n\x0cuser_context\x18\x02 \x01(\x0b\x32\x1a.spark.connect.UserContextR\x0buserContext\x12$\n\x0b\x63lient_type\x18\x03 \x01(\tH\x01R\nclientType\x88\x01\x01\x12)\n\x0enew_session_id\x18\x04 \x01(\tH\x02R\x0cnewSessionId\x88\x01\x01\x42)\n\'_client_observed_server_side_session_idB\x0e\n\x0c_client_typeB\x11\n\x0f_new_session_id"\xcc\x01\n\x14\x43loneSessionResponse\x12\x1d\n\nsession_id\x18\x01 \x01(\tR\tsessionId\x12\x33\n\x16server_side_session_id\x18\x02 \x01(\tR\x13serverSideSessionId\x12$\n\x0enew_session_id\x18\x03 \x01(\tR\x0cnewSessionId\x12:\n\x1anew_server_side_session_id\x18\x04 
\x01(\tR\x16newServerSideSessionId*Q\n\x10\x43ompressionCodec\x12!\n\x1d\x43OMPRESSION_CODEC_UNSPECIFIED\x10\x00\x12\x1a\n\x16\x43OMPRESSION_CODEC_ZSTD\x10\x01\x32\x8d\x08\n\x13SparkConnectService\x12X\n\x0b\x45xecutePlan\x12!.spark.connect.ExecutePlanRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12V\n\x0b\x41nalyzePlan\x12!.spark.connect.AnalyzePlanRequest\x1a".spark.connect.AnalyzePlanResponse"\x00\x12G\n\x06\x43onfig\x12\x1c.spark.connect.ConfigRequest\x1a\x1d.spark.connect.ConfigResponse"\x00\x12[\n\x0c\x41\x64\x64\x41rtifacts\x12".spark.connect.AddArtifactsRequest\x1a#.spark.connect.AddArtifactsResponse"\x00(\x01\x12\x63\n\x0e\x41rtifactStatus\x12&.spark.connect.ArtifactStatusesRequest\x1a\'.spark.connect.ArtifactStatusesResponse"\x00\x12P\n\tInterrupt\x12\x1f.spark.connect.InterruptRequest\x1a .spark.connect.InterruptResponse"\x00\x12`\n\x0fReattachExecute\x12%.spark.connect.ReattachExecuteRequest\x1a".spark.connect.ExecutePlanResponse"\x00\x30\x01\x12_\n\x0eReleaseExecute\x12$.spark.connect.ReleaseExecuteRequest\x1a%.spark.connect.ReleaseExecuteResponse"\x00\x12_\n\x0eReleaseSession\x12$.spark.connect.ReleaseSessionRequest\x1a%.spark.connect.ReleaseSessionResponse"\x00\x12h\n\x11\x46\x65tchErrorDetails\x12\'.spark.connect.FetchErrorDetailsRequest\x1a(.spark.connect.FetchErrorDetailsResponse"\x00\x12Y\n\x0c\x43loneSession\x12".spark.connect.CloneSessionRequest\x1a#.spark.connect.CloneSessionResponse"\x00\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
)
_globals = globals()
@@ -70,200 +70,206 @@
_globals[
"_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY"
]._serialized_options = b"8\001"
- _globals["_PLAN"]._serialized_start = 274
- _globals["_PLAN"]._serialized_end = 390
- _globals["_USERCONTEXT"]._serialized_start = 392
- _globals["_USERCONTEXT"]._serialized_end = 514
- _globals["_ANALYZEPLANREQUEST"]._serialized_start = 517
- _globals["_ANALYZEPLANREQUEST"]._serialized_end = 3194
- _globals["_ANALYZEPLANREQUEST_SCHEMA"]._serialized_start = 1879
- _globals["_ANALYZEPLANREQUEST_SCHEMA"]._serialized_end = 1928
- _globals["_ANALYZEPLANREQUEST_EXPLAIN"]._serialized_start = 1931
- _globals["_ANALYZEPLANREQUEST_EXPLAIN"]._serialized_end = 2246
- _globals["_ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE"]._serialized_start = 2074
- _globals["_ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE"]._serialized_end = 2246
- _globals["_ANALYZEPLANREQUEST_TREESTRING"]._serialized_start = 2248
- _globals["_ANALYZEPLANREQUEST_TREESTRING"]._serialized_end = 2338
- _globals["_ANALYZEPLANREQUEST_ISLOCAL"]._serialized_start = 2340
- _globals["_ANALYZEPLANREQUEST_ISLOCAL"]._serialized_end = 2390
- _globals["_ANALYZEPLANREQUEST_ISSTREAMING"]._serialized_start = 2392
- _globals["_ANALYZEPLANREQUEST_ISSTREAMING"]._serialized_end = 2446
- _globals["_ANALYZEPLANREQUEST_INPUTFILES"]._serialized_start = 2448
- _globals["_ANALYZEPLANREQUEST_INPUTFILES"]._serialized_end = 2501
- _globals["_ANALYZEPLANREQUEST_SPARKVERSION"]._serialized_start = 2503
- _globals["_ANALYZEPLANREQUEST_SPARKVERSION"]._serialized_end = 2517
- _globals["_ANALYZEPLANREQUEST_DDLPARSE"]._serialized_start = 2519
- _globals["_ANALYZEPLANREQUEST_DDLPARSE"]._serialized_end = 2560
- _globals["_ANALYZEPLANREQUEST_SAMESEMANTICS"]._serialized_start = 2562
- _globals["_ANALYZEPLANREQUEST_SAMESEMANTICS"]._serialized_end = 2683
- _globals["_ANALYZEPLANREQUEST_SEMANTICHASH"]._serialized_start = 2685
- _globals["_ANALYZEPLANREQUEST_SEMANTICHASH"]._serialized_end = 2740
- _globals["_ANALYZEPLANREQUEST_PERSIST"]._serialized_start = 2743
- _globals["_ANALYZEPLANREQUEST_PERSIST"]._serialized_end = 2894
- _globals["_ANALYZEPLANREQUEST_UNPERSIST"]._serialized_start = 2896
- _globals["_ANALYZEPLANREQUEST_UNPERSIST"]._serialized_end = 3006
- _globals["_ANALYZEPLANREQUEST_GETSTORAGELEVEL"]._serialized_start = 3008
- _globals["_ANALYZEPLANREQUEST_GETSTORAGELEVEL"]._serialized_end = 3078
- _globals["_ANALYZEPLANREQUEST_JSONTODDL"]._serialized_start = 3080
- _globals["_ANALYZEPLANREQUEST_JSONTODDL"]._serialized_end = 3124
- _globals["_ANALYZEPLANRESPONSE"]._serialized_start = 3197
- _globals["_ANALYZEPLANRESPONSE"]._serialized_end = 5063
- _globals["_ANALYZEPLANRESPONSE_SCHEMA"]._serialized_start = 4438
- _globals["_ANALYZEPLANRESPONSE_SCHEMA"]._serialized_end = 4495
- _globals["_ANALYZEPLANRESPONSE_EXPLAIN"]._serialized_start = 4497
- _globals["_ANALYZEPLANRESPONSE_EXPLAIN"]._serialized_end = 4545
- _globals["_ANALYZEPLANRESPONSE_TREESTRING"]._serialized_start = 4547
- _globals["_ANALYZEPLANRESPONSE_TREESTRING"]._serialized_end = 4592
- _globals["_ANALYZEPLANRESPONSE_ISLOCAL"]._serialized_start = 4594
- _globals["_ANALYZEPLANRESPONSE_ISLOCAL"]._serialized_end = 4630
- _globals["_ANALYZEPLANRESPONSE_ISSTREAMING"]._serialized_start = 4632
- _globals["_ANALYZEPLANRESPONSE_ISSTREAMING"]._serialized_end = 4680
- _globals["_ANALYZEPLANRESPONSE_INPUTFILES"]._serialized_start = 4682
- _globals["_ANALYZEPLANRESPONSE_INPUTFILES"]._serialized_end = 4716
- _globals["_ANALYZEPLANRESPONSE_SPARKVERSION"]._serialized_start = 4718
- _globals["_ANALYZEPLANRESPONSE_SPARKVERSION"]._serialized_end = 4758
- _globals["_ANALYZEPLANRESPONSE_DDLPARSE"]._serialized_start = 4760
- _globals["_ANALYZEPLANRESPONSE_DDLPARSE"]._serialized_end = 4819
- _globals["_ANALYZEPLANRESPONSE_SAMESEMANTICS"]._serialized_start = 4821
- _globals["_ANALYZEPLANRESPONSE_SAMESEMANTICS"]._serialized_end = 4860
- _globals["_ANALYZEPLANRESPONSE_SEMANTICHASH"]._serialized_start = 4862
- _globals["_ANALYZEPLANRESPONSE_SEMANTICHASH"]._serialized_end = 4900
- _globals["_ANALYZEPLANRESPONSE_PERSIST"]._serialized_start = 2743
- _globals["_ANALYZEPLANRESPONSE_PERSIST"]._serialized_end = 2752
- _globals["_ANALYZEPLANRESPONSE_UNPERSIST"]._serialized_start = 2896
- _globals["_ANALYZEPLANRESPONSE_UNPERSIST"]._serialized_end = 2907
- _globals["_ANALYZEPLANRESPONSE_GETSTORAGELEVEL"]._serialized_start = 4926
- _globals["_ANALYZEPLANRESPONSE_GETSTORAGELEVEL"]._serialized_end = 5009
- _globals["_ANALYZEPLANRESPONSE_JSONTODDL"]._serialized_start = 5011
- _globals["_ANALYZEPLANRESPONSE_JSONTODDL"]._serialized_end = 5053
- _globals["_EXECUTEPLANREQUEST"]._serialized_start = 5066
- _globals["_EXECUTEPLANREQUEST"]._serialized_end = 5837
- _globals["_EXECUTEPLANREQUEST_REQUESTOPTION"]._serialized_start = 5500
- _globals["_EXECUTEPLANREQUEST_REQUESTOPTION"]._serialized_end = 5761
- _globals["_EXECUTEPLANRESPONSE"]._serialized_start = 5840
- _globals["_EXECUTEPLANRESPONSE"]._serialized_end = 9297
- _globals["_EXECUTEPLANRESPONSE_SQLCOMMANDRESULT"]._serialized_start = 7940
- _globals["_EXECUTEPLANRESPONSE_SQLCOMMANDRESULT"]._serialized_end = 8011
- _globals["_EXECUTEPLANRESPONSE_ARROWBATCH"]._serialized_start = 8014
- _globals["_EXECUTEPLANRESPONSE_ARROWBATCH"]._serialized_end = 8262
- _globals["_EXECUTEPLANRESPONSE_METRICS"]._serialized_start = 8265
- _globals["_EXECUTEPLANRESPONSE_METRICS"]._serialized_end = 8782
- _globals["_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT"]._serialized_start = 8360
- _globals["_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT"]._serialized_end = 8692
+ _globals["_COMPRESSIONCODEC"]._serialized_start = 18571
+ _globals["_COMPRESSIONCODEC"]._serialized_end = 18652
+ _globals["_PLAN"]._serialized_start = 275
+ _globals["_PLAN"]._serialized_end = 758
+ _globals["_PLAN_COMPRESSEDOPERATION"]._serialized_start = 477
+ _globals["_PLAN_COMPRESSEDOPERATION"]._serialized_end = 747
+ _globals["_PLAN_COMPRESSEDOPERATION_OPTYPE"]._serialized_start = 671
+ _globals["_PLAN_COMPRESSEDOPERATION_OPTYPE"]._serialized_end = 747
+ _globals["_USERCONTEXT"]._serialized_start = 760
+ _globals["_USERCONTEXT"]._serialized_end = 882
+ _globals["_ANALYZEPLANREQUEST"]._serialized_start = 885
+ _globals["_ANALYZEPLANREQUEST"]._serialized_end = 3562
+ _globals["_ANALYZEPLANREQUEST_SCHEMA"]._serialized_start = 2247
+ _globals["_ANALYZEPLANREQUEST_SCHEMA"]._serialized_end = 2296
+ _globals["_ANALYZEPLANREQUEST_EXPLAIN"]._serialized_start = 2299
+ _globals["_ANALYZEPLANREQUEST_EXPLAIN"]._serialized_end = 2614
+ _globals["_ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE"]._serialized_start = 2442
+ _globals["_ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE"]._serialized_end = 2614
+ _globals["_ANALYZEPLANREQUEST_TREESTRING"]._serialized_start = 2616
+ _globals["_ANALYZEPLANREQUEST_TREESTRING"]._serialized_end = 2706
+ _globals["_ANALYZEPLANREQUEST_ISLOCAL"]._serialized_start = 2708
+ _globals["_ANALYZEPLANREQUEST_ISLOCAL"]._serialized_end = 2758
+ _globals["_ANALYZEPLANREQUEST_ISSTREAMING"]._serialized_start = 2760
+ _globals["_ANALYZEPLANREQUEST_ISSTREAMING"]._serialized_end = 2814
+ _globals["_ANALYZEPLANREQUEST_INPUTFILES"]._serialized_start = 2816
+ _globals["_ANALYZEPLANREQUEST_INPUTFILES"]._serialized_end = 2869
+ _globals["_ANALYZEPLANREQUEST_SPARKVERSION"]._serialized_start = 2871
+ _globals["_ANALYZEPLANREQUEST_SPARKVERSION"]._serialized_end = 2885
+ _globals["_ANALYZEPLANREQUEST_DDLPARSE"]._serialized_start = 2887
+ _globals["_ANALYZEPLANREQUEST_DDLPARSE"]._serialized_end = 2928
+ _globals["_ANALYZEPLANREQUEST_SAMESEMANTICS"]._serialized_start = 2930
+ _globals["_ANALYZEPLANREQUEST_SAMESEMANTICS"]._serialized_end = 3051
+ _globals["_ANALYZEPLANREQUEST_SEMANTICHASH"]._serialized_start = 3053
+ _globals["_ANALYZEPLANREQUEST_SEMANTICHASH"]._serialized_end = 3108
+ _globals["_ANALYZEPLANREQUEST_PERSIST"]._serialized_start = 3111
+ _globals["_ANALYZEPLANREQUEST_PERSIST"]._serialized_end = 3262
+ _globals["_ANALYZEPLANREQUEST_UNPERSIST"]._serialized_start = 3264
+ _globals["_ANALYZEPLANREQUEST_UNPERSIST"]._serialized_end = 3374
+ _globals["_ANALYZEPLANREQUEST_GETSTORAGELEVEL"]._serialized_start = 3376
+ _globals["_ANALYZEPLANREQUEST_GETSTORAGELEVEL"]._serialized_end = 3446
+ _globals["_ANALYZEPLANREQUEST_JSONTODDL"]._serialized_start = 3448
+ _globals["_ANALYZEPLANREQUEST_JSONTODDL"]._serialized_end = 3492
+ _globals["_ANALYZEPLANRESPONSE"]._serialized_start = 3565
+ _globals["_ANALYZEPLANRESPONSE"]._serialized_end = 5431
+ _globals["_ANALYZEPLANRESPONSE_SCHEMA"]._serialized_start = 4806
+ _globals["_ANALYZEPLANRESPONSE_SCHEMA"]._serialized_end = 4863
+ _globals["_ANALYZEPLANRESPONSE_EXPLAIN"]._serialized_start = 4865
+ _globals["_ANALYZEPLANRESPONSE_EXPLAIN"]._serialized_end = 4913
+ _globals["_ANALYZEPLANRESPONSE_TREESTRING"]._serialized_start = 4915
+ _globals["_ANALYZEPLANRESPONSE_TREESTRING"]._serialized_end = 4960
+ _globals["_ANALYZEPLANRESPONSE_ISLOCAL"]._serialized_start = 4962
+ _globals["_ANALYZEPLANRESPONSE_ISLOCAL"]._serialized_end = 4998
+ _globals["_ANALYZEPLANRESPONSE_ISSTREAMING"]._serialized_start = 5000
+ _globals["_ANALYZEPLANRESPONSE_ISSTREAMING"]._serialized_end = 5048
+ _globals["_ANALYZEPLANRESPONSE_INPUTFILES"]._serialized_start = 5050
+ _globals["_ANALYZEPLANRESPONSE_INPUTFILES"]._serialized_end = 5084
+ _globals["_ANALYZEPLANRESPONSE_SPARKVERSION"]._serialized_start = 5086
+ _globals["_ANALYZEPLANRESPONSE_SPARKVERSION"]._serialized_end = 5126
+ _globals["_ANALYZEPLANRESPONSE_DDLPARSE"]._serialized_start = 5128
+ _globals["_ANALYZEPLANRESPONSE_DDLPARSE"]._serialized_end = 5187
+ _globals["_ANALYZEPLANRESPONSE_SAMESEMANTICS"]._serialized_start = 5189
+ _globals["_ANALYZEPLANRESPONSE_SAMESEMANTICS"]._serialized_end = 5228
+ _globals["_ANALYZEPLANRESPONSE_SEMANTICHASH"]._serialized_start = 5230
+ _globals["_ANALYZEPLANRESPONSE_SEMANTICHASH"]._serialized_end = 5268
+ _globals["_ANALYZEPLANRESPONSE_PERSIST"]._serialized_start = 3111
+ _globals["_ANALYZEPLANRESPONSE_PERSIST"]._serialized_end = 3120
+ _globals["_ANALYZEPLANRESPONSE_UNPERSIST"]._serialized_start = 3264
+ _globals["_ANALYZEPLANRESPONSE_UNPERSIST"]._serialized_end = 3275
+ _globals["_ANALYZEPLANRESPONSE_GETSTORAGELEVEL"]._serialized_start = 5294
+ _globals["_ANALYZEPLANRESPONSE_GETSTORAGELEVEL"]._serialized_end = 5377
+ _globals["_ANALYZEPLANRESPONSE_JSONTODDL"]._serialized_start = 5379
+ _globals["_ANALYZEPLANRESPONSE_JSONTODDL"]._serialized_end = 5421
+ _globals["_EXECUTEPLANREQUEST"]._serialized_start = 5434
+ _globals["_EXECUTEPLANREQUEST"]._serialized_end = 6205
+ _globals["_EXECUTEPLANREQUEST_REQUESTOPTION"]._serialized_start = 5868
+ _globals["_EXECUTEPLANREQUEST_REQUESTOPTION"]._serialized_end = 6129
+ _globals["_EXECUTEPLANRESPONSE"]._serialized_start = 6208
+ _globals["_EXECUTEPLANRESPONSE"]._serialized_end = 9665
+ _globals["_EXECUTEPLANRESPONSE_SQLCOMMANDRESULT"]._serialized_start = 8308
+ _globals["_EXECUTEPLANRESPONSE_SQLCOMMANDRESULT"]._serialized_end = 8379
+ _globals["_EXECUTEPLANRESPONSE_ARROWBATCH"]._serialized_start = 8382
+ _globals["_EXECUTEPLANRESPONSE_ARROWBATCH"]._serialized_end = 8630
+ _globals["_EXECUTEPLANRESPONSE_METRICS"]._serialized_start = 8633
+ _globals["_EXECUTEPLANRESPONSE_METRICS"]._serialized_end = 9150
+ _globals["_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT"]._serialized_start = 8728
+ _globals["_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT"]._serialized_end = 9060
_globals[
"_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY"
- ]._serialized_start = 8569
+ ]._serialized_start = 8937
_globals[
"_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY"
- ]._serialized_end = 8692
- _globals["_EXECUTEPLANRESPONSE_METRICS_METRICVALUE"]._serialized_start = 8694
- _globals["_EXECUTEPLANRESPONSE_METRICS_METRICVALUE"]._serialized_end = 8782
- _globals["_EXECUTEPLANRESPONSE_OBSERVEDMETRICS"]._serialized_start = 8785
- _globals["_EXECUTEPLANRESPONSE_OBSERVEDMETRICS"]._serialized_end = 8926
- _globals["_EXECUTEPLANRESPONSE_RESULTCOMPLETE"]._serialized_start = 8928
- _globals["_EXECUTEPLANRESPONSE_RESULTCOMPLETE"]._serialized_end = 8944
- _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS"]._serialized_start = 8947
- _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS"]._serialized_end = 9280
- _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO"]._serialized_start = 9103
- _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO"]._serialized_end = 9280
- _globals["_KEYVALUE"]._serialized_start = 9299
- _globals["_KEYVALUE"]._serialized_end = 9364
- _globals["_CONFIGREQUEST"]._serialized_start = 9367
- _globals["_CONFIGREQUEST"]._serialized_end = 10566
- _globals["_CONFIGREQUEST_OPERATION"]._serialized_start = 9675
- _globals["_CONFIGREQUEST_OPERATION"]._serialized_end = 10173
- _globals["_CONFIGREQUEST_SET"]._serialized_start = 10175
- _globals["_CONFIGREQUEST_SET"]._serialized_end = 10267
- _globals["_CONFIGREQUEST_GET"]._serialized_start = 10269
- _globals["_CONFIGREQUEST_GET"]._serialized_end = 10294
- _globals["_CONFIGREQUEST_GETWITHDEFAULT"]._serialized_start = 10296
- _globals["_CONFIGREQUEST_GETWITHDEFAULT"]._serialized_end = 10359
- _globals["_CONFIGREQUEST_GETOPTION"]._serialized_start = 10361
- _globals["_CONFIGREQUEST_GETOPTION"]._serialized_end = 10392
- _globals["_CONFIGREQUEST_GETALL"]._serialized_start = 10394
- _globals["_CONFIGREQUEST_GETALL"]._serialized_end = 10442
- _globals["_CONFIGREQUEST_UNSET"]._serialized_start = 10444
- _globals["_CONFIGREQUEST_UNSET"]._serialized_end = 10471
- _globals["_CONFIGREQUEST_ISMODIFIABLE"]._serialized_start = 10473
- _globals["_CONFIGREQUEST_ISMODIFIABLE"]._serialized_end = 10507
- _globals["_CONFIGRESPONSE"]._serialized_start = 10569
- _globals["_CONFIGRESPONSE"]._serialized_end = 10744
- _globals["_ADDARTIFACTSREQUEST"]._serialized_start = 10747
- _globals["_ADDARTIFACTSREQUEST"]._serialized_end = 11749
- _globals["_ADDARTIFACTSREQUEST_ARTIFACTCHUNK"]._serialized_start = 11222
- _globals["_ADDARTIFACTSREQUEST_ARTIFACTCHUNK"]._serialized_end = 11275
- _globals["_ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT"]._serialized_start = 11277
- _globals["_ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT"]._serialized_end = 11388
- _globals["_ADDARTIFACTSREQUEST_BATCH"]._serialized_start = 11390
- _globals["_ADDARTIFACTSREQUEST_BATCH"]._serialized_end = 11483
- _globals["_ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT"]._serialized_start = 11486
- _globals["_ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT"]._serialized_end = 11679
- _globals["_ADDARTIFACTSRESPONSE"]._serialized_start = 11752
- _globals["_ADDARTIFACTSRESPONSE"]._serialized_end = 12024
- _globals["_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY"]._serialized_start = 11943
- _globals["_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY"]._serialized_end = 12024
- _globals["_ARTIFACTSTATUSESREQUEST"]._serialized_start = 12027
- _globals["_ARTIFACTSTATUSESREQUEST"]._serialized_end = 12353
- _globals["_ARTIFACTSTATUSESRESPONSE"]._serialized_start = 12356
- _globals["_ARTIFACTSTATUSESRESPONSE"]._serialized_end = 12708
- _globals["_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY"]._serialized_start = 12551
- _globals["_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY"]._serialized_end = 12666
- _globals["_ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS"]._serialized_start = 12668
- _globals["_ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS"]._serialized_end = 12708
- _globals["_INTERRUPTREQUEST"]._serialized_start = 12711
- _globals["_INTERRUPTREQUEST"]._serialized_end = 13314
- _globals["_INTERRUPTREQUEST_INTERRUPTTYPE"]._serialized_start = 13114
- _globals["_INTERRUPTREQUEST_INTERRUPTTYPE"]._serialized_end = 13242
- _globals["_INTERRUPTRESPONSE"]._serialized_start = 13317
- _globals["_INTERRUPTRESPONSE"]._serialized_end = 13461
- _globals["_REATTACHOPTIONS"]._serialized_start = 13463
- _globals["_REATTACHOPTIONS"]._serialized_end = 13516
- _globals["_RESULTCHUNKINGOPTIONS"]._serialized_start = 13519
- _globals["_RESULTCHUNKINGOPTIONS"]._serialized_end = 13700
- _globals["_REATTACHEXECUTEREQUEST"]._serialized_start = 13703
- _globals["_REATTACHEXECUTEREQUEST"]._serialized_end = 14109
- _globals["_RELEASEEXECUTEREQUEST"]._serialized_start = 14112
- _globals["_RELEASEEXECUTEREQUEST"]._serialized_end = 14697
- _globals["_RELEASEEXECUTEREQUEST_RELEASEALL"]._serialized_start = 14566
- _globals["_RELEASEEXECUTEREQUEST_RELEASEALL"]._serialized_end = 14578
- _globals["_RELEASEEXECUTEREQUEST_RELEASEUNTIL"]._serialized_start = 14580
- _globals["_RELEASEEXECUTEREQUEST_RELEASEUNTIL"]._serialized_end = 14627
- _globals["_RELEASEEXECUTERESPONSE"]._serialized_start = 14700
- _globals["_RELEASEEXECUTERESPONSE"]._serialized_end = 14865
- _globals["_RELEASESESSIONREQUEST"]._serialized_start = 14868
- _globals["_RELEASESESSIONREQUEST"]._serialized_end = 15080
- _globals["_RELEASESESSIONRESPONSE"]._serialized_start = 15082
- _globals["_RELEASESESSIONRESPONSE"]._serialized_end = 15190
- _globals["_FETCHERRORDETAILSREQUEST"]._serialized_start = 15193
- _globals["_FETCHERRORDETAILSREQUEST"]._serialized_end = 15525
- _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_start = 15528
- _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_end = 17537
- _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_start = 15757
- _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_end = 15931
- _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_start = 15934
- _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_end = 16302
- _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_start = 16265
- _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_end = 16302
- _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_start = 16305
- _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_end = 16855
+ ]._serialized_end = 9060
+ _globals["_EXECUTEPLANRESPONSE_METRICS_METRICVALUE"]._serialized_start = 9062
+ _globals["_EXECUTEPLANRESPONSE_METRICS_METRICVALUE"]._serialized_end = 9150
+ _globals["_EXECUTEPLANRESPONSE_OBSERVEDMETRICS"]._serialized_start = 9153
+ _globals["_EXECUTEPLANRESPONSE_OBSERVEDMETRICS"]._serialized_end = 9294
+ _globals["_EXECUTEPLANRESPONSE_RESULTCOMPLETE"]._serialized_start = 9296
+ _globals["_EXECUTEPLANRESPONSE_RESULTCOMPLETE"]._serialized_end = 9312
+ _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS"]._serialized_start = 9315
+ _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS"]._serialized_end = 9648
+ _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO"]._serialized_start = 9471
+ _globals["_EXECUTEPLANRESPONSE_EXECUTIONPROGRESS_STAGEINFO"]._serialized_end = 9648
+ _globals["_KEYVALUE"]._serialized_start = 9667
+ _globals["_KEYVALUE"]._serialized_end = 9732
+ _globals["_CONFIGREQUEST"]._serialized_start = 9735
+ _globals["_CONFIGREQUEST"]._serialized_end = 10934
+ _globals["_CONFIGREQUEST_OPERATION"]._serialized_start = 10043
+ _globals["_CONFIGREQUEST_OPERATION"]._serialized_end = 10541
+ _globals["_CONFIGREQUEST_SET"]._serialized_start = 10543
+ _globals["_CONFIGREQUEST_SET"]._serialized_end = 10635
+ _globals["_CONFIGREQUEST_GET"]._serialized_start = 10637
+ _globals["_CONFIGREQUEST_GET"]._serialized_end = 10662
+ _globals["_CONFIGREQUEST_GETWITHDEFAULT"]._serialized_start = 10664
+ _globals["_CONFIGREQUEST_GETWITHDEFAULT"]._serialized_end = 10727
+ _globals["_CONFIGREQUEST_GETOPTION"]._serialized_start = 10729
+ _globals["_CONFIGREQUEST_GETOPTION"]._serialized_end = 10760
+ _globals["_CONFIGREQUEST_GETALL"]._serialized_start = 10762
+ _globals["_CONFIGREQUEST_GETALL"]._serialized_end = 10810
+ _globals["_CONFIGREQUEST_UNSET"]._serialized_start = 10812
+ _globals["_CONFIGREQUEST_UNSET"]._serialized_end = 10839
+ _globals["_CONFIGREQUEST_ISMODIFIABLE"]._serialized_start = 10841
+ _globals["_CONFIGREQUEST_ISMODIFIABLE"]._serialized_end = 10875
+ _globals["_CONFIGRESPONSE"]._serialized_start = 10937
+ _globals["_CONFIGRESPONSE"]._serialized_end = 11112
+ _globals["_ADDARTIFACTSREQUEST"]._serialized_start = 11115
+ _globals["_ADDARTIFACTSREQUEST"]._serialized_end = 12117
+ _globals["_ADDARTIFACTSREQUEST_ARTIFACTCHUNK"]._serialized_start = 11590
+ _globals["_ADDARTIFACTSREQUEST_ARTIFACTCHUNK"]._serialized_end = 11643
+ _globals["_ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT"]._serialized_start = 11645
+ _globals["_ADDARTIFACTSREQUEST_SINGLECHUNKARTIFACT"]._serialized_end = 11756
+ _globals["_ADDARTIFACTSREQUEST_BATCH"]._serialized_start = 11758
+ _globals["_ADDARTIFACTSREQUEST_BATCH"]._serialized_end = 11851
+ _globals["_ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT"]._serialized_start = 11854
+ _globals["_ADDARTIFACTSREQUEST_BEGINCHUNKEDARTIFACT"]._serialized_end = 12047
+ _globals["_ADDARTIFACTSRESPONSE"]._serialized_start = 12120
+ _globals["_ADDARTIFACTSRESPONSE"]._serialized_end = 12392
+ _globals["_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY"]._serialized_start = 12311
+ _globals["_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY"]._serialized_end = 12392
+ _globals["_ARTIFACTSTATUSESREQUEST"]._serialized_start = 12395
+ _globals["_ARTIFACTSTATUSESREQUEST"]._serialized_end = 12721
+ _globals["_ARTIFACTSTATUSESRESPONSE"]._serialized_start = 12724
+ _globals["_ARTIFACTSTATUSESRESPONSE"]._serialized_end = 13076
+ _globals["_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY"]._serialized_start = 12919
+ _globals["_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY"]._serialized_end = 13034
+ _globals["_ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS"]._serialized_start = 13036
+ _globals["_ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS"]._serialized_end = 13076
+ _globals["_INTERRUPTREQUEST"]._serialized_start = 13079
+ _globals["_INTERRUPTREQUEST"]._serialized_end = 13682
+ _globals["_INTERRUPTREQUEST_INTERRUPTTYPE"]._serialized_start = 13482
+ _globals["_INTERRUPTREQUEST_INTERRUPTTYPE"]._serialized_end = 13610
+ _globals["_INTERRUPTRESPONSE"]._serialized_start = 13685
+ _globals["_INTERRUPTRESPONSE"]._serialized_end = 13829
+ _globals["_REATTACHOPTIONS"]._serialized_start = 13831
+ _globals["_REATTACHOPTIONS"]._serialized_end = 13884
+ _globals["_RESULTCHUNKINGOPTIONS"]._serialized_start = 13887
+ _globals["_RESULTCHUNKINGOPTIONS"]._serialized_end = 14068
+ _globals["_REATTACHEXECUTEREQUEST"]._serialized_start = 14071
+ _globals["_REATTACHEXECUTEREQUEST"]._serialized_end = 14477
+ _globals["_RELEASEEXECUTEREQUEST"]._serialized_start = 14480
+ _globals["_RELEASEEXECUTEREQUEST"]._serialized_end = 15065
+ _globals["_RELEASEEXECUTEREQUEST_RELEASEALL"]._serialized_start = 14934
+ _globals["_RELEASEEXECUTEREQUEST_RELEASEALL"]._serialized_end = 14946
+ _globals["_RELEASEEXECUTEREQUEST_RELEASEUNTIL"]._serialized_start = 14948
+ _globals["_RELEASEEXECUTEREQUEST_RELEASEUNTIL"]._serialized_end = 14995
+ _globals["_RELEASEEXECUTERESPONSE"]._serialized_start = 15068
+ _globals["_RELEASEEXECUTERESPONSE"]._serialized_end = 15233
+ _globals["_RELEASESESSIONREQUEST"]._serialized_start = 15236
+ _globals["_RELEASESESSIONREQUEST"]._serialized_end = 15448
+ _globals["_RELEASESESSIONRESPONSE"]._serialized_start = 15450
+ _globals["_RELEASESESSIONRESPONSE"]._serialized_end = 15558
+ _globals["_FETCHERRORDETAILSREQUEST"]._serialized_start = 15561
+ _globals["_FETCHERRORDETAILSREQUEST"]._serialized_end = 15893
+ _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_start = 15896
+ _globals["_FETCHERRORDETAILSRESPONSE"]._serialized_end = 17905
+ _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_start = 16125
+ _globals["_FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT"]._serialized_end = 16299
+ _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_start = 16302
+ _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT"]._serialized_end = 16670
+ _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_start = 16633
+ _globals["_FETCHERRORDETAILSRESPONSE_QUERYCONTEXT_CONTEXTTYPE"]._serialized_end = 16670
+ _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_start = 16673
+ _globals["_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE"]._serialized_end = 17223
_globals[
"_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY"
- ]._serialized_start = 16732
+ ]._serialized_start = 17100
_globals[
"_FETCHERRORDETAILSRESPONSE_SPARKTHROWABLE_MESSAGEPARAMETERSENTRY"
- ]._serialized_end = 16800
- _globals["_FETCHERRORDETAILSRESPONSE_BREAKINGCHANGEINFO"]._serialized_start = 16858
- _globals["_FETCHERRORDETAILSRESPONSE_BREAKINGCHANGEINFO"]._serialized_end = 17108
- _globals["_FETCHERRORDETAILSRESPONSE_MITIGATIONCONFIG"]._serialized_start = 17110
- _globals["_FETCHERRORDETAILSRESPONSE_MITIGATIONCONFIG"]._serialized_end = 17168
- _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_start = 17171
- _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_end = 17518
- _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_start = 17539
- _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_end = 17629
- _globals["_CLONESESSIONREQUEST"]._serialized_start = 17632
- _globals["_CLONESESSIONREQUEST"]._serialized_end = 17994
- _globals["_CLONESESSIONRESPONSE"]._serialized_start = 17997
- _globals["_CLONESESSIONRESPONSE"]._serialized_end = 18201
- _globals["_SPARKCONNECTSERVICE"]._serialized_start = 18204
- _globals["_SPARKCONNECTSERVICE"]._serialized_end = 19241
+ ]._serialized_end = 17168
+ _globals["_FETCHERRORDETAILSRESPONSE_BREAKINGCHANGEINFO"]._serialized_start = 17226
+ _globals["_FETCHERRORDETAILSRESPONSE_BREAKINGCHANGEINFO"]._serialized_end = 17476
+ _globals["_FETCHERRORDETAILSRESPONSE_MITIGATIONCONFIG"]._serialized_start = 17478
+ _globals["_FETCHERRORDETAILSRESPONSE_MITIGATIONCONFIG"]._serialized_end = 17536
+ _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_start = 17539
+ _globals["_FETCHERRORDETAILSRESPONSE_ERROR"]._serialized_end = 17886
+ _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_start = 17907
+ _globals["_CHECKPOINTCOMMANDRESULT"]._serialized_end = 17997
+ _globals["_CLONESESSIONREQUEST"]._serialized_start = 18000
+ _globals["_CLONESESSIONREQUEST"]._serialized_end = 18362
+ _globals["_CLONESESSIONRESPONSE"]._serialized_start = 18365
+ _globals["_CLONESESSIONRESPONSE"]._serialized_end = 18569
+ _globals["_SPARKCONNECTSERVICE"]._serialized_start = 18655
+ _globals["_SPARKCONNECTSERVICE"]._serialized_end = 19692
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi
index dc3099ecdffca..f12c21e5536de 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -57,42 +57,123 @@ else:
DESCRIPTOR: google.protobuf.descriptor.FileDescriptor
+class _CompressionCodec:
+ ValueType = typing.NewType("ValueType", builtins.int)
+ V: typing_extensions.TypeAlias = ValueType
+
+class _CompressionCodecEnumTypeWrapper(
+ google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_CompressionCodec.ValueType],
+ builtins.type,
+): # noqa: F821
+ DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+ COMPRESSION_CODEC_UNSPECIFIED: _CompressionCodec.ValueType # 0
+ COMPRESSION_CODEC_ZSTD: _CompressionCodec.ValueType # 1
+
+class CompressionCodec(_CompressionCodec, metaclass=_CompressionCodecEnumTypeWrapper):
+ """Compression codec for plan compression."""
+
+COMPRESSION_CODEC_UNSPECIFIED: CompressionCodec.ValueType # 0
+COMPRESSION_CODEC_ZSTD: CompressionCodec.ValueType # 1
+global___CompressionCodec = CompressionCodec
+
class Plan(google.protobuf.message.Message):
"""A [[Plan]] is the structure that carries the runtime information for the execution from the
- client to the server. A [[Plan]] can either be of the type [[Relation]] which is a reference
- to the underlying logical plan or it can be of the [[Command]] type that is used to execute
- commands on the server.
+ client to the server. A [[Plan]] can be one of the following:
+ - [[Relation]]: a reference to the underlying logical plan.
+ - [[Command]]: used to execute commands on the server.
+ - [[CompressedOperation]]: a compressed representation of either a Relation or a Command.
"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
+ class CompressedOperation(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ class _OpType:
+ ValueType = typing.NewType("ValueType", builtins.int)
+ V: typing_extensions.TypeAlias = ValueType
+
+ class _OpTypeEnumTypeWrapper(
+ google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[
+ Plan.CompressedOperation._OpType.ValueType
+ ],
+ builtins.type,
+ ): # noqa: F821
+ DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+ OP_TYPE_UNSPECIFIED: Plan.CompressedOperation._OpType.ValueType # 0
+ OP_TYPE_RELATION: Plan.CompressedOperation._OpType.ValueType # 1
+ OP_TYPE_COMMAND: Plan.CompressedOperation._OpType.ValueType # 2
+
+ class OpType(_OpType, metaclass=_OpTypeEnumTypeWrapper): ...
+ OP_TYPE_UNSPECIFIED: Plan.CompressedOperation.OpType.ValueType # 0
+ OP_TYPE_RELATION: Plan.CompressedOperation.OpType.ValueType # 1
+ OP_TYPE_COMMAND: Plan.CompressedOperation.OpType.ValueType # 2
+
+ DATA_FIELD_NUMBER: builtins.int
+ OP_TYPE_FIELD_NUMBER: builtins.int
+ COMPRESSION_CODEC_FIELD_NUMBER: builtins.int
+ data: builtins.bytes
+ op_type: global___Plan.CompressedOperation.OpType.ValueType
+ compression_codec: global___CompressionCodec.ValueType
+ def __init__(
+ self,
+ *,
+ data: builtins.bytes = ...,
+ op_type: global___Plan.CompressedOperation.OpType.ValueType = ...,
+ compression_codec: global___CompressionCodec.ValueType = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "compression_codec", b"compression_codec", "data", b"data", "op_type", b"op_type"
+ ],
+ ) -> None: ...
+
ROOT_FIELD_NUMBER: builtins.int
COMMAND_FIELD_NUMBER: builtins.int
+ COMPRESSED_OPERATION_FIELD_NUMBER: builtins.int
@property
def root(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: ...
@property
def command(self) -> pyspark.sql.connect.proto.commands_pb2.Command: ...
+ @property
+ def compressed_operation(self) -> global___Plan.CompressedOperation: ...
def __init__(
self,
*,
root: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
command: pyspark.sql.connect.proto.commands_pb2.Command | None = ...,
+ compressed_operation: global___Plan.CompressedOperation | None = ...,
) -> None: ...
def HasField(
self,
field_name: typing_extensions.Literal[
- "command", b"command", "op_type", b"op_type", "root", b"root"
+ "command",
+ b"command",
+ "compressed_operation",
+ b"compressed_operation",
+ "op_type",
+ b"op_type",
+ "root",
+ b"root",
],
) -> builtins.bool: ...
def ClearField(
self,
field_name: typing_extensions.Literal[
- "command", b"command", "op_type", b"op_type", "root", b"root"
+ "command",
+ b"command",
+ "compressed_operation",
+ b"compressed_operation",
+ "op_type",
+ b"op_type",
+ "root",
+ b"root",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["op_type", b"op_type"]
- ) -> typing_extensions.Literal["root", "command"] | None: ...
+ ) -> typing_extensions.Literal["root", "command", "compressed_operation"] | None: ...
global___Plan = Plan
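
Illustrative note (not part of this patch): the stub above adds a `compressed_operation` member to the `Plan` oneof, carrying `data`, `op_type`, and `compression_codec`. A minimal sketch of how a client might populate it follows, assuming the payload is the serialized `Relation` compressed with the zstandard package (the codec implied by `COMPRESSION_CODEC_ZSTD`); the exact compression path used by the Spark Connect client is not shown in this diff.

    # Hypothetical usage sketch; field and enum names are taken from the stubs above,
    # the use of zstandard for the payload is an assumption.
    import zstandard

    import pyspark.sql.connect.proto.base_pb2 as base_pb2
    import pyspark.sql.connect.proto.relations_pb2 as relations_pb2

    relation = relations_pb2.Relation()  # placeholder logical plan
    compressed = base_pb2.Plan.CompressedOperation(
        # Compress the serialized Relation bytes (assumed codec: zstd).
        data=zstandard.ZstdCompressor().compress(relation.SerializeToString()),
        op_type=base_pb2.Plan.CompressedOperation.OpType.OP_TYPE_RELATION,
        compression_codec=base_pb2.CompressionCodec.COMPRESSION_CODEC_ZSTD,
    )
    plan = base_pb2.Plan(compressed_operation=compressed)
    # The new oneof member is selected instead of `root` or `command`.
    assert plan.WhichOneof("op_type") == "compressed_operation"
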
diff --git a/python/pyspark/sql/connect/proto/catalog_pb2.py b/python/pyspark/sql/connect/proto/catalog_pb2.py
index 58c129a01daa8..054b367bd3b34 100644
--- a/python/pyspark/sql/connect/proto/catalog_pb2.py
+++ b/python/pyspark/sql/connect/proto/catalog_pb2.py
@@ -18,7 +18,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: spark/connect/catalog.proto
-# Protobuf Python Version: 5.29.5
+# Protobuf Python Version: 6.33.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -27,7 +27,7 @@
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
- _runtime_version.Domain.PUBLIC, 5, 29, 5, "", "spark/connect/catalog.proto"
+ _runtime_version.Domain.PUBLIC, 6, 33, 0, "", "spark/connect/catalog.proto"
)
# @@protoc_insertion_point(imports)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index 694b4a9a9aa37..4eccf1b71706d 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -18,7 +18,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: spark/connect/commands.proto
-# Protobuf Python Version: 5.29.5
+# Protobuf Python Version: 6.33.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -27,7 +27,7 @@
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
- _runtime_version.Domain.PUBLIC, 5, 29, 5, "", "spark/connect/commands.proto"
+ _runtime_version.Domain.PUBLIC, 6, 33, 0, "", "spark/connect/commands.proto"
)
# @@protoc_insertion_point(imports)
diff --git a/python/pyspark/sql/connect/proto/common_pb2.py b/python/pyspark/sql/connect/proto/common_pb2.py
index 07ea9f7ed3173..8abd8fa6dc041 100644
--- a/python/pyspark/sql/connect/proto/common_pb2.py
+++ b/python/pyspark/sql/connect/proto/common_pb2.py
@@ -18,7 +18,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: spark/connect/common.proto
-# Protobuf Python Version: 5.29.5
+# Protobuf Python Version: 6.33.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -27,7 +27,7 @@
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
- _runtime_version.Domain.PUBLIC, 5, 29, 5, "", "spark/connect/common.proto"
+ _runtime_version.Domain.PUBLIC, 6, 33, 0, "", "spark/connect/common.proto"
)
# @@protoc_insertion_point(imports)
diff --git a/python/pyspark/sql/connect/proto/example_plugins_pb2.py b/python/pyspark/sql/connect/proto/example_plugins_pb2.py
index 71a73a6d592ae..423768ee63d65 100644
--- a/python/pyspark/sql/connect/proto/example_plugins_pb2.py
+++ b/python/pyspark/sql/connect/proto/example_plugins_pb2.py
@@ -18,7 +18,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: spark/connect/example_plugins.proto
-# Protobuf Python Version: 5.29.5
+# Protobuf Python Version: 6.33.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -27,7 +27,7 @@
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
- _runtime_version.Domain.PUBLIC, 5, 29, 5, "", "spark/connect/example_plugins.proto"
+ _runtime_version.Domain.PUBLIC, 6, 33, 0, "", "spark/connect/example_plugins.proto"
)
# @@protoc_insertion_point(imports)
diff --git a/python/pyspark/sql/connect/proto/expressions_pb2.py b/python/pyspark/sql/connect/proto/expressions_pb2.py
index bd75ade02d8be..0c466aeb67a0d 100644
--- a/python/pyspark/sql/connect/proto/expressions_pb2.py
+++ b/python/pyspark/sql/connect/proto/expressions_pb2.py
@@ -18,7 +18,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: spark/connect/expressions.proto
-# Protobuf Python Version: 5.29.5
+# Protobuf Python Version: 6.33.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -27,7 +27,7 @@
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
- _runtime_version.Domain.PUBLIC, 5, 29, 5, "", "spark/connect/expressions.proto"
+ _runtime_version.Domain.PUBLIC, 6, 33, 0, "", "spark/connect/expressions.proto"
)
# @@protoc_insertion_point(imports)
diff --git a/python/pyspark/sql/connect/proto/ml_common_pb2.py b/python/pyspark/sql/connect/proto/ml_common_pb2.py
index a49491b8ad1ed..de547fc2a102f 100644
--- a/python/pyspark/sql/connect/proto/ml_common_pb2.py
+++ b/python/pyspark/sql/connect/proto/ml_common_pb2.py
@@ -18,7 +18,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: spark/connect/ml_common.proto
-# Protobuf Python Version: 5.29.5
+# Protobuf Python Version: 6.33.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -27,7 +27,7 @@
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
- _runtime_version.Domain.PUBLIC, 5, 29, 5, "", "spark/connect/ml_common.proto"
+ _runtime_version.Domain.PUBLIC, 6, 33, 0, "", "spark/connect/ml_common.proto"
)
# @@protoc_insertion_point(imports)
diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py b/python/pyspark/sql/connect/proto/ml_pb2.py
index 9574966472a58..3bd141815c8eb 100644
--- a/python/pyspark/sql/connect/proto/ml_pb2.py
+++ b/python/pyspark/sql/connect/proto/ml_pb2.py
@@ -18,7 +18,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: spark/connect/ml.proto
-# Protobuf Python Version: 5.29.5
+# Protobuf Python Version: 6.33.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -27,7 +27,7 @@
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
- _runtime_version.Domain.PUBLIC, 5, 29, 5, "", "spark/connect/ml.proto"
+ _runtime_version.Domain.PUBLIC, 6, 33, 0, "", "spark/connect/ml.proto"
)
# @@protoc_insertion_point(imports)
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.py b/python/pyspark/sql/connect/proto/pipelines_pb2.py
index f3489f55ed874..7a30def861d29 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.py
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.py
@@ -18,7 +18,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: spark/connect/pipelines.proto
-# Protobuf Python Version: 5.29.5
+# Protobuf Python Version: 6.33.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -27,7 +27,7 @@
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
- _runtime_version.Domain.PUBLIC, 5, 29, 5, "", "spark/connect/pipelines.proto"
+ _runtime_version.Domain.PUBLIC, 6, 33, 0, "", "spark/connect/pipelines.proto"
)
# @@protoc_insertion_point(imports)
@@ -42,7 +42,7 @@
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\x9c"\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12R\n\rdefine_output\x18\x02 \x01(\x0b\x32+.spark.connect.PipelineCommand.DefineOutputH\x00R\x0c\x64\x65\x66ineOutput\x12L\n\x0b\x64\x65\x66ine_flow\x18\x03 \x01(\x0b\x32).spark.connect.PipelineCommand.DefineFlowH\x00R\ndefineFlow\x12\x62\n\x13\x64rop_dataflow_graph\x18\x04 \x01(\x0b\x32\x30.spark.connect.PipelineCommand.DropDataflowGraphH\x00R\x11\x64ropDataflowGraph\x12\x46\n\tstart_run\x18\x05 \x01(\x0b\x32\'.spark.connect.PipelineCommand.StartRunH\x00R\x08startRun\x12r\n\x19\x64\x65\x66ine_sql_graph_elements\x18\x06 \x01(\x0b\x32\x35.spark.connect.PipelineCommand.DefineSqlGraphElementsH\x00R\x16\x64\x65\x66ineSqlGraphElements\x12\xa1\x01\n*get_query_function_execution_signal_stream\x18\x07 \x01(\x0b\x32\x44.spark.connect.PipelineCommand.GetQueryFunctionExecutionSignalStreamH\x00R%getQueryFunctionExecutionSignalStream\x12\x88\x01\n!define_flow_query_function_result\x18\x08 \x01(\x0b\x32<.spark.connect.PipelineCommand.DefineFlowQueryFunctionResultH\x00R\x1d\x64\x65\x66ineFlowQueryFunctionResult\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\xb4\x02\n\x13\x43reateDataflowGraph\x12,\n\x0f\x64\x65\x66\x61ult_catalog\x18\x01 \x01(\tH\x00R\x0e\x64\x65\x66\x61ultCatalog\x88\x01\x01\x12.\n\x10\x64\x65\x66\x61ult_database\x18\x02 \x01(\tH\x01R\x0f\x64\x65\x66\x61ultDatabase\x88\x01\x01\x12Z\n\x08sql_conf\x18\x05 \x03(\x0b\x32?.spark.connect.PipelineCommand.CreateDataflowGraph.SqlConfEntryR\x07sqlConf\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x12\n\x10_default_catalogB\x13\n\x11_default_database\x1aZ\n\x11\x44ropDataflowGraph\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\xe3\t\n\x0c\x44\x65\x66ineOutput\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12$\n\x0boutput_name\x18\x02 \x01(\tH\x02R\noutputName\x88\x01\x01\x12?\n\x0boutput_type\x18\x03 \x01(\x0e\x32\x19.spark.connect.OutputTypeH\x03R\noutputType\x88\x01\x01\x12\x1d\n\x07\x63omment\x18\x04 \x01(\tH\x04R\x07\x63omment\x88\x01\x01\x12X\n\x14source_code_location\x18\x05 \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x05R\x12sourceCodeLocation\x88\x01\x01\x12_\n\rtable_details\x18\x06 \x01(\x0b\x32\x38.spark.connect.PipelineCommand.DefineOutput.TableDetailsH\x00R\x0ctableDetails\x12\\\n\x0csink_details\x18\x07 \x01(\x0b\x32\x37.spark.connect.PipelineCommand.DefineOutput.SinkDetailsH\x00R\x0bsinkDetails\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\x91\x03\n\x0cTableDetails\x12x\n\x10table_properties\x18\x01 \x03(\x0b\x32M.spark.connect.PipelineCommand.DefineOutput.TableDetails.TablePropertiesEntryR\x0ftableProperties\x12%\n\x0epartition_cols\x18\x02 \x03(\tR\rpartitionCols\x12\x1b\n\x06\x66ormat\x18\x03 \x01(\tH\x01R\x06\x66ormat\x88\x01\x01\x12\x43\n\x10schema_data_type\x18\x04 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x0eschemaDataType\x12%\n\rschema_string\x18\x05 
\x01(\tH\x00R\x0cschemaString\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x08\n\x06schemaB\t\n\x07_format\x1a\xd1\x01\n\x0bSinkDetails\x12^\n\x07options\x18\x01 \x03(\x0b\x32\x44.spark.connect.PipelineCommand.DefineOutput.SinkDetails.OptionsEntryR\x07options\x12\x1b\n\x06\x66ormat\x18\x02 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07\x64\x65tailsB\x14\n\x12_dataflow_graph_idB\x0e\n\x0c_output_nameB\x0e\n\x0c_output_typeB\n\n\x08_commentB\x17\n\x15_source_code_location\x1a\xdd\x06\n\nDefineFlow\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tflow_name\x18\x02 \x01(\tH\x02R\x08\x66lowName\x88\x01\x01\x12\x33\n\x13target_dataset_name\x18\x03 \x01(\tH\x03R\x11targetDatasetName\x88\x01\x01\x12Q\n\x08sql_conf\x18\x04 \x03(\x0b\x32\x36.spark.connect.PipelineCommand.DefineFlow.SqlConfEntryR\x07sqlConf\x12 \n\tclient_id\x18\x05 \x01(\tH\x04R\x08\x63lientId\x88\x01\x01\x12X\n\x14source_code_location\x18\x06 \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x05R\x12sourceCodeLocation\x88\x01\x01\x12x\n\x15relation_flow_details\x18\x07 \x01(\x0b\x32\x42.spark.connect.PipelineCommand.DefineFlow.WriteRelationFlowDetailsH\x00R\x13relationFlowDetails\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x61\n\x18WriteRelationFlowDetails\x12\x38\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x08relation\x88\x01\x01\x42\x0b\n\t_relation\x1a:\n\x08Response\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x42\x0c\n\n_flow_nameB\t\n\x07\x64\x65tailsB\x14\n\x12_dataflow_graph_idB\x0c\n\n_flow_nameB\x16\n\x14_target_dataset_nameB\x0c\n\n_client_idB\x17\n\x15_source_code_location\x1a\xc2\x02\n\x08StartRun\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x34\n\x16\x66ull_refresh_selection\x18\x02 \x03(\tR\x14\x66ullRefreshSelection\x12-\n\x10\x66ull_refresh_all\x18\x03 \x01(\x08H\x01R\x0e\x66ullRefreshAll\x88\x01\x01\x12+\n\x11refresh_selection\x18\x04 \x03(\tR\x10refreshSelection\x12\x15\n\x03\x64ry\x18\x05 \x01(\x08H\x02R\x03\x64ry\x88\x01\x01\x12\x1d\n\x07storage\x18\x06 \x01(\tH\x03R\x07storage\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x13\n\x11_full_refresh_allB\x06\n\x04_dryB\n\n\x08_storage\x1a\xc7\x01\n\x16\x44\x65\x66ineSqlGraphElements\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\'\n\rsql_file_path\x18\x02 \x01(\tH\x01R\x0bsqlFilePath\x88\x01\x01\x12\x1e\n\x08sql_text\x18\x03 \x01(\tH\x02R\x07sqlText\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x10\n\x0e_sql_file_pathB\x0b\n\t_sql_text\x1a\x9e\x01\n%GetQueryFunctionExecutionSignalStream\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tclient_id\x18\x02 \x01(\tH\x01R\x08\x63lientId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x0c\n\n_client_id\x1a\xdd\x01\n\x1d\x44\x65\x66ineFlowQueryFunctionResult\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x12/\n\x11\x64\x61taflow_graph_id\x18\x02 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x38\n\x08relation\x18\x03 
\x01(\x0b\x32\x17.spark.connect.RelationH\x02R\x08relation\x88\x01\x01\x42\x0c\n\n_flow_nameB\x14\n\x12_dataflow_graph_idB\x0b\n\t_relationB\x0e\n\x0c\x63ommand_type"\xf0\x05\n\x15PipelineCommandResult\x12\x81\x01\n\x1c\x63reate_dataflow_graph_result\x18\x01 \x01(\x0b\x32>.spark.connect.PipelineCommandResult.CreateDataflowGraphResultH\x00R\x19\x63reateDataflowGraphResult\x12k\n\x14\x64\x65\x66ine_output_result\x18\x02 \x01(\x0b\x32\x37.spark.connect.PipelineCommandResult.DefineOutputResultH\x00R\x12\x64\x65\x66ineOutputResult\x12\x65\n\x12\x64\x65\x66ine_flow_result\x18\x03 \x01(\x0b\x32\x35.spark.connect.PipelineCommandResult.DefineFlowResultH\x00R\x10\x64\x65\x66ineFlowResult\x1a\x62\n\x19\x43reateDataflowGraphResult\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\x85\x01\n\x12\x44\x65\x66ineOutputResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifier\x1a\x83\x01\n\x10\x44\x65\x66ineFlowResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifierB\r\n\x0bresult_type"I\n\x13PipelineEventResult\x12\x32\n\x05\x65vent\x18\x01 \x01(\x0b\x32\x1c.spark.connect.PipelineEventR\x05\x65vent"t\n\rPipelineEvent\x12\x38\n\ttimestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n\x07message\x18\x02 \x01(\tH\x00R\x07message\x88\x01\x01\x42\n\n\x08_message"\xf1\x01\n\x12SourceCodeLocation\x12 \n\tfile_name\x18\x01 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12$\n\x0bline_number\x18\x02 \x01(\x05H\x01R\nlineNumber\x88\x01\x01\x12,\n\x0f\x64\x65\x66inition_path\x18\x03 \x01(\tH\x02R\x0e\x64\x65\x66initionPath\x88\x01\x01\x12\x33\n\textension\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\textensionB\x0c\n\n_file_nameB\x0e\n\x0c_line_numberB\x12\n\x10_definition_path"E\n$PipelineQueryFunctionExecutionSignal\x12\x1d\n\nflow_names\x18\x01 \x03(\tR\tflowNames"\xd7\x01\n\x17PipelineAnalysisContext\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12,\n\x0f\x64\x65\x66inition_path\x18\x02 \x01(\tH\x01R\x0e\x64\x65\x66initionPath\x88\x01\x01\x12\x33\n\textension\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\textensionB\x14\n\x12_dataflow_graph_idB\x12\n\x10_definition_path*i\n\nOutputType\x12\x1b\n\x17OUTPUT_TYPE_UNSPECIFIED\x10\x00\x12\x15\n\x11MATERIALIZED_VIEW\x10\x01\x12\t\n\x05TABLE\x10\x02\x12\x12\n\x0eTEMPORARY_VIEW\x10\x03\x12\x08\n\x04SINK\x10\x04\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
+ b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xed"\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12R\n\rdefine_output\x18\x02 \x01(\x0b\x32+.spark.connect.PipelineCommand.DefineOutputH\x00R\x0c\x64\x65\x66ineOutput\x12L\n\x0b\x64\x65\x66ine_flow\x18\x03 \x01(\x0b\x32).spark.connect.PipelineCommand.DefineFlowH\x00R\ndefineFlow\x12\x62\n\x13\x64rop_dataflow_graph\x18\x04 \x01(\x0b\x32\x30.spark.connect.PipelineCommand.DropDataflowGraphH\x00R\x11\x64ropDataflowGraph\x12\x46\n\tstart_run\x18\x05 \x01(\x0b\x32\'.spark.connect.PipelineCommand.StartRunH\x00R\x08startRun\x12r\n\x19\x64\x65\x66ine_sql_graph_elements\x18\x06 \x01(\x0b\x32\x35.spark.connect.PipelineCommand.DefineSqlGraphElementsH\x00R\x16\x64\x65\x66ineSqlGraphElements\x12\xa1\x01\n*get_query_function_execution_signal_stream\x18\x07 \x01(\x0b\x32\x44.spark.connect.PipelineCommand.GetQueryFunctionExecutionSignalStreamH\x00R%getQueryFunctionExecutionSignalStream\x12\x88\x01\n!define_flow_query_function_result\x18\x08 \x01(\x0b\x32<.spark.connect.PipelineCommand.DefineFlowQueryFunctionResultH\x00R\x1d\x64\x65\x66ineFlowQueryFunctionResult\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\xb4\x02\n\x13\x43reateDataflowGraph\x12,\n\x0f\x64\x65\x66\x61ult_catalog\x18\x01 \x01(\tH\x00R\x0e\x64\x65\x66\x61ultCatalog\x88\x01\x01\x12.\n\x10\x64\x65\x66\x61ult_database\x18\x02 \x01(\tH\x01R\x0f\x64\x65\x66\x61ultDatabase\x88\x01\x01\x12Z\n\x08sql_conf\x18\x05 \x03(\x0b\x32?.spark.connect.PipelineCommand.CreateDataflowGraph.SqlConfEntryR\x07sqlConf\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x12\n\x10_default_catalogB\x13\n\x11_default_database\x1aZ\n\x11\x44ropDataflowGraph\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\x92\n\n\x0c\x44\x65\x66ineOutput\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12$\n\x0boutput_name\x18\x02 \x01(\tH\x02R\noutputName\x88\x01\x01\x12?\n\x0boutput_type\x18\x03 \x01(\x0e\x32\x19.spark.connect.OutputTypeH\x03R\noutputType\x88\x01\x01\x12\x1d\n\x07\x63omment\x18\x04 \x01(\tH\x04R\x07\x63omment\x88\x01\x01\x12X\n\x14source_code_location\x18\x05 \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x05R\x12sourceCodeLocation\x88\x01\x01\x12_\n\rtable_details\x18\x06 \x01(\x0b\x32\x38.spark.connect.PipelineCommand.DefineOutput.TableDetailsH\x00R\x0ctableDetails\x12\\\n\x0csink_details\x18\x07 \x01(\x0b\x32\x37.spark.connect.PipelineCommand.DefineOutput.SinkDetailsH\x00R\x0bsinkDetails\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\xc0\x03\n\x0cTableDetails\x12x\n\x10table_properties\x18\x01 \x03(\x0b\x32M.spark.connect.PipelineCommand.DefineOutput.TableDetails.TablePropertiesEntryR\x0ftableProperties\x12%\n\x0epartition_cols\x18\x02 \x03(\tR\rpartitionCols\x12\x1b\n\x06\x66ormat\x18\x03 \x01(\tH\x01R\x06\x66ormat\x88\x01\x01\x12\x43\n\x10schema_data_type\x18\x04 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x0eschemaDataType\x12%\n\rschema_string\x18\x05 \x01(\tH\x00R\x0cschemaString\x12-\n\x12\x63lustering_columns\x18\x06 
\x03(\tR\x11\x63lusteringColumns\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x08\n\x06schemaB\t\n\x07_format\x1a\xd1\x01\n\x0bSinkDetails\x12^\n\x07options\x18\x01 \x03(\x0b\x32\x44.spark.connect.PipelineCommand.DefineOutput.SinkDetails.OptionsEntryR\x07options\x12\x1b\n\x06\x66ormat\x18\x02 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07\x64\x65tailsB\x14\n\x12_dataflow_graph_idB\x0e\n\x0c_output_nameB\x0e\n\x0c_output_typeB\n\n\x08_commentB\x17\n\x15_source_code_location\x1a\xff\x06\n\nDefineFlow\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tflow_name\x18\x02 \x01(\tH\x02R\x08\x66lowName\x88\x01\x01\x12\x33\n\x13target_dataset_name\x18\x03 \x01(\tH\x03R\x11targetDatasetName\x88\x01\x01\x12Q\n\x08sql_conf\x18\x04 \x03(\x0b\x32\x36.spark.connect.PipelineCommand.DefineFlow.SqlConfEntryR\x07sqlConf\x12 \n\tclient_id\x18\x05 \x01(\tH\x04R\x08\x63lientId\x88\x01\x01\x12X\n\x14source_code_location\x18\x06 \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x05R\x12sourceCodeLocation\x88\x01\x01\x12x\n\x15relation_flow_details\x18\x07 \x01(\x0b\x32\x42.spark.connect.PipelineCommand.DefineFlow.WriteRelationFlowDetailsH\x00R\x13relationFlowDetails\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x17\n\x04once\x18\x08 \x01(\x08H\x06R\x04once\x88\x01\x01\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x61\n\x18WriteRelationFlowDetails\x12\x38\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x08relation\x88\x01\x01\x42\x0b\n\t_relation\x1a:\n\x08Response\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x42\x0c\n\n_flow_nameB\t\n\x07\x64\x65tailsB\x14\n\x12_dataflow_graph_idB\x0c\n\n_flow_nameB\x16\n\x14_target_dataset_nameB\x0c\n\n_client_idB\x17\n\x15_source_code_locationB\x07\n\x05_once\x1a\xc2\x02\n\x08StartRun\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x34\n\x16\x66ull_refresh_selection\x18\x02 \x03(\tR\x14\x66ullRefreshSelection\x12-\n\x10\x66ull_refresh_all\x18\x03 \x01(\x08H\x01R\x0e\x66ullRefreshAll\x88\x01\x01\x12+\n\x11refresh_selection\x18\x04 \x03(\tR\x10refreshSelection\x12\x15\n\x03\x64ry\x18\x05 \x01(\x08H\x02R\x03\x64ry\x88\x01\x01\x12\x1d\n\x07storage\x18\x06 \x01(\tH\x03R\x07storage\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x13\n\x11_full_refresh_allB\x06\n\x04_dryB\n\n\x08_storage\x1a\xc7\x01\n\x16\x44\x65\x66ineSqlGraphElements\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\'\n\rsql_file_path\x18\x02 \x01(\tH\x01R\x0bsqlFilePath\x88\x01\x01\x12\x1e\n\x08sql_text\x18\x03 \x01(\tH\x02R\x07sqlText\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x10\n\x0e_sql_file_pathB\x0b\n\t_sql_text\x1a\x9e\x01\n%GetQueryFunctionExecutionSignalStream\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tclient_id\x18\x02 \x01(\tH\x01R\x08\x63lientId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x0c\n\n_client_id\x1a\xdd\x01\n\x1d\x44\x65\x66ineFlowQueryFunctionResult\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x12/\n\x11\x64\x61taflow_graph_id\x18\x02 
\x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x38\n\x08relation\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationH\x02R\x08relation\x88\x01\x01\x42\x0c\n\n_flow_nameB\x14\n\x12_dataflow_graph_idB\x0b\n\t_relationB\x0e\n\x0c\x63ommand_type"\xf0\x05\n\x15PipelineCommandResult\x12\x81\x01\n\x1c\x63reate_dataflow_graph_result\x18\x01 \x01(\x0b\x32>.spark.connect.PipelineCommandResult.CreateDataflowGraphResultH\x00R\x19\x63reateDataflowGraphResult\x12k\n\x14\x64\x65\x66ine_output_result\x18\x02 \x01(\x0b\x32\x37.spark.connect.PipelineCommandResult.DefineOutputResultH\x00R\x12\x64\x65\x66ineOutputResult\x12\x65\n\x12\x64\x65\x66ine_flow_result\x18\x03 \x01(\x0b\x32\x35.spark.connect.PipelineCommandResult.DefineFlowResultH\x00R\x10\x64\x65\x66ineFlowResult\x1a\x62\n\x19\x43reateDataflowGraphResult\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\x85\x01\n\x12\x44\x65\x66ineOutputResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifier\x1a\x83\x01\n\x10\x44\x65\x66ineFlowResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifierB\r\n\x0bresult_type"I\n\x13PipelineEventResult\x12\x32\n\x05\x65vent\x18\x01 \x01(\x0b\x32\x1c.spark.connect.PipelineEventR\x05\x65vent"t\n\rPipelineEvent\x12\x38\n\ttimestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n\x07message\x18\x02 \x01(\tH\x00R\x07message\x88\x01\x01\x42\n\n\x08_message"\xf1\x01\n\x12SourceCodeLocation\x12 \n\tfile_name\x18\x01 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12$\n\x0bline_number\x18\x02 \x01(\x05H\x01R\nlineNumber\x88\x01\x01\x12,\n\x0f\x64\x65\x66inition_path\x18\x03 \x01(\tH\x02R\x0e\x64\x65\x66initionPath\x88\x01\x01\x12\x33\n\textension\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\textensionB\x0c\n\n_file_nameB\x0e\n\x0c_line_numberB\x12\n\x10_definition_path"E\n$PipelineQueryFunctionExecutionSignal\x12\x1d\n\nflow_names\x18\x01 \x03(\tR\tflowNames"\x87\x02\n\x17PipelineAnalysisContext\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12,\n\x0f\x64\x65\x66inition_path\x18\x02 \x01(\tH\x01R\x0e\x64\x65\x66initionPath\x88\x01\x01\x12 \n\tflow_name\x18\x03 \x01(\tH\x02R\x08\x66lowName\x88\x01\x01\x12\x33\n\textension\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\textensionB\x14\n\x12_dataflow_graph_idB\x12\n\x10_definition_pathB\x0c\n\n_flow_name*i\n\nOutputType\x12\x1b\n\x17OUTPUT_TYPE_UNSPECIFIED\x10\x00\x12\x15\n\x11MATERIALIZED_VIEW\x10\x01\x12\t\n\x05TABLE\x10\x02\x12\x12\n\x0eTEMPORARY_VIEW\x10\x03\x12\x08\n\x04SINK\x10\x04\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3'
)
_globals = globals()
@@ -69,10 +69,10 @@
]._serialized_options = b"8\001"
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._loaded_options = None
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_options = b"8\001"
- _globals["_OUTPUTTYPE"]._serialized_start = 6058
- _globals["_OUTPUTTYPE"]._serialized_end = 6163
+ _globals["_OUTPUTTYPE"]._serialized_start = 6187
+ _globals["_OUTPUTTYPE"]._serialized_end = 6292
_globals["_PIPELINECOMMAND"]._serialized_start = 195
- _globals["_PIPELINECOMMAND"]._serialized_end = 4575
+ _globals["_PIPELINECOMMAND"]._serialized_end = 4656
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 1129
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_end = 1437
_globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_start = 1338
@@ -80,51 +80,51 @@
_globals["_PIPELINECOMMAND_DROPDATAFLOWGRAPH"]._serialized_start = 1439
_globals["_PIPELINECOMMAND_DROPDATAFLOWGRAPH"]._serialized_end = 1529
_globals["_PIPELINECOMMAND_DEFINEOUTPUT"]._serialized_start = 1532
- _globals["_PIPELINECOMMAND_DEFINEOUTPUT"]._serialized_end = 2783
+ _globals["_PIPELINECOMMAND_DEFINEOUTPUT"]._serialized_end = 2830
_globals["_PIPELINECOMMAND_DEFINEOUTPUT_TABLEDETAILS"]._serialized_start = 2068
- _globals["_PIPELINECOMMAND_DEFINEOUTPUT_TABLEDETAILS"]._serialized_end = 2469
+ _globals["_PIPELINECOMMAND_DEFINEOUTPUT_TABLEDETAILS"]._serialized_end = 2516
_globals[
"_PIPELINECOMMAND_DEFINEOUTPUT_TABLEDETAILS_TABLEPROPERTIESENTRY"
- ]._serialized_start = 2382
+ ]._serialized_start = 2429
_globals[
"_PIPELINECOMMAND_DEFINEOUTPUT_TABLEDETAILS_TABLEPROPERTIESENTRY"
- ]._serialized_end = 2448
- _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS"]._serialized_start = 2472
- _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS"]._serialized_end = 2681
- _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS_OPTIONSENTRY"]._serialized_start = 2612
- _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS_OPTIONSENTRY"]._serialized_end = 2670
- _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_start = 2786
- _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_end = 3647
+ ]._serialized_end = 2495
+ _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS"]._serialized_start = 2519
+ _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS"]._serialized_end = 2728
+ _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS_OPTIONSENTRY"]._serialized_start = 2659
+ _globals["_PIPELINECOMMAND_DEFINEOUTPUT_SINKDETAILS_OPTIONSENTRY"]._serialized_end = 2717
+ _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_start = 2833
+ _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_end = 3728
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_start = 1338
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_end = 1396
- _globals["_PIPELINECOMMAND_DEFINEFLOW_WRITERELATIONFLOWDETAILS"]._serialized_start = 3380
- _globals["_PIPELINECOMMAND_DEFINEFLOW_WRITERELATIONFLOWDETAILS"]._serialized_end = 3477
- _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_start = 3479
- _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_end = 3537
- _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 3650
- _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 3972
- _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 3975
- _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 4174
- _globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_start = 4177
- _globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_end = 4335
- _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_start = 4338
- _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_end = 4559
- _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 4578
- _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 5330
- _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start = 4947
- _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 5045
- _globals["_PIPELINECOMMANDRESULT_DEFINEOUTPUTRESULT"]._serialized_start = 5048
- _globals["_PIPELINECOMMANDRESULT_DEFINEOUTPUTRESULT"]._serialized_end = 5181
- _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_start = 5184
- _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_end = 5315
- _globals["_PIPELINEEVENTRESULT"]._serialized_start = 5332
- _globals["_PIPELINEEVENTRESULT"]._serialized_end = 5405
- _globals["_PIPELINEEVENT"]._serialized_start = 5407
- _globals["_PIPELINEEVENT"]._serialized_end = 5523
- _globals["_SOURCECODELOCATION"]._serialized_start = 5526
- _globals["_SOURCECODELOCATION"]._serialized_end = 5767
- _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 5769
- _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 5838
- _globals["_PIPELINEANALYSISCONTEXT"]._serialized_start = 5841
- _globals["_PIPELINEANALYSISCONTEXT"]._serialized_end = 6056
+ _globals["_PIPELINECOMMAND_DEFINEFLOW_WRITERELATIONFLOWDETAILS"]._serialized_start = 3452
+ _globals["_PIPELINECOMMAND_DEFINEFLOW_WRITERELATIONFLOWDETAILS"]._serialized_end = 3549
+ _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_start = 3551
+ _globals["_PIPELINECOMMAND_DEFINEFLOW_RESPONSE"]._serialized_end = 3609
+ _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 3731
+ _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 4053
+ _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 4056
+ _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 4255
+ _globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_start = 4258
+ _globals["_PIPELINECOMMAND_GETQUERYFUNCTIONEXECUTIONSIGNALSTREAM"]._serialized_end = 4416
+ _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_start = 4419
+ _globals["_PIPELINECOMMAND_DEFINEFLOWQUERYFUNCTIONRESULT"]._serialized_end = 4640
+ _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 4659
+ _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 5411
+ _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start = 5028
+ _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 5126
+ _globals["_PIPELINECOMMANDRESULT_DEFINEOUTPUTRESULT"]._serialized_start = 5129
+ _globals["_PIPELINECOMMANDRESULT_DEFINEOUTPUTRESULT"]._serialized_end = 5262
+ _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_start = 5265
+ _globals["_PIPELINECOMMANDRESULT_DEFINEFLOWRESULT"]._serialized_end = 5396
+ _globals["_PIPELINEEVENTRESULT"]._serialized_start = 5413
+ _globals["_PIPELINEEVENTRESULT"]._serialized_end = 5486
+ _globals["_PIPELINEEVENT"]._serialized_start = 5488
+ _globals["_PIPELINEEVENT"]._serialized_end = 5604
+ _globals["_SOURCECODELOCATION"]._serialized_start = 5607
+ _globals["_SOURCECODELOCATION"]._serialized_end = 5848
+ _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 5850
+ _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 5919
+ _globals["_PIPELINEANALYSISCONTEXT"]._serialized_start = 5922
+ _globals["_PIPELINEANALYSISCONTEXT"]._serialized_end = 6185
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
index b9170e763ed92..39a1e29ae7dde 100644
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi
@@ -240,6 +240,7 @@ class PipelineCommand(google.protobuf.message.Message):
FORMAT_FIELD_NUMBER: builtins.int
SCHEMA_DATA_TYPE_FIELD_NUMBER: builtins.int
SCHEMA_STRING_FIELD_NUMBER: builtins.int
+ CLUSTERING_COLUMNS_FIELD_NUMBER: builtins.int
@property
def table_properties(
self,
@@ -255,6 +256,11 @@ class PipelineCommand(google.protobuf.message.Message):
@property
def schema_data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
schema_string: builtins.str
+ @property
+ def clustering_columns(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+        """Optional clustering columns for the table."""
def __init__(
self,
*,
@@ -263,6 +269,7 @@ class PipelineCommand(google.protobuf.message.Message):
format: builtins.str | None = ...,
schema_data_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
schema_string: builtins.str = ...,
+ clustering_columns: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def HasField(
self,
@@ -284,6 +291,8 @@ class PipelineCommand(google.protobuf.message.Message):
field_name: typing_extensions.Literal[
"_format",
b"_format",
+ "clustering_columns",
+ b"clustering_columns",
"format",
b"format",
"partition_cols",
@@ -579,6 +588,7 @@ class PipelineCommand(google.protobuf.message.Message):
SOURCE_CODE_LOCATION_FIELD_NUMBER: builtins.int
RELATION_FLOW_DETAILS_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
+ ONCE_FIELD_NUMBER: builtins.int
dataflow_graph_id: builtins.str
"""The graph to attach this flow to."""
flow_name: builtins.str
@@ -603,6 +613,13 @@ class PipelineCommand(google.protobuf.message.Message):
) -> global___PipelineCommand.DefineFlow.WriteRelationFlowDetails: ...
@property
def extension(self) -> google.protobuf.any_pb2.Any: ...
+ once: builtins.bool
+ """If true, define the flow as a one-time flow, such as for backfill.
+    Setting it to true changes the flow in two ways:
+    - The flow runs only once by default. If the pipeline is run with a full refresh,
+    the flow will run again.
+    - The flow function must return a batch DataFrame, not a streaming DataFrame.
+ """
def __init__(
self,
*,
@@ -615,6 +632,7 @@ class PipelineCommand(google.protobuf.message.Message):
relation_flow_details: global___PipelineCommand.DefineFlow.WriteRelationFlowDetails
| None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
+ once: builtins.bool | None = ...,
) -> None: ...
def HasField(
self,
@@ -625,6 +643,8 @@ class PipelineCommand(google.protobuf.message.Message):
b"_dataflow_graph_id",
"_flow_name",
b"_flow_name",
+ "_once",
+ b"_once",
"_source_code_location",
b"_source_code_location",
"_target_dataset_name",
@@ -639,6 +659,8 @@ class PipelineCommand(google.protobuf.message.Message):
b"extension",
"flow_name",
b"flow_name",
+ "once",
+ b"once",
"relation_flow_details",
b"relation_flow_details",
"source_code_location",
@@ -656,6 +678,8 @@ class PipelineCommand(google.protobuf.message.Message):
b"_dataflow_graph_id",
"_flow_name",
b"_flow_name",
+ "_once",
+ b"_once",
"_source_code_location",
b"_source_code_location",
"_target_dataset_name",
@@ -670,6 +694,8 @@ class PipelineCommand(google.protobuf.message.Message):
b"extension",
"flow_name",
b"flow_name",
+ "once",
+ b"once",
"relation_flow_details",
b"relation_flow_details",
"source_code_location",
@@ -694,6 +720,10 @@ class PipelineCommand(google.protobuf.message.Message):
self, oneof_group: typing_extensions.Literal["_flow_name", b"_flow_name"]
) -> typing_extensions.Literal["flow_name"] | None: ...
@typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_once", b"_once"]
+ ) -> typing_extensions.Literal["once"] | None: ...
+ @typing.overload
def WhichOneof(
self,
oneof_group: typing_extensions.Literal[
@@ -1469,11 +1499,14 @@ class PipelineAnalysisContext(google.protobuf.message.Message):
DATAFLOW_GRAPH_ID_FIELD_NUMBER: builtins.int
DEFINITION_PATH_FIELD_NUMBER: builtins.int
+ FLOW_NAME_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
dataflow_graph_id: builtins.str
"""Unique identifier of the dataflow graph associated with this pipeline."""
definition_path: builtins.str
"""The path of the top-level pipeline file determined at runtime during pipeline initialization."""
+ flow_name: builtins.str
+ """The name of the Flow involved in this analysis"""
@property
def extension(
self,
@@ -1486,6 +1519,7 @@ class PipelineAnalysisContext(google.protobuf.message.Message):
*,
dataflow_graph_id: builtins.str | None = ...,
definition_path: builtins.str | None = ...,
+ flow_name: builtins.str | None = ...,
extension: collections.abc.Iterable[google.protobuf.any_pb2.Any] | None = ...,
) -> None: ...
def HasField(
@@ -1495,10 +1529,14 @@ class PipelineAnalysisContext(google.protobuf.message.Message):
b"_dataflow_graph_id",
"_definition_path",
b"_definition_path",
+ "_flow_name",
+ b"_flow_name",
"dataflow_graph_id",
b"dataflow_graph_id",
"definition_path",
b"definition_path",
+ "flow_name",
+ b"flow_name",
],
) -> builtins.bool: ...
def ClearField(
@@ -1508,12 +1546,16 @@ class PipelineAnalysisContext(google.protobuf.message.Message):
b"_dataflow_graph_id",
"_definition_path",
b"_definition_path",
+ "_flow_name",
+ b"_flow_name",
"dataflow_graph_id",
b"dataflow_graph_id",
"definition_path",
b"definition_path",
"extension",
b"extension",
+ "flow_name",
+ b"flow_name",
],
) -> None: ...
@typing.overload
@@ -1524,5 +1566,9 @@ class PipelineAnalysisContext(google.protobuf.message.Message):
def WhichOneof(
self, oneof_group: typing_extensions.Literal["_definition_path", b"_definition_path"]
) -> typing_extensions.Literal["definition_path"] | None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_flow_name", b"_flow_name"]
+ ) -> typing_extensions.Literal["flow_name"] | None: ...
global___PipelineAnalysisContext = PipelineAnalysisContext
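
Editor's note: the regenerated pipelines stubs above add three client-visible fields: clustering_columns on DefineOutput.TableDetails, once on DefineFlow, and flow_name on PipelineAnalysisContext. A minimal sketch of how a Connect client could populate the first two, assuming the standard generated module path; the graph id and names are placeholders and the surrounding command/request plumbing is elided:

    from pyspark.sql.connect.proto import pipelines_pb2 as pb2

    # One-time (backfill-style) flow: `once` is a plain optional bool.
    define_flow = pb2.PipelineCommand.DefineFlow(
        dataflow_graph_id="graph-123",  # placeholder id
        flow_name="daily_backfill",
        once=True,
    )

    # Table output using the new repeated `clustering_columns` field.
    define_output = pb2.PipelineCommand.DefineOutput(
        dataflow_graph_id="graph-123",
        output_name="events",
        table_details=pb2.PipelineCommand.DefineOutput.TableDetails(
            clustering_columns=["region", "event_date"],
        ),
    )
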
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index e7f319554c5e2..9e630b6ba5e4c 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -18,7 +18,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: spark/connect/relations.proto
-# Protobuf Python Version: 5.29.5
+# Protobuf Python Version: 6.33.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -27,7 +27,7 @@
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
- _runtime_version.Domain.PUBLIC, 5, 29, 5, "", "spark/connect/relations.proto"
+ _runtime_version.Domain.PUBLIC, 6, 33, 0, "", "spark/connect/relations.proto"
)
# @@protoc_insertion_point(imports)
diff --git a/python/pyspark/sql/connect/proto/types_pb2.py b/python/pyspark/sql/connect/proto/types_pb2.py
index 9a52129103ad5..fc5b14d068a87 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.py
+++ b/python/pyspark/sql/connect/proto/types_pb2.py
@@ -18,7 +18,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: spark/connect/types.proto
-# Protobuf Python Version: 5.29.5
+# Protobuf Python Version: 6.33.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -27,7 +27,7 @@
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
- _runtime_version.Domain.PUBLIC, 5, 29, 5, "", "spark/connect/types.proto"
+ _runtime_version.Domain.PUBLIC, 6, 33, 0, "", "spark/connect/types.proto"
)
# @@protoc_insertion_point(imports)
@@ -35,7 +35,7 @@
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b"\n\x19spark/connect/types.proto\x12\rspark.connect\"\xac#\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01(\x0b\x32\x1d.spark.connect.DataType.ShortH\x00R\x05short\x12;\n\x07integer\x18\x06 \x01(\x0b\x32\x1f.spark.connect.DataType.IntegerH\x00R\x07integer\x12\x32\n\x04long\x18\x07 \x01(\x0b\x32\x1c.spark.connect.DataType.LongH\x00R\x04long\x12\x35\n\x05\x66loat\x18\x08 \x01(\x0b\x32\x1d.spark.connect.DataType.FloatH\x00R\x05\x66loat\x12\x38\n\x06\x64ouble\x18\t \x01(\x0b\x32\x1e.spark.connect.DataType.DoubleH\x00R\x06\x64ouble\x12;\n\x07\x64\x65\x63imal\x18\n \x01(\x0b\x32\x1f.spark.connect.DataType.DecimalH\x00R\x07\x64\x65\x63imal\x12\x38\n\x06string\x18\x0b \x01(\x0b\x32\x1e.spark.connect.DataType.StringH\x00R\x06string\x12\x32\n\x04\x63har\x18\x0c \x01(\x0b\x32\x1c.spark.connect.DataType.CharH\x00R\x04\x63har\x12<\n\x08var_char\x18\r \x01(\x0b\x32\x1f.spark.connect.DataType.VarCharH\x00R\x07varChar\x12\x32\n\x04\x64\x61te\x18\x0e \x01(\x0b\x32\x1c.spark.connect.DataType.DateH\x00R\x04\x64\x61te\x12\x41\n\ttimestamp\x18\x0f \x01(\x0b\x32!.spark.connect.DataType.TimestampH\x00R\ttimestamp\x12K\n\rtimestamp_ntz\x18\x10 \x01(\x0b\x32$.spark.connect.DataType.TimestampNTZH\x00R\x0ctimestampNtz\x12W\n\x11\x63\x61lendar_interval\x18\x11 \x01(\x0b\x32(.spark.connect.DataType.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12[\n\x13year_month_interval\x18\x12 \x01(\x0b\x32).spark.connect.DataType.YearMonthIntervalH\x00R\x11yearMonthInterval\x12U\n\x11\x64\x61y_time_interval\x18\x13 \x01(\x0b\x32'.spark.connect.DataType.DayTimeIntervalH\x00R\x0f\x64\x61yTimeInterval\x12\x35\n\x05\x61rray\x18\x14 \x01(\x0b\x32\x1d.spark.connect.DataType.ArrayH\x00R\x05\x61rray\x12\x38\n\x06struct\x18\x15 \x01(\x0b\x32\x1e.spark.connect.DataType.StructH\x00R\x06struct\x12/\n\x03map\x18\x16 \x01(\x0b\x32\x1b.spark.connect.DataType.MapH\x00R\x03map\x12;\n\x07variant\x18\x19 \x01(\x0b\x32\x1f.spark.connect.DataType.VariantH\x00R\x07variant\x12/\n\x03udt\x18\x17 \x01(\x0b\x32\x1b.spark.connect.DataType.UDTH\x00R\x03udt\x12>\n\x08unparsed\x18\x18 \x01(\x0b\x32 .spark.connect.DataType.UnparsedH\x00R\x08unparsed\x12\x32\n\x04time\x18\x1c \x01(\x0b\x32\x1c.spark.connect.DataType.TimeH\x00R\x04time\x1a\x43\n\x07\x42oolean\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04\x42yte\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x41\n\x05Short\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x43\n\x07Integer\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04Long\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x41\n\x05\x46loat\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x42\n\x06\x44ouble\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a`\n\x06String\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x12\x1c\n\tcollation\x18\x02 \x01(\tR\tcollation\x1a\x42\n\x06\x42inary\x12\x38\n\x18type_variation_reference\x18\x01 
\x01(\rR\x16typeVariationReference\x1a@\n\x04NULL\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x45\n\tTimestamp\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04\x44\x61te\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1aH\n\x0cTimestampNTZ\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1aq\n\x04Time\x12!\n\tprecision\x18\x01 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReferenceB\x0c\n\n_precision\x1aL\n\x10\x43\x61lendarInterval\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\xb3\x01\n\x11YearMonthInterval\x12$\n\x0bstart_field\x18\x01 \x01(\x05H\x00R\nstartField\x88\x01\x01\x12 \n\tend_field\x18\x02 \x01(\x05H\x01R\x08\x65ndField\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x0e\n\x0c_start_fieldB\x0c\n\n_end_field\x1a\xb1\x01\n\x0f\x44\x61yTimeInterval\x12$\n\x0bstart_field\x18\x01 \x01(\x05H\x00R\nstartField\x88\x01\x01\x12 \n\tend_field\x18\x02 \x01(\x05H\x01R\x08\x65ndField\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x0e\n\x0c_start_fieldB\x0c\n\n_end_field\x1aX\n\x04\x43har\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a[\n\x07VarChar\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a\x99\x01\n\x07\x44\x65\x63imal\x12\x19\n\x05scale\x18\x01 \x01(\x05H\x00R\x05scale\x88\x01\x01\x12!\n\tprecision\x18\x02 \x01(\x05H\x01R\tprecision\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x08\n\x06_scaleB\x0c\n\n_precision\x1a\xa1\x01\n\x0bStructField\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x34\n\tdata_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x08\x64\x61taType\x12\x1a\n\x08nullable\x18\x03 \x01(\x08R\x08nullable\x12\x1f\n\x08metadata\x18\x04 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x7f\n\x06Struct\x12;\n\x06\x66ields\x18\x01 \x03(\x0b\x32#.spark.connect.DataType.StructFieldR\x06\x66ields\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a\xa2\x01\n\x05\x41rray\x12:\n\x0c\x65lement_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x0b\x65lementType\x12#\n\rcontains_null\x18\x02 \x01(\x08R\x0c\x63ontainsNull\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReference\x1a\xdb\x01\n\x03Map\x12\x32\n\x08key_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x07keyType\x12\x36\n\nvalue_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\tvalueType\x12.\n\x13value_contains_null\x18\x03 \x01(\x08R\x11valueContainsNull\x12\x38\n\x18type_variation_reference\x18\x04 \x01(\rR\x16typeVariationReference\x1a\x43\n\x07Variant\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\xa1\x02\n\x03UDT\x12\x12\n\x04type\x18\x01 \x01(\tR\x04type\x12 \n\tjvm_class\x18\x02 \x01(\tH\x00R\x08jvmClass\x88\x01\x01\x12&\n\x0cpython_class\x18\x03 \x01(\tH\x01R\x0bpythonClass\x88\x01\x01\x12;\n\x17serialized_python_class\x18\x04 \x01(\tH\x02R\x15serializedPythonClass\x88\x01\x01\x12\x37\n\x08sql_type\x18\x05 
\x01(\x0b\x32\x17.spark.connect.DataTypeH\x03R\x07sqlType\x88\x01\x01\x42\x0c\n\n_jvm_classB\x0f\n\r_python_classB\x1a\n\x18_serialized_python_classB\x0b\n\t_sql_type\x1a\x34\n\x08Unparsed\x12(\n\x10\x64\x61ta_type_string\x18\x01 \x01(\tR\x0e\x64\x61taTypeStringB\x06\n\x04kindJ\x04\x08\x1a\x10\x1bJ\x04\x08\x1b\x10\x1c\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3"
+ b"\n\x19spark/connect/types.proto\x12\rspark.connect\"\xd8%\n\x08\x44\x61taType\x12\x32\n\x04null\x18\x01 \x01(\x0b\x32\x1c.spark.connect.DataType.NULLH\x00R\x04null\x12\x38\n\x06\x62inary\x18\x02 \x01(\x0b\x32\x1e.spark.connect.DataType.BinaryH\x00R\x06\x62inary\x12;\n\x07\x62oolean\x18\x03 \x01(\x0b\x32\x1f.spark.connect.DataType.BooleanH\x00R\x07\x62oolean\x12\x32\n\x04\x62yte\x18\x04 \x01(\x0b\x32\x1c.spark.connect.DataType.ByteH\x00R\x04\x62yte\x12\x35\n\x05short\x18\x05 \x01(\x0b\x32\x1d.spark.connect.DataType.ShortH\x00R\x05short\x12;\n\x07integer\x18\x06 \x01(\x0b\x32\x1f.spark.connect.DataType.IntegerH\x00R\x07integer\x12\x32\n\x04long\x18\x07 \x01(\x0b\x32\x1c.spark.connect.DataType.LongH\x00R\x04long\x12\x35\n\x05\x66loat\x18\x08 \x01(\x0b\x32\x1d.spark.connect.DataType.FloatH\x00R\x05\x66loat\x12\x38\n\x06\x64ouble\x18\t \x01(\x0b\x32\x1e.spark.connect.DataType.DoubleH\x00R\x06\x64ouble\x12;\n\x07\x64\x65\x63imal\x18\n \x01(\x0b\x32\x1f.spark.connect.DataType.DecimalH\x00R\x07\x64\x65\x63imal\x12\x38\n\x06string\x18\x0b \x01(\x0b\x32\x1e.spark.connect.DataType.StringH\x00R\x06string\x12\x32\n\x04\x63har\x18\x0c \x01(\x0b\x32\x1c.spark.connect.DataType.CharH\x00R\x04\x63har\x12<\n\x08var_char\x18\r \x01(\x0b\x32\x1f.spark.connect.DataType.VarCharH\x00R\x07varChar\x12\x32\n\x04\x64\x61te\x18\x0e \x01(\x0b\x32\x1c.spark.connect.DataType.DateH\x00R\x04\x64\x61te\x12\x41\n\ttimestamp\x18\x0f \x01(\x0b\x32!.spark.connect.DataType.TimestampH\x00R\ttimestamp\x12K\n\rtimestamp_ntz\x18\x10 \x01(\x0b\x32$.spark.connect.DataType.TimestampNTZH\x00R\x0ctimestampNtz\x12W\n\x11\x63\x61lendar_interval\x18\x11 \x01(\x0b\x32(.spark.connect.DataType.CalendarIntervalH\x00R\x10\x63\x61lendarInterval\x12[\n\x13year_month_interval\x18\x12 \x01(\x0b\x32).spark.connect.DataType.YearMonthIntervalH\x00R\x11yearMonthInterval\x12U\n\x11\x64\x61y_time_interval\x18\x13 \x01(\x0b\x32'.spark.connect.DataType.DayTimeIntervalH\x00R\x0f\x64\x61yTimeInterval\x12\x35\n\x05\x61rray\x18\x14 \x01(\x0b\x32\x1d.spark.connect.DataType.ArrayH\x00R\x05\x61rray\x12\x38\n\x06struct\x18\x15 \x01(\x0b\x32\x1e.spark.connect.DataType.StructH\x00R\x06struct\x12/\n\x03map\x18\x16 \x01(\x0b\x32\x1b.spark.connect.DataType.MapH\x00R\x03map\x12;\n\x07variant\x18\x19 \x01(\x0b\x32\x1f.spark.connect.DataType.VariantH\x00R\x07variant\x12/\n\x03udt\x18\x17 \x01(\x0b\x32\x1b.spark.connect.DataType.UDTH\x00R\x03udt\x12>\n\x08geometry\x18\x1a \x01(\x0b\x32 .spark.connect.DataType.GeometryH\x00R\x08geometry\x12\x41\n\tgeography\x18\x1b \x01(\x0b\x32!.spark.connect.DataType.GeographyH\x00R\tgeography\x12>\n\x08unparsed\x18\x18 \x01(\x0b\x32 .spark.connect.DataType.UnparsedH\x00R\x08unparsed\x12\x32\n\x04time\x18\x1c \x01(\x0b\x32\x1c.spark.connect.DataType.TimeH\x00R\x04time\x1a\x43\n\x07\x42oolean\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04\x42yte\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x41\n\x05Short\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x43\n\x07Integer\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04Long\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x41\n\x05\x46loat\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x42\n\x06\x44ouble\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a`\n\x06String\x12\x38\n\x18type_variation_reference\x18\x01 
\x01(\rR\x16typeVariationReference\x12\x1c\n\tcollation\x18\x02 \x01(\tR\tcollation\x1a\x42\n\x06\x42inary\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04NULL\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\x45\n\tTimestamp\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a@\n\x04\x44\x61te\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1aH\n\x0cTimestampNTZ\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1aq\n\x04Time\x12!\n\tprecision\x18\x01 \x01(\x05H\x00R\tprecision\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReferenceB\x0c\n\n_precision\x1aL\n\x10\x43\x61lendarInterval\x12\x38\n\x18type_variation_reference\x18\x01 \x01(\rR\x16typeVariationReference\x1a\xb3\x01\n\x11YearMonthInterval\x12$\n\x0bstart_field\x18\x01 \x01(\x05H\x00R\nstartField\x88\x01\x01\x12 \n\tend_field\x18\x02 \x01(\x05H\x01R\x08\x65ndField\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x0e\n\x0c_start_fieldB\x0c\n\n_end_field\x1a\xb1\x01\n\x0f\x44\x61yTimeInterval\x12$\n\x0bstart_field\x18\x01 \x01(\x05H\x00R\nstartField\x88\x01\x01\x12 \n\tend_field\x18\x02 \x01(\x05H\x01R\x08\x65ndField\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x0e\n\x0c_start_fieldB\x0c\n\n_end_field\x1aX\n\x04\x43har\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a[\n\x07VarChar\x12\x16\n\x06length\x18\x01 \x01(\x05R\x06length\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a\x99\x01\n\x07\x44\x65\x63imal\x12\x19\n\x05scale\x18\x01 \x01(\x05H\x00R\x05scale\x88\x01\x01\x12!\n\tprecision\x18\x02 \x01(\x05H\x01R\tprecision\x88\x01\x01\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReferenceB\x08\n\x06_scaleB\x0c\n\n_precision\x1a\xa1\x01\n\x0bStructField\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x34\n\tdata_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x08\x64\x61taType\x12\x1a\n\x08nullable\x18\x03 \x01(\x08R\x08nullable\x12\x1f\n\x08metadata\x18\x04 \x01(\tH\x00R\x08metadata\x88\x01\x01\x42\x0b\n\t_metadata\x1a\x7f\n\x06Struct\x12;\n\x06\x66ields\x18\x01 \x03(\x0b\x32#.spark.connect.DataType.StructFieldR\x06\x66ields\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a\xa2\x01\n\x05\x41rray\x12:\n\x0c\x65lement_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x0b\x65lementType\x12#\n\rcontains_null\x18\x02 \x01(\x08R\x0c\x63ontainsNull\x12\x38\n\x18type_variation_reference\x18\x03 \x01(\rR\x16typeVariationReference\x1a\xdb\x01\n\x03Map\x12\x32\n\x08key_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x07keyType\x12\x36\n\nvalue_type\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\tvalueType\x12.\n\x13value_contains_null\x18\x03 \x01(\x08R\x11valueContainsNull\x12\x38\n\x18type_variation_reference\x18\x04 \x01(\rR\x16typeVariationReference\x1aX\n\x08Geometry\x12\x12\n\x04srid\x18\x01 \x01(\x05R\x04srid\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1aY\n\tGeography\x12\x12\n\x04srid\x18\x01 \x01(\x05R\x04srid\x12\x38\n\x18type_variation_reference\x18\x02 \x01(\rR\x16typeVariationReference\x1a\x43\n\x07Variant\x12\x38\n\x18type_variation_reference\x18\x01 
\x01(\rR\x16typeVariationReference\x1a\xa1\x02\n\x03UDT\x12\x12\n\x04type\x18\x01 \x01(\tR\x04type\x12 \n\tjvm_class\x18\x02 \x01(\tH\x00R\x08jvmClass\x88\x01\x01\x12&\n\x0cpython_class\x18\x03 \x01(\tH\x01R\x0bpythonClass\x88\x01\x01\x12;\n\x17serialized_python_class\x18\x04 \x01(\tH\x02R\x15serializedPythonClass\x88\x01\x01\x12\x37\n\x08sql_type\x18\x05 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x03R\x07sqlType\x88\x01\x01\x42\x0c\n\n_jvm_classB\x0f\n\r_python_classB\x1a\n\x18_serialized_python_classB\x0b\n\t_sql_type\x1a\x34\n\x08Unparsed\x12(\n\x10\x64\x61ta_type_string\x18\x01 \x01(\tR\x0e\x64\x61taTypeStringB\x06\n\x04kindB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3"
)
_globals = globals()
@@ -47,59 +47,63 @@
"DESCRIPTOR"
]._serialized_options = b"\n\036org.apache.spark.connect.protoP\001Z\022internal/generated"
_globals["_DATATYPE"]._serialized_start = 45
- _globals["_DATATYPE"]._serialized_end = 4569
- _globals["_DATATYPE_BOOLEAN"]._serialized_start = 1647
- _globals["_DATATYPE_BOOLEAN"]._serialized_end = 1714
- _globals["_DATATYPE_BYTE"]._serialized_start = 1716
- _globals["_DATATYPE_BYTE"]._serialized_end = 1780
- _globals["_DATATYPE_SHORT"]._serialized_start = 1782
- _globals["_DATATYPE_SHORT"]._serialized_end = 1847
- _globals["_DATATYPE_INTEGER"]._serialized_start = 1849
- _globals["_DATATYPE_INTEGER"]._serialized_end = 1916
- _globals["_DATATYPE_LONG"]._serialized_start = 1918
- _globals["_DATATYPE_LONG"]._serialized_end = 1982
- _globals["_DATATYPE_FLOAT"]._serialized_start = 1984
- _globals["_DATATYPE_FLOAT"]._serialized_end = 2049
- _globals["_DATATYPE_DOUBLE"]._serialized_start = 2051
- _globals["_DATATYPE_DOUBLE"]._serialized_end = 2117
- _globals["_DATATYPE_STRING"]._serialized_start = 2119
- _globals["_DATATYPE_STRING"]._serialized_end = 2215
- _globals["_DATATYPE_BINARY"]._serialized_start = 2217
- _globals["_DATATYPE_BINARY"]._serialized_end = 2283
- _globals["_DATATYPE_NULL"]._serialized_start = 2285
- _globals["_DATATYPE_NULL"]._serialized_end = 2349
- _globals["_DATATYPE_TIMESTAMP"]._serialized_start = 2351
- _globals["_DATATYPE_TIMESTAMP"]._serialized_end = 2420
- _globals["_DATATYPE_DATE"]._serialized_start = 2422
- _globals["_DATATYPE_DATE"]._serialized_end = 2486
- _globals["_DATATYPE_TIMESTAMPNTZ"]._serialized_start = 2488
- _globals["_DATATYPE_TIMESTAMPNTZ"]._serialized_end = 2560
- _globals["_DATATYPE_TIME"]._serialized_start = 2562
- _globals["_DATATYPE_TIME"]._serialized_end = 2675
- _globals["_DATATYPE_CALENDARINTERVAL"]._serialized_start = 2677
- _globals["_DATATYPE_CALENDARINTERVAL"]._serialized_end = 2753
- _globals["_DATATYPE_YEARMONTHINTERVAL"]._serialized_start = 2756
- _globals["_DATATYPE_YEARMONTHINTERVAL"]._serialized_end = 2935
- _globals["_DATATYPE_DAYTIMEINTERVAL"]._serialized_start = 2938
- _globals["_DATATYPE_DAYTIMEINTERVAL"]._serialized_end = 3115
- _globals["_DATATYPE_CHAR"]._serialized_start = 3117
- _globals["_DATATYPE_CHAR"]._serialized_end = 3205
- _globals["_DATATYPE_VARCHAR"]._serialized_start = 3207
- _globals["_DATATYPE_VARCHAR"]._serialized_end = 3298
- _globals["_DATATYPE_DECIMAL"]._serialized_start = 3301
- _globals["_DATATYPE_DECIMAL"]._serialized_end = 3454
- _globals["_DATATYPE_STRUCTFIELD"]._serialized_start = 3457
- _globals["_DATATYPE_STRUCTFIELD"]._serialized_end = 3618
- _globals["_DATATYPE_STRUCT"]._serialized_start = 3620
- _globals["_DATATYPE_STRUCT"]._serialized_end = 3747
- _globals["_DATATYPE_ARRAY"]._serialized_start = 3750
- _globals["_DATATYPE_ARRAY"]._serialized_end = 3912
- _globals["_DATATYPE_MAP"]._serialized_start = 3915
- _globals["_DATATYPE_MAP"]._serialized_end = 4134
- _globals["_DATATYPE_VARIANT"]._serialized_start = 4136
- _globals["_DATATYPE_VARIANT"]._serialized_end = 4203
- _globals["_DATATYPE_UDT"]._serialized_start = 4206
- _globals["_DATATYPE_UDT"]._serialized_end = 4495
- _globals["_DATATYPE_UNPARSED"]._serialized_start = 4497
- _globals["_DATATYPE_UNPARSED"]._serialized_end = 4549
+ _globals["_DATATYPE"]._serialized_end = 4869
+ _globals["_DATATYPE_BOOLEAN"]._serialized_start = 1778
+ _globals["_DATATYPE_BOOLEAN"]._serialized_end = 1845
+ _globals["_DATATYPE_BYTE"]._serialized_start = 1847
+ _globals["_DATATYPE_BYTE"]._serialized_end = 1911
+ _globals["_DATATYPE_SHORT"]._serialized_start = 1913
+ _globals["_DATATYPE_SHORT"]._serialized_end = 1978
+ _globals["_DATATYPE_INTEGER"]._serialized_start = 1980
+ _globals["_DATATYPE_INTEGER"]._serialized_end = 2047
+ _globals["_DATATYPE_LONG"]._serialized_start = 2049
+ _globals["_DATATYPE_LONG"]._serialized_end = 2113
+ _globals["_DATATYPE_FLOAT"]._serialized_start = 2115
+ _globals["_DATATYPE_FLOAT"]._serialized_end = 2180
+ _globals["_DATATYPE_DOUBLE"]._serialized_start = 2182
+ _globals["_DATATYPE_DOUBLE"]._serialized_end = 2248
+ _globals["_DATATYPE_STRING"]._serialized_start = 2250
+ _globals["_DATATYPE_STRING"]._serialized_end = 2346
+ _globals["_DATATYPE_BINARY"]._serialized_start = 2348
+ _globals["_DATATYPE_BINARY"]._serialized_end = 2414
+ _globals["_DATATYPE_NULL"]._serialized_start = 2416
+ _globals["_DATATYPE_NULL"]._serialized_end = 2480
+ _globals["_DATATYPE_TIMESTAMP"]._serialized_start = 2482
+ _globals["_DATATYPE_TIMESTAMP"]._serialized_end = 2551
+ _globals["_DATATYPE_DATE"]._serialized_start = 2553
+ _globals["_DATATYPE_DATE"]._serialized_end = 2617
+ _globals["_DATATYPE_TIMESTAMPNTZ"]._serialized_start = 2619
+ _globals["_DATATYPE_TIMESTAMPNTZ"]._serialized_end = 2691
+ _globals["_DATATYPE_TIME"]._serialized_start = 2693
+ _globals["_DATATYPE_TIME"]._serialized_end = 2806
+ _globals["_DATATYPE_CALENDARINTERVAL"]._serialized_start = 2808
+ _globals["_DATATYPE_CALENDARINTERVAL"]._serialized_end = 2884
+ _globals["_DATATYPE_YEARMONTHINTERVAL"]._serialized_start = 2887
+ _globals["_DATATYPE_YEARMONTHINTERVAL"]._serialized_end = 3066
+ _globals["_DATATYPE_DAYTIMEINTERVAL"]._serialized_start = 3069
+ _globals["_DATATYPE_DAYTIMEINTERVAL"]._serialized_end = 3246
+ _globals["_DATATYPE_CHAR"]._serialized_start = 3248
+ _globals["_DATATYPE_CHAR"]._serialized_end = 3336
+ _globals["_DATATYPE_VARCHAR"]._serialized_start = 3338
+ _globals["_DATATYPE_VARCHAR"]._serialized_end = 3429
+ _globals["_DATATYPE_DECIMAL"]._serialized_start = 3432
+ _globals["_DATATYPE_DECIMAL"]._serialized_end = 3585
+ _globals["_DATATYPE_STRUCTFIELD"]._serialized_start = 3588
+ _globals["_DATATYPE_STRUCTFIELD"]._serialized_end = 3749
+ _globals["_DATATYPE_STRUCT"]._serialized_start = 3751
+ _globals["_DATATYPE_STRUCT"]._serialized_end = 3878
+ _globals["_DATATYPE_ARRAY"]._serialized_start = 3881
+ _globals["_DATATYPE_ARRAY"]._serialized_end = 4043
+ _globals["_DATATYPE_MAP"]._serialized_start = 4046
+ _globals["_DATATYPE_MAP"]._serialized_end = 4265
+ _globals["_DATATYPE_GEOMETRY"]._serialized_start = 4267
+ _globals["_DATATYPE_GEOMETRY"]._serialized_end = 4355
+ _globals["_DATATYPE_GEOGRAPHY"]._serialized_start = 4357
+ _globals["_DATATYPE_GEOGRAPHY"]._serialized_end = 4446
+ _globals["_DATATYPE_VARIANT"]._serialized_start = 4448
+ _globals["_DATATYPE_VARIANT"]._serialized_end = 4515
+ _globals["_DATATYPE_UDT"]._serialized_start = 4518
+ _globals["_DATATYPE_UDT"]._serialized_end = 4807
+ _globals["_DATATYPE_UNPARSED"]._serialized_start = 4809
+ _globals["_DATATYPE_UNPARSED"]._serialized_end = 4861
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/types_pb2.pyi b/python/pyspark/sql/connect/proto/types_pb2.pyi
index d46770c4f888e..3f625890a809b 100644
--- a/python/pyspark/sql/connect/proto/types_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/types_pb2.pyi
@@ -674,6 +674,46 @@ class DataType(google.protobuf.message.Message):
],
) -> None: ...
+ class Geometry(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ SRID_FIELD_NUMBER: builtins.int
+ TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
+ srid: builtins.int
+ type_variation_reference: builtins.int
+ def __init__(
+ self,
+ *,
+ srid: builtins.int = ...,
+ type_variation_reference: builtins.int = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "srid", b"srid", "type_variation_reference", b"type_variation_reference"
+ ],
+ ) -> None: ...
+
+ class Geography(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ SRID_FIELD_NUMBER: builtins.int
+ TYPE_VARIATION_REFERENCE_FIELD_NUMBER: builtins.int
+ srid: builtins.int
+ type_variation_reference: builtins.int
+ def __init__(
+ self,
+ *,
+ srid: builtins.int = ...,
+ type_variation_reference: builtins.int = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "srid", b"srid", "type_variation_reference", b"type_variation_reference"
+ ],
+ ) -> None: ...
+
class Variant(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -821,6 +861,8 @@ class DataType(google.protobuf.message.Message):
MAP_FIELD_NUMBER: builtins.int
VARIANT_FIELD_NUMBER: builtins.int
UDT_FIELD_NUMBER: builtins.int
+ GEOMETRY_FIELD_NUMBER: builtins.int
+ GEOGRAPHY_FIELD_NUMBER: builtins.int
UNPARSED_FIELD_NUMBER: builtins.int
TIME_FIELD_NUMBER: builtins.int
@property
@@ -878,6 +920,11 @@ class DataType(google.protobuf.message.Message):
def udt(self) -> global___DataType.UDT:
"""UserDefinedType"""
@property
+ def geometry(self) -> global___DataType.Geometry:
+ """Geospatial types"""
+ @property
+ def geography(self) -> global___DataType.Geography: ...
+ @property
def unparsed(self) -> global___DataType.Unparsed:
"""UnparsedDataType"""
@property
@@ -909,6 +956,8 @@ class DataType(google.protobuf.message.Message):
map: global___DataType.Map | None = ...,
variant: global___DataType.Variant | None = ...,
udt: global___DataType.UDT | None = ...,
+ geometry: global___DataType.Geometry | None = ...,
+ geography: global___DataType.Geography | None = ...,
unparsed: global___DataType.Unparsed | None = ...,
time: global___DataType.Time | None = ...,
) -> None: ...
@@ -937,6 +986,10 @@ class DataType(google.protobuf.message.Message):
b"double",
"float",
b"float",
+ "geography",
+ b"geography",
+ "geometry",
+ b"geometry",
"integer",
b"integer",
"kind",
@@ -996,6 +1049,10 @@ class DataType(google.protobuf.message.Message):
b"double",
"float",
b"float",
+ "geography",
+ b"geography",
+ "geometry",
+ b"geometry",
"integer",
b"integer",
"kind",
@@ -1058,6 +1115,8 @@ class DataType(google.protobuf.message.Message):
"map",
"variant",
"udt",
+ "geometry",
+ "geography",
"unparsed",
"time",
]
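
Editor's note: as with the pipeline stubs, Geometry and Geography are ordinary nested DataType variants behind the kind oneof. A minimal sketch of building one directly from the generated module (the SRID value is illustrative):

    from pyspark.sql.connect.proto import types_pb2 as pb2

    # Geography is constructed the same way, via pb2.DataType.Geography.
    geometry_type = pb2.DataType(geometry=pb2.DataType.Geometry(srid=4326))
    print(geometry_type.WhichOneof("kind"))  # prints "geometry"
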
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 21a7c8329a354..ac1d1f5681e3e 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -29,6 +29,7 @@
from typing import (
Optional,
Any,
+ Iterator,
Union,
Dict,
List,
@@ -537,6 +538,7 @@ def createDataFrame(
"spark.sql.session.localRelationCacheThreshold",
"spark.sql.session.localRelationChunkSizeRows",
"spark.sql.session.localRelationChunkSizeBytes",
+ "spark.sql.session.localRelationBatchOfChunksSizeBytes",
"spark.sql.execution.pandas.convertToArrowArraySafely",
"spark.sql.execution.pandas.inferPandasDictAsMap",
"spark.sql.pyspark.inferNestedDictAsStruct.enabled",
@@ -772,10 +774,16 @@ def createDataFrame(
max_chunk_size_bytes = int(
configs["spark.sql.session.localRelationChunkSizeBytes"] # type: ignore[arg-type]
)
+ max_batch_of_chunks_size_bytes = int(
+ configs["spark.sql.session.localRelationBatchOfChunksSizeBytes"] # type: ignore[arg-type] # noqa: E501
+ )
plan: LogicalPlan = local_relation
if cache_threshold <= _table.nbytes:
plan = self._cache_local_relation(
- local_relation, max_chunk_size_rows, max_chunk_size_bytes
+ local_relation,
+ max_chunk_size_rows,
+ max_chunk_size_bytes,
+ max_batch_of_chunks_size_bytes,
)
df = DataFrame(plan, self)
@@ -1054,30 +1062,62 @@ def _cache_local_relation(
local_relation: LocalRelation,
max_chunk_size_rows: int,
max_chunk_size_bytes: int,
+ max_batch_of_chunks_size_bytes: int,
) -> ChunkedCachedLocalRelation:
"""
Cache the local relation at the server side if it has not been cached yet.
- Should only be called on LocalRelations with _table set.
+        This method serializes the input local relation into multiple data chunks, plus a
+        schema chunk if the schema is available, and uploads these chunks as artifacts
+        to the server.
+
+        The method collects chunks into batches of up to max_batch_of_chunks_size_bytes each
+        and uploads every batch to the server in a single call.
+        Uploading each chunk separately would require an additional RPC call per chunk, while
+        uploading all chunks at once would require materializing them all in memory, which
+        may cause high memory usage on the client.
+        Uploading batches of chunks is the middle ground between these two extremes.
+
+ Should only be called on a LocalRelation with a non-empty _table.
"""
- assert local_relation._table is not None
+ assert local_relation._table is not None, "table cannot be None"
has_schema = local_relation._schema is not None
- # Serialize table into chunks
- data_chunks = local_relation._serialize_table_chunks(
- max_chunk_size_rows, max_chunk_size_bytes
+ hashes = []
+ current_batch = []
+ current_batch_size = 0
+ if has_schema:
+ schema_chunk = local_relation._serialize_schema()
+ current_batch.append(schema_chunk)
+ current_batch_size += len(schema_chunk)
+
+ data_chunks: Iterator[bytes] = local_relation._serialize_table_chunks(
+ max_chunk_size_rows, min(max_chunk_size_bytes, max_batch_of_chunks_size_bytes)
)
- blobs = data_chunks.copy() # Start with data chunks
- if has_schema:
- blobs.append(local_relation._serialize_schema())
+ for chunk in data_chunks:
+ chunk_size = len(chunk)
- hashes = self._client.cache_artifacts(blobs)
+ # Check if adding this chunk would exceed batch size
+ if (
+ len(current_batch) > 0
+ and current_batch_size + chunk_size > max_batch_of_chunks_size_bytes
+ ):
+ hashes += self._client.cache_artifacts(current_batch)
+ # start a new batch
+ current_batch = []
+ current_batch_size = 0
- # Extract data hashes and schema hash
- data_hashes = hashes[: len(data_chunks)]
- schema_hash = hashes[len(data_chunks)] if has_schema else None
+ current_batch.append(chunk)
+ current_batch_size += chunk_size
+ hashes += self._client.cache_artifacts(current_batch)
+ if has_schema:
+ schema_hash = hashes[0]
+ data_hashes = hashes[1:]
+ else:
+ schema_hash = None
+ data_hashes = hashes
return ChunkedCachedLocalRelation(data_hashes, schema_hash)
def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None:
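
Editor's note: the new upload path in _cache_local_relation is a greedy batching loop. A hedged standalone sketch of the same strategy, with a stand-in grouping step instead of the real cache_artifacts RPC (chunk sizes and the batch limit below are illustrative):

    from typing import Iterator, List

    def batch_chunks(chunks: Iterator[bytes], max_batch_size_bytes: int) -> List[List[bytes]]:
        """Greedily group chunks so each batch stays within max_batch_size_bytes.
        A single chunk larger than the limit is still sent, alone, in its own batch."""
        batches: List[List[bytes]] = []
        current: List[bytes] = []
        current_size = 0
        for chunk in chunks:
            if current and current_size + len(chunk) > max_batch_size_bytes:
                batches.append(current)  # flush the full batch
                current, current_size = [], 0
            current.append(chunk)
            current_size += len(chunk)
        if current:
            batches.append(current)  # flush the trailing partial batch
        return batches

    # Seven 1 KiB chunks with a 3 KiB batch limit -> batches of 3, 3 and 1 chunks.
    chunks = iter([b"x" * 1024 for _ in range(7)])
    print([len(batch) for batch in batch_chunks(chunks, 3 * 1024)])  # [3, 3, 1]
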
diff --git a/python/pyspark/sql/connect/tvf.py b/python/pyspark/sql/connect/tvf.py
index 59a4e4fbe344e..ac41b0b21725f 100644
--- a/python/pyspark/sql/connect/tvf.py
+++ b/python/pyspark/sql/connect/tvf.py
@@ -109,6 +109,11 @@ def variant_explode_outer(self, input: "Column") -> "DataFrame":
variant_explode_outer.__doc__ = PySparkTableValuedFunction.variant_explode_outer.__doc__
+ def python_worker_logs(self) -> "DataFrame":
+ return self._fn("python_worker_logs")
+
+ python_worker_logs.__doc__ = PySparkTableValuedFunction.python_worker_logs.__doc__
+
def _fn(self, name: str, *args: "Column") -> "DataFrame":
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import UnresolvedTableValuedFunction
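
Editor's note: the new python_worker_logs table-valued function is wired up like the other TVFs on the session accessor. A hedged usage sketch, assuming an active Spark Connect session (the schema of the returned DataFrame is not shown in this diff):

    # Returns a DataFrame over the Python worker logs; takes no arguments.
    logs_df = spark.tvf.python_worker_logs()
    logs_df.show(truncate=False)
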
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 8f9e7c0561cc0..d3352b618d7c7 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -48,7 +48,10 @@
BinaryType,
BooleanType,
NullType,
+ NumericType,
VariantType,
+ GeographyType,
+ GeometryType,
UserDefinedType,
)
from pyspark.errors import PySparkAssertionError, PySparkValueError
@@ -190,6 +193,10 @@ def pyspark_types_to_proto_types(data_type: DataType) -> pb2.DataType:
ret.array.contains_null = data_type.containsNull
elif isinstance(data_type, VariantType):
ret.variant.CopyFrom(pb2.DataType.Variant())
+ elif isinstance(data_type, GeometryType):
+ ret.geometry.srid = data_type.srid
+ elif isinstance(data_type, GeographyType):
+ ret.geography.srid = data_type.srid
elif isinstance(data_type, UserDefinedType):
json_value = data_type.jsonValue()
ret.udt.type = "udt"
@@ -302,6 +309,18 @@ def proto_schema_to_pyspark_data_type(schema: pb2.DataType) -> DataType:
)
elif schema.HasField("variant"):
return VariantType()
+ elif schema.HasField("geometry"):
+ srid = schema.geometry.srid
+ if srid == GeometryType.MIXED_SRID:
+ return GeometryType("ANY")
+ else:
+ return GeometryType(srid)
+ elif schema.HasField("geography"):
+ srid = schema.geography.srid
+ if srid == GeographyType.MIXED_SRID:
+ return GeographyType("ANY")
+ else:
+ return GeographyType(srid)
elif schema.HasField("udt"):
assert schema.udt.type == "udt"
json_value = {}
@@ -367,15 +386,42 @@ def verify_col_name(name: str, schema: StructType) -> bool:
if parts is None or len(parts) == 0:
return False
- def _quick_verify(parts: List[str], schema: DataType) -> bool:
+ def _quick_verify(parts: List[str], dt: DataType) -> bool:
if len(parts) == 0:
return True
_schema: Optional[StructType] = None
- if isinstance(schema, StructType):
- _schema = schema
- elif isinstance(schema, ArrayType) and isinstance(schema.elementType, StructType):
- _schema = schema.elementType
+ if isinstance(dt, StructType):
+ _schema = dt
+ elif isinstance(dt, ArrayType) and isinstance(dt.elementType, StructType):
+ _schema = dt.elementType
+ else:
+ return False
+
+ part = parts[0]
+ for field in _schema:
+ if field.name == part:
+ return _quick_verify(parts[1:], field.dataType)
+
+ return False
+
+ return _quick_verify(parts, schema)
+
+
+def verify_numeric_col_name(name: str, schema: StructType) -> bool:
+ parts = parse_attr_name(name)
+ if parts is None or len(parts) == 0:
+ return False
+
+ def _quick_verify(parts: List[str], dt: DataType) -> bool:
+ if len(parts) == 0 and isinstance(dt, NumericType):
+ return True
+
+ _schema: Optional[StructType] = None
+ if isinstance(dt, StructType):
+ _schema = dt
+ elif isinstance(dt, ArrayType) and isinstance(dt.elementType, StructType):
+ _schema = dt.elementType
else:
return False
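
Editor's note: verify_numeric_col_name walks the same dotted-attribute path as verify_col_name but additionally requires the resolved leaf field to be numeric. A hedged sketch against a small made-up schema (the helper comes from the change to pyspark/sql/connect/types.py above):

    from pyspark.sql.connect.types import verify_numeric_col_name
    from pyspark.sql.types import LongType, StringType, StructField, StructType

    schema = StructType([
        StructField("user", StructType([
            StructField("age", LongType()),
            StructField("name", StringType()),
        ])),
    ])

    verify_numeric_col_name("user.age", schema)      # True: LongType is numeric
    verify_numeric_col_name("user.name", schema)     # False: StringType is not numeric
    verify_numeric_col_name("user.missing", schema)  # False: no such nested field
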
diff --git a/python/pyspark/sql/connect/utils.py b/python/pyspark/sql/connect/utils.py
index a2511836816c9..0e0e042446531 100644
--- a/python/pyspark/sql/connect/utils.py
+++ b/python/pyspark/sql/connect/utils.py
@@ -37,6 +37,7 @@ def check_dependencies(mod_name: str) -> None:
require_minimum_grpc_version()
require_minimum_grpcio_status_version()
require_minimum_googleapis_common_protos_version()
+ require_minimum_zstandard_version()
def require_minimum_grpc_version() -> None:
@@ -96,5 +97,21 @@ def require_minimum_googleapis_common_protos_version() -> None:
) from error
+def require_minimum_zstandard_version() -> None:
+ """Raise ImportError if zstandard is not installed"""
+ minimum_zstandard_version = "0.25.0"
+
+ try:
+ import zstandard # noqa
+ except ImportError as error:
+ raise PySparkImportError(
+ errorClass="PACKAGE_NOT_INSTALLED",
+ messageParameters={
+ "package_name": "zstandard",
+ "minimum_version": str(minimum_zstandard_version),
+ },
+ ) from error
+
+
def get_python_ver() -> str:
return "%d.%d" % sys.version_info[:2]
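
Editor's note: the zstandard guard follows the existing require_minimum_* pattern in this file: it only verifies that the package imports, and surfaces a PySparkImportError naming the minimum version otherwise. A hedged usage sketch:

    from pyspark.errors import PySparkImportError

    try:
        require_minimum_zstandard_version()  # helper added above
    except PySparkImportError as error:
        # Spark Connect clients are now expected to have zstandard >= 0.25.0 installed.
        print(error)
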
diff --git a/python/pyspark/sql/conversion.py b/python/pyspark/sql/conversion.py
index a8f621277a0af..f73727d1d5344 100644
--- a/python/pyspark/sql/conversion.py
+++ b/python/pyspark/sql/conversion.py
@@ -28,6 +28,10 @@
BinaryType,
DataType,
DecimalType,
+ GeographyType,
+ Geography,
+ GeometryType,
+ Geometry,
MapType,
NullType,
Row,
@@ -89,6 +93,10 @@ def _need_converter(
return True
elif isinstance(dataType, VariantType):
return True
+ elif isinstance(dataType, GeometryType):
+ return True
+ elif isinstance(dataType, GeographyType):
+ return True
else:
return False
@@ -392,6 +400,34 @@ def convert_variant(value: Any) -> Any:
return convert_variant
+ elif isinstance(dataType, GeographyType):
+
+ def convert_geography(value: Any) -> Any:
+ if value is None:
+ if not nullable:
+ raise PySparkValueError(f"input for {dataType} must not be None")
+ return None
+ elif isinstance(value, Geography):
+ return dataType.toInternal(value)
+ else:
+ raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY")
+
+ return convert_geography
+
+ elif isinstance(dataType, GeometryType):
+
+ def convert_geometry(value: Any) -> Any:
+ if value is None:
+ if not nullable:
+ raise PySparkValueError(f"input for {dataType} must not be None")
+ return None
+ elif isinstance(value, Geometry):
+ return dataType.toInternal(value)
+ else:
+ raise PySparkValueError(errorClass="MALFORMED_GEOMETRY")
+
+ return convert_geometry
+
elif not nullable:
def convert_other(value: Any) -> Any:
@@ -511,6 +547,10 @@ def _need_converter(dataType: DataType) -> bool:
return True
elif isinstance(dataType, VariantType):
return True
+ elif isinstance(dataType, GeographyType):
+ return True
+ elif isinstance(dataType, GeometryType):
+ return True
else:
return False
@@ -719,6 +759,40 @@ def convert_variant(value: Any) -> Any:
return convert_variant
+ elif isinstance(dataType, GeographyType):
+
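+        # Here the value is expected to be a dict with "wkb" (bytes) and "srid" (int);
+        # rebuild a Geography object from it via Geography.fromWKB().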
+ def convert_geography(value: Any) -> Any:
+ if value is None:
+ return None
+ elif (
+ isinstance(value, dict)
+ and all(key in value for key in ["wkb", "srid"])
+ and isinstance(value["wkb"], bytes)
+ and isinstance(value["srid"], int)
+ ):
+ return Geography.fromWKB(value["wkb"], value["srid"])
+ else:
+ raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY")
+
+ return convert_geography
+
+ elif isinstance(dataType, GeometryType):
+
+ def convert_geometry(value: Any) -> Any:
+ if value is None:
+ return None
+ elif (
+ isinstance(value, dict)
+ and all(key in value for key in ["wkb", "srid"])
+ and isinstance(value["wkb"], bytes)
+ and isinstance(value["srid"], int)
+ ):
+ return Geometry.fromWKB(value["wkb"], value["srid"])
+ else:
+ raise PySparkValueError(errorClass="MALFORMED_GEOMETRY")
+
+ return convert_geometry
+
else:
if none_on_identity:
return None
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index ca33539df960b..f8f5cfecd2c65 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -852,7 +852,6 @@ def isEmpty(self) -> bool:
Notes
-----
- - Unlike `count()`, this method does not trigger any computation.
- An empty DataFrame has no rows. It may have columns, but no data.
Examples
@@ -6691,6 +6690,64 @@ def asTable(self) -> TableArg:
-------
:class:`table_arg.TableArg`
A `TableArg` object representing a table argument.
+
+ Examples
+ --------
+ >>> from pyspark.sql.functions import udtf
+ >>>
+ >>> # Create a simple UDTF that processes table data
+ >>> @udtf(returnType="id: int, doubled: int")
+ ... class DoubleUDTF:
+ ... def eval(self, row):
+ ... yield row["id"], row["id"] * 2
+ ...
+ >>> # Create a DataFrame
+ >>> df = spark.createDataFrame([(1,), (2,), (3,)], ["id"])
+ >>>
+ >>> # Use asTable() to pass the DataFrame as a table argument to the UDTF
+ >>> result = DoubleUDTF(df.asTable())
+ >>> result.show()
+ +---+-------+
+ | id|doubled|
+ +---+-------+
+ | 1| 2|
+ | 2| 4|
+ | 3| 6|
+ +---+-------+
+ >>>
+ >>> # Use partitionBy and orderBy to control data partitioning and ordering
+ >>> df2 = spark.createDataFrame(
+ ... [(1, "a"), (1, "b"), (2, "c"), (2, "d")], ["key", "value"]
+ ... )
+ >>>
+ >>> @udtf(returnType="key: int, value: string")
+ ... class ProcessUDTF:
+ ... def eval(self, row):
+ ... yield row["key"], row["value"]
+ ...
+ >>> # Partition by 'key' and order by 'value' within each partition
+ >>> result2 = ProcessUDTF(df2.asTable().partitionBy("key").orderBy("value"))
+ >>> result2.show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ | 1| a|
+ | 1| b|
+ | 2| c|
+ | 2| d|
+ +---+-----+
+ >>>
+ >>> # Use withSinglePartition to process all data in a single partition
+ >>> result3 = ProcessUDTF(df2.asTable().withSinglePartition().orderBy("value"))
+ >>> result3.show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ | 1| a|
+ | 1| b|
+ | 2| c|
+ | 2| d|
+ +---+-----+
"""
...
diff --git a/python/pyspark/sql/datasource_internal.py b/python/pyspark/sql/datasource_internal.py
index 6df0be4192ec8..9467bfdf73bc2 100644
--- a/python/pyspark/sql/datasource_internal.py
+++ b/python/pyspark/sql/datasource_internal.py
@@ -28,7 +28,7 @@
SimpleDataSourceStreamReader,
)
from pyspark.sql.types import StructType
-from pyspark.errors import PySparkNotImplementedError
+from pyspark.errors import PySparkException, PySparkNotImplementedError
def _streamReader(datasource: DataSource, schema: StructType) -> "DataSourceStreamReader":
@@ -88,12 +88,36 @@ def initialOffset(self) -> dict:
self.initial_offset = self.simple_reader.initialOffset()
return self.initial_offset
+ def add_result_to_cache(self, start: dict, end: dict, it: Iterator[Tuple]) -> None:
+ """
+        Validates that read() did not return a non-empty batch with an end offset equal
+        to the start offset, which would cause the same batch to be processed repeatedly.
+        When end != start, appends the result to the cache; when end == start and the
+        iterator is empty, skips caching to avoid unbounded cache growth.
+ """
+ start_str = json.dumps(start)
+ end_str = json.dumps(end)
+ if end_str != start_str:
+ self.cache.append(PrefetchedCacheEntry(start, end, it))
+ return
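+        # Offsets did not advance: peek the iterator to confirm the batch is really empty
+        # before discarding it; any remaining row means read() did not advance the offset.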
+ try:
+ next(it)
+ except StopIteration:
+ return
+ raise PySparkException(
+ errorClass="SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE",
+ messageParameters={
+ "start_offset": start_str,
+ "end_offset": end_str,
+ },
+ )
+
def latestOffset(self) -> dict:
# when query start for the first time, use initial offset as the start offset.
if self.current_offset is None:
self.current_offset = self.initialOffset()
(iter, end) = self.simple_reader.read(self.current_offset)
- self.cache.append(PrefetchedCacheEntry(self.current_offset, end, iter))
+ self.add_result_to_cache(self.current_offset, end, iter)
self.current_offset = end
return end
diff --git a/python/pyspark/sql/functions/__init__.py b/python/pyspark/sql/functions/__init__.py
index e1b320c98f7fe..64446a835d842 100644
--- a/python/pyspark/sql/functions/__init__.py
+++ b/python/pyspark/sql/functions/__init__.py
@@ -372,6 +372,12 @@
"histogram_numeric",
"hll_sketch_agg",
"hll_union_agg",
+ "kll_sketch_agg_bigint",
+ "kll_sketch_agg_double",
+ "kll_sketch_agg_float",
+ "kll_merge_agg_bigint",
+ "kll_merge_agg_double",
+ "kll_merge_agg_float",
"kurtosis",
"last",
"last_value",
@@ -495,6 +501,21 @@
"input_file_block_start",
"input_file_name",
"java_method",
+ "kll_sketch_get_n_bigint",
+ "kll_sketch_get_n_double",
+ "kll_sketch_get_n_float",
+ "kll_sketch_get_quantile_bigint",
+ "kll_sketch_get_quantile_double",
+ "kll_sketch_get_quantile_float",
+ "kll_sketch_get_rank_bigint",
+ "kll_sketch_get_rank_double",
+ "kll_sketch_get_rank_float",
+ "kll_sketch_merge_bigint",
+ "kll_sketch_merge_double",
+ "kll_sketch_merge_float",
+ "kll_sketch_to_string_bigint",
+ "kll_sketch_to_string_double",
+ "kll_sketch_to_string_float",
"monotonically_increasing_id",
"raise_error",
"reflect",
@@ -522,6 +543,13 @@
"UserDefinedFunction",
"UserDefinedTableFunction",
"arrow_udf",
+ # Geospatial ST Functions
+ "st_asbinary",
+ "st_geogfromwkb",
+ "st_geomfromwkb",
+ "st_setsrid",
+ "st_srid",
+ # Call Functions
"call_udf",
"pandas_udf",
"udf",
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 1ac3ac23e888e..63b4ad64b5792 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -25901,6 +25901,145 @@ def bucket(numBuckets: Union[Column, int], col: "ColumnOrName") -> Column:
return partitioning.bucket(numBuckets, col)
+# Geospatial ST Functions
+
+
+@_try_remote_functions
+def st_asbinary(geo: "ColumnOrName") -> Column:
+ """Returns the input GEOGRAPHY or GEOMETRY value in WKB format.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ geo : :class:`~pyspark.sql.Column` or str
+ A geospatial value, either a GEOGRAPHY or a GEOMETRY.
+
+ Examples
+ --------
+
+ Example 1: Getting WKB from GEOGRAPHY.
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([(bytes.fromhex('0101000000000000000000F03F0000000000000040'),)], ['wkb']) # noqa
+ >>> df.select(sf.hex(sf.st_asbinary(sf.st_geogfromwkb('wkb')))).collect()
+ [Row(hex(st_asbinary(st_geogfromwkb(wkb)))='0101000000000000000000F03F0000000000000040')]
+
+ Example 2: Getting WKB from GEOMETRY.
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([(bytes.fromhex('0101000000000000000000F03F0000000000000040'),)], ['wkb']) # noqa
+ >>> df.select(sf.hex(sf.st_asbinary(sf.st_geomfromwkb('wkb')))).collect()
+ [Row(hex(st_asbinary(st_geomfromwkb(wkb)))='0101000000000000000000F03F0000000000000040')]
+ """
+ return _invoke_function_over_columns("st_asbinary", geo)
+
+
+@_try_remote_functions
+def st_geogfromwkb(wkb: "ColumnOrName") -> Column:
+ """Parses the input WKB description and returns the corresponding GEOGRAPHY value.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ wkb : :class:`~pyspark.sql.Column` or str
+ A BINARY value in WKB format, representing a GEOGRAPHY value.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([(bytes.fromhex('0101000000000000000000F03F0000000000000040'),)], ['wkb']) # noqa
+ >>> df.select(sf.hex(sf.st_asbinary(sf.st_geogfromwkb('wkb')))).collect()
+ [Row(hex(st_asbinary(st_geogfromwkb(wkb)))='0101000000000000000000F03F0000000000000040')]
+ """
+ return _invoke_function_over_columns("st_geogfromwkb", wkb)
+
+
+@_try_remote_functions
+def st_geomfromwkb(wkb: "ColumnOrName") -> Column:
+ """Parses the input WKB description and returns the corresponding GEOMETRY value.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ wkb : :class:`~pyspark.sql.Column` or str
+ A BINARY value in WKB format, representing a GEOMETRY value.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([(bytes.fromhex('0101000000000000000000F03F0000000000000040'),)], ['wkb']) # noqa
+ >>> df.select(sf.hex(sf.st_asbinary(sf.st_geomfromwkb('wkb')))).collect()
+ [Row(hex(st_asbinary(st_geomfromwkb(wkb)))='0101000000000000000000F03F0000000000000040')]
+ """
+ return _invoke_function_over_columns("st_geomfromwkb", wkb)
+
+
+@_try_remote_functions
+def st_setsrid(geo: "ColumnOrName", srid: Union["ColumnOrName", int]) -> Column:
+ """Returns a new GEOGRAPHY or GEOMETRY value whose SRID is the specified SRID value.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ geo : :class:`~pyspark.sql.Column` or str
+ A geospatial value, either a GEOGRAPHY or a GEOMETRY.
+ srid : :class:`~pyspark.sql.Column` or int
+ An INTEGER representing the new SRID of the geospatial value.
+
+ Examples
+ --------
+
+ Example 1: Setting the SRID on GEOGRAPHY with SRID from another column.
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([(bytes.fromhex('0101000000000000000000F03F0000000000000040'), 4326)], ['wkb', 'srid']) # noqa
+ >>> df.select(sf.st_srid(sf.st_setsrid(sf.st_geogfromwkb('wkb'), 'srid'))).collect()
+ [Row(st_srid(st_setsrid(st_geogfromwkb(wkb), srid))=4326)]
+
+ Example 2: Setting the SRID on GEOMETRY with SRID as an integer literal.
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([(bytes.fromhex('0101000000000000000000F03F0000000000000040'),)], ['wkb']) # noqa
+ >>> df.select(sf.st_srid(sf.st_setsrid(sf.st_geomfromwkb('wkb'), 4326))).collect()
+ [Row(st_srid(st_setsrid(st_geomfromwkb(wkb), 4326))=4326)]
+ """
+ srid = _enum_to_value(srid)
+ srid = lit(srid) if isinstance(srid, int) else srid
+ return _invoke_function_over_columns("st_setsrid", geo, srid)
+
+
+@_try_remote_functions
+def st_srid(geo: "ColumnOrName") -> Column:
+ """Returns the SRID of the input GEOGRAPHY or GEOMETRY value.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ geo : :class:`~pyspark.sql.Column` or str
+ A geospatial value, either a GEOGRAPHY or a GEOMETRY.
+
+ Examples
+ --------
+
+ Example 1: Getting the SRID of GEOGRAPHY.
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([(bytes.fromhex('0101000000000000000000F03F0000000000000040'),)], ['wkb']) # noqa
+ >>> df.select(sf.st_srid(sf.st_geogfromwkb('wkb'))).collect()
+ [Row(st_srid(st_geogfromwkb(wkb))=4326)]
+
+ Example 2: Getting the SRID of GEOMETRY.
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([(bytes.fromhex('0101000000000000000000F03F0000000000000040'),)], ['wkb']) # noqa
+ >>> df.select(sf.st_srid(sf.st_geomfromwkb('wkb'))).collect()
+ [Row(st_srid(st_geomfromwkb(wkb))=0)]
+ """
+ return _invoke_function_over_columns("st_srid", geo)
+
+
+# Call Functions
+
+
@_try_remote_functions
def call_udf(udfName: str, *cols: "ColumnOrName") -> Column:
"""
@@ -26438,6 +26577,756 @@ def theta_intersection_agg(col: "ColumnOrName") -> Column:
return _invoke_function_over_columns(fn, col)
+@_try_remote_functions
+def kll_sketch_agg_bigint(
+ col: "ColumnOrName",
+ k: Optional[Union[int, Column]] = None,
+) -> Column:
+ """
+ Aggregate function: returns the compact binary representation of the Datasketches
+ KllLongsSketch built with the values in the input column. The optional k parameter
+ controls the size and accuracy of the sketch (default 200, range 8-65535).
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ The column containing bigint values to aggregate
+ k : :class:`~pyspark.sql.Column` or int, optional
+ The k parameter that controls size and accuracy (default 200, range 8-65535)
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The binary representation of the KllLongsSketch.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1,2,3,4,5], "INT")
+ >>> result = df.agg(sf.kll_sketch_agg_bigint("value")).first()[0]
+ >>> result is not None and len(result) > 0
+ True
+ """
+ fn = "kll_sketch_agg_bigint"
+ if k is None:
+ return _invoke_function_over_columns(fn, col)
+ else:
+ return _invoke_function_over_columns(fn, col, lit(k))
+
+
+@_try_remote_functions
+def kll_sketch_agg_float(
+ col: "ColumnOrName",
+ k: Optional[Union[int, Column]] = None,
+) -> Column:
+ """
+ Aggregate function: returns the compact binary representation of the Datasketches
+ KllFloatsSketch built with the values in the input column. The optional k parameter
+ controls the size and accuracy of the sketch (default 200, range 8-65535).
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ The column containing float values to aggregate
+ k : :class:`~pyspark.sql.Column` or int, optional
+ The k parameter that controls size and accuracy (default 200, range 8-65535)
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The binary representation of the KllFloatsSketch.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1.0,2.0,3.0,4.0,5.0], "FLOAT")
+ >>> result = df.agg(sf.kll_sketch_agg_float("value")).first()[0]
+ >>> result is not None and len(result) > 0
+ True
+ """
+ fn = "kll_sketch_agg_float"
+ if k is None:
+ return _invoke_function_over_columns(fn, col)
+ else:
+ return _invoke_function_over_columns(fn, col, lit(k))
+
+
+@_try_remote_functions
+def kll_sketch_agg_double(
+ col: "ColumnOrName",
+ k: Optional[Union[int, Column]] = None,
+) -> Column:
+ """
+ Aggregate function: returns the compact binary representation of the Datasketches
+ KllDoublesSketch built with the values in the input column. The optional k parameter
+ controls the size and accuracy of the sketch (default 200, range 8-65535).
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ The column containing double values to aggregate
+ k : :class:`~pyspark.sql.Column` or int, optional
+ The k parameter that controls size and accuracy (default 200, range 8-65535)
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The binary representation of the KllDoublesSketch.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1.0,2.0,3.0,4.0,5.0], "DOUBLE")
+ >>> result = df.agg(sf.kll_sketch_agg_double("value")).first()[0]
+ >>> result is not None and len(result) > 0
+ True
+ """
+ fn = "kll_sketch_agg_double"
+ if k is None:
+ return _invoke_function_over_columns(fn, col)
+ else:
+ return _invoke_function_over_columns(fn, col, lit(k))
+
+
+@_try_remote_functions
+def kll_merge_agg_bigint(
+ col: "ColumnOrName",
+ k: Optional[Union[int, Column]] = None,
+) -> Column:
+ """
+ Aggregate function: merges binary KllLongsSketch representations and returns the
+ merged sketch. The optional k parameter controls the size and accuracy of the merged
+ sketch (range 8-65535). If k is not specified, the merged sketch adopts the k value
+ from the first input sketch.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ The column containing binary KllLongsSketch representations
+ k : :class:`~pyspark.sql.Column` or int, optional
+ The k parameter that controls size and accuracy (range 8-65535)
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The merged binary representation of the KllLongsSketch.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df1 = spark.createDataFrame([1,2,3], "INT")
+ >>> df2 = spark.createDataFrame([4,5,6], "INT")
+ >>> sketch1 = df1.agg(sf.kll_sketch_agg_bigint("value").alias("sketch"))
+ >>> sketch2 = df2.agg(sf.kll_sketch_agg_bigint("value").alias("sketch"))
+ >>> merged = sketch1.union(sketch2).agg(sf.kll_merge_agg_bigint("sketch").alias("merged"))
+ >>> n = merged.select(sf.kll_sketch_get_n_bigint("merged")).first()[0]
+ >>> n
+ 6
+ """
+ fn = "kll_merge_agg_bigint"
+ if k is None:
+ return _invoke_function_over_columns(fn, col)
+ else:
+ return _invoke_function_over_columns(fn, col, lit(k))
+
+
+@_try_remote_functions
+def kll_merge_agg_float(
+ col: "ColumnOrName",
+ k: Optional[Union[int, Column]] = None,
+) -> Column:
+ """
+ Aggregate function: merges binary KllFloatsSketch representations and returns the
+ merged sketch. The optional k parameter controls the size and accuracy of the merged
+ sketch (range 8-65535). If k is not specified, the merged sketch adopts the k value
+ from the first input sketch.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ The column containing binary KllFloatsSketch representations
+ k : :class:`~pyspark.sql.Column` or int, optional
+ The k parameter that controls size and accuracy (range 8-65535)
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The merged binary representation of the KllFloatsSketch.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df1 = spark.createDataFrame([1.0,2.0,3.0], "FLOAT")
+ >>> df2 = spark.createDataFrame([4.0,5.0,6.0], "FLOAT")
+ >>> sketch1 = df1.agg(sf.kll_sketch_agg_float("value").alias("sketch"))
+ >>> sketch2 = df2.agg(sf.kll_sketch_agg_float("value").alias("sketch"))
+ >>> merged = sketch1.union(sketch2).agg(sf.kll_merge_agg_float("sketch").alias("merged"))
+ >>> n = merged.select(sf.kll_sketch_get_n_float("merged")).first()[0]
+ >>> n
+ 6
+ """
+ fn = "kll_merge_agg_float"
+ if k is None:
+ return _invoke_function_over_columns(fn, col)
+ else:
+ return _invoke_function_over_columns(fn, col, lit(k))
+
+
+@_try_remote_functions
+def kll_merge_agg_double(
+ col: "ColumnOrName",
+ k: Optional[Union[int, Column]] = None,
+) -> Column:
+ """
+ Aggregate function: merges binary KllDoublesSketch representations and returns the
+ merged sketch. The optional k parameter controls the size and accuracy of the merged
+ sketch (range 8-65535). If k is not specified, the merged sketch adopts the k value
+ from the first input sketch.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ The column containing binary KllDoublesSketch representations
+ k : :class:`~pyspark.sql.Column` or int, optional
+ The k parameter that controls size and accuracy (range 8-65535)
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The merged binary representation of the KllDoublesSketch.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df1 = spark.createDataFrame([1.0,2.0,3.0], "DOUBLE")
+ >>> df2 = spark.createDataFrame([4.0,5.0,6.0], "DOUBLE")
+ >>> sketch1 = df1.agg(sf.kll_sketch_agg_double("value").alias("sketch"))
+ >>> sketch2 = df2.agg(sf.kll_sketch_agg_double("value").alias("sketch"))
+ >>> merged = sketch1.union(sketch2).agg(sf.kll_merge_agg_double("sketch").alias("merged"))
+ >>> n = merged.select(sf.kll_sketch_get_n_double("merged")).first()[0]
+ >>> n
+ 6
+ """
+ fn = "kll_merge_agg_double"
+ if k is None:
+ return _invoke_function_over_columns(fn, col)
+ else:
+ return _invoke_function_over_columns(fn, col, lit(k))
+
+
+@_try_remote_functions
+def kll_sketch_to_string_bigint(col: "ColumnOrName") -> Column:
+ """
+    Returns a string with human-readable summary information about the KLL bigint sketch.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ The KLL bigint sketch binary representation
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ A string representation of the sketch.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1,2,3,4,5], "INT")
+ >>> sketch_df = df.agg(sf.kll_sketch_agg_bigint("value").alias("sketch"))
+ >>> result = sketch_df.select(sf.kll_sketch_to_string_bigint("sketch")).first()[0]
+ >>> "kll" in result.lower()
+ True
+ """
+ fn = "kll_sketch_to_string_bigint"
+ return _invoke_function_over_columns(fn, col)
+
+
+@_try_remote_functions
+def kll_sketch_to_string_float(col: "ColumnOrName") -> Column:
+ """
+    Returns a string with human-readable summary information about the KLL float sketch.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ The KLL float sketch binary representation
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ A string representation of the sketch.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1.0,2.0,3.0,4.0,5.0], "FLOAT")
+ >>> sketch_df = df.agg(sf.kll_sketch_agg_float("value").alias("sketch"))
+ >>> result = sketch_df.select(sf.kll_sketch_to_string_float("sketch")).first()[0]
+ >>> "kll" in result.lower()
+ True
+ """
+ fn = "kll_sketch_to_string_float"
+ return _invoke_function_over_columns(fn, col)
+
+
+@_try_remote_functions
+def kll_sketch_to_string_double(col: "ColumnOrName") -> Column:
+ """
+    Returns a string with human-readable summary information about the KLL double sketch.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ The KLL double sketch binary representation
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ A string representation of the sketch.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1.0,2.0,3.0,4.0,5.0], "DOUBLE")
+ >>> sketch_df = df.agg(sf.kll_sketch_agg_double("value").alias("sketch"))
+ >>> result = sketch_df.select(sf.kll_sketch_to_string_double("sketch")).first()[0]
+ >>> "kll" in result.lower()
+ True
+ """
+ fn = "kll_sketch_to_string_double"
+ return _invoke_function_over_columns(fn, col)
+
+
+@_try_remote_functions
+def kll_sketch_get_n_bigint(col: "ColumnOrName") -> Column:
+ """
+ Returns the number of items collected in the KLL bigint sketch.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ The KLL bigint sketch binary representation
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The count of items in the sketch.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1,2,3,4,5], "INT")
+ >>> sketch_df = df.agg(sf.kll_sketch_agg_bigint("value").alias("sketch"))
+ >>> sketch_df.select(sf.kll_sketch_get_n_bigint("sketch")).show()
+ +-------------------------------+
+ |kll_sketch_get_n_bigint(sketch)|
+ +-------------------------------+
+ | 5|
+ +-------------------------------+
+ """
+ fn = "kll_sketch_get_n_bigint"
+ return _invoke_function_over_columns(fn, col)
+
+
+@_try_remote_functions
+def kll_sketch_get_n_float(col: "ColumnOrName") -> Column:
+ """
+ Returns the number of items collected in the KLL float sketch.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ The KLL float sketch binary representation
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The count of items in the sketch.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1.0,2.0,3.0,4.0,5.0], "FLOAT")
+ >>> sketch_df = df.agg(sf.kll_sketch_agg_float("value").alias("sketch"))
+ >>> sketch_df.select(sf.kll_sketch_get_n_float("sketch")).show()
+ +------------------------------+
+ |kll_sketch_get_n_float(sketch)|
+ +------------------------------+
+ | 5|
+ +------------------------------+
+ """
+ fn = "kll_sketch_get_n_float"
+ return _invoke_function_over_columns(fn, col)
+
+
+@_try_remote_functions
+def kll_sketch_get_n_double(col: "ColumnOrName") -> Column:
+ """
+ Returns the number of items collected in the KLL double sketch.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ col : :class:`~pyspark.sql.Column` or column name
+ The KLL double sketch binary representation
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The count of items in the sketch.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1.0,2.0,3.0,4.0,5.0], "DOUBLE")
+ >>> sketch_df = df.agg(sf.kll_sketch_agg_double("value").alias("sketch"))
+ >>> sketch_df.select(sf.kll_sketch_get_n_double("sketch")).show()
+ +-------------------------------+
+ |kll_sketch_get_n_double(sketch)|
+ +-------------------------------+
+ | 5|
+ +-------------------------------+
+ """
+ fn = "kll_sketch_get_n_double"
+ return _invoke_function_over_columns(fn, col)
+
+
+@_try_remote_functions
+def kll_sketch_merge_bigint(left: "ColumnOrName", right: "ColumnOrName") -> Column:
+ """
+    Merges two KLL bigint sketch buffers into one.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ left : :class:`~pyspark.sql.Column` or column name
+ The first KLL bigint sketch
+ right : :class:`~pyspark.sql.Column` or column name
+ The second KLL bigint sketch
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The merged KLL sketch.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1,2,3,4,5], "INT")
+ >>> sketch_df = df.agg(sf.kll_sketch_agg_bigint("value").alias("sketch"))
+ >>> result = sketch_df.select(sf.kll_sketch_merge_bigint("sketch", "sketch")).first()[0]
+ >>> result is not None and len(result) > 0
+ True
+ """
+ fn = "kll_sketch_merge_bigint"
+ return _invoke_function_over_columns(fn, left, right)
+
+
+@_try_remote_functions
+def kll_sketch_merge_float(left: "ColumnOrName", right: "ColumnOrName") -> Column:
+ """
+    Merges two KLL float sketch buffers into one.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ left : :class:`~pyspark.sql.Column` or column name
+ The first KLL float sketch
+ right : :class:`~pyspark.sql.Column` or column name
+ The second KLL float sketch
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The merged KLL sketch.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1.0,2.0,3.0,4.0,5.0], "FLOAT")
+ >>> sketch_df = df.agg(sf.kll_sketch_agg_float("value").alias("sketch"))
+ >>> result = sketch_df.select(sf.kll_sketch_merge_float("sketch", "sketch")).first()[0]
+ >>> result is not None and len(result) > 0
+ True
+ """
+ fn = "kll_sketch_merge_float"
+ return _invoke_function_over_columns(fn, left, right)
+
+
+@_try_remote_functions
+def kll_sketch_merge_double(left: "ColumnOrName", right: "ColumnOrName") -> Column:
+ """
+    Merges two KLL double sketch buffers into one.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ left : :class:`~pyspark.sql.Column` or column name
+ The first KLL double sketch
+ right : :class:`~pyspark.sql.Column` or column name
+ The second KLL double sketch
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The merged KLL sketch.
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1.0,2.0,3.0,4.0,5.0], "DOUBLE")
+ >>> sketch_df = df.agg(sf.kll_sketch_agg_double("value").alias("sketch"))
+ >>> result = sketch_df.select(sf.kll_sketch_merge_double("sketch", "sketch")).first()[0]
+ >>> result is not None and len(result) > 0
+ True
+ """
+ fn = "kll_sketch_merge_double"
+ return _invoke_function_over_columns(fn, left, right)
+
+
+@_try_remote_functions
+def kll_sketch_get_quantile_bigint(sketch: "ColumnOrName", rank: "ColumnOrName") -> Column:
+ """
+ Extracts a quantile value from a KLL bigint sketch given an input rank value.
+ The rank can be a single value or an array.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ sketch : :class:`~pyspark.sql.Column` or column name
+ The KLL bigint sketch binary representation
+ rank : :class:`~pyspark.sql.Column` or column name
+ The rank value(s) to extract (between 0.0 and 1.0)
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The quantile value(s).
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1,2,3,4,5], "INT")
+ >>> sketch_df = df.agg(sf.kll_sketch_agg_bigint("value").alias("sketch"))
+ >>> sketch_df.select(sf.kll_sketch_get_quantile_bigint("sketch", sf.lit(0.5))).show()
+ +-------------------------------------------+
+ |kll_sketch_get_quantile_bigint(sketch, 0.5)|
+ +-------------------------------------------+
+ | 3|
+ +-------------------------------------------+
+ """
+ fn = "kll_sketch_get_quantile_bigint"
+ return _invoke_function_over_columns(fn, sketch, rank)
+
+
+@_try_remote_functions
+def kll_sketch_get_quantile_float(sketch: "ColumnOrName", rank: "ColumnOrName") -> Column:
+ """
+ Extracts a quantile value from a KLL float sketch given an input rank value.
+ The rank can be a single value or an array.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ sketch : :class:`~pyspark.sql.Column` or column name
+ The KLL float sketch binary representation
+ rank : :class:`~pyspark.sql.Column` or column name
+ The rank value(s) to extract (between 0.0 and 1.0)
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The quantile value(s).
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1.0,2.0,3.0,4.0,5.0], "FLOAT")
+ >>> sketch_df = df.agg(sf.kll_sketch_agg_float("value").alias("sketch"))
+ >>> sketch_df.select(sf.kll_sketch_get_quantile_float("sketch", sf.lit(0.5))).show()
+ +------------------------------------------+
+ |kll_sketch_get_quantile_float(sketch, 0.5)|
+ +------------------------------------------+
+ | 3.0|
+ +------------------------------------------+
+ """
+ fn = "kll_sketch_get_quantile_float"
+ return _invoke_function_over_columns(fn, sketch, rank)
+
+
+@_try_remote_functions
+def kll_sketch_get_quantile_double(sketch: "ColumnOrName", rank: "ColumnOrName") -> Column:
+ """
+ Extracts a quantile value from a KLL double sketch given an input rank value.
+ The rank can be a single value or an array.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ sketch : :class:`~pyspark.sql.Column` or column name
+ The KLL double sketch binary representation
+ rank : :class:`~pyspark.sql.Column` or column name
+ The rank value(s) to extract (between 0.0 and 1.0)
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The quantile value(s).
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1.0,2.0,3.0,4.0,5.0], "DOUBLE")
+ >>> sketch_df = df.agg(sf.kll_sketch_agg_double("value").alias("sketch"))
+ >>> sketch_df.select(sf.kll_sketch_get_quantile_double("sketch", sf.lit(0.5))).show()
+ +-------------------------------------------+
+ |kll_sketch_get_quantile_double(sketch, 0.5)|
+ +-------------------------------------------+
+ | 3.0|
+ +-------------------------------------------+
+ """
+ fn = "kll_sketch_get_quantile_double"
+ return _invoke_function_over_columns(fn, sketch, rank)
+
+
+@_try_remote_functions
+def kll_sketch_get_rank_bigint(sketch: "ColumnOrName", quantile: "ColumnOrName") -> Column:
+ """
+ Extracts a rank value from a KLL bigint sketch given an input quantile value.
+ The quantile can be a single value or an array.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ sketch : :class:`~pyspark.sql.Column` or column name
+ The KLL bigint sketch binary representation
+ quantile : :class:`~pyspark.sql.Column` or column name
+        The quantile value(s) to look up
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The rank value(s) (between 0.0 and 1.0).
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1,2,3,4,5], "INT")
+ >>> sketch_df = df.agg(sf.kll_sketch_agg_bigint("value").alias("sketch"))
+ >>> sketch_df.select(sf.kll_sketch_get_rank_bigint("sketch", sf.lit(3))).show()
+ +-------------------------------------+
+ |kll_sketch_get_rank_bigint(sketch, 3)|
+ +-------------------------------------+
+ | 0.6|
+ +-------------------------------------+
+ """
+ fn = "kll_sketch_get_rank_bigint"
+ return _invoke_function_over_columns(fn, sketch, quantile)
+
+
+@_try_remote_functions
+def kll_sketch_get_rank_float(sketch: "ColumnOrName", quantile: "ColumnOrName") -> Column:
+ """
+ Extracts a rank value from a KLL float sketch given an input quantile value.
+ The quantile can be a single value or an array.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ sketch : :class:`~pyspark.sql.Column` or column name
+ The KLL float sketch binary representation
+ quantile : :class:`~pyspark.sql.Column` or column name
+        The quantile value(s) to look up
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The rank value(s) (between 0.0 and 1.0).
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1.0,2.0,3.0,4.0,5.0], "FLOAT")
+ >>> sketch_df = df.agg(sf.kll_sketch_agg_float("value").alias("sketch"))
+ >>> sketch_df.select(sf.kll_sketch_get_rank_float("sketch", sf.lit(3.0))).show()
+ +--------------------------------------+
+ |kll_sketch_get_rank_float(sketch, 3.0)|
+ +--------------------------------------+
+ | 0.6|
+ +--------------------------------------+
+ """
+ fn = "kll_sketch_get_rank_float"
+ return _invoke_function_over_columns(fn, sketch, quantile)
+
+
+@_try_remote_functions
+def kll_sketch_get_rank_double(sketch: "ColumnOrName", quantile: "ColumnOrName") -> Column:
+ """
+ Extracts a rank value from a KLL double sketch given an input quantile value.
+ The quantile can be a single value or an array.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ sketch : :class:`~pyspark.sql.Column` or column name
+ The KLL double sketch binary representation
+ quantile : :class:`~pyspark.sql.Column` or column name
+        The quantile value(s) to look up
+
+ Returns
+ -------
+ :class:`~pyspark.sql.Column`
+ The rank value(s) (between 0.0 and 1.0).
+
+ Examples
+ --------
+ >>> from pyspark.sql import functions as sf
+ >>> df = spark.createDataFrame([1.0,2.0,3.0,4.0,5.0], "DOUBLE")
+ >>> sketch_df = df.agg(sf.kll_sketch_agg_double("value").alias("sketch"))
+ >>> sketch_df.select(sf.kll_sketch_get_rank_double("sketch", sf.lit(3.0))).show()
+ +---------------------------------------+
+ |kll_sketch_get_rank_double(sketch, 3.0)|
+ +---------------------------------------+
+ | 0.6|
+ +---------------------------------------+
+ """
+ fn = "kll_sketch_get_rank_double"
+ return _invoke_function_over_columns(fn, sketch, quantile)
+
+
@_try_remote_functions
def theta_sketch_estimate(col: "ColumnOrName") -> Column:
"""
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 05021aabb50f8..939f7ff6b610c 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -456,7 +456,7 @@ def pivot(self, pivot_col: str, values: Optional[List["LiteralType"]] = None) ->
Examples
--------
- >>> from pyspark.sql import Row
+ >>> from pyspark.sql import Row, functions as sf
>>> df1 = spark.createDataFrame([
... Row(course="dotNET", year=2012, earnings=10000),
... Row(course="Java", year=2012, earnings=20000),
@@ -474,28 +474,30 @@ def pivot(self, pivot_col: str, values: Optional[List["LiteralType"]] = None) ->
|dotNET|2013| 48000|
| Java|2013| 30000|
+------+----+--------+
+
>>> df2 = spark.createDataFrame([
... Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=10000)),
... Row(training="junior", sales=Row(course="Java", year=2012, earnings=20000)),
... Row(training="expert", sales=Row(course="dotNET", year=2012, earnings=5000)),
... Row(training="junior", sales=Row(course="dotNET", year=2013, earnings=48000)),
... Row(training="expert", sales=Row(course="Java", year=2013, earnings=30000)),
- ... ]) # doctest: +SKIP
- >>> df2.show() # doctest: +SKIP
- +--------+--------------------+
- |training| sales|
- +--------+--------------------+
- | expert|{dotNET, 2012, 10...|
- | junior| {Java, 2012, 20000}|
- | expert|{dotNET, 2012, 5000}|
- | junior|{dotNET, 2013, 48...|
- | expert| {Java, 2013, 30000}|
- +--------+--------------------+
+ ... ])
+ >>> df2.show(truncate=False)
+ +--------+---------------------+
+ |training|sales |
+ +--------+---------------------+
+ |expert |{dotNET, 2012, 10000}|
+ |junior |{Java, 2012, 20000} |
+ |expert |{dotNET, 2012, 5000} |
+ |junior |{dotNET, 2013, 48000}|
+ |expert |{Java, 2013, 30000} |
+ +--------+---------------------+
Compute the sum of earnings for each year by course with each course as a separate column
>>> df1.groupBy("year").pivot(
- ... "course", ["dotNET", "Java"]).sum("earnings").sort("year").show()
+ ... "course", ["dotNET", "Java"]
+ ... ).sum("earnings").sort("year").show()
+----+------+-----+
|year|dotNET| Java|
+----+------+-----+
@@ -512,9 +514,10 @@ def pivot(self, pivot_col: str, values: Optional[List["LiteralType"]] = None) ->
|2012|20000| 15000|
|2013|30000| 48000|
+----+-----+------+
- >>> df2.groupBy(
- ... "sales.year").pivot("sales.course").sum("sales.earnings").sort("year").show()
- ... # doctest: +SKIP
+
+ >>> df2.groupBy("sales.year").pivot(
+ ... "sales.course"
+ ... ).agg(sf.sum("sales.earnings")).sort("year").show()
+----+-----+------+
|year| Java|dotNET|
+----+-----+------+
diff --git a/python/pyspark/sql/metrics.py b/python/pyspark/sql/metrics.py
index 4ab9b041e3135..50d8ae444c471 100644
--- a/python/pyspark/sql/metrics.py
+++ b/python/pyspark/sql/metrics.py
@@ -68,6 +68,20 @@ def value(self) -> Union[int, float]:
def metric_type(self) -> str:
return self._type
+ def to_dict(self) -> Dict[str, Any]:
+ """Return a JSON-serializable dictionary representation of this metric value.
+
+ Returns
+ -------
+ dict
+ A dictionary with keys 'name', 'value', and 'type'.
+ """
+ return {
+ "name": self._name,
+ "value": self._value,
+ "type": self._type,
+ }
+
class PlanMetrics:
"""Represents a particular plan node and the associated metrics of this node."""
@@ -97,6 +111,21 @@ def parent_plan_id(self) -> int:
def metrics(self) -> List[MetricValue]:
return self._metrics
+ def to_dict(self) -> Dict[str, Any]:
+ """Return a JSON-serializable dictionary representation of this plan metrics.
+
+ Returns
+ -------
+ dict
+ A dictionary with keys 'name', 'plan_id', 'parent_plan_id', and 'metrics'.
+ """
+ return {
+ "name": self._name,
+ "plan_id": self._id,
+ "parent_plan_id": self._parent_id,
+ "metrics": [m.to_dict() for m in self._metrics],
+ }
+
class CollectedMetrics:
@dataclasses.dataclass
diff --git a/python/pyspark/sql/pandas/typehints.py b/python/pyspark/sql/pandas/typehints.py
index 18858ab0cf686..7c95feee0cfea 100644
--- a/python/pyspark/sql/pandas/typehints.py
+++ b/python/pyspark/sql/pandas/typehints.py
@@ -353,6 +353,9 @@ def infer_group_arrow_eval_type(
return_annotation, parameter_check_func=lambda t: t == pa.RecordBatch
)
)
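+    # Return as soon as a matching signature is found; the remaining checks are skipped.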
+ if is_iterator_batch:
+ return PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
+
# Tuple[pa.Scalar, ...], Iterator[pa.RecordBatch] -> Iterator[pa.RecordBatch]
is_iterator_batch_with_keys = (
len(parameters_sig) == 2
@@ -364,19 +367,21 @@ def infer_group_arrow_eval_type(
return_annotation, parameter_check_func=lambda t: t == pa.RecordBatch
)
)
-
- if is_iterator_batch or is_iterator_batch_with_keys:
+ if is_iterator_batch_with_keys:
return PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF
# pa.Table -> pa.Table
is_table = (
len(parameters_sig) == 1 and parameters_sig[0] == pa.Table and return_annotation == pa.Table
)
+ if is_table:
+ return PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+
# Tuple[pa.Scalar, ...], pa.Table -> pa.Table
is_table_with_keys = (
len(parameters_sig) == 2 and parameters_sig[1] == pa.Table and return_annotation == pa.Table
)
- if is_table or is_table_with_keys:
+ if is_table_with_keys:
return PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
return None
@@ -441,6 +446,9 @@ def infer_group_pandas_eval_type(
return_annotation, parameter_check_func=lambda t: t == pd.DataFrame
)
)
+ if is_iterator_dataframe:
+ return PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
+
# Tuple[Any, ...], Iterator[pd.DataFrame] -> Iterator[pd.DataFrame]
is_iterator_dataframe_with_keys = (
len(parameters_sig) == 2
@@ -452,8 +460,7 @@ def infer_group_pandas_eval_type(
return_annotation, parameter_check_func=lambda t: t == pd.DataFrame
)
)
-
- if is_iterator_dataframe or is_iterator_dataframe_with_keys:
+ if is_iterator_dataframe_with_keys:
return PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF
# pd.DataFrame -> pd.DataFrame
@@ -462,13 +469,16 @@ def infer_group_pandas_eval_type(
and parameters_sig[0] == pd.DataFrame
and return_annotation == pd.DataFrame
)
+ if is_dataframe:
+ return PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+
# Tuple[Any, ...], pd.DataFrame -> pd.DataFrame
is_dataframe_with_keys = (
len(parameters_sig) == 2
and parameters_sig[1] == pd.DataFrame
and return_annotation == pd.DataFrame
)
- if is_dataframe or is_dataframe_with_keys:
+ if is_dataframe_with_keys:
return PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
return None
@@ -512,11 +522,9 @@ def check_iterator_annotation(
def check_union_annotation(
annotation: Any, parameter_check_func: Optional[Callable[[Any], bool]] = None
) -> bool:
- import typing
-
# Note that we cannot rely on '__origin__' in other type hints as it has changed from version
# to version.
origin = getattr(annotation, "__origin__", None)
- return origin == typing.Union and (
+ return origin == Union and (
parameter_check_func is None or all(map(parameter_check_func, annotation.__args__))
)
diff --git a/python/pyspark/sql/pandas/types.py b/python/pyspark/sql/pandas/types.py
index 327e3941d9386..9583c8ac72888 100644
--- a/python/pyspark/sql/pandas/types.py
+++ b/python/pyspark/sql/pandas/types.py
@@ -23,6 +23,8 @@
import itertools
from typing import Any, Callable, Iterable, List, Optional, Union, TYPE_CHECKING
+from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError
+from pyspark.loose_version import LooseVersion
from pyspark.sql.types import (
cast,
BooleanType,
@@ -50,10 +52,12 @@
UserDefinedType,
VariantType,
VariantVal,
+ GeometryType,
+ Geometry,
+ GeographyType,
+ Geography,
_create_row,
)
-from pyspark.errors import PySparkTypeError, UnsupportedOperationException, PySparkValueError
-from pyspark.loose_version import LooseVersion
if TYPE_CHECKING:
import pandas as pd
@@ -202,6 +206,28 @@ def to_arrow_type(
pa.field("metadata", pa.binary(), nullable=False, metadata={b"variant": b"true"}),
]
arrow_type = pa.struct(fields)
+ elif type(dt) == GeometryType:
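+        # Geometry is represented as an Arrow struct<srid: int32, wkb: binary>; the metadata
+        # on the "wkb" field marks the struct so from_arrow_type can map it back.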
+ fields = [
+ pa.field("srid", pa.int32(), nullable=False),
+ pa.field(
+ "wkb",
+ pa.binary(),
+ nullable=False,
+ metadata={b"geometry": b"true", b"srid": str(dt.srid)},
+ ),
+ ]
+ arrow_type = pa.struct(fields)
+ elif type(dt) == GeographyType:
+ fields = [
+ pa.field("srid", pa.int32(), nullable=False),
+ pa.field(
+ "wkb",
+ pa.binary(),
+ nullable=False,
+ metadata={b"geography": b"true", b"srid": str(dt.srid)},
+ ),
+ ]
+ arrow_type = pa.struct(fields)
else:
raise PySparkTypeError(
errorClass="UNSUPPORTED_DATA_TYPE_FOR_ARROW_CONVERSION",
@@ -272,6 +298,38 @@ def is_variant(at: "pa.DataType") -> bool:
) and any(field.name == "value" for field in at)
+def is_geometry(at: "pa.DataType") -> bool:
+ """Check if a PyArrow struct data type represents a geometry"""
+ import pyarrow.types as types
+
+ assert types.is_struct(at)
+
+ return any(
+ (
+ field.name == "wkb"
+ and b"geometry" in field.metadata
+ and field.metadata[b"geometry"] == b"true"
+ )
+ for field in at
+ ) and any(field.name == "srid" for field in at)
+
+
+def is_geography(at: "pa.DataType") -> bool:
+ """Check if a PyArrow struct data type represents a geography"""
+ import pyarrow.types as types
+
+ assert types.is_struct(at)
+
+ return any(
+ (
+ field.name == "wkb"
+ and b"geography" in field.metadata
+ and field.metadata[b"geography"] == b"true"
+ )
+ for field in at
+ ) and any(field.name == "srid" for field in at)
+
+
def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
"""Convert pyarrow type to Spark data type."""
import pyarrow.types as types
@@ -337,6 +395,18 @@ def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> Da
elif types.is_struct(at):
if is_variant(at):
return VariantType()
+ elif is_geometry(at):
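+            # The SRID is carried in the "wkb" field metadata; the MIXED_SRID sentinel maps
+            # back to GeometryType("ANY").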
+ srid = int(at.field("wkb").metadata.get(b"srid"))
+ if srid == GeometryType.MIXED_SRID:
+ return GeometryType("ANY")
+ else:
+ return GeometryType(srid)
+ elif is_geography(at):
+ srid = int(at.field("wkb").metadata.get(b"srid"))
+ if srid == GeographyType.MIXED_SRID:
+ return GeographyType("ANY")
+ else:
+ return GeographyType(srid)
return StructType(
[
StructField(
@@ -1087,6 +1157,8 @@ def convert_udt(value: Any) -> Any:
elif isinstance(dt, VariantType):
def convert_variant(value: Any) -> Any:
+ if isinstance(value, VariantVal):
+ return value
if (
isinstance(value, dict)
and all(key in value for key in ["value", "metadata"])
@@ -1098,6 +1170,40 @@ def convert_variant(value: Any) -> Any:
return convert_variant
+ elif isinstance(dt, GeographyType):
+
+ def convert_geography(value: Any) -> Any:
+ if value is None:
+ return None
+ elif (
+ isinstance(value, dict)
+ and all(key in value for key in ["wkb", "srid"])
+ and isinstance(value["wkb"], bytes)
+ and isinstance(value["srid"], int)
+ ):
+ return Geography.fromWKB(value["wkb"], value["srid"])
+ else:
+ raise PySparkValueError(errorClass="MALFORMED_GEOGRAPHY")
+
+ return convert_geography
+
+ elif isinstance(dt, GeometryType):
+
+ def convert_geometry(value: Any) -> Any:
+ if value is None:
+ return None
+ elif (
+ isinstance(value, dict)
+ and all(key in value for key in ["wkb", "srid"])
+ and isinstance(value["wkb"], bytes)
+ and isinstance(value["srid"], int)
+ ):
+ return Geometry.fromWKB(value["wkb"], value["srid"])
+ else:
+ raise PySparkValueError(errorClass="MALFORMED_GEOMETRY")
+
+ return convert_geometry
+
else:
return None
@@ -1360,6 +1466,22 @@ def convert_variant(variant: Any) -> Any:
return convert_variant
+ elif isinstance(dt, GeographyType):
+
+ def convert_geography(value: Any) -> Any:
+ assert isinstance(value, Geography)
+ return {"srid": value.srid, "wkb": value.wkb}
+
+ return convert_geography
+
+ elif isinstance(dt, GeometryType):
+
+ def convert_geometry(value: Any) -> Any:
+ assert isinstance(value, Geometry)
+ return {"srid": value.srid, "wkb": value.wkb}
+
+ return convert_geometry
+
return None
conv = _converter(data_type)
diff --git a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py
index b88fc2c5ca402..1305a6213c137 100644
--- a/python/pyspark/sql/streaming/proto/StateMessage_pb2.py
+++ b/python/pyspark/sql/streaming/proto/StateMessage_pb2.py
@@ -18,7 +18,7 @@
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: org/apache/spark/sql/execution/streaming/StateMessage.proto
-# Protobuf Python Version: 5.29.5
+# Protobuf Python Version: 6.33.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
@@ -28,9 +28,9 @@
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
- 5,
- 29,
- 5,
+ 6,
+ 33,
+ 0,
"",
"org/apache/spark/sql/execution/streaming/StateMessage.proto",
)
diff --git a/python/pyspark/sql/table_arg.py b/python/pyspark/sql/table_arg.py
index f96b40b2dee1f..483b26eb97abd 100644
--- a/python/pyspark/sql/table_arg.py
+++ b/python/pyspark/sql/table_arg.py
@@ -40,6 +40,59 @@ class TableArg(TableValuedFunctionArgument):
def partitionBy(self, *cols: "ColumnOrName") -> "TableArg":
"""
Partitions the data based on the specified columns.
+
+ This method partitions the table argument data by the specified columns.
+ It must be called before `orderBy()` and cannot be called after
+ `withSinglePartition()` has been called.
+
+ Parameters
+ ----------
+ cols : str, :class:`Column`, or list
+ Column names or :class:`Column` objects to partition by.
+
+ Returns
+ -------
+ :class:`TableArg`
+ A new `TableArg` instance with partitioning applied.
+
+ Examples
+ --------
+ >>> from pyspark.sql.functions import udtf
+ >>>
+ >>> @udtf(returnType="key: int, value: string")
+ ... class ProcessUDTF:
+ ... def eval(self, row):
+ ... yield row["key"], row["value"]
+ ...
+ >>> df = spark.createDataFrame(
+ ... [(1, "a"), (1, "b"), (2, "c"), (2, "d")], ["key", "value"]
+ ... )
+ >>>
+ >>> # Partition by a single column
+ >>> result = ProcessUDTF(df.asTable().partitionBy("key"))
+ >>> result.show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ | 1| a|
+ | 1| b|
+ | 2| c|
+ | 2| d|
+ +---+-----+
+ >>>
+ >>> # Partition by multiple columns
+ >>> df2 = spark.createDataFrame(
+ ... [(1, "x", 10), (1, "x", 20), (2, "y", 30)], ["key", "category", "value"]
+ ... )
+ >>> result2 = ProcessUDTF(df2.asTable().partitionBy("key", "category"))
+ >>> result2.show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ | 1| x|
+ | 1| x|
+ | 2| y|
+ +---+-----+
"""
...
@@ -47,6 +100,72 @@ def partitionBy(self, *cols: "ColumnOrName") -> "TableArg":
def orderBy(self, *cols: "ColumnOrName") -> "TableArg":
"""
Orders the data within each partition by the specified columns.
+
+ This method orders the data within partitions. It must be called after
+ `partitionBy()` or `withSinglePartition()` has been called.
+
+ Parameters
+ ----------
+ cols : str, :class:`Column`, or list
+ Column names or :class:`Column` objects to order by. Columns can be
+ ordered in ascending or descending order using :meth:`Column.asc` or
+ :meth:`Column.desc`.
+
+ Returns
+ -------
+ :class:`TableArg`
+ A new `TableArg` instance with ordering applied.
+
+ Examples
+ --------
+ >>> from pyspark.sql.functions import udtf
+ >>>
+ >>> @udtf(returnType="key: int, value: string")
+ ... class ProcessUDTF:
+ ... def eval(self, row):
+ ... yield row["key"], row["value"]
+ ...
+ >>> df = spark.createDataFrame(
+ ... [(1, "b"), (1, "a"), (2, "d"), (2, "c")], ["key", "value"]
+ ... )
+ >>>
+ >>> # Order by a single column within partitions
+ >>> result = ProcessUDTF(df.asTable().partitionBy("key").orderBy("value"))
+ >>> result.show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ | 1| a|
+ | 1| b|
+ | 2| c|
+ | 2| d|
+ +---+-----+
+ >>>
+ >>> # Order by multiple columns
+ >>> df2 = spark.createDataFrame(
+ ... [(1, "a", 2), (1, "a", 1), (1, "b", 3)], ["key", "value", "num"]
+ ... )
+ >>> result2 = ProcessUDTF(df2.asTable().partitionBy("key").orderBy("value", "num"))
+ >>> result2.show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ | 1| a|
+ | 1| a|
+ | 1| b|
+ +---+-----+
+ >>>
+ >>> # Order by descending order
+ >>> result3 = ProcessUDTF(df.asTable().partitionBy("key").orderBy(df.value.desc()))
+ >>> result3.show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ | 1| b|
+ | 1| a|
+ | 2| d|
+ | 2| c|
+ +---+-----+
"""
...
@@ -54,5 +173,52 @@ def orderBy(self, *cols: "ColumnOrName") -> "TableArg":
def withSinglePartition(self) -> "TableArg":
"""
Forces the data to be processed in a single partition.
+
+ This method indicates that all data should be treated as a single partition.
+ It cannot be called after `partitionBy()` has been called. `orderBy()` can
+ be called after this method to order the data within the single partition.
+
+ Returns
+ -------
+ :class:`TableArg`
+ A new `TableArg` instance with single partition constraint applied.
+
+ Examples
+ --------
+ >>> from pyspark.sql.functions import udtf
+ >>>
+ >>> @udtf(returnType="key: int, value: string")
+ ... class ProcessUDTF:
+ ... def eval(self, row):
+ ... yield row["key"], row["value"]
+ ...
+ >>> df = spark.createDataFrame(
+ ... [(1, "a"), (2, "b"), (3, "c")], ["key", "value"]
+ ... )
+ >>>
+ >>> # Process all data in a single partition
+ >>> result = ProcessUDTF(df.asTable().withSinglePartition())
+ >>> result.show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ | 1| a|
+ | 2| b|
+ | 3| c|
+ +---+-----+
+ >>>
+ >>> # Use withSinglePartition and orderBy together
+ >>> df2 = spark.createDataFrame(
+ ... [(3, "c"), (1, "a"), (2, "b")], ["key", "value"]
+ ... )
+ >>> result2 = ProcessUDTF(df2.asTable().withSinglePartition().orderBy("key"))
+ >>> result2.show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ | 1| a|
+ | 2| b|
+ | 3| c|
+ +---+-----+
"""
...
diff --git a/python/pyspark/sql/tests/arrow/test_arrow.py b/python/pyspark/sql/tests/arrow/test_arrow.py
index be7dd2febc94a..79d8bf77d9d5b 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow.py
@@ -438,10 +438,12 @@ def check_cached_local_relation_changing_values(self):
assert not df.filter(df["col2"].endswith(suffix)).isEmpty()
def check_large_cached_local_relation_same_values(self):
- data = [("C000000032", "R20", 0.2555)] * 500_000
+ row_count = 500_000
+ data = [("C000000032", "R20", 0.2555)] * row_count
pdf = pd.DataFrame(data=data, columns=["Contrat", "Recommandation", "Distance"])
- df = self.spark.createDataFrame(pdf)
- df.collect()
+ for _ in range(2):
+ df = self.spark.createDataFrame(pdf)
+ assert df.count() == row_count
def test_toArrow_keep_utc_timezone(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
@@ -1810,6 +1812,81 @@ def test_createDataFrame_arrow_fixed_size_list(self):
df = self.spark.createDataFrame(t)
self.assertIsInstance(df.schema["fsl"].dataType, ArrayType)
+ def test_toPandas_with_compression_codec(self):
+ # Test toPandas() with different compression codec settings
+ df = self.spark.createDataFrame(self.data, schema=self.schema)
+ expected = self.create_pandas_data_frame()
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compression.codec": codec}):
+ pdf = df.toPandas()
+ assert_frame_equal(expected, pdf)
+
+ def test_toArrow_with_compression_codec(self):
+ # Test toArrow() with different compression codec settings
+ import pyarrow.compute as pc
+
+ t_in = self.create_arrow_table()
+
+ # Convert timezone-naive local timestamp column in input table to UTC
+ # to enable comparison to UTC timestamp column in output table
+ timezone = self.spark.conf.get("spark.sql.session.timeZone")
+ t_in = t_in.set_column(
+ t_in.schema.get_field_index("8_timestamp_t"),
+ "8_timestamp_t",
+ pc.assume_timezone(t_in["8_timestamp_t"], timezone),
+ )
+ t_in = t_in.cast(
+ t_in.schema.set(
+ t_in.schema.get_field_index("8_timestamp_t"),
+ pa.field("8_timestamp_t", pa.timestamp("us", tz="UTC")),
+ )
+ )
+
+ df = self.spark.createDataFrame(self.data, schema=self.schema)
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compression.codec": codec}):
+ t_out = df.toArrow()
+ self.assertTrue(t_out.equals(t_in))
+
+ def test_toPandas_with_compression_codec_large_dataset(self):
+ # Test compression with a larger dataset to verify memory savings
+ # Create a dataset with repetitive data that compresses well
+ from pyspark.sql.functions import lit, col
+
+ df = self.spark.range(10000).select(
+ col("id"),
+ lit("test_string_value_" * 10).alias("str_col"),
+ (col("id") % 100).alias("mod_col"),
+ )
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compression.codec": codec}):
+ pdf = df.toPandas()
+ self.assertEqual(len(pdf), 10000)
+ self.assertEqual(pdf.columns.tolist(), ["id", "str_col", "mod_col"])
+
+ def test_toArrow_with_compression_codec_large_dataset(self):
+ # Test compression with a larger dataset for toArrow
+ from pyspark.sql.functions import lit, col
+
+ df = self.spark.range(10000).select(
+ col("id"),
+ lit("test_string_value_" * 10).alias("str_col"),
+ (col("id") % 100).alias("mod_col"),
+ )
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compression.codec": codec}):
+ t = df.toArrow()
+ self.assertEqual(t.num_rows, 10000)
+ self.assertEqual(t.column_names, ["id", "str_col", "mod_col"])
+
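The compression tests above all drive the same knob, `spark.sql.execution.arrow.compression.codec`. A minimal standalone sketch of the pattern, assuming an active SparkSession and the same codec values accepted by the tests:

# Sketch: toggle Arrow IPC compression for toPandas()/toArrow().
# Assumes an active SparkSession named `spark`; codec names mirror the tests above.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.range(10_000).selectExpr("id", "repeat('x', 100) AS payload")

for codec in ["none", "zstd", "lz4"]:
    spark.conf.set("spark.sql.execution.arrow.compression.codec", codec)
    pdf = df.toPandas()        # Arrow batches are (optionally) compressed in transit
    assert len(pdf) == 10_000  # results must be identical regardless of codec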
@unittest.skipIf(
not have_pandas or not have_pyarrow,
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
index 2bdd7bda3bc21..abfd1af7a741e 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
@@ -313,21 +313,19 @@ def arrow_func(key, left, right):
self.assertEqual(df2.join(df2).count(), 1)
def test_arrow_batch_slicing(self):
- df1 = self.spark.range(10000000).select(
- (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
- )
+ m, n = 100000, 10000
+
+ df1 = self.spark.range(m).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
cols = {f"col_{i}": sf.col("v") + i for i in range(10)}
df1 = df1.withColumns(cols)
- df2 = self.spark.range(100000).select(
- (sf.col("id") % 4).alias("key"), sf.col("id").alias("v")
- )
+ df2 = self.spark.range(n).select((sf.col("id") % 4).alias("key"), sf.col("id").alias("v"))
cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
df2 = df2.withColumns(cols)
def summarize(key, left, right):
- assert len(left) == 10000000 / 2 or len(left) == 0, len(left)
- assert len(right) == 100000 / 4, len(right)
+ assert len(left) == m / 2 or len(left) == 0, len(left)
+ assert len(right) == n / 4, len(right)
return pa.Table.from_pydict(
{
"key": [key[0].as_py()],
@@ -341,13 +339,13 @@ def summarize(key, left, right):
schema = "key long, left_rows long, left_columns long, right_rows long, right_columns long"
expected = [
- Row(key=0, left_rows=5000000, left_columns=12, right_rows=25000, right_columns=22),
- Row(key=1, left_rows=5000000, left_columns=12, right_rows=25000, right_columns=22),
- Row(key=2, left_rows=0, left_columns=12, right_rows=25000, right_columns=22),
- Row(key=3, left_rows=0, left_columns=12, right_rows=25000, right_columns=22),
+ Row(key=0, left_rows=m / 2, left_columns=12, right_rows=n / 4, right_columns=22),
+ Row(key=1, left_rows=m / 2, left_columns=12, right_rows=n / 4, right_columns=22),
+ Row(key=2, left_rows=0, left_columns=12, right_rows=n / 4, right_columns=22),
+ Row(key=3, left_rows=0, left_columns=12, right_rows=n / 4, right_columns=22),
]
- for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+ for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 4096), (1000, 4096)]:
with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
with self.sql_conf(
{
@@ -398,20 +396,20 @@ def func_with_logging(left, right):
+ [Row(id=2, v1=20, v2=200)],
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"arrow cogrouped map: {dict(v1=v1, v2=v2)}",
- context={"func_name": func_with_logging.__name__},
- logger="test_arrow_cogrouped_map",
- )
- for v1, v2 in [([10, 30], [100, 300]), ([20], [200])]
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"arrow cogrouped map: {dict(v1=v1, v2=v2)}",
+ context={"func_name": func_with_logging.__name__},
+ logger="test_arrow_cogrouped_map",
+ )
+ for v1, v2 in [([10, 30], [100, 300]), ([20], [200])]
+ ],
+ )
class CogroupedMapInArrowTests(CogroupedMapInArrowTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
index 829c38385bd0e..cb8d74f724270 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
@@ -356,14 +356,14 @@ def arrow_func(key, table):
self.assertEqual(df2.join(df2).count(), 1)
def test_arrow_batch_slicing(self):
- df = self.spark.range(10000000).select(
- (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
- )
+ n = 100000
+
+ df = self.spark.range(n).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
df = df.withColumns(cols)
def min_max_v(table):
- assert len(table) == 10000000 / 2, len(table)
+ assert len(table) == n / 2, len(table)
return pa.Table.from_pydict(
{
"key": [table.column("key")[0].as_py()],
@@ -376,7 +376,7 @@ def min_max_v(table):
df.groupby("key").agg(sf.min("v").alias("min"), sf.max("v").alias("max")).sort("key")
).collect()
- for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+ for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 4096), (1000, 4096)]:
with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
with self.sql_conf(
{
@@ -416,20 +416,20 @@ def func_with_logging(group):
df,
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"arrow grouped map: {dict(id=lst, value=[v*10 for v in lst])}",
- context={"func_name": func_with_logging.__name__},
- logger="test_arrow_grouped_map",
- )
- for lst in [[0, 2, 4, 6, 8], [1, 3, 5, 7]]
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"arrow grouped map: {dict(id=lst, value=[v*10 for v in lst])}",
+ context={"func_name": func_with_logging.__name__},
+ logger="test_arrow_grouped_map",
+ )
+ for lst in [[0, 2, 4, 6, 8], [1, 3, 5, 7]]
+ ],
+ )
@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_apply_in_arrow_iter_with_logging(self):
@@ -456,20 +456,20 @@ def func_with_logging(group: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatc
df,
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"arrow grouped map: {dict(id=lst, value=[v*10 for v in lst])}",
- context={"func_name": func_with_logging.__name__},
- logger="test_arrow_grouped_map",
- )
- for lst in [[0, 2, 4], [6, 8], [1, 3, 5], [7]]
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"arrow grouped map: {dict(id=lst, value=[v*10 for v in lst])}",
+ context={"func_name": func_with_logging.__name__},
+ logger="test_arrow_grouped_map",
+ )
+ for lst in [[0, 2, 4], [6, 8], [1, 3, 5], [7]]
+ ],
+ )
class ApplyInArrowTests(ApplyInArrowTestsMixin, ReusedSQLTestCase):
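The slicing tests now use smaller inputs and a much lower byte threshold so each group is still split across several Arrow batches. A rough sketch of the record-count knob involved, assuming `spark.sql.execution.arrow.maxRecordsPerBatch`; the byte-based limit exercised above is assumed to be set through the same sql_conf block as in the tests:

# Sketch: force small Arrow batches so a grouped-map function still sees the whole group.
import pyarrow as pa

spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")
df = spark.range(100_000).selectExpr("id % 2 AS key", "id AS v")

def min_max_v(table):
    # The full group is reassembled before the function runs (50_000 rows per key here).
    vals = table.column("v").to_pylist()
    return pa.Table.from_pydict(
        {"key": [table.column("key")[0].as_py()], "min": [min(vals)], "max": [max(vals)]}
    )

df.groupby("key").applyInArrow(min_max_v, "key long, min long, max long").show()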
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_map.py b/python/pyspark/sql/tests/arrow/test_arrow_map.py
index 4a56a32fbcddb..2921023db4bef 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_map.py
@@ -55,6 +55,16 @@ def func(iterator):
expected = df.collect()
self.assertEqual(actual, expected)
+ def test_map_in_arrow_with_limit(self):
+ def get_size(iterator):
+ for batch in iterator:
+ assert isinstance(batch, pa.RecordBatch)
+ if batch.num_rows > 0:
+ yield pa.RecordBatch.from_arrays([pa.array([batch.num_rows])], names=["size"])
+
+ df = self.spark.range(100)
+ df.mapInArrow(get_size, "size long").limit(1).collect()
+
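The new test guards the case where a limit() downstream of mapInArrow stops consuming batches early. A small usage sketch under the same assumptions (PyArrow available, active session):

# Sketch: mapInArrow followed by an early-terminating limit().
# The generator may not be fully drained once limit(1) short-circuits.
import pyarrow as pa

def batch_sizes(batches):
    for batch in batches:
        if batch.num_rows > 0:
            yield pa.RecordBatch.from_arrays([pa.array([batch.num_rows])], names=["size"])

spark.range(100).mapInArrow(batch_sizes, "size long").limit(1).collect()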
def test_multiple_columns(self):
data = [(1, "foo"), (2, None), (3, "bar"), (4, "bar")]
df = self.spark.createDataFrame(data, "a int, b string")
@@ -247,12 +257,12 @@ def func_with_logging(iterator):
[Row(id=i) for i in range(9)],
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- self._expected_logs_for_test_map_in_arrow_with_logging(func_with_logging.__name__),
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ self._expected_logs_for_test_map_in_arrow_with_logging(func_with_logging.__name__),
+ )
def _expected_logs_for_test_map_in_arrow_with_logging(self, func_name):
return [
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
index 90e05caf21800..c315151d4d759 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_python_udf.py
@@ -508,11 +508,13 @@ class ArrowPythonUDFLegacyTests(ArrowPythonUDFLegacyTestsMixin, ReusedSQLTestCas
@classmethod
def setUpClass(cls):
super(ArrowPythonUDFLegacyTests, cls).setUpClass()
+ cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.concurrency.level", "4")
cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
@classmethod
def tearDownClass(cls):
try:
+ cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.concurrency.level")
cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
finally:
super(ArrowPythonUDFLegacyTests, cls).tearDownClass()
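The legacy Arrow UDF suite now also runs with a worker-side concurrency level. A hedged sketch of the same configuration outside the test harness; both conf names come from the test above, and 4 is an arbitrary example value:

# Sketch: run Arrow-optimized Python UDFs with a bounded concurrency level.
from pyspark.sql.functions import udf

spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
spark.conf.set("spark.sql.execution.pythonUDF.arrow.concurrency.level", "4")

inc = udf(lambda x: x + 1, "long")
spark.range(1000).select(inc("id")).count()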
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
index f719b4fb16bd2..3d8588ffb7af3 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_grouped_agg.py
@@ -1044,20 +1044,20 @@ def my_grouped_agg_arrow_udf(x):
[Row(id=1, result=3.0), Row(id=2, result=18.0)],
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"grouped agg arrow udf: {n}",
- context={"func_name": my_grouped_agg_arrow_udf.__name__},
- logger="test_grouped_agg_arrow",
- )
- for n in [2, 3]
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"grouped agg arrow udf: {n}",
+ context={"func_name": my_grouped_agg_arrow_udf.__name__},
+ logger="test_grouped_agg_arrow",
+ )
+ for n in [2, 3]
+ ],
+ )
class GroupedAggArrowUDFTests(GroupedAggArrowUDFTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
index 05f33a4ae42f7..3bd00d2cc921a 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_scalar.py
@@ -1201,20 +1201,20 @@ def my_scalar_arrow_udf(x):
[Row(result=f"scalar_arrow_{i}") for i in range(3)],
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"scalar arrow udf: {lst}",
- context={"func_name": my_scalar_arrow_udf.__name__},
- logger="test_scalar_arrow",
- )
- for lst in [[0], [1, 2]]
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"scalar arrow udf: {lst}",
+ context={"func_name": my_scalar_arrow_udf.__name__},
+ logger="test_scalar_arrow",
+ )
+ for lst in [[0], [1, 2]]
+ ],
+ )
@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_scalar_iter_arrow_udf_with_logging(self):
@@ -1241,20 +1241,20 @@ def my_scalar_iter_arrow_udf(it):
[Row(result=f"scalar_iter_arrow_{i}") for i in range(9)],
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"scalar iter arrow udf: {lst}",
- context={"func_name": my_scalar_iter_arrow_udf.__name__},
- logger="test_scalar_iter_arrow",
- )
- for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"scalar iter arrow udf: {lst}",
+ context={"func_name": my_scalar_iter_arrow_udf.__name__},
+ logger="test_scalar_iter_arrow",
+ )
+ for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
+ ],
+ )
class ScalarArrowUDFTests(ScalarArrowUDFTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
index 240e34487b006..b0adfbe131864 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udf_window.py
@@ -834,20 +834,20 @@ def my_window_arrow_udf(x):
],
)
- logs = self.spark.table("system.session.python_worker_logs")
-
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"window arrow udf: {lst}",
- context={"func_name": my_window_arrow_udf.__name__},
- logger="test_window_arrow",
- )
- for lst in [[1.0], [1.0, 2.0], [3.0], [3.0, 5.0], [3.0, 5.0, 10.0]]
- ],
- )
+ logs = self.spark.tvf.python_worker_logs()
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"window arrow udf: {lst}",
+ context={"func_name": my_window_arrow_udf.__name__},
+ logger="test_window_arrow",
+ )
+ for lst in [[1.0], [1.0, 2.0], [3.0], [3.0, 5.0], [3.0, 5.0, 10.0]]
+ ],
+ )
class WindowArrowUDFTests(WindowArrowUDFTestsMixin, ReusedSQLTestCase):
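All of these suites switch from reading the `system.session.python_worker_logs` table to the new `spark.tvf.python_worker_logs()` accessor. A minimal sketch of querying the captured worker logs after running a UDF, assuming worker log capture is enabled the same way the tests enable it:

# Sketch: inspect Python worker logs through the table-valued-function accessor.
logs = spark.tvf.python_worker_logs()
logs.select("level", "msg", "context", "logger").show(truncate=False)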
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
index 24c5ec0acf91e..cc0edda378abf 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_udtf.py
@@ -1721,20 +1721,20 @@ def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
[Row(id=i, doubled=i * 2) for i in range(9)],
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"arrow udtf: {dict(id=lst)}",
- context={"class_name": "TestArrowUDTFWithLogging", "func_name": "eval"},
- logger="test_arrow_udtf",
- )
- for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"arrow udtf: {dict(id=lst)}",
+ context={"class_name": "TestArrowUDTFWithLogging", "func_name": "eval"},
+ logger="test_arrow_udtf",
+ )
+ for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
+ ],
+ )
class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index c189f996cbe43..c876fca241692 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -26,6 +26,7 @@
if should_test_connect:
import grpc
import google.protobuf.any_pb2 as any_pb2
+ import google.protobuf.wrappers_pb2 as wrappers_pb2
from google.rpc import status_pb2
from google.rpc.error_details_pb2 import ErrorInfo
import pandas as pd
@@ -136,9 +137,11 @@ class MockService:
def __init__(self, session_id: str):
self._session_id = session_id
self.req = None
+ self.client_user_context_extensions = []
def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
self.req = req
+ self.client_user_context_extensions = list(req.user_context.extensions)
resp = proto.ExecutePlanResponse()
resp.session_id = self._session_id
resp.operation_id = req.operation_id
@@ -159,12 +162,14 @@ def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
def Interrupt(self, req: proto.InterruptRequest, metadata):
self.req = req
+ self.client_user_context_extensions = list(req.user_context.extensions)
resp = proto.InterruptResponse()
resp.session_id = self._session_id
return resp
def Config(self, req: proto.ConfigRequest, metadata):
self.req = req
+ self.client_user_context_extensions = list(req.user_context.extensions)
resp = proto.ConfigResponse()
resp.session_id = self._session_id
if req.operation.HasField("get"):
@@ -177,6 +182,15 @@ def Config(self, req: proto.ConfigRequest, metadata):
pair.value = req.operation.get_with_default.pairs[0].value or "true"
return resp
+ def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata):
+ self.req = req
+ self.client_user_context_extensions = list(req.user_context.extensions)
+ resp = proto.AnalyzePlanResponse()
+ resp.session_id = self._session_id
+ # Return a minimal response with a semantic hash
+ resp.semantic_hash.result = 12345
+ return resp
+
# The _cleanup_ml_cache invocation will hang in this test (no valid spark cluster)
# and it blocks the test process exiting because it is registered as the atexit handler
# in `SparkConnectClient` constructor. To bypass the issue, patch the method in the test.
@@ -229,6 +243,96 @@ def userId(self) -> Optional[str]:
self.assertEqual(client._user_id, "abc")
+ def test_user_context_extension(self):
+ client = SparkConnectClient("sc://foo/", use_reattachable_execute=False)
+ mock = MockService(client._session_id)
+ client._stub = mock
+
+ try:
+ exlocal = any_pb2.Any()
+ exlocal.Pack(wrappers_pb2.StringValue(value="abc"))
+ exlocal2 = any_pb2.Any()
+ exlocal2.Pack(wrappers_pb2.StringValue(value="def"))
+ exglobal = any_pb2.Any()
+ exglobal.Pack(wrappers_pb2.StringValue(value="ghi"))
+ exglobal2 = any_pb2.Any()
+ exglobal2.Pack(wrappers_pb2.StringValue(value="jkl"))
+
+ exlocal_id = client.add_threadlocal_user_context_extension(exlocal)
+ exglobal_id = client.add_global_user_context_extension(exglobal)
+
+ mock.client_user_context_extensions = []
+ command = proto.Command()
+ client.execute_command(command)
+ self.assertTrue(exlocal in mock.client_user_context_extensions)
+ self.assertTrue(exglobal in mock.client_user_context_extensions)
+ self.assertFalse(exlocal2 in mock.client_user_context_extensions)
+ self.assertFalse(exglobal2 in mock.client_user_context_extensions)
+
+ client.add_threadlocal_user_context_extension(exlocal2)
+
+ mock.client_user_context_extensions = []
+ plan = proto.Plan()
+ client.semantic_hash(plan) # use semantic_hash to test analyze
+ self.assertTrue(exlocal in mock.client_user_context_extensions)
+ self.assertTrue(exglobal in mock.client_user_context_extensions)
+ self.assertTrue(exlocal2 in mock.client_user_context_extensions)
+ self.assertFalse(exglobal2 in mock.client_user_context_extensions)
+
+ client.add_global_user_context_extension(exglobal2)
+
+ mock.client_user_context_extensions = []
+ client.interrupt_all()
+ self.assertTrue(exlocal in mock.client_user_context_extensions)
+ self.assertTrue(exglobal in mock.client_user_context_extensions)
+ self.assertTrue(exlocal2 in mock.client_user_context_extensions)
+ self.assertTrue(exglobal2 in mock.client_user_context_extensions)
+
+ client.remove_user_context_extension(exlocal_id)
+
+ mock.client_user_context_extensions = []
+ client.get_configs("foo", "bar")
+ self.assertFalse(exlocal in mock.client_user_context_extensions)
+ self.assertTrue(exglobal in mock.client_user_context_extensions)
+ self.assertTrue(exlocal2 in mock.client_user_context_extensions)
+ self.assertTrue(exglobal2 in mock.client_user_context_extensions)
+
+ client.remove_user_context_extension(exglobal_id)
+
+ mock.client_user_context_extensions = []
+ command = proto.Command()
+ client.execute_command(command)
+ self.assertFalse(exlocal in mock.client_user_context_extensions)
+ self.assertFalse(exglobal in mock.client_user_context_extensions)
+ self.assertTrue(exlocal2 in mock.client_user_context_extensions)
+ self.assertTrue(exglobal2 in mock.client_user_context_extensions)
+
+ client.clear_user_context_extensions()
+
+ mock.client_user_context_extensions = []
+ plan = proto.Plan()
+ client.semantic_hash(plan) # use semantic_hash to test analyze
+ self.assertFalse(exlocal in mock.client_user_context_extensions)
+ self.assertFalse(exglobal in mock.client_user_context_extensions)
+ self.assertFalse(exlocal2 in mock.client_user_context_extensions)
+ self.assertFalse(exglobal2 in mock.client_user_context_extensions)
+
+ mock.client_user_context_extensions = []
+ client.interrupt_all()
+ self.assertFalse(exlocal in mock.client_user_context_extensions)
+ self.assertFalse(exglobal in mock.client_user_context_extensions)
+ self.assertFalse(exlocal2 in mock.client_user_context_extensions)
+ self.assertFalse(exglobal2 in mock.client_user_context_extensions)
+
+ mock.client_user_context_extensions = []
+ client.get_configs("foo", "bar")
+ self.assertFalse(exlocal in mock.client_user_context_extensions)
+ self.assertFalse(exglobal in mock.client_user_context_extensions)
+ self.assertFalse(exlocal2 in mock.client_user_context_extensions)
+ self.assertFalse(exglobal2 in mock.client_user_context_extensions)
+ finally:
+ client.close()
+
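The new client test drives thread-local and global user-context extensions on outgoing Spark Connect requests. A compact sketch of the client-side calls it exercises, assuming a SparkConnectClient instance named `client`; the method names are taken directly from the test:

# Sketch: attach a protobuf Any payload to every outgoing request's user_context.
import google.protobuf.any_pb2 as any_pb2
import google.protobuf.wrappers_pb2 as wrappers_pb2

ext = any_pb2.Any()
ext.Pack(wrappers_pb2.StringValue(value="tenant-123"))

ext_id = client.add_threadlocal_user_context_extension(ext)  # current thread only
# ... issue RPCs; the payload rides along in user_context.extensions ...
client.remove_user_context_extension(ext_id)                 # or clear_user_context_extensions()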
def test_interrupt_all(self):
client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False)
mock = MockService(client._session_id)
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index c1ba9a6fc2d4a..b789d7919c94e 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1447,6 +1447,33 @@ def test_truncate_message(self):
proto_string_truncated_3 = self.connect._client._proto_to_string(plan3, True)
self.assertTrue(len(proto_string_truncated_3) < 64000, len(proto_string_truncated_3))
+ def test_plan_compression(self):
+ self.assertTrue(self.connect._client._zstd_module is not None)
+ self.connect.range(1).count()
+ default_plan_compression_threshold = self.connect._client._plan_compression_threshold
+ self.assertTrue(default_plan_compression_threshold > 0)
+ self.assertTrue(self.connect._client._plan_compression_algorithm == "ZSTD")
+ try:
+ self.connect._client._plan_compression_threshold = 1000
+
+ # Small plan should not be compressed
+ cdf1 = self.connect.range(1).select(CF.lit("Apache Spark"))
+ plan1 = cdf1._plan.to_proto(self.connect._client)
+ self.assertTrue(plan1.root is not None)
+ self.assertTrue(cdf1.count() == 1)
+
+ # Large plan should be compressed
+ cdf2 = self.connect.range(1).select(CF.lit("Apache Spark" * 1000))
+ plan2 = cdf2._plan.to_proto(self.connect._client)
+ self.assertTrue(plan2.compressed_operation is not None)
+ # Test compressed relation
+ self.assertTrue(cdf2.count() == 1)
+ # Test compressed command
+ cdf2.createOrReplaceTempView("temp_view_cdf2")
+ self.assertTrue(self.connect.sql("SELECT * FROM temp_view_cdf2").count() == 1)
+ finally:
+ self.connect._client._plan_compression_threshold = default_plan_compression_threshold
+
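test_plan_compression checks that large Connect plans are sent as compressed_operation while small ones stay uncompressed. A hedged sketch of the threshold manipulation it relies on; `_plan_compression_threshold` and `_plan.to_proto` are the internal handles named in the test, not public API:

# Sketch: force compression of a Connect plan by lowering the internal threshold.
from pyspark.sql.functions import lit

client = spark._client                      # Spark Connect session assumed
old = client._plan_compression_threshold
client._plan_compression_threshold = 1000
try:
    big = spark.range(1).select(lit("Apache Spark" * 1000))
    proto_plan = big._plan.to_proto(client)  # expected to carry compressed_operation
    assert big.count() == 1                  # server transparently decompresses the plan
finally:
    client._plan_compression_threshold = old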
class SparkConnectGCTests(SparkConnectSQLTestCase):
@classmethod
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index d25799f0c9f26..3f785d4ee7130 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -753,7 +753,7 @@ def test_column_regexp(self):
self.assertIsInstance(col, Column)
self.assertEqual("Column<'UnresolvedRegex(col_name)'>", str(col))
- col_plan = col.to_plan(self.session.client)
+ col_plan = col.to_plan(None)
self.assertIsNotNone(col_plan)
self.assertEqual(col_plan.unresolved_regex.col_name, "col_name")
@@ -864,7 +864,7 @@ def test_uuid_literal(self):
def test_column_literals(self):
df = self.connect.with_plan(Read("table"))
lit_df = df.select(lit(10))
- self.assertIsNotNone(lit_df._plan.to_proto(None))
+ self.assertIsNotNone(lit_df._plan.to_proto(self.connect))
self.assertIsNotNone(lit(10).to_plan(None))
plan = lit(10).to_plan(None)
@@ -937,7 +937,7 @@ def test_column_alias(self) -> None:
self.assertEqual("Column<'a AS martin'>", str(col0))
col0 = col("a").alias("martin", metadata={"pii": True})
- plan = col0.to_plan(self.session.client)
+ plan = col0.to_plan(None)
self.assertIsNotNone(plan)
self.assertEqual(plan.alias.metadata, '{"pii": true}')
diff --git a/python/pyspark/sql/tests/connect/test_parity_geographytype.py b/python/pyspark/sql/tests/connect/test_parity_geographytype.py
new file mode 100644
index 0000000000000..501bbed20ff19
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_geographytype.py
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.test_geographytype import GeographyTypeTestMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class GeographyTypeParityTest(GeographyTypeTestMixin, ReusedConnectTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.connect.test_parity_geographytype import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_parity_geometrytype.py b/python/pyspark/sql/tests/connect/test_parity_geometrytype.py
new file mode 100644
index 0000000000000..b95321b3c61be
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_geometrytype.py
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.test_geometrytype import GeometryTypeTestMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class GeometryTypeParityTest(GeometryTypeTestMixin, ReusedConnectTestCase):
+ pass
+
+
+if __name__ == "__main__":
+ import unittest
+ from pyspark.sql.tests.connect.test_parity_geometrytype import * # noqa: F401
+
+ try:
+ import xmlrunner
+
+ testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index 6d06611def6af..a39e92233bc0e 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -34,6 +34,10 @@ def test_apply_schema_to_dict_and_rows(self):
def test_apply_schema_to_row(self):
super().test_apply_schema_to_row()
+ @unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
+ def test_geospatial_create_dataframe_rdd(self):
+ super().test_geospatial_create_dataframe_rdd()
+
@unittest.skip("Spark Connect does not support RDD but the tests depend on them.")
def test_create_dataframe_schema_mismatch(self):
super().test_create_dataframe_schema_mismatch()
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index ab954dd133f3c..0d91da0354979 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -660,21 +660,19 @@ def __test_merge_error(
self.__test_merge(left, right, by, fn, output_schema)
def test_arrow_batch_slicing(self):
- df1 = self.spark.range(10000000).select(
- (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
- )
+ m, n = 100000, 10000
+
+ df1 = self.spark.range(m).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
cols = {f"col_{i}": sf.col("v") + i for i in range(10)}
df1 = df1.withColumns(cols)
- df2 = self.spark.range(100000).select(
- (sf.col("id") % 4).alias("key"), sf.col("id").alias("v")
- )
+ df2 = self.spark.range(n).select((sf.col("id") % 4).alias("key"), sf.col("id").alias("v"))
cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
df2 = df2.withColumns(cols)
def summarize(key, left, right):
- assert len(left) == 10000000 / 2 or len(left) == 0, len(left)
- assert len(right) == 100000 / 4, len(right)
+ assert len(left) == m / 2 or len(left) == 0, len(left)
+ assert len(right) == n / 4, len(right)
return pd.DataFrame(
{
"key": [key[0]],
@@ -688,13 +686,13 @@ def summarize(key, left, right):
schema = "key long, left_rows long, left_columns long, right_rows long, right_columns long"
expected = [
- Row(key=0, left_rows=5000000, left_columns=12, right_rows=25000, right_columns=22),
- Row(key=1, left_rows=5000000, left_columns=12, right_rows=25000, right_columns=22),
- Row(key=2, left_rows=0, left_columns=12, right_rows=25000, right_columns=22),
- Row(key=3, left_rows=0, left_columns=12, right_rows=25000, right_columns=22),
+ Row(key=0, left_rows=m / 2, left_columns=12, right_rows=n / 4, right_columns=22),
+ Row(key=1, left_rows=m / 2, left_columns=12, right_rows=n / 4, right_columns=22),
+ Row(key=2, left_rows=0, left_columns=12, right_rows=n / 4, right_columns=22),
+ Row(key=3, left_rows=0, left_columns=12, right_rows=n / 4, right_columns=22),
]
- for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+ for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 4096), (1000, 4096)]:
with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
with self.sql_conf(
{
@@ -744,20 +742,20 @@ def func_with_logging(left_pdf, right_pdf):
+ [Row(id=2, v1=20, v2=200)],
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"pandas cogrouped map: {dict(v1=v1, v2=v2)}",
- context={"func_name": func_with_logging.__name__},
- logger="test_pandas_cogrouped_map",
- )
- for v1, v2 in [([10, 30], [100, 300]), ([20], [200])]
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"pandas cogrouped map: {dict(v1=v1, v2=v2)}",
+ context={"func_name": func_with_logging.__name__},
+ logger="test_pandas_cogrouped_map",
+ )
+ for v1, v2 in [([10, 30], [100, 300]), ([20], [200])]
+ ],
+ )
class CogroupedApplyInPandasTests(CogroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index ef84673179dcc..fb18c5f062b80 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -18,6 +18,7 @@
import datetime
import unittest
import logging
+import os
from collections import OrderedDict
from decimal import Decimal
@@ -277,28 +278,20 @@ def check_apply_in_pandas_not_returning_pandas_dataframe(self):
):
self._test_apply_in_pandas(lambda key, pdf: key)
- @staticmethod
- def stats_with_column_names(key, pdf):
- # order of column can be different to applyInPandas schema when column names are given
- return pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"])
-
- @staticmethod
- def stats_with_no_column_names(key, pdf):
- # columns must be in order of applyInPandas schema when no columns given
- return pd.DataFrame([key + (pdf.v.mean(),)])
-
def test_apply_in_pandas_returning_column_names(self):
- self._test_apply_in_pandas(ApplyInPandasTestsMixin.stats_with_column_names)
+ self._test_apply_in_pandas(
+ lambda key, pdf: pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"])
+ )
def test_apply_in_pandas_returning_no_column_names(self):
- self._test_apply_in_pandas(ApplyInPandasTestsMixin.stats_with_no_column_names)
+ self._test_apply_in_pandas(lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(),)]))
def test_apply_in_pandas_returning_column_names_sometimes(self):
def stats(key, pdf):
if key[0] % 2:
- return ApplyInPandasTestsMixin.stats_with_column_names(key, pdf)
+ return pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"])
else:
- return ApplyInPandasTestsMixin.stats_with_no_column_names(key, pdf)
+ return pd.DataFrame([key + (pdf.v.mean(),)])
self._test_apply_in_pandas(stats)
@@ -332,9 +325,15 @@ def check_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(), pdf.v.std())])
)
+ @unittest.skipIf(
+ os.environ.get("SPARK_SKIP_CONNECT_COMPAT_TESTS") == "1", "SPARK-54482: To be reenabled"
+ )
def test_apply_in_pandas_returning_empty_dataframe(self):
self._test_apply_in_pandas_returning_empty_dataframe(pd.DataFrame())
+ @unittest.skipIf(
+ os.environ.get("SPARK_SKIP_CONNECT_COMPAT_TESTS") == "1", "SPARK-54482: To be reenabled"
+ )
def test_apply_in_pandas_returning_incompatible_type(self):
with self.quiet():
self.check_apply_in_pandas_returning_incompatible_type()
@@ -870,7 +869,7 @@ def _test_apply_in_pandas_returning_empty_dataframe(self, empty_df):
def stats(key, pdf):
if key[0] % 2 == 0:
- return ApplyInPandasTestsMixin.stats_with_no_column_names(key, pdf)
+ return pd.DataFrame([key + (pdf.v.mean(),)])
return empty_df
result = (
@@ -938,14 +937,14 @@ def test(pdf):
self.assertEqual(row[1], 123)
def test_arrow_batch_slicing(self):
- df = self.spark.range(10000000).select(
- (sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
- )
+ n = 100000
+
+ df = self.spark.range(n).select((sf.col("id") % 2).alias("key"), sf.col("id").alias("v"))
cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
df = df.withColumns(cols)
def min_max_v(pdf):
- assert len(pdf) == 10000000 / 2, len(pdf)
+ assert len(pdf) == n / 2, len(pdf)
return pd.DataFrame(
{
"key": [pdf.key.iloc[0]],
@@ -958,7 +957,7 @@ def min_max_v(pdf):
df.groupby("key").agg(sf.min("v").alias("min"), sf.max("v").alias("max")).sort("key")
).collect()
- for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+ for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 4096), (1000, 4096)]:
with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
with self.sql_conf(
{
@@ -1000,20 +999,20 @@ def func_with_logging(pdf):
df,
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"pandas grouped map: {dict(id=lst, value=[v*10 for v in lst])}",
- context={"func_name": func_with_logging.__name__},
- logger="test_pandas_grouped_map",
- )
- for lst in [[0, 2, 4, 6, 8], [1, 3, 5, 7]]
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"pandas grouped map: {dict(id=lst, value=[v*10 for v in lst])}",
+ context={"func_name": func_with_logging.__name__},
+ logger="test_pandas_grouped_map",
+ )
+ for lst in [[0, 2, 4, 6, 8], [1, 3, 5, 7]]
+ ],
+ )
def test_apply_in_pandas_iterator_basic(self):
df = self.spark.createDataFrame(
@@ -1057,7 +1056,7 @@ def sum_func(
self.assertEqual(result[1][1], 18.0)
def test_apply_in_pandas_iterator_batch_slicing(self):
- df = self.spark.range(10000000).select(
+ df = self.spark.range(100000).select(
(sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
)
cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
@@ -1073,7 +1072,7 @@ def min_max_v(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
key_val = batch.key.iloc[0]
combined = pd.concat(all_data, ignore_index=True)
- assert len(combined) == 10000000 / 2, len(combined)
+ assert len(combined) == 100000 / 2, len(combined)
yield pd.DataFrame(
{
@@ -1092,7 +1091,7 @@ def min_max_v(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
.sort("key")
).collect()
- for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+ for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 4096), (1000, 4096)]:
with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
with self.sql_conf(
{
@@ -1109,7 +1108,7 @@ def min_max_v(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
self.assertEqual(expected, result)
def test_apply_in_pandas_iterator_with_keys_batch_slicing(self):
- df = self.spark.range(10000000).select(
+ df = self.spark.range(100000).select(
(sf.col("id") % 2).alias("key"), sf.col("id").alias("v")
)
cols = {f"col_{i}": sf.col("v") + i for i in range(20)}
@@ -1124,7 +1123,7 @@ def min_max_v(
all_data.append(batch)
combined = pd.concat(all_data, ignore_index=True)
- assert len(combined) == 10000000 / 2, len(combined)
+ assert len(combined) == 100000 / 2, len(combined)
yield pd.DataFrame(
{
@@ -1138,7 +1137,7 @@ def min_max_v(
df.groupby("key").agg(sf.min("v").alias("min"), sf.max("v").alias("max")).sort("key")
).collect()
- for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 1048576), (1000, 1048576)]:
+ for maxRecords, maxBytes in [(1000, 2**31 - 1), (0, 4096), (1000, 4096)]:
with self.subTest(maxRecords=maxRecords, maxBytes=maxBytes):
with self.sql_conf(
{
@@ -1406,6 +1405,66 @@ def func(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
actual = grouped_df.applyInPandas(func, "value long").collect()
self.assertEqual(actual, expected)
+ def test_grouped_map_pandas_udf_with_compression_codec(self):
+ # Test grouped map Pandas UDF with different compression codec settings
+ @pandas_udf("id long, v int, v1 double", PandasUDFType.GROUPED_MAP)
+ def foo(pdf):
+ return pdf.assign(v1=pdf.v * pdf.id * 1.0)
+
+ df = self.data
+ pdf = df.toPandas()
+ expected = pdf.groupby("id", as_index=False).apply(foo.func).reset_index(drop=True)
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compression.codec": codec}):
+ result = df.groupby("id").apply(foo).sort("id").toPandas()
+ assert_frame_equal(expected, result)
+
+ def test_apply_in_pandas_with_compression_codec(self):
+ # Test applyInPandas with different compression codec settings
+ def stats(key, pdf):
+ return pd.DataFrame([(key[0], pdf.v.mean())], columns=["id", "mean"])
+
+ df = self.data
+ expected = df.select("id").distinct().withColumn("mean", sf.lit(24.5)).toPandas()
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compression.codec": codec}):
+ result = (
+ df.groupby("id")
+ .applyInPandas(stats, schema="id long, mean double")
+ .sort("id")
+ .toPandas()
+ )
+ assert_frame_equal(expected, result)
+
+ def test_apply_in_pandas_iterator_with_compression_codec(self):
+ # Test applyInPandas with iterator and compression
+ df = self.spark.createDataFrame(
+ [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+ )
+
+ def sum_func(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
+ total = 0
+ for batch in batches:
+ total += batch["v"].sum()
+ yield pd.DataFrame({"v": [total]})
+
+ expected = [Row(v=3.0), Row(v=18.0)]
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compression.codec": codec}):
+ result = (
+ df.groupby("id")
+ .applyInPandas(sum_func, schema="v double")
+ .orderBy("v")
+ .collect()
+ )
+ self.assertEqual(result, expected)
+
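The applyInPandas tests above re-run the same grouped computation under each Arrow codec and compare against a fixed expected result. A minimal end-to-end sketch of that pattern, assuming an active session:

# Sketch: grouped applyInPandas under different Arrow compression codecs.
import pandas as pd

def stats(key, pdf):
    return pd.DataFrame([(key[0], pdf.v.mean())], columns=["id", "mean"])

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))
baseline = df.groupby("id").applyInPandas(stats, "id long, mean double").sort("id").collect()

for codec in ["none", "zstd", "lz4"]:
    spark.conf.set("spark.sql.execution.arrow.compression.codec", codec)
    result = df.groupby("id").applyInPandas(stats, "id long, mean double").sort("id").collect()
    assert result == baseline  # the codec must not change results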
class ApplyInPandasTests(ApplyInPandasTestsMixin, ReusedSQLTestCase):
pass
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index 5e0e33a05b22b..946d56f2fe637 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -510,20 +510,20 @@ def func_with_logging(iterator):
[Row(id=i) for i in range(9)],
)
- logs = self.spark.table("system.session.python_worker_logs")
-
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"pandas map: {lst}",
- context={"func_name": func_with_logging.__name__},
- logger="test_pandas_map",
- )
- for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
- ],
- )
+ logs = self.spark.tvf.python_worker_logs()
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"pandas map: {lst}",
+ context={"func_name": func_with_logging.__name__},
+ logger="test_pandas_map",
+ )
+ for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
+ ],
+ )
class MapInPandasTests(ReusedSQLTestCase, MapInPandasTestsMixin):
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
index 2b3e42312df99..4b66dee5b7af5 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_grouped_agg.py
@@ -846,20 +846,56 @@ def my_grouped_agg_pandas_udf(x):
[Row(id=1, result=3.0), Row(id=2, result=18.0)],
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"grouped agg pandas udf: {n}",
- context={"func_name": my_grouped_agg_pandas_udf.__name__},
- logger="test_grouped_agg_pandas",
- )
- for n in [2, 3]
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"grouped agg pandas udf: {n}",
+ context={"func_name": my_grouped_agg_pandas_udf.__name__},
+ logger="test_grouped_agg_pandas",
+ )
+ for n in [2, 3]
+ ],
+ )
+
+ def test_grouped_agg_pandas_udf_with_compression_codec(self):
+ # Test grouped agg Pandas UDF with different compression codec settings
+ @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+ def sum_udf(v):
+ return v.sum()
+
+ df = self.data
+ expected = df.groupby("id").agg(sum_udf(df.v)).sort("id").toPandas()
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compression.codec": codec}):
+ result = df.groupby("id").agg(sum_udf(df.v)).sort("id").toPandas()
+ assert_frame_equal(expected, result)
+
+ def test_grouped_agg_pandas_udf_with_compression_codec_complex(self):
+ # Test grouped agg with multiple UDFs and compression
+ @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+ def mean_udf(v):
+ return v.mean()
+
+ @pandas_udf("double", PandasUDFType.GROUPED_AGG)
+ def sum_udf(v):
+ return v.sum()
+
+ df = self.data
+ expected = df.groupby("id").agg(mean_udf(df.v), sum_udf(df.v)).sort("id").toPandas()
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compression.codec": codec}):
+ result = (
+ df.groupby("id").agg(mean_udf(df.v), sum_udf(df.v)).sort("id").toPandas()
+ )
+ assert_frame_equal(expected, result)
class GroupedAggPandasUDFTests(GroupedAggPandasUDFTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
index fbfe1a226b5e2..b936a9240e529 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_scalar.py
@@ -1935,20 +1935,20 @@ def my_scalar_pandas_udf(x):
[Row(result=f"scalar_pandas_{i}") for i in range(3)],
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"scalar pandas udf: {lst}",
- context={"func_name": my_scalar_pandas_udf.__name__},
- logger="test_scalar_pandas",
- )
- for lst in [[0], [1, 2]]
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"scalar pandas udf: {lst}",
+ context={"func_name": my_scalar_pandas_udf.__name__},
+ logger="test_scalar_pandas",
+ )
+ for lst in [[0], [1, 2]]
+ ],
+ )
@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_scalar_iter_pandas_udf_with_logging(self):
@@ -1973,20 +1973,76 @@ def my_scalar_iter_pandas_udf(it):
[Row(result=f"scalar_iter_pandas_{i}") for i in range(9)],
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"scalar iter pandas udf: {lst}",
- context={"func_name": my_scalar_iter_pandas_udf.__name__},
- logger="test_scalar_iter_pandas",
- )
- for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"scalar iter pandas udf: {lst}",
+ context={"func_name": my_scalar_iter_pandas_udf.__name__},
+ logger="test_scalar_iter_pandas",
+ )
+ for lst in [[0, 1, 2], [3], [4, 5, 6], [7, 8]]
+ ],
+ )
+
+ def test_scalar_pandas_udf_with_compression_codec(self):
+ # Test scalar Pandas UDF with different compression codec settings
+ @pandas_udf("long")
+ def plus_one(v):
+ return v + 1
+
+ df = self.spark.range(100)
+ expected = [Row(result=i + 1) for i in range(100)]
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compression.codec": codec}):
+ result = df.select(plus_one("id").alias("result")).collect()
+ self.assertEqual(expected, result)
+
+ def test_scalar_pandas_udf_with_compression_codec_complex_types(self):
+ # Test scalar Pandas UDF with compression for complex types (strings, arrays)
+ @pandas_udf("string")
+ def concat_string(v):
+ return v.apply(lambda x: "value_" + str(x))
+
+ @pandas_udf(ArrayType(IntegerType()))
+ def create_array(v):
+ return v.apply(lambda x: [x, x * 2, x * 3])
+
+ df = self.spark.range(50)
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compression.codec": codec}):
+ # Test string UDF
+ result = df.select(concat_string("id").alias("result")).collect()
+ expected = [Row(result=f"value_{i}") for i in range(50)]
+ self.assertEqual(expected, result)
+
+ # Test array UDF
+ result = df.select(create_array("id").alias("result")).collect()
+ expected = [Row(result=[i, i * 2, i * 3]) for i in range(50)]
+ self.assertEqual(expected, result)
+
+ def test_scalar_iter_pandas_udf_with_compression_codec(self):
+ # Test scalar iterator Pandas UDF with compression
+ @pandas_udf("long", PandasUDFType.SCALAR_ITER)
+ def plus_two(iterator):
+ for s in iterator:
+ yield s + 2
+
+ df = self.spark.range(100)
+ expected = [Row(result=i + 2) for i in range(100)]
+
+ for codec in ["none", "zstd", "lz4"]:
+ with self.subTest(compressionCodec=codec):
+ with self.sql_conf({"spark.sql.execution.arrow.compression.codec": codec}):
+ result = df.select(plus_two("id").alias("result")).collect()
+ self.assertEqual(expected, result)
class ScalarPandasUDFTests(ScalarPandasUDFTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
index 6fa7e9063836b..6e1cbdaf73cff 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_udf_window.py
@@ -664,20 +664,20 @@ def my_window_pandas_udf(x):
],
)
- logs = self.spark.table("system.session.python_worker_logs")
-
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"window pandas udf: {lst}",
- context={"func_name": my_window_pandas_udf.__name__},
- logger="test_window_pandas",
- )
- for lst in [[1.0], [1.0, 2.0], [3.0], [3.0, 5.0], [3.0, 5.0, 10.0]]
- ],
- )
+ logs = self.spark.tvf.python_worker_logs()
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"window pandas udf: {lst}",
+ context={"func_name": my_window_pandas_udf.__name__},
+ logger="test_window_pandas",
+ )
+ for lst in [[1.0], [1.0, 2.0], [3.0], [3.0, 5.0], [3.0, 5.0, 10.0]]
+ ],
+ )
class WindowPandasUDFTests(WindowPandasUDFTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 75a553b62838e..a726fc85d90a5 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -159,6 +159,13 @@ def test_self_join_IV(self):
self.assertTrue(df3.columns, ["id", "value", "id", "value"])
self.assertTrue(df3.count() == 20)
+ def test_select_join_keys(self):
+ df1 = self.spark.range(10).withColumn("v1", lit(1))
+ df2 = self.spark.range(10).withColumn("v2", lit(2))
+ for how in ["inner", "left", "right", "full", "cross"]:
+ self.assertTrue(df1.join(df2, "id", how).select(df1["id"]).count() >= 0, how)
+ self.assertTrue(df1.join(df2, "id", how).select(df2["id"]).count() >= 0, how)
+
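test_select_join_keys only asserts that resolving the join key through either parent DataFrame no longer fails after the join. A small illustration of the pattern it guards:

# Sketch: after joining on "id", the key should be selectable via either side.
from pyspark.sql.functions import lit

df1 = spark.range(10).withColumn("v1", lit(1))
df2 = spark.range(10).withColumn("v2", lit(2))
joined = df1.join(df2, "id", "inner")
joined.select(df1["id"]).show()  # resolving through df1 must not raise
joined.select(df2["id"]).show()  # nor through df2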
def test_lateral_column_alias(self):
df1 = self.spark.range(10).select(
(col("id") + lit(1)).alias("x"), (col("x") + lit(1)).alias("y")
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 18f824c463c93..a776fa6e80b74 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -80,10 +80,6 @@ def test_function_parity(self):
missing_in_py = jvm_fn_set.difference(py_fn_set)
- # Temporarily disable Scala/Python parity check for ST geospatial functions, while the
- # feature is under development. Once the geospatial module is stable, remove this.
- missing_in_py = {fn for fn in missing_in_py if not fn.startswith("st_")}
-
# Functions that we expect to be missing in python until they are added to pyspark
expected_missing_in_py = set()
@@ -2118,6 +2114,253 @@ def test_listagg_distinct_functions(self):
None,
)
+ def test_kll_sketch_agg_bigint(self):
+ """Test kll_sketch_agg_bigint function"""
+ df = self.spark.createDataFrame([1, 2, 3, 4, 5], "INT")
+
+ # Test with default k
+ sketch = df.agg(F.kll_sketch_agg_bigint("value")).first()[0]
+ self.assertIsNotNone(sketch)
+ self.assertIsInstance(sketch, (bytes, bytearray))
+
+ # Test with explicit k
+ sketch_k = df.agg(F.kll_sketch_agg_bigint("value", 400)).first()[0]
+ self.assertIsNotNone(sketch_k)
+
+ def test_kll_sketch_agg_float(self):
+ """Test kll_sketch_agg_float function"""
+ df = self.spark.createDataFrame([1.0, 2.0, 3.0, 4.0, 5.0], "FLOAT")
+
+ sketch = df.agg(F.kll_sketch_agg_float("value")).first()[0]
+ self.assertIsNotNone(sketch)
+ self.assertIsInstance(sketch, (bytes, bytearray))
+
+ def test_kll_sketch_agg_double(self):
+ """Test kll_sketch_agg_double function"""
+ df = self.spark.createDataFrame([1.0, 2.0, 3.0, 4.0, 5.0], "DOUBLE")
+
+ sketch = df.agg(F.kll_sketch_agg_double("value")).first()[0]
+ self.assertIsNotNone(sketch)
+ self.assertIsInstance(sketch, (bytes, bytearray))
+
+ def test_kll_sketch_to_string_bigint(self):
+ """Test kll_sketch_to_string_bigint function"""
+ df = self.spark.createDataFrame([1, 2, 3, 4, 5], "INT")
+ sketch_df = df.agg(F.kll_sketch_agg_bigint("value").alias("sketch"))
+
+ result = sketch_df.select(F.kll_sketch_to_string_bigint("sketch")).first()[0]
+ self.assertIsNotNone(result)
+ self.assertIsInstance(result, str)
+ self.assertIn("kll", result.lower())
+
+ def test_kll_sketch_get_n_bigint(self):
+ """Test kll_sketch_get_n_bigint function"""
+ df = self.spark.createDataFrame([1, 2, 3, 4, 5], "INT")
+ sketch_df = df.agg(F.kll_sketch_agg_bigint("value").alias("sketch"))
+
+ n = sketch_df.select(F.kll_sketch_get_n_bigint("sketch")).first()[0]
+ self.assertEqual(n, 5)
+
+ def test_kll_sketch_merge_bigint(self):
+ """Test kll_sketch_merge_bigint function"""
+ df = self.spark.createDataFrame([1, 2, 3], "INT")
+ sketch_df = df.agg(F.kll_sketch_agg_bigint("value").alias("sketch"))
+
+ merged = sketch_df.select(F.kll_sketch_merge_bigint("sketch", "sketch")).first()[0]
+ self.assertIsNotNone(merged)
+ self.assertIsInstance(merged, (bytes, bytearray))
+
+ def test_kll_sketch_get_quantile_bigint(self):
+ """Test kll_sketch_get_quantile_bigint function"""
+ df = self.spark.createDataFrame([1, 2, 3, 4, 5], "INT")
+ sketch_df = df.agg(F.kll_sketch_agg_bigint("value").alias("sketch"))
+
+ quantile = sketch_df.select(F.kll_sketch_get_quantile_bigint("sketch", F.lit(0.5))).first()[
+ 0
+ ]
+ self.assertIsNotNone(quantile)
+ self.assertGreaterEqual(quantile, 1)
+ self.assertLessEqual(quantile, 5)
+
+ def test_kll_sketch_get_quantile_bigint_array(self):
+ """Test kll_sketch_get_quantile_bigint with array of ranks"""
+ df = self.spark.createDataFrame([1, 2, 3, 4, 5], "INT")
+ sketch_df = df.agg(F.kll_sketch_agg_bigint("value").alias("sketch"))
+
+ quantiles = sketch_df.select(
+ F.kll_sketch_get_quantile_bigint(
+ "sketch", F.array(F.lit(0.25), F.lit(0.5), F.lit(0.75))
+ )
+ ).first()[0]
+ self.assertIsNotNone(quantiles)
+ self.assertEqual(len(quantiles), 3)
+
+ def test_kll_sketch_get_rank_bigint(self):
+ """Test kll_sketch_get_rank_bigint function"""
+ df = self.spark.createDataFrame([1, 2, 3, 4, 5], "INT")
+ sketch_df = df.agg(F.kll_sketch_agg_bigint("value").alias("sketch"))
+
+ rank = sketch_df.select(F.kll_sketch_get_rank_bigint("sketch", F.lit(3))).first()[0]
+ self.assertIsNotNone(rank)
+ self.assertGreaterEqual(rank, 0.0)
+ self.assertLessEqual(rank, 1.0)
+
+ def test_kll_sketch_float_variants(self):
+ """Test all float variant functions"""
+ df = self.spark.createDataFrame([1.0, 2.0, 3.0, 4.0, 5.0], "FLOAT")
+ sketch_df = df.agg(F.kll_sketch_agg_float("value").alias("sketch"))
+
+ # Test to_string
+ string_result = sketch_df.select(F.kll_sketch_to_string_float("sketch")).first()[0]
+ self.assertIn("kll", string_result.lower())
+
+ # Test get_n
+ n = sketch_df.select(F.kll_sketch_get_n_float("sketch")).first()[0]
+ self.assertEqual(n, 5)
+
+ # Test merge
+ merged = sketch_df.select(F.kll_sketch_merge_float("sketch", "sketch")).first()[0]
+ self.assertIsNotNone(merged)
+
+ # Test get_quantile
+ quantile = sketch_df.select(F.kll_sketch_get_quantile_float("sketch", F.lit(0.5))).first()[
+ 0
+ ]
+ self.assertIsNotNone(quantile)
+
+ # Test get_rank
+ rank = sketch_df.select(F.kll_sketch_get_rank_float("sketch", F.lit(3.0))).first()[0]
+ self.assertGreaterEqual(rank, 0.0)
+ self.assertLessEqual(rank, 1.0)
+
+ def test_kll_sketch_double_variants(self):
+ """Test all double variant functions"""
+ df = self.spark.createDataFrame([1.0, 2.0, 3.0, 4.0, 5.0], "DOUBLE")
+ sketch_df = df.agg(F.kll_sketch_agg_double("value").alias("sketch"))
+
+ # Test to_string
+ string_result = sketch_df.select(F.kll_sketch_to_string_double("sketch")).first()[0]
+ self.assertIn("kll", string_result.lower())
+
+ # Test get_n
+ n = sketch_df.select(F.kll_sketch_get_n_double("sketch")).first()[0]
+ self.assertEqual(n, 5)
+
+ # Test merge
+ merged = sketch_df.select(F.kll_sketch_merge_double("sketch", "sketch")).first()[0]
+ self.assertIsNotNone(merged)
+
+ # Test get_quantile
+ quantile = sketch_df.select(F.kll_sketch_get_quantile_double("sketch", F.lit(0.5))).first()[
+ 0
+ ]
+ self.assertIsNotNone(quantile)
+
+ # Test get_rank
+ rank = sketch_df.select(F.kll_sketch_get_rank_double("sketch", F.lit(3.0))).first()[0]
+ self.assertGreaterEqual(rank, 0.0)
+ self.assertLessEqual(rank, 1.0)
+
+ def test_kll_sketch_with_nulls(self):
+ """Test KLL sketch with null values"""
+ df = self.spark.createDataFrame([(1,), (None,), (3,), (4,), (None,)], ["value"])
+ sketch_df = df.agg(F.kll_sketch_agg_bigint("value").alias("sketch"))
+
+ n = sketch_df.select(F.kll_sketch_get_n_bigint("sketch")).first()[0]
+ # Should only count non-null values
+ self.assertEqual(n, 3)
+
+ def test_kll_merge_agg_bigint(self):
+ """Test kll_merge_agg_bigint function"""
+ df1 = self.spark.createDataFrame([1, 2, 3], "INT")
+ df2 = self.spark.createDataFrame([4, 5, 6], "INT")
+
+ sketch1 = df1.agg(F.kll_sketch_agg_bigint("value").alias("sketch"))
+ sketch2 = df2.agg(F.kll_sketch_agg_bigint("value").alias("sketch"))
+
+ # Union and merge sketches
+ merged = sketch1.union(sketch2).agg(F.kll_merge_agg_bigint("sketch").alias("merged"))
+
+ # Verify the merged sketch contains all values
+ n = merged.select(F.kll_sketch_get_n_bigint("merged")).first()[0]
+ self.assertEqual(n, 6)
+
+ # Test with explicit k parameter
+ merged_with_k = sketch1.union(sketch2).agg(
+ F.kll_merge_agg_bigint("sketch", 400).alias("merged")
+ )
+ self.assertIsNotNone(merged_with_k.first()[0])
+
+ def test_kll_merge_agg_float(self):
+ """Test kll_merge_agg_float function"""
+ df1 = self.spark.createDataFrame([1.0, 2.0, 3.0], "FLOAT")
+ df2 = self.spark.createDataFrame([4.0, 5.0, 6.0], "FLOAT")
+
+ sketch1 = df1.agg(F.kll_sketch_agg_float("value").alias("sketch"))
+ sketch2 = df2.agg(F.kll_sketch_agg_float("value").alias("sketch"))
+
+ # Union and merge sketches
+ merged = sketch1.union(sketch2).agg(F.kll_merge_agg_float("sketch").alias("merged"))
+
+ # Verify the merged sketch contains all values
+ n = merged.select(F.kll_sketch_get_n_float("merged")).first()[0]
+ self.assertEqual(n, 6)
+
+ # Test with explicit k parameter
+ merged_with_k = sketch1.union(sketch2).agg(
+ F.kll_merge_agg_float("sketch", 300).alias("merged")
+ )
+ self.assertIsNotNone(merged_with_k.first()[0])
+
+ def test_kll_merge_agg_double(self):
+ """Test kll_merge_agg_double function"""
+ df1 = self.spark.createDataFrame([1.0, 2.0, 3.0], "DOUBLE")
+ df2 = self.spark.createDataFrame([4.0, 5.0, 6.0], "DOUBLE")
+
+ sketch1 = df1.agg(F.kll_sketch_agg_double("value").alias("sketch"))
+ sketch2 = df2.agg(F.kll_sketch_agg_double("value").alias("sketch"))
+
+ # Union and merge sketches
+ merged = sketch1.union(sketch2).agg(F.kll_merge_agg_double("sketch").alias("merged"))
+
+ # Verify the merged sketch contains all values
+ n = merged.select(F.kll_sketch_get_n_double("merged")).first()[0]
+ self.assertEqual(n, 6)
+
+ # Test quantile on merged sketch
+ quantile = merged.select(F.kll_sketch_get_quantile_double("merged", F.lit(0.5))).first()[0]
+ self.assertIsNotNone(quantile)
+
+ def test_kll_merge_agg_with_different_k(self):
+ """Test kll_merge_agg with different k values"""
+ df1 = self.spark.createDataFrame([1, 2, 3], "INT")
+ df2 = self.spark.createDataFrame([4, 5, 6], "INT")
+
+ # Create sketches with different k values
+ sketch1 = df1.agg(F.kll_sketch_agg_bigint("value", 200).alias("sketch"))
+ sketch2 = df2.agg(F.kll_sketch_agg_bigint("value", 400).alias("sketch"))
+
+ # Merge sketches built with different k values (the merged sketch adopts k from the first sketch)
+ merged = sketch1.union(sketch2).agg(F.kll_merge_agg_bigint("sketch").alias("merged"))
+
+ n = merged.select(F.kll_sketch_get_n_bigint("merged")).first()[0]
+ self.assertEqual(n, 6)
+
+ def test_kll_merge_agg_with_nulls(self):
+ """Test kll_merge_agg with null values"""
+ df1 = self.spark.createDataFrame([1, 2, 3], "INT")
+ df2 = self.spark.createDataFrame([4, None, 6], "INT")
+
+ sketch1 = df1.agg(F.kll_sketch_agg_bigint("value").alias("sketch"))
+ sketch2 = df2.agg(F.kll_sketch_agg_bigint("value").alias("sketch"))
+
+ # Merge sketches - null values should be ignored
+ merged = sketch1.union(sketch2).agg(F.kll_merge_agg_bigint("sketch").alias("merged"))
+
+ n = merged.select(F.kll_sketch_get_n_bigint("merged")).first()[0]
+ # Should have 5 values (1,2,3,4,6 - null is ignored)
+ self.assertEqual(n, 5)
+
def test_datetime_functions(self):
df = self.spark.range(1).selectExpr("'2017-01-22' as dateCol")
parse_result = df.select(F.to_date(F.col("dateCol"))).first()
@@ -2805,6 +3048,57 @@ def test_string_validation(self):
result_try_validate_utf8 = df.select(F.try_validate_utf8(df.a).alias("r"))
assertDataFrameEqual([Row(r="abc")], result_try_validate_utf8)
+ # Geospatial ST Functions
+
+ def test_st_asbinary(self):
+ df = self.spark.createDataFrame(
+ [(bytes.fromhex("0101000000000000000000F03F0000000000000040"),)],
+ ["wkb"],
+ )
+ results = df.select(
+ F.hex(F.st_asbinary(F.st_geogfromwkb("wkb"))),
+ F.hex(F.st_asbinary(F.st_geomfromwkb("wkb"))),
+ ).collect()
+ expected = Row(
+ "0101000000000000000000F03F0000000000000040",
+ "0101000000000000000000F03F0000000000000040",
+ )
+ self.assertEqual(results, [expected])
+
+ def test_st_setsrid(self):
+ df = self.spark.createDataFrame(
+ [(bytes.fromhex("0101000000000000000000F03F0000000000000040"), 4326)],
+ ["wkb", "srid"],
+ )
+ results = df.select(
+ F.st_srid(F.st_setsrid(F.st_geogfromwkb("wkb"), "srid")),
+ F.st_srid(F.st_setsrid(F.st_geomfromwkb("wkb"), "srid")),
+ F.st_srid(F.st_setsrid(F.st_geogfromwkb("wkb"), 4326)),
+ F.st_srid(F.st_setsrid(F.st_geomfromwkb("wkb"), 4326)),
+ ).collect()
+ expected = Row(
+ 4326,
+ 4326,
+ 4326,
+ 4326,
+ )
+ self.assertEqual(results, [expected])
+
+ def test_st_srid(self):
+ df = self.spark.createDataFrame(
+ [(bytes.fromhex("0101000000000000000000F03F0000000000000040"),)],
+ ["wkb"],
+ )
+ results = df.select(
+ F.st_srid(F.st_geogfromwkb("wkb")),
+ F.st_srid(F.st_geomfromwkb("wkb")),
+ ).collect()
+ expected = Row(
+ 4326,
+ 0,
+ )
+ self.assertEqual(results, [expected])
+
class FunctionsTests(ReusedSQLTestCase, FunctionsTestsMixin):
pass
diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py
index bbc089b00c133..ac868c34a9133 100644
--- a/python/pyspark/sql/tests/test_group.py
+++ b/python/pyspark/sql/tests/test_group.py
@@ -126,6 +126,22 @@ def test_group_by_ordinal(self):
with self.assertRaises(IndexError):
df.groupBy(10).agg(sf.sum("b"))
+ def test_numeric_agg_with_nest_type(self):
+ df = self.spark.createDataFrame(
+ [
+ Row(a="a", b=Row(c=1)),
+ Row(a="a", b=Row(c=2)),
+ Row(a="a", b=Row(c=3)),
+ Row(a="b", b=Row(c=4)),
+ Row(a="b", b=Row(c=5)),
+ ]
+ )
+
+ res = df.groupBy("a").max("b.c").sort("a").collect()
+ # [Row(a='a', max(b.c AS c)=3), Row(a='b', max(b.c AS c)=5)]
+
+ self.assertEqual([["a", 3], ["b", 5]], [list(r) for r in res])
+
@unittest.skipIf(not have_pandas, pandas_requirement_message) # type: ignore
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message) # type: ignore
def test_order_by_ordinal(self):
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index cfedf1cf075ba..e5171876656c7 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -16,7 +16,6 @@
#
import os
import platform
-import sys
import tempfile
import unittest
import logging
@@ -258,7 +257,6 @@ def test_read_with_invalid_return_row_type(self):
with self.assertRaisesRegex(PythonException, "DATA_SOURCE_INVALID_RETURN_TYPE"):
df.collect()
- @unittest.skipIf(sys.version_info > (3, 13), "SPARK-54065")
def test_in_memory_data_source(self):
class InMemDataSourceReader(DataSourceReader):
DEFAULT_NUM_PARTITIONS: int = 3
@@ -967,49 +965,49 @@ def reader(self, schema) -> "DataSourceReader":
],
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=msg,
- context=context,
- logger="test_data_source_reader",
- )
- for msg, context in [
- (
- "TestJsonDataSource.__init__: ['path']",
- {"class_name": "TestJsonDataSource", "func_name": "__init__"},
- ),
- (
- "TestJsonDataSource.name",
- {"class_name": "TestJsonDataSource", "func_name": "name"},
- ),
- (
- "TestJsonDataSource.schema",
- {"class_name": "TestJsonDataSource", "func_name": "schema"},
- ),
- (
- "TestJsonDataSource.reader: ['name', 'age']",
- {"class_name": "TestJsonDataSource", "func_name": "reader"},
- ),
- (
- "TestJsonReader.__init__: ['path']",
- {"class_name": "TestJsonDataSource", "func_name": "reader"},
- ),
- (
- "TestJsonReader.partitions",
- {"class_name": "TestJsonReader", "func_name": "partitions"},
- ),
- (
- "TestJsonReader.read: None",
- {"class_name": "TestJsonReader", "func_name": "read"},
- ),
- ]
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=msg,
+ context=context,
+ logger="test_data_source_reader",
+ )
+ for msg, context in [
+ (
+ "TestJsonDataSource.__init__: ['path']",
+ {"class_name": "TestJsonDataSource", "func_name": "__init__"},
+ ),
+ (
+ "TestJsonDataSource.name",
+ {"class_name": "TestJsonDataSource", "func_name": "name"},
+ ),
+ (
+ "TestJsonDataSource.schema",
+ {"class_name": "TestJsonDataSource", "func_name": "schema"},
+ ),
+ (
+ "TestJsonDataSource.reader: ['name', 'age']",
+ {"class_name": "TestJsonDataSource", "func_name": "reader"},
+ ),
+ (
+ "TestJsonReader.__init__: ['path']",
+ {"class_name": "TestJsonDataSource", "func_name": "reader"},
+ ),
+ (
+ "TestJsonReader.partitions",
+ {"class_name": "TestJsonReader", "func_name": "partitions"},
+ ),
+ (
+ "TestJsonReader.read: None",
+ {"class_name": "TestJsonReader", "func_name": "read"},
+ ),
+ ]
+ ],
+ )
@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_data_source_reader_pushdown_with_logging(self):
@@ -1074,53 +1072,53 @@ def reader(self, schema) -> "DataSourceReader":
],
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=msg,
- context=context,
- logger="test_data_source_reader_pushdown",
- )
- for msg, context in [
- (
- "TestJsonDataSource.__init__: ['path']",
- {"class_name": "TestJsonDataSource", "func_name": "__init__"},
- ),
- (
- "TestJsonDataSource.name",
- {"class_name": "TestJsonDataSource", "func_name": "name"},
- ),
- (
- "TestJsonDataSource.schema",
- {"class_name": "TestJsonDataSource", "func_name": "schema"},
- ),
- (
- "TestJsonDataSource.reader: ['name', 'age']",
- {"class_name": "TestJsonDataSource", "func_name": "reader"},
- ),
- (
- "TestJsonReader.pushFilters: [IsNotNull(attribute=('age',))]",
- {"class_name": "TestJsonReader", "func_name": "pushFilters"},
- ),
- (
- "TestJsonReader.__init__: ['path']",
- {"class_name": "TestJsonDataSource", "func_name": "reader"},
- ),
- (
- "TestJsonReader.partitions",
- {"class_name": "TestJsonReader", "func_name": "partitions"},
- ),
- (
- "TestJsonReader.read: None",
- {"class_name": "TestJsonReader", "func_name": "read"},
- ),
- ]
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=msg,
+ context=context,
+ logger="test_data_source_reader_pushdown",
+ )
+ for msg, context in [
+ (
+ "TestJsonDataSource.__init__: ['path']",
+ {"class_name": "TestJsonDataSource", "func_name": "__init__"},
+ ),
+ (
+ "TestJsonDataSource.name",
+ {"class_name": "TestJsonDataSource", "func_name": "name"},
+ ),
+ (
+ "TestJsonDataSource.schema",
+ {"class_name": "TestJsonDataSource", "func_name": "schema"},
+ ),
+ (
+ "TestJsonDataSource.reader: ['name', 'age']",
+ {"class_name": "TestJsonDataSource", "func_name": "reader"},
+ ),
+ (
+ "TestJsonReader.pushFilters: [IsNotNull(attribute=('age',))]",
+ {"class_name": "TestJsonReader", "func_name": "pushFilters"},
+ ),
+ (
+ "TestJsonReader.__init__: ['path']",
+ {"class_name": "TestJsonDataSource", "func_name": "reader"},
+ ),
+ (
+ "TestJsonReader.partitions",
+ {"class_name": "TestJsonReader", "func_name": "partitions"},
+ ),
+ (
+ "TestJsonReader.read: None",
+ {"class_name": "TestJsonReader", "func_name": "read"},
+ ),
+ ]
+ ],
+ )
@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_data_source_writer_with_logging(self):
@@ -1199,69 +1197,69 @@ def writer(self, schema, overwrite):
with self.assertRaises(Exception, msg="abort test"):
df.write.format("my-json").mode("append").option("abort", "true").save(d)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=msg,
- context=context,
- logger="test_datasource_writer",
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=msg,
+ context=context,
+ logger="test_datasource_writer",
+ )
+ for msg, context in [
+ (
+ "TestJsonDataSource.name",
+ {"class_name": "TestJsonDataSource", "func_name": "name"},
+ ),
+ (
+ "TestJsonDataSource.writer: (['name', 'age'], {True})",
+ {"class_name": "TestJsonDataSource", "func_name": "writer"},
+ ),
+ (
+ "TestJsonWriter.__init__: ['path']",
+ {"class_name": "TestJsonDataSource", "func_name": "writer"},
+ ),
+ (
+ "TestJsonWriter.write: 1, [{'name': 'Diana', 'age': 28}]",
+ {"class_name": "TestJsonWriter", "func_name": "write"},
+ ),
+ (
+ "TestJsonWriter.write: 1, [{'name': 'Charlie', 'age': 35}]",
+ {"class_name": "TestJsonWriter", "func_name": "write"},
+ ),
+ (
+ "TestJsonWriter.commit: 2",
+ {"class_name": "TestJsonWriter", "func_name": "commit"},
+ ),
+ (
+ "TestJsonDataSource.name",
+ {"class_name": "TestJsonDataSource", "func_name": "name"},
+ ),
+ (
+ "TestJsonDataSource.writer: (['name', 'age'], {False})",
+ {"class_name": "TestJsonDataSource", "func_name": "writer"},
+ ),
+ (
+ "TestJsonWriter.__init__: ['abort', 'path']",
+ {"class_name": "TestJsonDataSource", "func_name": "writer"},
+ ),
+ (
+ "TestJsonWriter.write: abort test",
+ {"class_name": "TestJsonWriter", "func_name": "write"},
+ ),
+ (
+ "TestJsonWriter.write: abort test",
+ {"class_name": "TestJsonWriter", "func_name": "write"},
+ ),
+ (
+ "TestJsonWriter.abort",
+ {"class_name": "TestJsonWriter", "func_name": "abort"},
+ ),
+ ]
+ ],
)
- for msg, context in [
- (
- "TestJsonDataSource.name",
- {"class_name": "TestJsonDataSource", "func_name": "name"},
- ),
- (
- "TestJsonDataSource.writer: (['name', 'age'], {True})",
- {"class_name": "TestJsonDataSource", "func_name": "writer"},
- ),
- (
- "TestJsonWriter.__init__: ['path']",
- {"class_name": "TestJsonDataSource", "func_name": "writer"},
- ),
- (
- "TestJsonWriter.write: 1, [{'name': 'Diana', 'age': 28}]",
- {"class_name": "TestJsonWriter", "func_name": "write"},
- ),
- (
- "TestJsonWriter.write: 1, [{'name': 'Charlie', 'age': 35}]",
- {"class_name": "TestJsonWriter", "func_name": "write"},
- ),
- (
- "TestJsonWriter.commit: 2",
- {"class_name": "TestJsonWriter", "func_name": "commit"},
- ),
- (
- "TestJsonDataSource.name",
- {"class_name": "TestJsonDataSource", "func_name": "name"},
- ),
- (
- "TestJsonDataSource.writer: (['name', 'age'], {False})",
- {"class_name": "TestJsonDataSource", "func_name": "writer"},
- ),
- (
- "TestJsonWriter.__init__: ['abort', 'path']",
- {"class_name": "TestJsonDataSource", "func_name": "writer"},
- ),
- (
- "TestJsonWriter.write: abort test",
- {"class_name": "TestJsonWriter", "func_name": "write"},
- ),
- (
- "TestJsonWriter.write: abort test",
- {"class_name": "TestJsonWriter", "func_name": "write"},
- ),
- (
- "TestJsonWriter.abort",
- {"class_name": "TestJsonWriter", "func_name": "abort"},
- ),
- ]
- ],
- )
class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/sql/tests/test_python_streaming_datasource.py b/python/pyspark/sql/tests/test_python_streaming_datasource.py
index 9879231540f1d..ecf28677689b2 100644
--- a/python/pyspark/sql/tests/test_python_streaming_datasource.py
+++ b/python/pyspark/sql/tests/test_python_streaming_datasource.py
@@ -34,6 +34,7 @@
have_pyarrow,
pyarrow_requirement_message,
)
+from pyspark.errors import PySparkException
from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -251,6 +252,61 @@ def check_batch(df, batch_id):
q.awaitTermination()
self.assertIsNone(q.exception(), "No exception has to be propagated.")
+ def test_simple_stream_reader_offset_did_not_advance_raises(self):
+ """Returning end == start with non-empty data raises
+ SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE."""
+ from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper
+
+ class BuggySimpleStreamReader(SimpleDataSourceStreamReader):
+ def initialOffset(self):
+ return {"offset": 0}
+
+ def read(self, start: dict):
+ # Bug: return same offset as end despite returning data
+ start_idx = start["offset"]
+ it = iter([(i,) for i in range(start_idx, start_idx + 3)])
+ return (it, start)
+
+ def readBetweenOffsets(self, start: dict, end: dict):
+ return iter([])
+
+ def commit(self, end: dict):
+ pass
+
+ reader = BuggySimpleStreamReader()
+ wrapper = _SimpleStreamReaderWrapper(reader)
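+ # latestOffset() on the wrapper prefetches by calling read() on the wrapped
+ # reader and validates that the returned end offset advanced past start.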
+ with self.assertRaises(PySparkException) as cm:
+ wrapper.latestOffset()
+ self.assertEqual(
+ cm.exception.getCondition(),
+ "SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE",
+ )
+
+ def test_simple_stream_reader_empty_iterator_start_equals_end_allowed(self):
+ """read() with end == start and empty iterator: no exception, no cache entry."""
+ from pyspark.sql.datasource_internal import _SimpleStreamReaderWrapper
+
+ class EmptyBatchReader(SimpleDataSourceStreamReader):
+ def initialOffset(self):
+ return {"offset": 0}
+
+ def read(self, start: dict):
+ # Valid: same offset as end but empty iterator (no data)
+ return (iter([]), start)
+
+ def readBetweenOffsets(self, start: dict, end: dict):
+ return iter([])
+
+ def commit(self, end: dict):
+ pass
+
+ reader = EmptyBatchReader()
+ wrapper = _SimpleStreamReaderWrapper(reader)
+ end = wrapper.latestOffset()
+ start = {"offset": 0}
+ self.assertEqual(end, start)
+ self.assertEqual(len(wrapper.cache), 0)
+
def test_stream_writer(self):
input_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_input")
output_dir = tempfile.TemporaryDirectory(prefix="test_data_stream_write_output")
diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py
index 6979095acca88..0a5219202a3a8 100644
--- a/python/pyspark/sql/tests/test_types.py
+++ b/python/pyspark/sql/tests/test_types.py
@@ -29,6 +29,8 @@
from pyspark.sql import functions as F
from pyspark.errors import (
AnalysisException,
+ IllegalArgumentException,
+ SparkRuntimeException,
ParseException,
PySparkTypeError,
PySparkValueError,
@@ -51,6 +53,8 @@
MapType,
StringType,
CharType,
+ Geography,
+ Geometry,
VarcharType,
StructType,
StructField,
@@ -1365,6 +1369,7 @@ def test_parse_datatype_json_string(self):
NullType(),
GeographyType(4326),
GeographyType("ANY"),
+ GeometryType(0),
GeometryType(4326),
GeometryType("ANY"),
VariantType(),
@@ -2447,6 +2452,340 @@ def test_variant_type(self):
with self.assertRaises(PySparkValueError, msg="Rows cannot be of type VariantVal"):
self.spark.createDataFrame([VariantVal.parseJson("2")], "v variant")
+ def test_variant_to_pandas(self):
+ import pandas as pd
+ import json
+
+ expected_values = [
+ ("str", '"%s"' % ("0123456789" * 10), "0123456789" * 10),
+ ("short_str", '"abc"', "abc"),
+ ]
+ json_str = "{%s}" % ",".join(['"%s": %s' % (t[0], t[1]) for t in expected_values])
+ df = self.spark.createDataFrame([({"json": json_str})])
+ df_variant = df.select(F.parse_json(df.json).alias("v"))
+ pandas = df_variant.toPandas()
+ test_record = json.loads(pandas["v"].iloc[0].toJson())
+ self.assertIsInstance(pandas, pd.DataFrame)
+ self.assertEqual(expected_values[0][2], test_record["str"])
+ self.assertEqual(expected_values[1][2], test_record["short_str"])
+
+ def test_geospatial_encoding(self):
+ df = self.spark.createDataFrame(
+ [
+ (
+ bytes.fromhex("0101000000000000000000F03F0000000000000040"),
+ 4326,
+ )
+ ],
+ ["wkb", "srid"],
+ )
+ row = df.select(
+ F.st_geomfromwkb(df.wkb).alias("geom"),
+ F.st_geogfromwkb(df.wkb).alias("geog"),
+ ).collect()[0]
+
+ self.assertEqual(type(row["geom"]), Geometry)
+ self.assertEqual(type(row["geog"]), Geography)
+ self.assertEqual(
+ row["geom"].getBytes(), bytes.fromhex("0101000000000000000000F03F0000000000000040")
+ )
+ self.assertEqual(row["geom"].getSrid(), 0)
+ self.assertEqual(
+ row["geog"].getBytes(), bytes.fromhex("0101000000000000000000F03F0000000000000040")
+ )
+ self.assertEqual(row["geog"].getSrid(), 4326)
+
+ def test_geospatial_create_dataframe_rdd(self):
+ schema = StructType(
+ [
+ StructField("id", IntegerType(), True),
+ StructField("geom", GeometryType(0), True),
+ StructField("geom4326", GeometryType(4326), True),
+ StructField("geog", GeographyType(4326), True),
+ ]
+ )
+ geospatial_data = [
+ (
+ 1,
+ Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0),
+ Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326),
+ Geography.fromWKB(
+ bytes.fromhex("010100000000000000000031400000000000001c40"), 4326
+ ),
+ ),
+ (
+ 2,
+ Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0),
+ Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326),
+ Geography.fromWKB(
+ bytes.fromhex("010100000000000000000014400000000000001440"), 4326
+ ),
+ ),
+ ]
+ rdd_data = self.sc.parallelize(geospatial_data)
+ df = self.spark.createDataFrame(rdd_data, schema)
+ rows = df.select(
+ F.st_asbinary(df.geom).alias("geom_wkb"),
+ F.st_srid(df.geom).alias("geom_srid"),
+ F.st_asbinary(df.geom4326).alias("geom4326_wkb"),
+ F.st_srid(df.geom4326).alias("geom4326_srid"),
+ F.st_asbinary(df.geog).alias("geog_wkb"),
+ F.st_srid(df.geog).alias("geog_srid"),
+ ).collect()
+
+ point0_wkb = bytes.fromhex("010100000000000000000031400000000000001c40")
+ point1_wkb = bytes.fromhex("010100000000000000000014400000000000001440")
+ self.assertEqual(rows[0]["geom_wkb"], point0_wkb)
+ self.assertEqual(rows[0]["geom4326_wkb"], point0_wkb)
+ self.assertEqual(rows[0]["geog_wkb"], point0_wkb)
+ self.assertEqual(rows[1]["geom_wkb"], point1_wkb)
+ self.assertEqual(rows[1]["geom4326_wkb"], point1_wkb)
+ self.assertEqual(rows[1]["geog_wkb"], point1_wkb)
+ self.assertEqual(rows[0]["geom_srid"], 0)
+ self.assertEqual(rows[0]["geom4326_srid"], 4326)
+ self.assertEqual(rows[0]["geog_srid"], 4326)
+ self.assertEqual(rows[1]["geom_srid"], 0)
+ self.assertEqual(rows[1]["geom4326_srid"], 4326)
+ self.assertEqual(rows[1]["geog_srid"], 4326)
+ schema_df = self.spark.createDataFrame(rdd_data).select(
+ F.col("_1").alias("id"),
+ F.col("_2").alias("geom"),
+ F.col("_3").alias("geom4326"),
+ F.col("_4").alias("geog"),
+ )
+ self.assertEqual(df.collect(), schema_df.collect())
+
+ def test_geospatial_create_dataframe(self):
+ # Positive test data: the same geospatial rows expressed as tuples (below), dicts, and Rows
+ geospatial_data = [
+ (
+ 1,
+ Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0),
+ Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326),
+ Geography.fromWKB(
+ bytes.fromhex("010100000000000000000031400000000000001c40"), 4326
+ ),
+ ),
+ (
+ 2,
+ Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0),
+ Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326),
+ Geography.fromWKB(
+ bytes.fromhex("010100000000000000000014400000000000001440"), 4326
+ ),
+ ),
+ ]
+ named_geospatial_data = [
+ {
+ "id": 1,
+ "geom": Geometry.fromWKB(
+ bytes.fromhex("010100000000000000000031400000000000001c40"), 0
+ ),
+ "geom4326": Geometry.fromWKB(
+ bytes.fromhex("010100000000000000000031400000000000001c40"), 4326
+ ),
+ "geog": Geography.fromWKB(
+ bytes.fromhex("010100000000000000000031400000000000001c40"), 4326
+ ),
+ },
+ {
+ "id": 2,
+ "geom": Geometry.fromWKB(
+ bytes.fromhex("010100000000000000000014400000000000001440"), 0
+ ),
+ "geom4326": Geometry.fromWKB(
+ bytes.fromhex("010100000000000000000014400000000000001440"), 4326
+ ),
+ "geog": Geography.fromWKB(
+ bytes.fromhex("010100000000000000000014400000000000001440"), 4326
+ ),
+ },
+ ]
+ GeospatialRow = Row("id", "geom", "geom4326", "geog")
+ spark_row_data = [
+ GeospatialRow(
+ 1,
+ Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0),
+ Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 4326),
+ Geography.fromWKB(
+ bytes.fromhex("010100000000000000000031400000000000001c40"), 4326
+ ),
+ ),
+ GeospatialRow(
+ 2,
+ Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 0),
+ Geometry.fromWKB(bytes.fromhex("010100000000000000000014400000000000001440"), 4326),
+ Geography.fromWKB(
+ bytes.fromhex("010100000000000000000014400000000000001440"), 4326
+ ),
+ ),
+ ]
+ schema = StructType(
+ [
+ StructField("id", IntegerType(), True),
+ StructField("geom", GeometryType(0), True),
+ StructField("geom4326", GeometryType(4326), True),
+ StructField("geog", GeographyType(4326), True),
+ ]
+ )
+ # Negative Test: Schema mismatch
+ mismatched_schema = StructType(
+ [
+ StructField("id", IntegerType(), True), # Should be GeometryType
+ StructField("geom", GeometryType(4326), True), # Should be GeometryType
+ StructField("geom4326", GeometryType(4326), True), # Should be GeometryType
+ StructField("geog", GeographyType(4326), True), # Should be GeographyType
+ ]
+ )
+
+ # Explicitly validate a single case; the remaining variants below are
+ # compared against this DataFrame.
+ df = self.spark.createDataFrame(geospatial_data, schema)
+ rows = df.select(
+ F.st_asbinary(df.geom).alias("geom_wkb"),
+ F.st_srid(df.geom).alias("geom_srid"),
+ F.st_asbinary(df.geom4326).alias("geom4326_wkb"),
+ F.st_srid(df.geom4326).alias("geom4326_srid"),
+ F.st_asbinary(df.geog).alias("geog_wkb"),
+ F.st_srid(df.geog).alias("geog_srid"),
+ ).collect()
+
+ point0_wkb = bytes.fromhex("010100000000000000000031400000000000001c40")
+ point1_wkb = bytes.fromhex("010100000000000000000014400000000000001440")
+ self.assertEqual(rows[0]["geom_wkb"], point0_wkb)
+ self.assertEqual(rows[0]["geom4326_wkb"], point0_wkb)
+ self.assertEqual(rows[0]["geog_wkb"], point0_wkb)
+ self.assertEqual(rows[1]["geom_wkb"], point1_wkb)
+ self.assertEqual(rows[1]["geom4326_wkb"], point1_wkb)
+ self.assertEqual(rows[1]["geog_wkb"], point1_wkb)
+ self.assertEqual(rows[0]["geom_srid"], 0)
+ self.assertEqual(rows[0]["geom4326_srid"], 4326)
+ self.assertEqual(rows[0]["geog_srid"], 4326)
+ self.assertEqual(rows[1]["geom_srid"], 0)
+ self.assertEqual(rows[1]["geom4326_srid"], 4326)
+ self.assertEqual(rows[1]["geog_srid"], 4326)
+
+ # The same datasets without an explicit schema.
+ self.assertEqual(
+ self.spark.createDataFrame(named_geospatial_data)
+ .select("id", "geom", "geom4326", "geog")
+ .collect(),
+ df.collect(),
+ )
+ self.assertEqual(self.spark.createDataFrame(geospatial_data).collect(), df.collect())
+ self.assertEqual(self.spark.createDataFrame(spark_row_data).collect(), df.collect())
+
+ # Pair each dataset variant with an equivalent schema representation
+ datasets = [named_geospatial_data, geospatial_data, spark_row_data]
+ schemas = [
+ schema,
+ "id INT, geom GEOMETRY(0), geom4326 GEOMETRY(4326), geog GEOGRAPHY(4326)",
+ ["id", "geom", "geom4326", "geog"],
+ ]
+
+ for dataset_to_check, schema_to_check in zip(datasets, schemas):
+ df_to_check = self.spark.createDataFrame(dataset_to_check, schema_to_check).select(
+ "id", "geom", "geom4326", "geog"
+ )
+ self.assertEqual(df_to_check.collect(), df.collect(), "DataFrame creation with schema")
+
+ # Negative Test: Schema mismatch
+ for dataset_to_check in datasets:
+ with self.assertRaises(SparkRuntimeException) as pe:
+ self.spark.createDataFrame(dataset_to_check, mismatched_schema).collect()
+
+ self.check_error(
+ exception=pe.exception,
+ errorClass="GEO_ENCODER_SRID_MISMATCH_ERROR",
+ messageParameters={"type": "GEOMETRY", "typeSrid": "4326", "valueSrid": "0"},
+ )
+
+ def test_geospatial_schema_inference(self):
+ # Mixed data with different SRIDs
+ wkb = bytes.fromhex("010100000000000000000031400000000000001c40")
+ geometry_dataset = [
+ (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), Geometry.fromWKB(wkb, 4326)),
+ (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), Geometry.fromWKB(wkb, 0)),
+ (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), Geometry.fromWKB(wkb, 4326)),
+ (Geometry.fromWKB(wkb, 0), Geometry.fromWKB(wkb, 4326), Geometry.fromWKB(wkb, 0)),
+ ]
+ # Create a DataFrame with inferred schema; the third column mixes SRIDs (0 and 4326)
+ df = self.spark.createDataFrame(geometry_dataset, ["geom0", "geom4326", "geom_any"])
+ expected_schema = StructType(
+ [
+ StructField("geom0", GeometryType(0), True),
+ StructField("geom4326", GeometryType(4326), True),
+ StructField("geom_any", GeometryType("ANY"), True),
+ ]
+ )
+ self.assertEqual(df.schema, expected_schema)
+
+ rows = df.select(
+ F.st_asbinary("geom0").alias("geom0_wkb"),
+ F.st_srid("geom0").alias("geom0_srid"),
+ F.st_asbinary("geom4326").alias("geom4326_wkb"),
+ F.st_srid("geom4326").alias("geom4326_srid"),
+ F.st_asbinary("geom_any").alias("geom_any_wkb"),
+ F.st_srid("geom_any").alias("geom_any_srid"),
+ ).collect()
+
+ point_wkb = bytes.fromhex("010100000000000000000031400000000000001c40")
+ self.assertEqual(rows[0]["geom0_wkb"], point_wkb)
+ self.assertEqual(rows[1]["geom0_wkb"], point_wkb)
+ self.assertEqual(rows[0]["geom4326_wkb"], point_wkb)
+ self.assertEqual(rows[1]["geom4326_wkb"], point_wkb)
+ self.assertEqual(rows[0]["geom_any_wkb"], point_wkb)
+ self.assertEqual(rows[1]["geom_any_wkb"], point_wkb)
+ self.assertEqual(rows[0]["geom0_srid"], 0)
+ self.assertEqual(rows[1]["geom0_srid"], 0)
+ self.assertEqual(rows[0]["geom4326_srid"], 4326)
+ self.assertEqual(rows[1]["geom4326_srid"], 4326)
+ self.assertEqual(rows[0]["geom_any_srid"], 4326)
+ self.assertEqual(rows[1]["geom_any_srid"], 0)
+
+ def test_geospatial_mixed_check_srid_validity(self):
+ geometry_mixed_invalid_data = [
+ (1, Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 0)),
+ (2, Geometry.fromWKB(bytes.fromhex("010100000000000000000031400000000000001c40"), 1)),
+ ]
+
+ with self.assertRaises(IllegalArgumentException) as pe:
+ self.spark.createDataFrame(geometry_mixed_invalid_data).collect()
+ self.check_error(
+ exception=pe.exception,
+ errorClass="ST_INVALID_SRID_VALUE",
+ messageParameters={"srid": "1"},
+ )
+
+ with self.assertRaises(IllegalArgumentException) as pe:
+ self.spark.createDataFrame(
+ geometry_mixed_invalid_data, "id INT, geom GEOMETRY(ANY)"
+ ).collect()
+ self.check_error(
+ exception=pe.exception,
+ errorClass="ST_INVALID_SRID_VALUE",
+ messageParameters={"srid": "1"},
+ )
+
+ def test_geospatial_result_encoding(self):
+ point_wkb = "010100000000000000000031400000000000001c40"
+ point_bytes = bytes.fromhex(point_wkb)
+ df = self.spark.sql(
+ f"""
+ SELECT ST_GeomFromWKB(X'{point_wkb}') AS geom,
+ ST_GeogFromWKB(X'{point_wkb}') AS geog"""
+ )
+ GeospatialRow = Row("geom", "geog")
+ self.assertEqual(
+ df.collect(),
+ [
+ GeospatialRow(
+ Geometry.fromWKB(point_bytes, 0),
+ Geography.fromWKB(point_bytes, 4326),
+ )
+ ],
+ )
+
def test_to_ddl(self):
schema = StructType().add("a", NullType()).add("b", BooleanType()).add("c", BinaryType())
self.assertEqual(schema.toDDL(), "a VOID,b BOOLEAN,c BINARY")
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index d6bc8ad28b330..9c5fa2ad1bba8 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -48,6 +48,7 @@
VariantVal,
)
from pyspark.errors import AnalysisException, PythonException, PySparkTypeError
+from pyspark.logger import PySparkLogger
from pyspark.testing.objects import ExamplePoint, ExamplePointUDT
from pyspark.testing.sqlutils import (
ReusedSQLTestCase,
@@ -1572,50 +1573,66 @@ def my_udf():
logger.exception("exception")
return "x"
+ # The TVF is not available when the feature is disabled.
+ with self.assertRaises(AnalysisException) as pe:
+ self.spark.tvf.python_worker_logs().count()
+
+ self.check_error(
+ exception=pe.exception,
+ errorClass="FEATURE_NOT_ENABLED",
+ messageParameters={
+ "featureName": "Python Worker Logging",
+ "configKey": "spark.sql.pyspark.worker.logging.enabled",
+ "configValue": "true",
+ },
+ )
+
# Logging is disabled by default
assertDataFrameEqual(
self.spark.range(1).select(my_udf().alias("result")), [Row(result="x")]
)
- self.assertEqual(self.spark.table("system.session.python_worker_logs").count(), 0)
with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+ # The logs were not collected when the feature was disabled.
+ self.assertEqual(self.spark.tvf.python_worker_logs().count(), 0)
+
assertDataFrameEqual(
self.spark.range(1).select(my_udf().alias("result")), [Row(result="x")]
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="INFO",
- msg="print to stdout ❤",
- context={"func_name": my_udf.__name__},
- logger="stdout",
- ),
- Row(
- level="ERROR",
- msg="print to stderr 😀",
- context={"func_name": my_udf.__name__},
- logger="stderr",
- ),
- Row(
- level="WARNING",
- msg="custom context",
- context={"func_name": my_udf.__name__, "abc": "123"},
- logger="test",
- ),
- Row(
- level="ERROR",
- msg="exception",
- context={"func_name": my_udf.__name__},
- logger="test",
- ),
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="INFO",
+ msg="print to stdout ❤",
+ context={"func_name": my_udf.__name__},
+ logger="stdout",
+ ),
+ Row(
+ level="ERROR",
+ msg="print to stderr 😀",
+ context={"func_name": my_udf.__name__},
+ logger="stderr",
+ ),
+ Row(
+ level="WARNING",
+ msg="custom context",
+ context={"func_name": my_udf.__name__, "abc": "123"},
+ logger="test",
+ ),
+ Row(
+ level="ERROR",
+ msg="exception",
+ context={"func_name": my_udf.__name__},
+ logger="test",
+ ),
+ ],
+ )
- self.assertEqual(logs.where("exception is not null").select("exception").count(), 1)
+ self.assertEqual(logs.where("exception is not null").select("exception").count(), 1)
@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_multiple_udfs_with_logging(self):
@@ -1637,25 +1654,54 @@ def my_udf2():
[Row(result="x", result2="y")],
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg="test1",
- context={"func_name": my_udf1.__name__},
- logger="test1",
- ),
- Row(
- level="WARNING",
- msg="test2",
- context={"func_name": my_udf2.__name__},
- logger="test2",
- ),
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg="test1",
+ context={"func_name": my_udf1.__name__},
+ logger="test1",
+ ),
+ Row(
+ level="WARNING",
+ msg="test2",
+ context={"func_name": my_udf2.__name__},
+ logger="test2",
+ ),
+ ],
+ )
+
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_udf_with_pyspark_logger(self):
+ @udf
+ def my_udf(x):
+ logger = PySparkLogger.getLogger("PySparkLogger")
+ logger.warning("PySparkLogger test", x=x)
+ return str(x)
+
+ with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+ assertDataFrameEqual(
+ self.spark.range(2).select(my_udf("id").alias("result")),
+ [Row(result=str(i)) for i in range(2)],
+ )
+
+ logs = self.spark.tvf.python_worker_logs()
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg="PySparkLogger test",
+ context={"func_name": my_udf.__name__, "x": str(i)},
+ logger="PySparkLogger",
+ )
+ for i in range(2)
+ ],
+ )
class UDFTests(BaseUDFTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/sql/tests/test_udf_profiler.py b/python/pyspark/sql/tests/test_udf_profiler.py
index 4e8f722c22cbd..e6a7bf40b9454 100644
--- a/python/pyspark/sql/tests/test_udf_profiler.py
+++ b/python/pyspark/sql/tests/test_udf_profiler.py
@@ -28,6 +28,7 @@
from pyspark import SparkConf
from pyspark.errors import PySparkValueError
from pyspark.sql import SparkSession
+from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.functions import col, arrow_udf, pandas_udf, udf
from pyspark.sql.window import Window
from pyspark.profiler import UDFBasicProfiler
@@ -325,59 +326,47 @@ def add2(x):
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
- def test_perf_profiler_pandas_udf_iterator_not_supported(self):
+ def test_perf_profiler_pandas_udf_iterator(self):
import pandas as pd
@pandas_udf("long")
- def add1(x):
- return x + 1
-
- @pandas_udf("long")
- def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
+ def add(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
for s in iter:
- yield s + 2
+ yield s + 1
with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
- df = self.spark.range(10, numPartitions=2).select(
- add1("id"), add2("id"), add1("id"), add2(col("id") + 1)
- )
+ df = self.spark.range(10, numPartitions=2).select(add("id"))
df.collect()
self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
for id in self.profile_results:
- self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
+ self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=4)
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
- def test_perf_profiler_arrow_udf_iterator_not_supported(self):
+ def test_perf_profiler_arrow_udf_iterator(self):
import pyarrow as pa
@arrow_udf("long")
- def add1(x):
- return pa.compute.add(x, 1)
-
- @arrow_udf("long")
- def add2(iter: Iterator[pa.Array]) -> Iterator[pa.Array]:
+ def add(iter: Iterator[pa.Array]) -> Iterator[pa.Array]:
for s in iter:
- yield pa.compute.add(s, 2)
+ yield pa.compute.add(s, 1)
with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
- df = self.spark.range(10, numPartitions=2).select(
- add1("id"), add2("id"), add1("id"), add2(col("id") + 1)
- )
+ df = self.spark.range(10, numPartitions=2).select(add("id"))
df.collect()
self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
for id in self.profile_results:
- self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
+ self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=4)
@unittest.skipIf(
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
- def test_perf_profiler_map_in_pandas_not_supported(self):
- df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
+ def test_perf_profiler_map_in_pandas(self):
+ df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", "age")).repartition(1)
def filter_func(iterator):
for pdf in iterator:
@@ -386,7 +375,28 @@ def filter_func(iterator):
with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
df.mapInPandas(filter_func, df.schema).show()
- self.assertEqual(0, len(self.profile_results), str(self.profile_results.keys()))
+ self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
+
+ for id in self.profile_results:
+ self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
+
+ @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
+ def test_perf_profiler_map_in_arrow(self):
+ import pyarrow as pa
+
+ df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", "age")).repartition(1)
+
+ def map_func(iterator: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
+ for batch in iterator:
+ yield pa.RecordBatch.from_arrays(
+ [batch.column("id"), pa.compute.add(batch.column("age"), 1)], ["id", "age"]
+ )
+
+ with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+ df.mapInArrow(map_func, df.schema).show()
+
+ for id in self.profile_results:
+ self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
@unittest.skipIf(
not have_pandas or not have_pyarrow,
@@ -575,6 +585,35 @@ def summarize(left, right):
for id in self.profile_results:
self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
+ @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
+ def test_perf_profiler_data_source(self):
+ class TestDataSourceReader(DataSourceReader):
+ def __init__(self, schema):
+ self.schema = schema
+
+ def partitions(self):
+ raise NotImplementedError
+
+ def read(self, partition):
+ yield from ((1,), (2,), (3,))
+
+ class TestDataSource(DataSource):
+ def schema(self):
+ return "id long"
+
+ def reader(self, schema) -> "DataSourceReader":
+ return TestDataSourceReader(schema)
+
+ self.spark.dataSource.register(TestDataSource)
+
+ with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+ self.spark.read.format("TestDataSource").load().collect()
+
+ self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
+
+ for id in self.profile_results:
+ self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=4)
+
def test_perf_profiler_render(self):
with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
_do_computation(self.spark)
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index b86a2624acd53..5ded5aa67b4eb 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -65,6 +65,7 @@
StructType,
VariantVal,
)
+from pyspark.logger import PySparkLogger
from pyspark.testing import assertDataFrameEqual, assertSchemaEqual
from pyspark.testing.objects import ExamplePoint, ExamplePointUDT
from pyspark.testing.sqlutils import (
@@ -3078,20 +3079,20 @@ def eval(self, x: int):
[Row(x=x, a=x * 2, b=x + 10) for x in [5, 10]],
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select("level", "msg", "context", "logger"),
- [
- Row(
- level="WARNING",
- msg=f"udtf with logging: {x}",
- context={"class_name": "TestUDTFWithLogging", "func_name": "eval"},
- logger="test_udtf",
- )
- for x in [5, 10]
- ],
- )
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=f"udtf with logging: {x}",
+ context={"class_name": "TestUDTFWithLogging", "func_name": "eval"},
+ logger="test_udtf",
+ )
+ for x in [5, 10]
+ ],
+ )
@unittest.skipIf(is_remote_only(), "Requires JVM access")
def test_udtf_analyze_with_logging(self):
@@ -3114,26 +3115,70 @@ def eval(self, x: int):
[Row(x=x, a=x * 2, b=x + 10) for x in [5, 10]],
)
- logs = self.spark.table("system.session.python_worker_logs")
+ logs = self.spark.tvf.python_worker_logs()
- assertDataFrameEqual(
- logs.select(
- "level",
- "msg",
- col("context.class_name").alias("context_class_name"),
- col("context.func_name").alias("context_func_name"),
- "logger",
- ).distinct(),
- [
- Row(
- level="WARNING",
- msg='udtf analyze: "long"',
- context_class_name="TestUDTFWithLogging",
- context_func_name="analyze",
- logger="test_udtf",
- )
- ],
- )
+ assertDataFrameEqual(
+ logs.select(
+ "level",
+ "msg",
+ col("context.class_name").alias("context_class_name"),
+ col("context.func_name").alias("context_func_name"),
+ "logger",
+ ).distinct(),
+ [
+ Row(
+ level="WARNING",
+ msg='udtf analyze: "long"',
+ context_class_name="TestUDTFWithLogging",
+ context_func_name="analyze",
+ logger="test_udtf",
+ )
+ ],
+ )
+
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_udtf_analyze_with_pyspark_logger(self):
+ @udtf
+ class TestUDTFWithLogging:
+ @staticmethod
+ def analyze(x: AnalyzeArgument) -> AnalyzeResult:
+ logger = PySparkLogger.getLogger("PySparkLogger")
+ logger.warning(f"udtf analyze: {x.dataType.json()}", dt=x.dataType.json())
+ return AnalyzeResult(StructType().add("a", IntegerType()).add("b", IntegerType()))
+
+ def eval(self, x: int):
+ yield x * 2, x + 10
+
+ with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+ assertDataFrameEqual(
+ self.spark.createDataFrame([(5,), (10,)], ["x"]).lateralJoin(
+ TestUDTFWithLogging(col("x").outer())
+ ),
+ [Row(x=x, a=x * 2, b=x + 10) for x in [5, 10]],
+ )
+
+ logs = self.spark.tvf.python_worker_logs()
+
+ assertDataFrameEqual(
+ logs.select(
+ "level",
+ "msg",
+ col("context.class_name").alias("context_class_name"),
+ col("context.func_name").alias("context_func_name"),
+ col("context.dt").alias("context_dt"),
+ "logger",
+ ).distinct(),
+ [
+ Row(
+ level="WARNING",
+ msg='udtf analyze: "long"',
+ context_class_name="TestUDTFWithLogging",
+ context_func_name="analyze",
+ context_dt='"long"',
+ logger="PySparkLogger",
+ )
+ ],
+ )
class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
diff --git a/python/pyspark/sql/tvf.py b/python/pyspark/sql/tvf.py
index b34877b03311f..c21e3751655c9 100644
--- a/python/pyspark/sql/tvf.py
+++ b/python/pyspark/sql/tvf.py
@@ -677,6 +677,45 @@ def variant_explode_outer(self, input: Column) -> DataFrame:
"""
return self._fn("variant_explode_outer", input)
+ def python_worker_logs(self) -> DataFrame:
+ """
+ Returns a DataFrame of logs collected from Python workers.
+
+ This requires ``spark.sql.pyspark.worker.logging.enabled`` to be set to
+ ``true``; otherwise querying it fails with ``FEATURE_NOT_ENABLED``.
+
+ .. versionadded:: 4.1.0
+
+ Returns
+ -------
+ :class:`DataFrame`
+
+ Examples
+ --------
+ >>> import pyspark.sql.functions as sf
+ >>> import logging
+ >>>
+ >>> @sf.udf("string")
+ ... def my_udf(x):
+ ... logger = logging.getLogger("my_custom_logger")
+ ... logger.warning("This is a warning")
+ ... return str(x)
+ ...
+ >>> spark.conf.set("spark.sql.pyspark.worker.logging.enabled", "true")
+ >>> spark.range(1).select(my_udf("id")).show()
+ +----------+
+ |my_udf(id)|
+ +----------+
+ | 0|
+ +----------+
+ >>> spark.tvf.python_worker_logs().select(
+ ... "level", "msg", "context", "logger"
+ ... ).show(truncate=False) # doctest: +SKIP
+ +-------+-----------------+---------------------+----------------+
+ |level |msg |context |logger |
+ +-------+-----------------+---------------------+----------------+
+ |WARNING|This is a warning|{func_name -> my_udf}|my_custom_logger|
+ +-------+-----------------+---------------------+----------------+
+ """
+ return self._fn("python_worker_logs")
+
def _fn(self, functionName: str, *args: Column) -> DataFrame:
from pyspark.sql.classic.column import _to_java_column
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 440100dba9312..0504999ac65aa 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -90,6 +90,8 @@
"TimestampNTZType",
"DecimalType",
"DoubleType",
+ "Geography",
+ "Geometry",
"FloatType",
"ByteType",
"IntegerType",
@@ -401,7 +403,11 @@ class AnyTimeType(DatetimeType):
class TimeType(AnyTimeType):
- """Time (datetime.time) data type."""
+ """
+ Time (datetime.time) data type.
+
+ .. versionadded:: 4.1.0
+ """
def __init__(self, precision: int = 6):
self.precision = precision
@@ -616,6 +622,20 @@ def jsonValue(self) -> Union[str, Dict[str, Any]]:
# The JSON representation always uses the CRS and algorithm value.
return f"geography({self._crs}, {self._alg})"
+ def needConversion(self) -> bool:
+ return True
+
+ def fromInternal(self, obj: Dict) -> Optional["Geography"]:
+ if obj is None or not all(key in obj for key in ["srid", "bytes"]):
+ return None
+ return Geography(obj["bytes"], obj["srid"])
+
+ def toInternal(self, geography: Any) -> Any:
+ if geography is None:
+ return None
+ assert isinstance(geography, Geography)
+ return {"srid": geography.srid, "wkb": geography.wkb}
+
class GeometryType(SpatialType):
"""
@@ -700,6 +720,20 @@ def jsonValue(self) -> Union[str, Dict[str, Any]]:
# The JSON representation always uses the CRS value.
return f"geometry({self._crs})"
+ def needConversion(self) -> bool:
+ return True
+
+ def fromInternal(self, obj: Dict) -> Optional["Geometry"]:
+ if obj is None or not all(key in obj for key in ["srid", "bytes"]):
+ return None
+ return Geometry(obj["bytes"], obj["srid"])
+
+ def toInternal(self, geometry: Any) -> Any:
+ if geometry is None:
+ return None
+ assert isinstance(geometry, Geometry)
+ return {"srid": geometry.srid, "wkb": geometry.wkb}
+
class ByteType(IntegralType):
"""Byte data type, representing signed 8-bit integers."""
@@ -2039,6 +2073,144 @@ def parseJson(cls, json_str: str) -> "VariantVal":
return VariantVal(value, metadata)
+class Geography:
+ """
+ A class to represent a Geography value in Python.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ wkb : bytes
+ The bytes representing the WKB of Geography.
+
+ srid : integer
+ The integer value representing SRID of Geography.
+
+ Methods
+ -------
+ getBytes()
+ Returns the WKB of Geography.
+
+ getSrid()
+ Returns the SRID of Geography.
+
+ Examples
+ --------
+ >>> g = Geography.fromWKB(bytes.fromhex('010100000000000000000031400000000000001c40'), 4326)
+ >>> g.getBytes().hex()
+ '010100000000000000000031400000000000001c40'
+ >>> g.getSrid()
+ 4326
+ """
+
+ def __init__(self, wkb: bytes, srid: int):
+ self.wkb = wkb
+ self.srid = srid
+
+ def __str__(self) -> str:
+ return "Geography(%r, %d)" % (self.wkb, self.srid)
+
+ def __repr__(self) -> str:
+ return "Geography(%r, %d)" % (self.wkb, self.srid)
+
+ def getSrid(self) -> int:
+ """
+ Returns the SRID of Geography.
+ """
+ return self.srid
+
+ def getBytes(self) -> bytes:
+ """
+ Returns the WKB of Geography.
+ """
+ return self.wkb
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, Geography):
+ # Don't attempt to compare against unrelated types.
+ return NotImplemented
+
+ return self.wkb == other.wkb and self.srid == other.srid
+
+ @classmethod
+ def fromWKB(cls, wkb: bytes, srid: int) -> "Geography":
+ """
+ Construct Python Geography object from WKB.
+ :return: Python representation of the Geography type value.
+ """
+ return Geography(wkb, srid)
+
+
+class Geometry:
+ """
+ A class to represent a Geometry value in Python.
+
+ .. versionadded:: 4.1.0
+
+ Parameters
+ ----------
+ wkb : bytes
+ The bytes representing the WKB of Geometry.
+
+ srid : integer
+ The integer value representing SRID of Geometry.
+
+ Methods
+ -------
+ getBytes()
+ Returns the WKB of Geometry.
+
+ getSrid()
+ Returns the SRID of Geometry.
+
+ Examples
+ --------
+ >>> g = Geometry.fromWKB(bytes.fromhex('010100000000000000000031400000000000001c40'), 0)
+ >>> g.getBytes().hex()
+ '010100000000000000000031400000000000001c40'
+ >>> g.getSrid()
+ 0
+ """
+
+ def __init__(self, wkb: bytes, srid: int):
+ self.wkb = wkb
+ self.srid = srid
+
+ def __str__(self) -> str:
+ return "Geometry(%r, %d)" % (self.wkb, self.srid)
+
+ def __repr__(self) -> str:
+ return "Geometry(%r, %d)" % (self.wkb, self.srid)
+
+ def getSrid(self) -> int:
+ """
+ Returns the SRID of Geometry.
+ """
+ return self.srid
+
+ def getBytes(self) -> bytes:
+ """
+ Returns the WKB of Geometry.
+ """
+ return self.wkb
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, Geometry):
+ # Don't attempt to compare against unrelated types.
+ return NotImplemented
+
+ return self.wkb == other.wkb and self.srid == other.srid
+
+ @classmethod
+ def fromWKB(cls, wkb: bytes, srid: int) -> "Geometry":
+ """
+ Construct Python Geometry object from WKB.
+ :return: Python representation of the Geometry type value.
+ """
+ return Geometry(wkb, srid)
+
+
_atomic_types: List[Type[DataType]] = [
StringType,
CharType,
@@ -2349,6 +2521,8 @@ def _assert_valid_collation_provider(provider: str) -> None:
# Mapping Python types to Spark SQL DataType
_type_mappings = {
type(None): NullType,
+ Geometry: GeometryType,
+ Geography: GeographyType,
bool: BooleanType,
int: LongType,
float: DoubleType,
@@ -2480,6 +2654,12 @@ def _infer_type(
return obj.__UDT__
dataType = _type_mappings.get(type(obj))
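+ # Geography/Geometry values carry their own SRID, so the inferred type is
+ # parameterized by the SRID of the observed value.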
+ if dataType is GeographyType:
+ assert isinstance(obj, Geography)
+ return GeographyType(obj.getSrid())
+ if dataType is GeometryType:
+ assert isinstance(obj, Geometry)
+ return GeometryType(obj.getSrid())
if dataType is DecimalType:
# the precision and scale of `obj` may be different from row to row.
return DecimalType(38, 18)
@@ -2747,6 +2927,10 @@ def new_name(n: str) -> str:
return a
elif isinstance(a, TimestampNTZType) and isinstance(b, TimestampType):
return b
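+ # Columns that mix SRIDs across rows cannot be described by a single fixed
+ # SRID, so schema merging widens them to the ANY variant.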
+ elif isinstance(a, GeometryType) and isinstance(b, GeometryType) and a.srid != b.srid:
+ return GeometryType("ANY")
+ elif isinstance(a, GeographyType) and isinstance(b, GeographyType) and a.srid != b.srid:
+ return GeographyType("ANY")
elif isinstance(a, AtomicType) and isinstance(b, StringType):
return b
elif isinstance(a, StringType) and isinstance(b, AtomicType):
@@ -2900,6 +3084,8 @@ def convert_struct(obj: Any) -> Optional[Tuple]:
ArrayType: (list, tuple, array),
MapType: (dict,),
StructType: (tuple, list, dict),
+ GeometryType: (Geometry,),
+ GeographyType: (Geography,),
VariantType: (
bool,
int,
@@ -3251,6 +3437,24 @@ def verify_variant(obj: Any) -> None:
verify_value = verify_variant
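+ # Geometry/Geography fields only accept the matching Geometry/Geography
+ # Python objects; anything else fails verification.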
+ elif isinstance(dataType, GeometryType):
+
+ def verify_geometry(obj: Any) -> None:
+ assert_acceptable_types(obj)
+ verify_acceptable_types(obj)
+ assert isinstance(obj, Geometry)
+
+ verify_value = verify_geometry
+
+ elif isinstance(dataType, GeographyType):
+
+ def verify_geography(obj: Any) -> None:
+ assert_acceptable_types(obj)
+ verify_acceptable_types(obj)
+ assert isinstance(obj, Geography)
+
+ verify_value = verify_geography
+
else:
def verify_default(obj: Any) -> None:
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 37aa30cc279f9..c7471d19f7d6f 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -642,25 +642,24 @@ def register(
>>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP
[Row(random_udf()=82)]
- >>> import pandas as pd # doctest: +SKIP
+ >>> import pandas as pd
>>> from pyspark.sql.functions import pandas_udf
- >>> @pandas_udf("integer") # doctest: +SKIP
+ >>> @pandas_udf("integer")
... def add_one(s: pd.Series) -> pd.Series:
... return s + 1
...
- >>> _ = spark.udf.register("add_one", add_one) # doctest: +SKIP
- >>> spark.sql("SELECT add_one(id) FROM range(3)").collect() # doctest: +SKIP
+ >>> _ = spark.udf.register("add_one", add_one)
+ >>> spark.sql("SELECT add_one(id) FROM range(3)").collect()
[Row(add_one(id)=1), Row(add_one(id)=2), Row(add_one(id)=3)]
- >>> @pandas_udf("integer") # doctest: +SKIP
+ >>> @pandas_udf("integer")
... def sum_udf(v: pd.Series) -> int:
... return v.sum()
...
- >>> _ = spark.udf.register("sum_udf", sum_udf) # doctest: +SKIP
+ >>> _ = spark.udf.register("sum_udf", sum_udf)
>>> q = "SELECT sum_udf(v1) FROM VALUES (3, 0), (2, 0), (1, 1) tbl(v1, v2) GROUP BY v2"
- >>> spark.sql(q).collect() # doctest: +SKIP
+ >>> spark.sql(q).sort("sum_udf(v1)").collect()
[Row(sum_udf(v1)=1), Row(sum_udf(v1)=5)]
-
"""
# This is to check whether the input function is from a user-defined function or
@@ -796,8 +795,13 @@ def _test() -> None:
import doctest
from pyspark.sql import SparkSession
import pyspark.sql.udf
+ from pyspark.testing.utils import have_pandas, have_pyarrow
globs = pyspark.sql.udf.__dict__.copy()
+
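+ # The register() doctests above exercise pandas_udf; drop them when pandas or
+ # PyArrow is unavailable so the doctest run still passes.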
+ if not have_pandas or not have_pyarrow:
+ del pyspark.sql.udf.UDFRegistration.register.__doc__
+
spark = SparkSession.builder.master("local[4]").appName("sql.udf tests").getOrCreate()
globs["spark"] = spark
(failure_count, test_count) = doctest.testmod(
diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py
index 665b1297fbc1f..526cb316862c7 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -15,7 +15,6 @@
# limitations under the License.
#
-import faulthandler
import inspect
import os
import sys
@@ -35,7 +34,12 @@
from pyspark.sql.functions import OrderingColumn, PartitioningColumn, SelectedColumn
from pyspark.sql.types import _parse_datatype_json_string, StructType
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.util import (
+ handle_worker_exception,
+ local_connect_and_auth,
+ with_faulthandler,
+ start_faulthandler_periodic_traceback,
+)
from pyspark.worker_util import (
check_python_version,
read_command,
@@ -100,6 +104,7 @@ def read_arguments(infile: IO) -> Tuple[List[AnalyzeArgument], Dict[str, Analyze
return args, kwargs
+@with_faulthandler
def main(infile: IO, outfile: IO) -> None:
"""
Runs the Python UDTF's `analyze` static method.
@@ -108,18 +113,10 @@ def main(infile: IO, outfile: IO) -> None:
in JVM and receive the Python UDTF and its arguments for the `analyze` static method,
and call the `analyze` static method, and send back a AnalyzeResult as a result of the method.
"""
- faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
- tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
try:
- if faulthandler_log_path:
- faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
- faulthandler_log_file = open(faulthandler_log_path, "w")
- faulthandler.enable(file=faulthandler_log_file)
-
check_python_version(infile)
- if tracebackDumpIntervalSeconds is not None and int(tracebackDumpIntervalSeconds) > 0:
- faulthandler.dump_traceback_later(int(tracebackDumpIntervalSeconds), repeat=True)
+ start_faulthandler_periodic_traceback()
memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
setup_memory_limits(memory_limit_mb)
@@ -266,11 +263,6 @@ def invalid_analyze_result_field(field_name: str, expected_field: str) -> PySpar
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
- finally:
- if faulthandler_log_path:
- faulthandler.disable()
- faulthandler_log_file.close()
- os.remove(faulthandler_log_path)
send_accumulator_updates(outfile)
@@ -282,9 +274,6 @@ def invalid_analyze_result_field(field_name: str, expected_field: str) -> PySpar
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
sys.exit(-1)
- # Force to cancel dump_traceback_later
- faulthandler.cancel_dump_traceback_later()
-
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
diff --git a/python/pyspark/sql/worker/commit_data_source_write.py b/python/pyspark/sql/worker/commit_data_source_write.py
index fb82b65f31229..37fee6ad8357e 100644
--- a/python/pyspark/sql/worker/commit_data_source_write.py
+++ b/python/pyspark/sql/worker/commit_data_source_write.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import faulthandler
import os
import sys
from typing import IO
@@ -29,7 +28,12 @@
SpecialLengths,
)
from pyspark.sql.datasource import DataSourceWriter, WriterCommitMessage
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.util import (
+ handle_worker_exception,
+ local_connect_and_auth,
+ with_faulthandler,
+ start_faulthandler_periodic_traceback,
+)
from pyspark.worker_util import (
check_python_version,
pickleSer,
@@ -40,6 +44,7 @@
)
+@with_faulthandler
def main(infile: IO, outfile: IO) -> None:
"""
Main method for committing or aborting a data source write operation.
@@ -49,18 +54,10 @@ def main(infile: IO, outfile: IO) -> None:
responsible for invoking either the `commit` or the `abort` method on a data source
writer instance, given a list of commit messages.
"""
- faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
- tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
try:
- if faulthandler_log_path:
- faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
- faulthandler_log_file = open(faulthandler_log_path, "w")
- faulthandler.enable(file=faulthandler_log_file)
-
check_python_version(infile)
- if tracebackDumpIntervalSeconds is not None and int(tracebackDumpIntervalSeconds) > 0:
- faulthandler.dump_traceback_later(int(tracebackDumpIntervalSeconds), repeat=True)
+ start_faulthandler_periodic_traceback()
memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
setup_memory_limits(memory_limit_mb)
@@ -106,11 +103,6 @@ def main(infile: IO, outfile: IO) -> None:
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
- finally:
- if faulthandler_log_path:
- faulthandler.disable()
- faulthandler_log_file.close()
- os.remove(faulthandler_log_path)
send_accumulator_updates(outfile)
@@ -122,9 +114,6 @@ def main(infile: IO, outfile: IO) -> None:
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
sys.exit(-1)
- # Force to cancel dump_traceback_later
- faulthandler.cancel_dump_traceback_later()
-
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py
index 15e8fdc618e29..bf6ceda41ffb9 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import faulthandler
import inspect
import os
import sys
@@ -32,7 +31,12 @@
)
from pyspark.sql.datasource import DataSource, CaseInsensitiveDict
from pyspark.sql.types import _parse_datatype_json_string, StructType
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.util import (
+ handle_worker_exception,
+ local_connect_and_auth,
+ with_faulthandler,
+ start_faulthandler_periodic_traceback,
+)
from pyspark.worker_util import (
check_python_version,
read_command,
@@ -45,6 +49,7 @@
)
+@with_faulthandler
def main(infile: IO, outfile: IO) -> None:
"""
Main method for creating a Python data source instance.
@@ -62,18 +67,10 @@ def main(infile: IO, outfile: IO) -> None:
This process then creates a `DataSource` instance using the above information and
sends the pickled instance as well as the schema back to the JVM.
"""
- faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
- tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
try:
- if faulthandler_log_path:
- faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
- faulthandler_log_file = open(faulthandler_log_path, "w")
- faulthandler.enable(file=faulthandler_log_file)
-
check_python_version(infile)
- if tracebackDumpIntervalSeconds is not None and int(tracebackDumpIntervalSeconds) > 0:
- faulthandler.dump_traceback_later(int(tracebackDumpIntervalSeconds), repeat=True)
+ start_faulthandler_periodic_traceback()
memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
setup_memory_limits(memory_limit_mb)
@@ -172,11 +169,6 @@ def main(infile: IO, outfile: IO) -> None:
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
- finally:
- if faulthandler_log_path:
- faulthandler.disable()
- faulthandler_log_file.close()
- os.remove(faulthandler_log_path)
send_accumulator_updates(outfile)
@@ -188,9 +180,6 @@ def main(infile: IO, outfile: IO) -> None:
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
sys.exit(-1)
- # Force to cancel dump_traceback_later
- faulthandler.cancel_dump_traceback_later()
-
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
diff --git a/python/pyspark/sql/worker/data_source_pushdown_filters.py b/python/pyspark/sql/worker/data_source_pushdown_filters.py
index 8601521bcfb13..7d255e1dbf778 100644
--- a/python/pyspark/sql/worker/data_source_pushdown_filters.py
+++ b/python/pyspark/sql/worker/data_source_pushdown_filters.py
@@ -16,7 +16,6 @@
#
import base64
-import faulthandler
import json
import os
import sys
@@ -49,7 +48,12 @@
)
from pyspark.sql.types import StructType, VariantVal, _parse_datatype_json_string
from pyspark.sql.worker.plan_data_source_read import write_read_func_and_partitions
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.util import (
+ handle_worker_exception,
+ local_connect_and_auth,
+ with_faulthandler,
+ start_faulthandler_periodic_traceback,
+)
from pyspark.worker_util import (
check_python_version,
pickleSer,
@@ -119,6 +123,7 @@ def deserializeFilter(jsonDict: dict) -> Filter:
return filter
+@with_faulthandler
def main(infile: IO, outfile: IO) -> None:
"""
Main method for planning a data source read with filter pushdown.
@@ -140,18 +145,10 @@ def main(infile: IO, outfile: IO) -> None:
on the reader and determines which filters are supported. The indices of the supported
filters are sent back to the JVM, along with the list of partitions and the read function.
"""
- faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
- tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
try:
- if faulthandler_log_path:
- faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
- faulthandler_log_file = open(faulthandler_log_path, "w")
- faulthandler.enable(file=faulthandler_log_file)
-
check_python_version(infile)
- if tracebackDumpIntervalSeconds is not None and int(tracebackDumpIntervalSeconds) > 0:
- faulthandler.dump_traceback_later(int(tracebackDumpIntervalSeconds), repeat=True)
+ start_faulthandler_periodic_traceback()
memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
setup_memory_limits(memory_limit_mb)
@@ -258,11 +255,6 @@ def main(infile: IO, outfile: IO) -> None:
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
- finally:
- if faulthandler_log_path:
- faulthandler.disable()
- faulthandler_log_file.close()
- os.remove(faulthandler_log_path)
send_accumulator_updates(outfile)
@@ -274,9 +266,6 @@ def main(infile: IO, outfile: IO) -> None:
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
sys.exit(-1)
- # Force to cancel dump_traceback_later
- faulthandler.cancel_dump_traceback_later()
-
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
diff --git a/python/pyspark/sql/worker/lookup_data_sources.py b/python/pyspark/sql/worker/lookup_data_sources.py
index eeb84263d4452..b23903cac8cb8 100644
--- a/python/pyspark/sql/worker/lookup_data_sources.py
+++ b/python/pyspark/sql/worker/lookup_data_sources.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import faulthandler
from importlib import import_module
from pkgutil import iter_modules
import os
@@ -29,7 +28,12 @@
SpecialLengths,
)
from pyspark.sql.datasource import DataSource
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.util import (
+ handle_worker_exception,
+ local_connect_and_auth,
+ with_faulthandler,
+ start_faulthandler_periodic_traceback,
+)
from pyspark.worker_util import (
check_python_version,
pickleSer,
@@ -40,6 +44,7 @@
)
+@with_faulthandler
def main(infile: IO, outfile: IO) -> None:
"""
Main method for looking up the available Python Data Sources in Python path.
@@ -51,18 +56,10 @@ def main(infile: IO, outfile: IO) -> None:
This is responsible for searching the available Python Data Sources so they can be
statically registered automatically.
"""
- faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
- tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
try:
- if faulthandler_log_path:
- faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
- faulthandler_log_file = open(faulthandler_log_path, "w")
- faulthandler.enable(file=faulthandler_log_file)
-
check_python_version(infile)
- if tracebackDumpIntervalSeconds is not None and int(tracebackDumpIntervalSeconds) > 0:
- faulthandler.dump_traceback_later(int(tracebackDumpIntervalSeconds), repeat=True)
+ start_faulthandler_periodic_traceback()
memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
setup_memory_limits(memory_limit_mb)
@@ -89,11 +86,6 @@ def main(infile: IO, outfile: IO) -> None:
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
- finally:
- if faulthandler_log_path:
- faulthandler.disable()
- faulthandler_log_file.close()
- os.remove(faulthandler_log_path)
send_accumulator_updates(outfile)
@@ -105,9 +97,6 @@ def main(infile: IO, outfile: IO) -> None:
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
sys.exit(-1)
- # Force to cancel dump_traceback_later
- faulthandler.cancel_dump_traceback_later()
-
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py
index db79e58f2ec4f..51036f17586f2 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -15,7 +15,6 @@
# limitations under the License.
#
-import faulthandler
import os
import sys
import functools
@@ -47,7 +46,12 @@
BinaryType,
StructType,
)
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.util import (
+ handle_worker_exception,
+ local_connect_and_auth,
+ with_faulthandler,
+ start_faulthandler_periodic_traceback,
+)
from pyspark.worker_util import (
check_python_version,
read_command,
@@ -267,6 +271,7 @@ def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.Rec
write_int(0, outfile)
+@with_faulthandler
def main(infile: IO, outfile: IO) -> None:
"""
Main method for planning a data source read.
@@ -287,18 +292,10 @@ def main(infile: IO, outfile: IO) -> None:
The partition values and the Arrow Batch are then serialized and sent back to the JVM
via the socket.
"""
- faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
- tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
try:
- if faulthandler_log_path:
- faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
- faulthandler_log_file = open(faulthandler_log_path, "w")
- faulthandler.enable(file=faulthandler_log_file)
-
check_python_version(infile)
- if tracebackDumpIntervalSeconds is not None and int(tracebackDumpIntervalSeconds) > 0:
- faulthandler.dump_traceback_later(int(tracebackDumpIntervalSeconds), repeat=True)
+ start_faulthandler_periodic_traceback()
memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
setup_memory_limits(memory_limit_mb)
@@ -402,11 +399,6 @@ def main(infile: IO, outfile: IO) -> None:
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
- finally:
- if faulthandler_log_path:
- faulthandler.disable()
- faulthandler_log_file.close()
- os.remove(faulthandler_log_path)
send_accumulator_updates(outfile)
@@ -418,9 +410,6 @@ def main(infile: IO, outfile: IO) -> None:
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
sys.exit(-1)
- # Force to cancel dump_traceback_later
- faulthandler.cancel_dump_traceback_later()
-
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
diff --git a/python/pyspark/sql/worker/python_streaming_sink_runner.py b/python/pyspark/sql/worker/python_streaming_sink_runner.py
index ed6907ce5b63d..5ca3307fca33c 100644
--- a/python/pyspark/sql/worker/python_streaming_sink_runner.py
+++ b/python/pyspark/sql/worker/python_streaming_sink_runner.py
@@ -15,7 +15,6 @@
# limitations under the License.
#
-import faulthandler
import os
import sys
from typing import IO
@@ -35,7 +34,12 @@
_parse_datatype_json_string,
StructType,
)
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.util import (
+ handle_worker_exception,
+ local_connect_and_auth,
+ with_faulthandler,
+ start_faulthandler_periodic_traceback,
+)
from pyspark.worker_util import (
check_python_version,
read_command,
@@ -48,6 +52,7 @@
)
+@with_faulthandler
def main(infile: IO, outfile: IO) -> None:
"""
Main method for committing or aborting a data source streaming write operation.
@@ -57,18 +62,10 @@ def main(infile: IO, outfile: IO) -> None:
responsible for invoking either the `commit` or the `abort` method on a data source
writer instance, given a list of commit messages.
"""
- faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
- tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
try:
- if faulthandler_log_path:
- faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
- faulthandler_log_file = open(faulthandler_log_path, "w")
- faulthandler.enable(file=faulthandler_log_file)
-
check_python_version(infile)
- if tracebackDumpIntervalSeconds is not None and int(tracebackDumpIntervalSeconds) > 0:
- faulthandler.dump_traceback_later(int(tracebackDumpIntervalSeconds), repeat=True)
+ start_faulthandler_periodic_traceback()
setup_spark_files(infile)
setup_broadcasts(infile)
@@ -138,11 +135,6 @@ def main(infile: IO, outfile: IO) -> None:
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
- finally:
- if faulthandler_log_path:
- faulthandler.disable()
- faulthandler_log_file.close()
- os.remove(faulthandler_log_path)
send_accumulator_updates(outfile)
@@ -154,9 +146,6 @@ def main(infile: IO, outfile: IO) -> None:
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
sys.exit(-1)
- # Force to cancel dump_traceback_later
- faulthandler.cancel_dump_traceback_later()
-
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py
index 917d0ca8e0079..b8a54f8397dc8 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import faulthandler
import inspect
import os
import sys
@@ -46,7 +45,12 @@
BinaryType,
_create_row,
)
-from pyspark.util import handle_worker_exception, local_connect_and_auth
+from pyspark.util import (
+ handle_worker_exception,
+ local_connect_and_auth,
+ with_faulthandler,
+ start_faulthandler_periodic_traceback,
+)
from pyspark.worker_util import (
check_python_version,
read_command,
@@ -59,6 +63,7 @@
)
+@with_faulthandler
def main(infile: IO, outfile: IO) -> None:
"""
Main method for saving into a Python data source.
@@ -78,18 +83,10 @@ def main(infile: IO, outfile: IO) -> None:
instance and send a function using the writer instance that can be used
in mapInPandas/mapInArrow back to the JVM.
"""
- faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
- tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
try:
- if faulthandler_log_path:
- faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
- faulthandler_log_file = open(faulthandler_log_path, "w")
- faulthandler.enable(file=faulthandler_log_file)
-
check_python_version(infile)
- if tracebackDumpIntervalSeconds is not None and int(tracebackDumpIntervalSeconds) > 0:
- faulthandler.dump_traceback_later(int(tracebackDumpIntervalSeconds), repeat=True)
+ start_faulthandler_periodic_traceback()
memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1"))
setup_memory_limits(memory_limit_mb)
@@ -264,11 +261,6 @@ def batch_to_rows() -> Iterator[Row]:
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
- finally:
- if faulthandler_log_path:
- faulthandler.disable()
- faulthandler_log_file.close()
- os.remove(faulthandler_log_path)
send_accumulator_updates(outfile)
@@ -280,9 +272,6 @@ def batch_to_rows() -> Iterator[Row]:
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
sys.exit(-1)
- # Force to cancel dump_traceback_later
- faulthandler.cancel_dump_traceback_later()
-
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index bfcb886e1c912..ee86f5c039744 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -16,17 +16,15 @@
#
import shutil
import tempfile
-import typing
import os
import functools
import unittest
import uuid
import contextlib
+from typing import Callable, Optional
from pyspark import Row, SparkConf
-from pyspark.util import is_remote_only
-from pyspark.testing.utils import PySparkErrorTestUtils
-from pyspark import Row, SparkConf
+from pyspark.loose_version import LooseVersion
from pyspark.util import is_remote_only
from pyspark.testing.utils import (
have_pandas,
@@ -52,6 +50,7 @@
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import Read, Range, SQL, LogicalPlan
from pyspark.sql.connect.session import SparkSession
+ import pyspark.sql.connect.proto as pb2
class MockRemoteSession:
@@ -113,6 +112,16 @@ def _session_range(
def _session_sql(cls, query):
return cls._df_mock(SQL(query))
+ @classmethod
+ def _set_relation_in_plan(self, plan: pb2.Plan, relation: pb2.Relation) -> None:
+ # Skip plan compression in plan-only tests.
+ plan.root.CopyFrom(relation)
+
+ @classmethod
+ def _set_command_in_plan(self, plan: pb2.Plan, command: pb2.Command) -> None:
+ # Skip plan compression in plan-only tests.
+ plan.command.CopyFrom(command)
+
if have_pandas:
@classmethod
@@ -122,13 +131,14 @@ def _with_plan(cls, plan):
@classmethod
def setUpClass(cls):
cls.connect = MockRemoteSession()
- cls.session = SparkSession.builder.remote().getOrCreate()
cls.tbl_name = "test_connect_plan_only_table_1"
cls.connect.set_hook("readTable", cls._read_table)
cls.connect.set_hook("range", cls._session_range)
cls.connect.set_hook("sql", cls._session_sql)
cls.connect.set_hook("with_plan", cls._with_plan)
+ cls.connect.set_hook("_set_relation_in_plan", cls._set_relation_in_plan)
+ cls.connect.set_hook("_set_command_in_plan", cls._set_command_in_plan)
@classmethod
def tearDownClass(cls):
@@ -295,3 +305,28 @@ def _both_conf():
yield
return _both_conf()
+
+
+def skip_if_server_version_is(
+ cond: Callable[[LooseVersion], bool], reason: Optional[str] = None
+) -> Callable:
+ def decorator(f: Callable) -> Callable:
+ @functools.wraps(f)
+ def wrapper(self, *args, **kwargs):
+ version = self.spark.version
+ if cond(LooseVersion(version)):
+ raise unittest.SkipTest(
+ f"Skipping test {f.__name__} because server version is {version}"
+ + (f" ({reason})" if reason else "")
+ )
+ return f(self, *args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
+def skip_if_server_version_is_greater_than_or_equal_to(
+ version: str, reason: Optional[str] = None
+) -> Callable:
+ return skip_if_server_version_is(lambda v: v >= LooseVersion(version), reason)
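
A hedged usage sketch of the new skip helpers; the test class and version thresholds below are illustrative, and the base class is assumed to expose a Spark Connect session as `self.spark`.

from pyspark.loose_version import LooseVersion
from pyspark.testing.connectutils import (
    skip_if_server_version_is,
    skip_if_server_version_is_greater_than_or_equal_to,
)


class MyConnectSuite(SomeConnectTestBase):  # hypothetical base class providing self.spark
    @skip_if_server_version_is_greater_than_or_equal_to("4.1.0", reason="behavior changed on newer servers")
    def test_old_server_behavior(self):
        ...

    @skip_if_server_version_is(lambda v: v < LooseVersion("4.0.0"), reason="requires a newer server")
    def test_new_server_behavior(self):
        ...
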
diff --git a/python/pyspark/tests/test_install_spark.py b/python/pyspark/tests/test_install_spark.py
index effbab6a90102..769876b46138b 100644
--- a/python/pyspark/tests/test_install_spark.py
+++ b/python/pyspark/tests/test_install_spark.py
@@ -32,7 +32,7 @@ class SparkInstallationTestCase(unittest.TestCase):
def test_install_spark(self):
# Test only one case. Testing this is expensive because it needs to download
# the Spark distribution, ensure it is available at https://dlcdn.apache.org/spark/
- spark_version, hadoop_version, hive_version = checked_versions("3.5.6", "3", "2.3")
+ spark_version, hadoop_version, hive_version = checked_versions("3.5.7", "3", "2.3")
with tempfile.TemporaryDirectory(prefix="test_install_spark") as tmp_dir:
install_spark(
diff --git a/python/pyspark/tests/test_memory_profiler.py b/python/pyspark/tests/test_memory_profiler.py
index df9d63c5260f9..1909358aa2bc2 100644
--- a/python/pyspark/tests/test_memory_profiler.py
+++ b/python/pyspark/tests/test_memory_profiler.py
@@ -341,12 +341,13 @@ def add2(x):
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
- def test_memory_profiler_pandas_udf_iterator_not_supported(self):
+ def test_memory_profiler_pandas_udf_iterator(self):
import pandas as pd
@pandas_udf("long")
- def add1(x):
- return x + 1
+ def add1(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
+ for s in iter:
+ yield s + 1
@pandas_udf("long")
def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
@@ -359,7 +360,7 @@ def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
)
df.collect()
- self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
+ self.assertEqual(3, len(self.profile_results), str(self.profile_results.keys()))
for id in self.profile_results:
self.assert_udf_memory_profile_present(udf_id=id)
@@ -368,7 +369,7 @@ def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
- def test_memory_profiler_map_in_pandas_not_supported(self):
+ def test_memory_profiler_map_in_pandas(self):
df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
def filter_func(iterator):
@@ -378,7 +379,10 @@ def filter_func(iterator):
with self.sql_conf({"spark.sql.pyspark.udf.profiler": "memory"}):
df.mapInPandas(filter_func, df.schema).show()
- self.assertEqual(0, len(self.profile_results), str(self.profile_results.keys()))
+ self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
+
+ for id in self.profile_results:
+ self.assert_udf_memory_profile_present(udf_id=id)
@unittest.skipIf(
not have_pandas or not have_pyarrow,
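
The renamed tests above reflect that iterator-based pandas UDFs and mapInPandas are now covered by the UDF profilers. A rough usage sketch follows; the DataFrame and function are illustrative, and the session-level `spark.profile.show` call is assumed to be available as in recent PySpark versions.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))


def filter_func(iterator):
    for pdf in iterator:
        yield pdf[pdf.id == 1]


# Enable the memory profiler for Python UDFs, now including mapInPandas.
spark.conf.set("spark.sql.pyspark.udf.profiler", "memory")
df.mapInPandas(filter_func, df.schema).show()

# Assumed API: render the collected per-line memory profile.
spark.profile.show(type="memory")
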
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index f633ed699ee2d..9935206df1771 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -18,6 +18,7 @@
import copy
import functools
+import faulthandler
import itertools
import os
import platform
@@ -917,8 +918,73 @@ def default_api_mode() -> str:
return "classic"
+class _FaulthandlerHelper:
+ def __init__(self) -> None:
+ self._log_path: Optional[str] = None
+ self._log_file: Optional[TextIO] = None
+ self._periodic_traceback = False
+
+ def start(self) -> None:
+ if self._log_path:
+ raise Exception("Fault handler is already registered. No second registration allowed")
+ self._log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
+ if self._log_path:
+ self._log_path = os.path.join(self._log_path, str(os.getpid()))
+ self._log_file = open(self._log_path, "w")
+
+ faulthandler.enable(file=self._log_file)
+
+ def stop(self) -> None:
+ if self._log_path:
+ faulthandler.disable()
+ if self._log_file:
+ self._log_file.close()
+ self._log_file = None
+ try:
+ os.remove(self._log_path)
+ finally:
+ self._log_path = None
+ if self._periodic_traceback:
+ faulthandler.cancel_dump_traceback_later()
+ self._periodic_traceback = False
+
+ def start_periodic_traceback(self) -> None:
+ # If periodic traceback dumping is already enabled, do nothing
+ if self._periodic_traceback:
+ return
+
+ traceback_dump_interval_seconds = os.environ.get(
+ "PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None
+ )
+ if traceback_dump_interval_seconds is not None and int(traceback_dump_interval_seconds) > 0:
+ self._periodic_traceback = True
+ faulthandler.dump_traceback_later(int(traceback_dump_interval_seconds), repeat=True)
+
+ def with_faulthandler(self, func: Callable) -> Callable:
+ """
+ Registers the fault handler for the duration of the wrapped function's execution.
+ When the function returns, the fault handler is disabled and any log files created
+ for it are removed.
+ """
+
+ @functools.wraps(func)
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
+ try:
+ self.start()
+ return func(*args, **kwargs)
+ finally:
+ self.stop()
+
+ return wrapper
+
+
+_faulthandler_helper = _FaulthandlerHelper()
+with_faulthandler = _faulthandler_helper.with_faulthandler
+start_faulthandler_periodic_traceback = _faulthandler_helper.start_periodic_traceback
+
+
if __name__ == "__main__":
- if "pypy" not in platform.python_implementation().lower() and sys.version_info[:2] >= (3, 9):
+ if "pypy" not in platform.python_implementation().lower():
import doctest
import pyspark.util
from pyspark.core.context import SparkContext
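
A minimal sketch of how a worker entry point is expected to use the new helpers; the environment values below are illustrative, since both helpers read PYTHON_FAULTHANDLER_DIR and PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS from the environment themselves.

import os

from pyspark.util import start_faulthandler_periodic_traceback, with_faulthandler

# Illustrative values; in practice the JVM launcher sets these for the worker process.
os.environ.setdefault("PYTHON_FAULTHANDLER_DIR", "/tmp/faulthandler-logs")
os.environ.setdefault("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", "60")


@with_faulthandler
def main(infile, outfile):
    # Periodic traceback dumping is opt-in and idempotent; repeated calls are no-ops.
    start_faulthandler_periodic_traceback()
    # ... worker logic; the fault handler and its log file are cleaned up when main returns.
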
diff --git a/python/pyspark/version.py b/python/pyspark/version.py
index 374bc8bbd8a47..8ee776a66e67d 100644
--- a/python/pyspark/version.py
+++ b/python/pyspark/version.py
@@ -16,4 +16,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version__: str = "4.1.0.dev0"
+__version__: str = "4.1.2.dev0"
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 09c6a40a33db9..4bae9f6dc48f5 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -27,7 +27,6 @@
import itertools
import json
from typing import Any, Callable, Iterable, Iterator, Optional, Tuple
-import faulthandler
from pyspark.accumulators import (
SpecialAccumulatorIds,
@@ -84,7 +83,12 @@
_create_row,
_parse_datatype_json_string,
)
-from pyspark.util import fail_on_stopiteration, handle_worker_exception
+from pyspark.util import (
+ fail_on_stopiteration,
+ handle_worker_exception,
+ with_faulthandler,
+ start_faulthandler_periodic_traceback,
+)
from pyspark import shuffle
from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
from pyspark.worker_util import (
@@ -1158,17 +1162,19 @@ def func(*args):
return f, args_offsets
-def _supports_profiler(eval_type: int) -> bool:
- return eval_type not in (
+def _is_iter_based(eval_type: int) -> bool:
+ return eval_type in (
PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+ PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
)
-def wrap_perf_profiler(f, result_id):
+def wrap_perf_profiler(f, eval_type, result_id):
import cProfile
import pstats
@@ -1178,38 +1184,89 @@ def wrap_perf_profiler(f, result_id):
SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam
)
- def profiling_func(*args, **kwargs):
- with cProfile.Profile() as pr:
- ret = f(*args, **kwargs)
- st = pstats.Stats(pr)
- st.stream = None # make it picklable
- st.strip_dirs()
+ if _is_iter_based(eval_type):
+
+ def profiling_func(*args, **kwargs):
+ iterator = iter(f(*args, **kwargs))
+ pr = cProfile.Profile()
+ while True:
+ try:
+ with pr:
+ item = next(iterator)
+ yield item
+ except StopIteration:
+ break
+
+ st = pstats.Stats(pr)
+ st.stream = None # make it picklable
+ st.strip_dirs()
+
+ accumulator.add({result_id: (st, None)})
- accumulator.add({result_id: (st, None)})
+ else:
+
+ def profiling_func(*args, **kwargs):
+ with cProfile.Profile() as pr:
+ ret = f(*args, **kwargs)
+ st = pstats.Stats(pr)
+ st.stream = None # make it picklable
+ st.strip_dirs()
- return ret
+ accumulator.add({result_id: (st, None)})
+
+ return ret
return profiling_func
-def wrap_memory_profiler(f, result_id):
+def wrap_memory_profiler(f, eval_type, result_id):
from pyspark.sql.profiler import ProfileResultsParam
from pyspark.profiler import UDFLineProfilerV2
+ if not has_memory_profiler:
+ return f
+
accumulator = _deserialize_accumulator(
SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam
)
- def profiling_func(*args, **kwargs):
- profiler = UDFLineProfilerV2()
+ if _is_iter_based(eval_type):
- wrapped = profiler(f)
- ret = wrapped(*args, **kwargs)
- codemap_dict = {
- filename: list(line_iterator) for filename, line_iterator in profiler.code_map.items()
- }
- accumulator.add({result_id: (None, codemap_dict)})
- return ret
+ def profiling_func(*args, **kwargs):
+ profiler = UDFLineProfilerV2()
+ profiler.add_function(f)
+
+ iterator = iter(f(*args, **kwargs))
+
+ while True:
+ try:
+ with profiler:
+ item = next(iterator)
+ yield item
+ except StopIteration:
+ break
+
+ codemap_dict = {
+ filename: list(line_iterator)
+ for filename, line_iterator in profiler.code_map.items()
+ }
+ accumulator.add({result_id: (None, codemap_dict)})
+
+ else:
+
+ def profiling_func(*args, **kwargs):
+ profiler = UDFLineProfilerV2()
+ profiler.add_function(f)
+
+ with profiler:
+ ret = f(*args, **kwargs)
+
+ codemap_dict = {
+ filename: list(line_iterator)
+ for filename, line_iterator in profiler.code_map.items()
+ }
+ accumulator.add({result_id: (None, codemap_dict)})
+ return ret
return profiling_func
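
The iterator branch above re-enters the profiler around each `next()` call so that generator-based UDFs accumulate statistics across the whole iteration rather than only the generator construction. A standalone sketch of that idea with plain cProfile; the generator and the stats formatting are illustrative.

import cProfile
import pstats


def profile_iterator(gen_func, *args, **kwargs):
    iterator = iter(gen_func(*args, **kwargs))
    pr = cProfile.Profile()
    while True:
        try:
            # cProfile.Profile is a context manager (Python 3.8+); re-entering it
            # resumes profiling for just this step of the iteration.
            with pr:
                item = next(iterator)
            yield item
        except StopIteration:
            break
    st = pstats.Stats(pr)
    st.strip_dirs()
    st.sort_stats("cumulative").print_stats(5)


def squares(n):
    for i in range(n):
        yield i * i


list(profile_iterator(squares, 5))
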
@@ -1254,17 +1311,12 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
if profiler == "perf":
result_id = read_long(infile)
- if _supports_profiler(eval_type):
- profiling_func = wrap_perf_profiler(chained_func, result_id)
- else:
- profiling_func = chained_func
+ profiling_func = wrap_perf_profiler(chained_func, eval_type, result_id)
elif profiler == "memory":
result_id = read_long(infile)
- if _supports_profiler(eval_type) and has_memory_profiler:
- profiling_func = wrap_memory_profiler(chained_func, result_id)
- else:
- profiling_func = chained_func
+
+ profiling_func = wrap_memory_profiler(chained_func, eval_type, result_id)
else:
profiling_func = chained_func
@@ -3242,23 +3294,14 @@ def func(_, it):
return func, None, ser, ser
+@with_faulthandler
def main(infile, outfile):
- faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
- tracebackDumpIntervalSeconds = os.environ.get("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", None)
try:
- if faulthandler_log_path:
- faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
- faulthandler_log_file = open(faulthandler_log_path, "w")
- faulthandler.enable(file=faulthandler_log_file)
-
boot_time = time.time()
split_index = read_int(infile)
if split_index == -1: # for unit tests
sys.exit(-1)
-
- if tracebackDumpIntervalSeconds is not None and int(tracebackDumpIntervalSeconds) > 0:
- faulthandler.dump_traceback_later(int(tracebackDumpIntervalSeconds), repeat=True)
-
+ start_faulthandler_periodic_traceback()
check_python_version(infile)
# read inputs only for a barrier task
@@ -3349,11 +3392,6 @@ def process():
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
- finally:
- if faulthandler_log_path:
- faulthandler.disable()
- faulthandler_log_file.close()
- os.remove(faulthandler_log_path)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
write_long(shuffle.MemoryBytesSpilled, outfile)
@@ -3371,9 +3409,6 @@ def process():
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
sys.exit(-1)
- # Force to cancel dump_traceback_later
- faulthandler.cancel_dump_traceback_later()
-
if __name__ == "__main__":
# Read information about how to connect back to the JVM from the environment.
diff --git a/python/run-with-viztracer b/python/run-with-viztracer
new file mode 100755
index 0000000000000..0f7a4d8494f0e
--- /dev/null
+++ b/python/run-with-viztracer
@@ -0,0 +1,69 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+set -o pipefail
+set -e
+
+FWDIR="$(cd "`dirname $0`"; pwd)"
+
+# Function to display help message
+function usage {
+  cat <<EOF
+Usage: $(basename "$0") <command> [args...]
+EOF
+}
+
+if ! command -v viztracer > /dev/null; then
+  echo "Error: viztracer is not installed. Please install it using 'pip install viztracer'."
+  exit 1
+fi
+
+export PYTHONPATH="$FWDIR/conf_viztracer:$PYTHONPATH"
+export SPARK_CONF_DIR="$FWDIR/conf_viztracer"
+export SPARK_VIZTRACER_OUTPUT_DIR="${SPARK_VIZTRACER_OUTPUT_DIR:-"$(pwd)"}"
+
+exec "$@"
diff --git a/repl/pom.xml b/repl/pom.xml
index 7b515eacc5401..d83963811474c 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.13</artifactId>
-    <version>4.1.0-SNAPSHOT</version>
+    <version>4.1.2-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
diff --git a/repl/src/test/scala/org/apache/spark/repl/SparkShellSuite.scala b/repl/src/test/scala/org/apache/spark/repl/SparkShellSuite.scala
index 067f08cb67528..1aad74c8aefc1 100644
--- a/repl/src/test/scala/org/apache/spark/repl/SparkShellSuite.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/SparkShellSuite.scala
@@ -95,6 +95,25 @@ class SparkShellSuite extends SparkFunSuite {
}
}
+ def handleException(cause: Throwable): Unit = lock.synchronized {
+ val message =
+ s"""
+ |=======================
+ |SparkShellSuite failure output
+ |=======================
+ |Spark Shell command line: ${command.mkString(" ")}
+ |Exception: $cause
+ |Failed to capture next expected output "${expectedAnswers(next)}" within $timeout.
+ |
+ |${buffer.mkString("\n")}
+ |===========================
+ |End SparkShellSuite failure output
+ |===========================
+ """.stripMargin
+ logError(message, cause)
+ fail(message, cause)
+ }
+
val process = new ProcessBuilder(command: _*).start()
val stdinWriter = new OutputStreamWriter(process.getOutputStream, StandardCharsets.UTF_8)
@@ -119,23 +138,8 @@ class SparkShellSuite extends SparkFunSuite {
}
ThreadUtils.awaitResult(foundAllExpectedAnswers.future, timeoutForQuery)
log.info("Found all expected output.")
- } catch { case cause: Throwable =>
- val message =
- s"""
- |=======================
- |SparkShellSuite failure output
- |=======================
- |Spark Shell command line: ${command.mkString(" ")}
- |Exception: $cause
- |Failed to capture next expected output "${expectedAnswers(next)}" within $timeout.
- |
- |${buffer.mkString("\n")}
- |===========================
- |End SparkShellSuite failure output
- |===========================
- """.stripMargin
- logError(message, cause)
- fail(message, cause)
+ } catch {
+ case cause: Throwable => handleException(cause)
} finally {
if (!process.waitFor(1, MINUTES)) {
try {
diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml
index 972b618a5caa2..9f8980b174e48 100644
--- a/resource-managers/kubernetes/core/pom.xml
+++ b/resource-managers/kubernetes/core/pom.xml
@@ -20,7 +20,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.13</artifactId>
-    <version>4.1.0-SNAPSHOT</version>
+    <version>4.1.2-SNAPSHOT</version>
    <relativePath>../../../pom.xml</relativePath>
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
index f4d708f30b43a..fafff5046b9dc 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit
import org.apache.spark.deploy.k8s.Constants._
import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.{ConfigBuilder, PYSPARK_DRIVER_PYTHON, PYSPARK_PYTHON}
+import org.apache.spark.internal.config.ConfigBuilder
private[spark] object Config extends Logging {
@@ -587,22 +587,6 @@ private[spark] object Config extends Logging {
"Ensure that memory overhead is non-negative")
.createWithDefault(0.1)
- val PYSPARK_MAJOR_PYTHON_VERSION =
- ConfigBuilder("spark.kubernetes.pyspark.pythonVersion")
- .doc(
- s"(Deprecated since Spark 3.1, please set '${PYSPARK_PYTHON.key}' and " +
- s"'${PYSPARK_DRIVER_PYTHON.key}' configurations or $ENV_PYSPARK_PYTHON and " +
- s"$ENV_PYSPARK_DRIVER_PYTHON environment variables instead.)")
- .version("2.4.0")
- .stringConf
- .checkValue("3" == _,
- "Python 2 was dropped from Spark 3.1, and only 3 is allowed in " +
- "this configuration. Note that this configuration was deprecated in Spark 3.1. " +
- s"Please set '${PYSPARK_PYTHON.key}' and '${PYSPARK_DRIVER_PYTHON.key}' " +
- s"configurations or $ENV_PYSPARK_PYTHON and $ENV_PYSPARK_DRIVER_PYTHON environment " +
- "variables instead.")
- .createOptional
-
val KUBERNETES_KERBEROS_KRB5_FILE =
ConfigBuilder("spark.kubernetes.kerberos.krb5.path")
.doc("Specify the local location of the krb5.conf file to be mounted on the driver " +
@@ -716,6 +700,15 @@ private[spark] object Config extends Logging {
.booleanConf
.createWithDefault(true)
+ val KUBERNETES_DELETED_EXECUTORS_CACHE_TIMEOUT =
+ ConfigBuilder("spark.kubernetes.executor.deletedExecutorsCacheTimeout")
+ .internal()
+ .doc("Time-to-live (TTL) value for the cache for deleted executors")
+ .version("4.1.0")
+ .timeConf(TimeUnit.SECONDS)
+ .checkValue(_ >= 0, "deletedExecutorsCacheTimeout must be non-negative")
+ .createWithDefault(180)
+
val KUBERNETES_EXECUTOR_TERMINATION_GRACE_PERIOD_SECONDS =
ConfigBuilder("spark.kubernetes.executor.terminationGracePeriodSeconds")
.doc("Time to wait for graceful termination of executor pods.")
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesDiagnosticsSetter.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesDiagnosticsSetter.scala
index 9a1e79594e483..24f42919ce05a 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesDiagnosticsSetter.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesDiagnosticsSetter.scala
@@ -20,7 +20,7 @@ import io.fabric8.kubernetes.api.model.{Pod, PodBuilder}
import io.fabric8.kubernetes.client.KubernetesClient
import org.apache.hadoop.util.StringUtils
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkMasterRegex}
import org.apache.spark.deploy.SparkDiagnosticsSetter
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants.EXIT_EXCEPTION_ANNOTATION
@@ -75,6 +75,6 @@ private[spark] class SparkKubernetesDiagnosticsSetter(clientProvider: Kubernetes
}
override def supports(clusterManagerUrl: String): Boolean = {
- clusterManagerUrl.startsWith("k8s://")
+ SparkMasterRegex.isK8s(clusterManagerUrl)
}
}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala
index 13d1f1bc98a0e..0cfa842ef3963 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala
@@ -232,8 +232,8 @@ private[spark] class BasicExecutorFeatureStep(
executorLimitCores.map { limitCores =>
val executorCpuLimitQuantity = new Quantity(limitCores)
if (executorCpuLimitQuantity.compareTo(executorCpuQuantity) < 0) {
- throw new SparkException(s"The executor cpu request ($executorCpuQuantity) should be " +
- s"less than or equal to cpu limit ($executorCpuLimitQuantity)")
+ throw new IllegalArgumentException(s"The executor cpu request ($executorCpuQuantity) " +
+ s"should be less than or equal to cpu limit ($executorCpuLimitQuantity)")
}
new ContainerBuilder(executorContainerWithConfVolume)
.editResources()
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStep.scala
index f15f5bc566b4b..0574fa4868f30 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStep.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverCommandFeatureStep.scala
@@ -25,7 +25,6 @@ import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
import org.apache.spark.deploy.k8s.submit._
import org.apache.spark.internal.Logging
-import org.apache.spark.internal.LogKeys.{CONFIG, CONFIG2, CONFIG3, CONFIG4, CONFIG5}
import org.apache.spark.internal.config.{PYSPARK_DRIVER_PYTHON, PYSPARK_PYTHON}
import org.apache.spark.launcher.SparkLauncher
@@ -77,15 +76,6 @@ private[spark] class DriverCommandFeatureStep(conf: KubernetesDriverConf)
private[spark] def environmentVariables: Map[String, String] = sys.env
private def configureForPython(pod: SparkPod, res: String): SparkPod = {
- if (conf.get(PYSPARK_MAJOR_PYTHON_VERSION).isDefined) {
- logWarning(
- log"${MDC(CONFIG, PYSPARK_MAJOR_PYTHON_VERSION.key)} was deprecated in Spark 3.1. " +
- log"Please set '${MDC(CONFIG2, PYSPARK_PYTHON.key)}' and " +
- log"'${MDC(CONFIG3, PYSPARK_DRIVER_PYTHON.key)}' " +
- log"configurations or ${MDC(CONFIG4, ENV_PYSPARK_PYTHON)} and " +
- log"${MDC(CONFIG5, ENV_PYSPARK_DRIVER_PYTHON)} environment variables instead.")
- }
-
val pythonEnvs = {
KubernetesUtils.buildEnvVars(
Seq(
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/K8sSubmitOps.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/K8sSubmitOps.scala
index 17704b908558e..bd8e0f97132dd 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/K8sSubmitOps.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/K8sSubmitOps.scala
@@ -23,7 +23,7 @@ import io.fabric8.kubernetes.api.model.Pod
import io.fabric8.kubernetes.client.KubernetesClient
import io.fabric8.kubernetes.client.dsl.PodResource
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkMasterRegex}
import org.apache.spark.deploy.SparkSubmitOperation
import org.apache.spark.deploy.k8s.{KubernetesUtils, SparkKubernetesClientFactory}
import org.apache.spark.deploy.k8s.Config.{KUBERNETES_AUTH_SUBMISSION_CONF_PREFIX, KUBERNETES_SUBMIT_GRACE_PERIOD}
@@ -165,7 +165,7 @@ private[spark] class K8SSparkSubmitOperation extends SparkSubmitOperation
}
override def supports(master: String): Boolean = {
- master.startsWith("k8s://")
+ SparkMasterRegex.isK8s(master)
}
}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala
index 35386aff4a80c..3a508add6ccf0 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManager.scala
@@ -53,7 +53,7 @@ private[spark] class ExecutorPodsLifecycleManager(
// bounds.
private lazy val removedExecutorsCache =
CacheBuilder.newBuilder()
- .expireAfterWrite(3, TimeUnit.MINUTES)
+ .expireAfterWrite(conf.get(KUBERNETES_DELETED_EXECUTORS_CACHE_TIMEOUT), TimeUnit.SECONDS)
.build[java.lang.Long, java.lang.Long]()
private var lastFullSnapshotTs: Long = 0
@@ -113,25 +113,23 @@ private[spark] class ExecutorPodsLifecycleManager(
inactivatedPods -= execId
case deleted@PodDeleted(_) =>
+ execIdsRemovedInThisRound += execId
if (removeExecutorFromSpark(schedulerBackend, deleted, execId)) {
- execIdsRemovedInThisRound += execId
logDebug(s"Snapshot reported deleted executor with id $execId," +
s" pod name ${state.pod.getMetadata.getName}")
}
inactivatedPods -= execId
case failed@PodFailed(_) =>
- val deleteFromK8s = !execIdsRemovedInThisRound.contains(execId)
+ val deleteFromK8s = execIdsRemovedInThisRound.add(execId)
if (onFinalNonDeletedState(failed, execId, schedulerBackend, deleteFromK8s)) {
- execIdsRemovedInThisRound += execId
logDebug(s"Snapshot reported failed executor with id $execId," +
s" pod name ${state.pod.getMetadata.getName}")
}
case succeeded@PodSucceeded(_) =>
- val deleteFromK8s = !execIdsRemovedInThisRound.contains(execId)
+ val deleteFromK8s = execIdsRemovedInThisRound.add(execId)
if (onFinalNonDeletedState(succeeded, execId, schedulerBackend, deleteFromK8s)) {
- execIdsRemovedInThisRound += execId
if (schedulerBackend.isExecutorActive(execId.toString)) {
logInfo(log"Snapshot reported succeeded executor with id " +
log"${MDC(LogKeys.EXECUTOR_ID, execId)}, even though the application has not " +
@@ -201,8 +199,16 @@ private[spark] class ExecutorPodsLifecycleManager(
private def removeExecutorFromK8s(execId: Long, updatedPod: Pod): Unit = {
Utils.tryLogNonFatalError {
if (shouldDeleteExecutors) {
- // Get pod before deleting it, we can skip deleting if pod is already deleted so that
- // we do not send too many requests to api server.
+ if (updatedPod.getMetadata.getDeletionTimestamp != null) {
+ // Skip the delete call if the deletion timestamp is already set on the
+ // updatedPod object; this avoids unnecessary roundtrips to the Kubernetes API.
+ return
+ }
+ // Get the pod before deleting it: we can skip the delete if the pod is already gone
+ // or already has a deletion timestamp set, so that we do not send too many requests
+ // to the API server.
// If deletion failed on a previous try, we can try again if resync informs us the pod
// is still around.
// Delete as best attempt - duplicate deletes will throw an exception but the end state
@@ -211,7 +217,9 @@ private[spark] class ExecutorPodsLifecycleManager(
.pods()
.inNamespace(namespace)
.withName(updatedPod.getMetadata.getName)
- if (podToDelete.get() != null) {
+
+ if (podToDelete.get() != null &&
+ podToDelete.get.getMetadata.getDeletionTimestamp == null) {
podToDelete.delete()
}
} else if (!inactivatedPods.contains(execId) && !isPodInactive(updatedPod)) {
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala
index 6953ed789f797..0d9f19ee11b71 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSource.scala
@@ -64,6 +64,7 @@ class ExecutorPodsWatchSnapshotSource(
.inNamespace(namespace)
.withLabel(SPARK_APP_ID_LABEL, applicationId)
.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)
+ .withoutLabel(SPARK_EXECUTOR_INACTIVE_LABEL, "true")
.watch(new ExecutorPodsWatcher())
}
}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala
index 49eac64745b7c..3fb1ed0c9c0fc 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils}
private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging {
import SparkMasterRegex._
- override def canCreate(masterURL: String): Boolean = masterURL.startsWith("k8s")
+ override def canCreate(masterURL: String): Boolean = SparkMasterRegex.isK8s(masterURL)
private def isLocal(conf: SparkConf): Boolean =
conf.get(KUBERNETES_DRIVER_MASTER_URL).startsWith("local")
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala
index ced1326e7938c..d264484f4d039 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala
@@ -123,7 +123,7 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
test("SPARK-52933: Verify if the executor cpu request exceeds limit") {
baseConf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "2")
baseConf.set(KUBERNETES_EXECUTOR_LIMIT_CORES, "1")
- val error = intercept[SparkException] {
+ val error = intercept[IllegalArgumentException] {
initDefaultProfile(baseConf)
val step = new BasicExecutorFeatureStep(newExecutorConf(), new SecurityManager(baseConf),
defaultProfile)
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala
index 4c7ffe692b105..cdbcae050ceb9 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsLifecycleManagerSuite.scala
@@ -20,7 +20,7 @@ import java.util.function.UnaryOperator
import scala.collection.mutable
-import io.fabric8.kubernetes.api.model.Pod
+import io.fabric8.kubernetes.api.model.{Pod, PodBuilder}
import io.fabric8.kubernetes.client.KubernetesClient
import io.fabric8.kubernetes.client.dsl.PodResource
import org.mockito.{Mock, MockitoAnnotations}
@@ -219,6 +219,47 @@ class ExecutorPodsLifecycleManagerSuite extends SparkFunSuite with BeforeAndAfte
verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason)
}
+ test("Don't delete pod from K8s if deletionTimestamp is already set.") {
+ // Create a failed pod with deletionTimestamp already in the past
+ val basePod = failedExecutorWithoutDeletion(1)
+ val failedPodWithDeletionTimestamp = new PodBuilder(basePod)
+ .editOrNewMetadata()
+ .withDeletionTimestamp("1970-01-01T00:00:00Z")
+ .endMetadata()
+ .build()
+
+ val mockPodResource = mock(classOf[PodResource])
+ namedExecutorPods.put("spark-executor-1", mockPodResource)
+ when(mockPodResource.get()).thenReturn(failedPodWithDeletionTimestamp)
+
+ snapshotsStore.updatePod(failedPodWithDeletionTimestamp)
+ snapshotsStore.notifySubscribers()
+
+ // Verify executor is removed from Spark
+ val msg = "The executor with id 1 was deleted by a user or the framework."
+ val expectedLossReason = ExecutorExited(1, exitCausedByApp = false, msg)
+ verify(schedulerBackend).doRemoveExecutor("1", expectedLossReason)
+
+ // Verify delete() is NOT called since deletionTimestamp is already set
+ verify(mockPodResource, never()).delete()
+ }
+
+ test("SPARK-54198: Delete Kubernetes executor pods only once per event processing interval") {
+ val failedPod = failedExecutorWithoutDeletion(1)
+ val mockPodResource = mock(classOf[PodResource])
+ namedExecutorPods.put("spark-executor-1", mockPodResource)
+ when(mockPodResource.get()).thenReturn(failedPod)
+ snapshotsStore.updatePod(failedPod)
+ snapshotsStore.notifySubscribers()
+ snapshotsStore.updatePod(failedPod)
+ snapshotsStore.updatePod(failedPod)
+ snapshotsStore.notifySubscribers()
+ val msg = exitReasonMessage(1, failedPod, 1)
+ val expectedLossReason = ExecutorExited(1, exitCausedByApp = true, msg)
+ verify(schedulerBackend, times(1)).doRemoveExecutor("1", expectedLossReason)
+ verify(namedExecutorPods(failedPod.getMetadata.getName), times(2)).delete()
+ }
+
private def exitReasonMessage(execId: Int, failedPod: Pod, exitCode: Int): String = {
val reason = Option(failedPod.getStatus.getReason)
val message = Option(failedPod.getStatus.getMessage)
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala
index 61080268cde60..f830abc0d1298 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsWatchSnapshotSourceSuite.scala
@@ -50,6 +50,9 @@ class ExecutorPodsWatchSnapshotSourceSuite extends SparkFunSuite with BeforeAndA
@Mock
private var executorRoleLabeledPods: LABELED_PODS = _
+ @Mock
+ private var executorRoleLabeledActivePods: LABELED_PODS = _
+
@Mock
private var watchConnection: Watch = _
@@ -66,7 +69,9 @@ class ExecutorPodsWatchSnapshotSourceSuite extends SparkFunSuite with BeforeAndA
.thenReturn(appIdLabeledPods)
when(appIdLabeledPods.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE))
.thenReturn(executorRoleLabeledPods)
- when(executorRoleLabeledPods.watch(watch.capture())).thenReturn(watchConnection)
+ when(executorRoleLabeledPods.withoutLabel(SPARK_EXECUTOR_INACTIVE_LABEL, "true"))
+ .thenReturn(executorRoleLabeledActivePods)
+ when(executorRoleLabeledActivePods.watch(watch.capture())).thenReturn(watchConnection)
}
test("Watch events should be pushed to the snapshots store as snapshot updates.") {
diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml
index d470356f39e66..f69f81259355f 100644
--- a/resource-managers/kubernetes/integration-tests/pom.xml
+++ b/resource-managers/kubernetes/integration-tests/pom.xml
@@ -20,7 +20,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent_2.13</artifactId>
-    <version>4.1.0-SNAPSHOT</version>
+    <version>4.1.2-SNAPSHOT</version>
    <relativePath>../../../pom.xml</relativePath>
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala
index d710add45eb96..22295dcaa2af7 100644
--- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/BasicTestsSuite.scala
@@ -53,6 +53,14 @@ private[spark] trait BasicTestsSuite { k8sSuite: KubernetesSuite =>
runSparkPiAndVerifyCompletion()
}
+ test("SPARK-53944: Run SparkPi without driver service", k8sTestTag) {
+ sparkAppConf.set(
+ "spark.kubernetes.driver.pod.excludedFeatureSteps",
+ "org.apache.spark.deploy.k8s.features.DriverServiceFeatureStep")
+ sparkAppConf.set("spark.kubernetes.executor.useDriverPodIP", "true")
+ runSparkPiAndVerifyCompletion()
+ }
+
test("Run SparkPi with no resources & statefulset allocation", k8sTestTag) {
sparkAppConf.set("spark.kubernetes.allocation.pods.allocator", "statefulset")
runSparkPiAndVerifyCompletion()
diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml
index 0629c24c56dc4..f46cf6d31392a 100644
--- a/resource-managers/yarn/pom.xml
+++ b/resource-managers/yarn/pom.xml
@@ -20,7 +20,7 @@
org.apache.spark
spark-parent_2.13
- 4.1.0-SNAPSHOT
+ 4.1.2-SNAPSHOT
../../pom.xml
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index cea4cd3ac4fc1..9bcdecbbf4fdc 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -846,4 +846,9 @@ This file is divided into 3 sections:
\bInts\.checkedCast\b
Use JavaUtils.checkedCast instead.
+
+
+ \bstartsWith\("k8s\b
+ Use SparkMasterRegex.isK8s instead.
+
diff --git a/sql/api/pom.xml b/sql/api/pom.xml
index 184d39c4b8ea1..b43610317e1c1 100644
--- a/sql/api/pom.xml
+++ b/sql/api/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.13
- 4.1.0-SNAPSHOT
+ 4.1.2-SNAPSHOT
../../pom.xml
diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 8ccac6a39d2ce..02b7ec195a933 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -46,6 +46,13 @@ options { tokenVocab = SqlBaseLexer; }
* When true, parameter markers are allowed everywhere a literal is supported.
*/
public boolean parameter_substitution_enabled = true;
+
+ /**
+ * When false (default), IDENTIFIER('literal') is resolved to an identifier at parse time (identifier-lite).
+ * When true, only the legacy IDENTIFIER(expression) function syntax is allowed.
+ * Controlled by spark.sql.legacy.identifierClause configuration.
+ */
+ public boolean legacy_identifier_clause_only = false;
}
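A rough sketch of the identifier-lite behavior this flag controls; the table name is hypothetical and the config key is taken from the comment above:

// With the default (legacy_identifier_clause_only = false), IDENTIFIER('literal') is folded
// into a plain identifier at parse time, so these two queries are equivalent.
spark.sql("SELECT id FROM IDENTIFIER('my_schema.my_table')")
spark.sql("SELECT id FROM my_schema.my_table")

// Setting the legacy flag restores the expression-only IDENTIFIER(expression) behavior.
spark.conf.set("spark.sql.legacy.identifierClause", "true")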
compoundOrSingleStatement
@@ -92,7 +99,7 @@ sqlStateValue
;
declareConditionStatement
- : DECLARE multipartIdentifier CONDITION (FOR SQLSTATE VALUE? sqlStateValue)?
+ : DECLARE strictIdentifier CONDITION (FOR SQLSTATE VALUE? sqlStateValue)?
;
conditionValue
@@ -125,11 +132,11 @@ repeatStatement
;
leaveStatement
- : LEAVE multipartIdentifier
+ : LEAVE strictIdentifier
;
iterateStatement
- : ITERATE multipartIdentifier
+ : ITERATE strictIdentifier
;
caseStatement
@@ -144,7 +151,7 @@ loopStatement
;
forStatement
- : beginLabel? FOR (multipartIdentifier AS)? query DO compoundBody END FOR endLabel?
+ : beginLabel? FOR (strictIdentifier AS)? query DO compoundBody END FOR endLabel?
;
singleStatement
@@ -152,11 +159,11 @@ singleStatement
;
beginLabel
- : multipartIdentifier COLON
+ : strictIdentifier COLON
;
endLabel
- : multipartIdentifier
+ : strictIdentifier
;
singleExpression
@@ -225,9 +232,9 @@ statement
createTableClauses
(AS? query)? #replaceTable
| ANALYZE TABLE identifierReference partitionSpec? COMPUTE STATISTICS
- (identifier | FOR COLUMNS identifierSeq | FOR ALL COLUMNS)? #analyze
+ (simpleIdentifier | FOR COLUMNS identifierSeq | FOR ALL COLUMNS)? #analyze
| ANALYZE TABLES ((FROM | IN) identifierReference)? COMPUTE STATISTICS
- (identifier)? #analyzeTables
+ (simpleIdentifier)? #analyzeTables
| ALTER TABLE identifierReference
ADD (COLUMN | COLUMNS)
columns=qualifiedColTypeWithPositionList #addTableColumns
@@ -321,7 +328,7 @@ statement
| SHOW VIEWS ((FROM | IN) identifierReference)?
(LIKE? pattern=stringLit)? #showViews
| SHOW PARTITIONS identifierReference partitionSpec? #showPartitions
- | SHOW identifier? FUNCTIONS ((FROM | IN) ns=identifierReference)?
+ | SHOW functionScope=simpleIdentifier? FUNCTIONS ((FROM | IN) ns=identifierReference)?
(LIKE? (legacy=multipartIdentifier | pattern=stringLit))? #showFunctions
| SHOW PROCEDURES ((FROM | IN) identifierReference)? #showProcedures
| SHOW CREATE TABLE identifierReference (AS SERDE)? #showCreateTable
@@ -349,7 +356,7 @@ statement
| TRUNCATE TABLE identifierReference partitionSpec? #truncateTable
| (MSCK)? REPAIR TABLE identifierReference
(option=(ADD|DROP|SYNC) PARTITIONS)? #repairTable
- | op=(ADD | LIST) identifier .*? #manageResource
+ | op=(ADD | LIST) simpleIdentifier .*? #manageResource
| CREATE INDEX (IF errorCapturingNot EXISTS)? identifier ON TABLE?
identifierReference (USING indexType=identifier)?
LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN
@@ -653,7 +660,7 @@ createFileFormat
fileFormat
: INPUTFORMAT inFmt=stringLit OUTPUTFORMAT outFmt=stringLit #tableFileFormat
- | identifier #genericFileFormat
+ | simpleIdentifier #genericFileFormat
;
storageHandler
@@ -661,7 +668,7 @@ storageHandler
;
resource
- : identifier stringLit
+ : simpleIdentifier stringLit
;
dmlStatementNoWith
@@ -833,8 +840,8 @@ hint
;
hintStatement
- : hintName=identifier
- | hintName=identifier LEFT_PAREN parameters+=primaryExpression (COMMA parameters+=primaryExpression)* RIGHT_PAREN
+ : hintName=simpleIdentifier
+ | hintName=simpleIdentifier LEFT_PAREN parameters+=primaryExpression (COMMA parameters+=primaryExpression)* RIGHT_PAREN
;
fromClause
@@ -1241,7 +1248,7 @@ primaryExpression
| identifier #columnReference
| base=primaryExpression DOT fieldName=identifier #dereference
| LEFT_PAREN expression RIGHT_PAREN #parenthesizedExpression
- | EXTRACT LEFT_PAREN field=identifier FROM source=valueExpression RIGHT_PAREN #extract
+ | EXTRACT LEFT_PAREN field=simpleIdentifier FROM source=valueExpression RIGHT_PAREN #extract
| (SUBSTR | SUBSTRING) LEFT_PAREN str=valueExpression (FROM | COMMA) pos=valueExpression
((FOR | COMMA) len=valueExpression)? RIGHT_PAREN #substring
| TRIM LEFT_PAREN trimOption=(BOTH | LEADING | TRAILING)? (trimStr=valueExpression)?
@@ -1297,7 +1304,7 @@ constant
;
namedParameterMarker
- : COLON identifier
+ : COLON simpleIdentifier
;
comparisonOperator
: EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ
@@ -1599,13 +1606,32 @@ identifier
| {!SQL_standard_keyword_behavior}? strictNonReserved
;
+// simpleIdentifier: like identifier but without IDENTIFIER('literal') support
+// Use this for contexts where IDENTIFIER() syntax is not appropriate:
+// - Named parameters (:param_name)
+// - Extract field names (EXTRACT(field FROM ...))
+// - Other keyword-like or string-like uses
+simpleIdentifier
+ : simpleStrictIdentifier
+ | {!SQL_standard_keyword_behavior}? strictNonReserved
+ ;
+
strictIdentifier
: IDENTIFIER #unquotedIdentifier
| quotedIdentifier #quotedIdentifierAlternative
+ | {!legacy_identifier_clause_only}? IDENTIFIER_KW LEFT_PAREN stringLit RIGHT_PAREN #identifierLiteral
| {SQL_standard_keyword_behavior}? ansiNonReserved #unquotedIdentifier
| {!SQL_standard_keyword_behavior}? nonReserved #unquotedIdentifier
;
+// simpleStrictIdentifier: like strictIdentifier but without IDENTIFIER('literal') support
+simpleStrictIdentifier
+ : IDENTIFIER #simpleUnquotedIdentifier
+ | quotedIdentifier #simpleQuotedIdentifierAlternative
+ | {SQL_standard_keyword_behavior}? ansiNonReserved #simpleUnquotedIdentifier
+ | {!SQL_standard_keyword_behavior}? nonReserved #simpleUnquotedIdentifier
+ ;
+
quotedIdentifier
: BACKQUOTED_IDENTIFIER
| {double_quoted_identifiers}? DOUBLEQUOTED_STRING
diff --git a/sql/api/src/main/java/org/apache/spark/sql/types/Geography.java b/sql/api/src/main/java/org/apache/spark/sql/types/Geography.java
index e99902336ffe9..4a52288ba3f2a 100644
--- a/sql/api/src/main/java/org/apache/spark/sql/types/Geography.java
+++ b/sql/api/src/main/java/org/apache/spark/sql/types/Geography.java
@@ -17,10 +17,13 @@
package org.apache.spark.sql.types;
+import org.apache.spark.annotation.Unstable;
+
import java.io.Serializable;
import java.util.Arrays;
// This class represents the Geography data for clients.
+@Unstable
public final class Geography implements Serializable {
// The GEOGRAPHY type is implemented as WKB bytes + SRID integer stored in class itself.
protected final byte[] value;
diff --git a/sql/api/src/main/java/org/apache/spark/sql/types/Geometry.java b/sql/api/src/main/java/org/apache/spark/sql/types/Geometry.java
index 3974aec131d00..fdd64f482130e 100644
--- a/sql/api/src/main/java/org/apache/spark/sql/types/Geometry.java
+++ b/sql/api/src/main/java/org/apache/spark/sql/types/Geometry.java
@@ -17,10 +17,13 @@
package org.apache.spark.sql.types;
+import org.apache.spark.annotation.Unstable;
+
import java.io.Serializable;
import java.util.Arrays;
// This class represents the Geometry data for clients.
+@Unstable
public final class Geometry implements Serializable {
// The GEOMETRY type is implemented as WKB bytes + SRID integer stored in class itself.
protected final byte[] value;
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
index cb1402e1b0f4a..7e698e58321ee 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -162,6 +162,20 @@ object Encoders {
*/
def BINARY: Encoder[Array[Byte]] = BinaryEncoder
+ /**
+ * An encoder for Geometry data type.
+ *
+ * @since 4.1.0
+ */
+ def GEOMETRY(dt: GeometryType): Encoder[Geometry] = GeometryEncoder(dt)
+
+ /**
+ * An encoder for Geography data type.
+ *
+ * @since 4.1.0
+ */
+ def GEOGRAPHY(dt: GeographyType): Encoder[Geography] = GeographyEncoder(dt)
+
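A minimal sketch of using the new encoders, assuming the GeometryType constructor and Geometry.DEFAULT_SRID shown elsewhere in this patch; the empty Dataset keeps it self-contained:

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.types.{Geometry, GeometryType}

// Explicit encoder for Geometry values with the default SRID, mirroring DEFAULT_GEOMETRY_ENCODER.
val geomEncoder: Encoder[Geometry] = Encoders.GEOMETRY(GeometryType(Geometry.DEFAULT_SRID))
// spark is a placeholder SparkSession.
val ds = spark.createDataset(Seq.empty[Geometry])(geomEncoder)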
/**
* Creates an encoder that serializes instances of the `java.time.Duration` class to the
* internal representation of nullable Catalyst's DayTimeIntervalType.
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
index 59c27d1e56304..2a25b4bd4430b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -70,6 +70,9 @@ class Observation(val name: String) {
* first action. Only the result of the first action is available. Subsequent actions do not
* modify the result.
*
+ * Note that an empty map may be returned if no metrics were recorded, for example when the
+ * operators used for observation are optimized away.
+ *
* @return
* the observed metrics as a `Map[String, Any]`
* @throws InterruptedException
@@ -78,7 +81,11 @@ class Observation(val name: String) {
@throws[InterruptedException]
def get: Map[String, Any] = {
val row = getRow
- row.getValuesMap(row.schema.map(_.name))
+ if (row == null || row.schema == null) {
+ Map.empty
+ } else {
+ row.getValuesMap(row.schema.map(_.name))
+ }
}
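A short usage sketch for the behavior documented above; df, the column name, and the metric names are illustrative only:

import org.apache.spark.sql.Observation
import org.apache.spark.sql.functions._

val observation = Observation("stats")
val observed = df.observe(observation, count(lit(1)).as("rows"), max(col("value")).as("max_value"))
observed.collect()
// After this change, get returns an empty Map instead of failing when the
// observe operator was optimized away and no metrics were recorded.
val metrics: Map[String, Any] = observation.get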
/**
@@ -86,6 +93,9 @@ class Observation(val name: String) {
* first action. Only the result of the first action is available. Subsequent actions do not
* modify the result.
*
+ * Note that an empty map may be returned if no metrics were recorded, for example when the
+ * operators used for observation are optimized away.
+ *
* @return
* the observed metrics as a `java.util.Map[String, Object]`
* @throws InterruptedException
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Row.scala b/sql/api/src/main/scala/org/apache/spark/sql/Row.scala
index 764bdb17b37e2..1019d4c9a2276 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Row.scala
@@ -302,6 +302,24 @@ trait Row extends Serializable {
*/
def getDecimal(i: Int): java.math.BigDecimal = getAs[java.math.BigDecimal](i)
+ /**
+ * Returns the value at position i of geometry type as org.apache.spark.sql.types.Geometry.
+ *
+ * @throws ClassCastException
+ * when data type does not match.
+ */
+ def getGeometry(i: Int): org.apache.spark.sql.types.Geometry =
+ getAs[org.apache.spark.sql.types.Geometry](i)
+
+ /**
+ * Returns the value at position i of geography type as org.apache.spark.sql.types.Geography.
+ *
+ * @throws ClassCastException
+ * when data type does not match.
+ */
+ def getGeography(i: Int): org.apache.spark.sql.types.Geography =
+ getAs[org.apache.spark.sql.types.Geography](i)
+
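A hedged sketch of the new spatial accessors, assuming a query that yields GEOMETRY/GEOGRAPHY columns; the table and column names are made up:

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{Geography, Geometry}

// head() triggers the query; the accessors throw ClassCastException if the column type differs.
val row: Row = spark.sql("SELECT geom, geog FROM spatial_table").head()
val geometry: Geometry = row.getGeometry(0)
val geography: Geography = row.getGeography(1)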
/**
* Returns the value at position i of date type as java.sql.Date.
*
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index a5b1060ca03db..9d64225b96633 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -104,6 +104,14 @@ trait EncoderImplicits extends LowPrioritySQLImplicits with Serializable {
implicit def newScalaDecimalEncoder: Encoder[scala.math.BigDecimal] =
DEFAULT_SCALA_DECIMAL_ENCODER
+ /** @since 4.1.0 */
+ implicit def newGeometryEncoder: Encoder[org.apache.spark.sql.types.Geometry] =
+ DEFAULT_GEOMETRY_ENCODER
+
+ /** @since 4.1.0 */
+ implicit def newGeographyEncoder: Encoder[org.apache.spark.sql.types.Geography] =
+ DEFAULT_GEOGRAPHY_ENCODER
+
/** @since 2.2.0 */
implicit def newDateEncoder: Encoder[java.sql.Date] = Encoders.DATE
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala b/sql/api/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala
index d2479ce9879c5..56f2f75c2c190 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/TableValuedFunction.scala
@@ -181,4 +181,12 @@ abstract class TableValuedFunction {
* @since 4.0.0
*/
def variant_explode_outer(input: Column): Dataset[Row]
+
+ /**
+ * Returns a `DataFrame` of logs collected from Python workers.
+ *
+ * @group table_funcs
+ * @since 4.1.0
+ */
+ def python_worker_logs(): Dataset[Row]
}
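Assuming the session exposes this class via spark.tvf, as with the other table-valued functions declared here, a call might look like the following sketch; the result schema is whatever the implementation returns:

// Collect logs emitted by Python workers as a DataFrame and print them.
val logs = spark.tvf.python_worker_logs()
logs.show(truncate = false)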
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 906e6419b3607..91947cf416fb6 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
import org.apache.commons.lang3.reflect.{TypeUtils => JavaTypeUtils}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_LOCAL_DATE_ENCODER, STRICT_TIMESTAMP_ENCODER, StringEncoder, UDTEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_GEOGRAPHY_ENCODER, DEFAULT_GEOMETRY_ENCODER, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_LOCAL_DATE_ENCODER, STRICT_TIMESTAMP_ENCODER, StringEncoder, UDTEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.util.ArrayImplicits._
@@ -86,6 +86,10 @@ object JavaTypeInference {
case c: Class[_] if c == classOf[java.lang.String] => StringEncoder
case c: Class[_] if c == classOf[Array[Byte]] => BinaryEncoder
+ case c: Class[_] if c == classOf[org.apache.spark.sql.types.Geometry] =>
+ DEFAULT_GEOMETRY_ENCODER
+ case c: Class[_] if c == classOf[org.apache.spark.sql.types.Geography] =>
+ DEFAULT_GEOGRAPHY_ENCODER
case c: Class[_] if c == classOf[java.math.BigDecimal] => DEFAULT_JAVA_DECIMAL_ENCODER
case c: Class[_] if c == classOf[java.math.BigInteger] => JavaBigIntEncoder
case c: Class[_] if c == classOf[java.time.LocalDate] => STRICT_LOCAL_DATE_ENCODER
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d2e0053597e4f..6f5c4be42bbd4 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -332,6 +332,10 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => LocalDateTimeEncoder
case t if isSubtype(t, localTypeOf[java.time.LocalTime]) => LocalTimeEncoder
case t if isSubtype(t, localTypeOf[VariantVal]) => VariantEncoder
+ case t if isSubtype(t, localTypeOf[Geography]) =>
+ DEFAULT_GEOGRAPHY_ENCODER
+ case t if isSubtype(t, localTypeOf[Geometry]) =>
+ DEFAULT_GEOMETRY_ENCODER
case t if isSubtype(t, localTypeOf[Row]) => UnboundRowEncoder
// UDT encoders
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 0c5295176608f..20949c188cb81 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -246,6 +246,8 @@ object AgnosticEncoders {
case object DayTimeIntervalEncoder extends LeafEncoder[Duration](DayTimeIntervalType())
case object YearMonthIntervalEncoder extends LeafEncoder[Period](YearMonthIntervalType())
case object VariantEncoder extends LeafEncoder[VariantVal](VariantType)
+ case class GeographyEncoder(dt: GeographyType) extends LeafEncoder[Geography](dt)
+ case class GeometryEncoder(dt: GeometryType) extends LeafEncoder[Geometry](dt)
case class DateEncoder(override val lenientSerialization: Boolean)
extends LeafEncoder[jsql.Date](DateType)
case class LocalDateEncoder(override val lenientSerialization: Boolean)
@@ -277,6 +279,10 @@ object AgnosticEncoders {
ScalaDecimalEncoder(DecimalType.SYSTEM_DEFAULT)
val DEFAULT_JAVA_DECIMAL_ENCODER: JavaDecimalEncoder =
JavaDecimalEncoder(DecimalType.SYSTEM_DEFAULT, lenientSerialization = false)
+ val DEFAULT_GEOMETRY_ENCODER: GeometryEncoder =
+ GeometryEncoder(GeometryType(Geometry.DEFAULT_SRID))
+ val DEFAULT_GEOGRAPHY_ENCODER: GeographyEncoder =
+ GeographyEncoder(GeographyType(Geography.DEFAULT_SRID))
/**
* Encoder that transforms external data into a representation that can be further processed by
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 620278c66d21d..73152017cf225 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.reflect.classTag
import org.apache.spark.sql.{AnalysisException, Row}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, VarcharEncoder, VariantEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, GeographyEncoder, GeometryEncoder, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, VarcharEncoder, VariantEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.errors.DataTypeErrorsBase
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types._
@@ -120,6 +120,8 @@ object RowEncoder extends DataTypeErrorsBase {
field.nullable,
field.metadata)
}.toImmutableArraySeq)
+ case g: GeographyType => GeographyEncoder(g)
+ case g: GeometryType => GeometryEncoder(g)
case _ =>
throw new AnalysisException(
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala
index 0f21972552339..b90d9f8013d6f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/codecs.scala
@@ -55,7 +55,7 @@ object JavaSerializationCodec extends (() => Codec[Any, Array[Byte]]) {
* server (driver & executors) very tricky. As a workaround a user can define their own Codec
* which internalizes the Kryo configuration.
*/
-object KryoSerializationCodec extends (() => Codec[Any, Array[Byte]]) {
+object KryoSerializationCodec extends (() => Codec[Any, Array[Byte]]) with Serializable {
private lazy val kryoCodecConstructor: MethodHandle = {
val cls = SparkClassUtils.classForName(
"org.apache.spark.sql.catalyst.encoders.KryoSerializationCodecImpl")
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
index 73767990bd3a3..51c846f93c1ec 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
@@ -20,7 +20,7 @@ import java.util.Locale
import scala.jdk.CollectionConverters._
-import org.antlr.v4.runtime.Token
+import org.antlr.v4.runtime.{ParserRuleContext, Token}
import org.antlr.v4.runtime.tree.ParseTree
import org.apache.spark.SparkException
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.catalyst.util.SparkParserUtils.{string, withOrigin}
import org.apache.spark.sql.connector.catalog.IdentityColumnSpec
-import org.apache.spark.sql.errors.QueryParsingErrors
+import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryParsingErrors}
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, GeographyType, GeometryType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, TimeType, VarcharType, VariantType, YearMonthIntervalType}
@@ -60,12 +60,52 @@ import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType,
*
* @see
* [[org.apache.spark.sql.catalyst.parser.AstBuilder]] for the full SQL statement parser
+ *
+ * ==CRITICAL: Extracting Identifier Names==
+ *
+ * When extracting identifier names from parser contexts, you MUST use the helper methods provided
+ * by this class instead of calling ctx.getText() directly:
+ *
+ * - '''getIdentifierText(ctx)''': For single identifiers (column names, aliases, window names)
+ * - '''getIdentifierParts(ctx)''': For qualified identifiers (table names, schema.table)
+ *
+ * '''DO NOT use ctx.getText() or ctx.identifier.getText()''' directly! These methods do not
+ * handle the IDENTIFIER('literal') syntax and will cause incorrect behavior.
+ *
+ * The IDENTIFIER('literal') syntax allows string literals to be used as identifiers at parse time
+ * (e.g., IDENTIFIER('my_col') resolves to the identifier my_col). If you use getText(), you'll
+ * get the raw text "IDENTIFIER('my_col')" instead of "my_col", breaking the feature.
+ *
+ * Example:
+ * {{{
+ * // WRONG - does not handle IDENTIFIER('literal'):
+ * val name = ctx.identifier.getText
+ * SubqueryAlias(ctx.name.getText, plan)
+ *
+ * // CORRECT - handles both regular identifiers and IDENTIFIER('literal'):
+ * val name = getIdentifierText(ctx.identifier)
+ * SubqueryAlias(getIdentifierText(ctx.name), plan)
+ * }}}
*/
-class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
+class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with DataTypeErrorsBase {
protected def typedVisit[T](ctx: ParseTree): T = {
ctx.accept(this).asInstanceOf[T]
}
+ /**
+ * Public helper to extract identifier parts from a context. This is exposed as public to allow
+ * utility classes like ParserUtils to reuse the identifier resolution logic without duplicating
+ * code.
+ *
+ * @param ctx
+ * The parser context containing the identifier.
+ * @return
+ * Sequence of identifier parts.
+ */
+ def extractIdentifierParts(ctx: ParserRuleContext): Seq[String] = {
+ getIdentifierParts(ctx)
+ }
+
override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
typedVisit[DataType](ctx.dataType)
}
@@ -161,11 +201,89 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
}
/**
- * Create a multi-part identifier.
+ * Parse a string into a multi-part identifier. Subclasses should override this method to
+ * provide proper multi-part identifier parsing with access to a full SQL parser.
+ *
+ * For example, in AstBuilder, this would parse "`catalog`.`schema`.`table`" into Seq("catalog",
+ * "schema", "table").
+ *
+ * This method is only called when parsing IDENTIFIER('literal') where the literal contains a
+ * qualified identifier (e.g., IDENTIFIER('schema.table')). Since DataTypeAstBuilder only parses
+ * data types (not full SQL with qualified table names), this should never be called in
+ * practice. The base implementation throws an error to catch unexpected usage.
+ *
+ * @param identifier
+ * The identifier string to parse, potentially containing dots and backticks.
+ * @return
+ * Sequence of identifier parts.
+ */
+ protected def parseMultipartIdentifier(identifier: String): Seq[String] = {
+ throw SparkException.internalError(
+ "parseMultipartIdentifier must be overridden by subclasses. " +
+ s"Attempted to parse: $identifier")
+ }
+
+ /**
+ * Get the identifier parts from a context, handling both regular identifiers and
+ * IDENTIFIER('literal'). This method is used to support identifier-lite syntax where
+ * IDENTIFIER('string') is folded at parse time. For qualified identifiers like
+ * IDENTIFIER('`catalog`.`schema`'), this will parse the string and return multiple parts.
+ *
+ * Subclasses should override this method to provide actual parsing logic.
+ */
+ protected def getIdentifierParts(ctx: ParserRuleContext): Seq[String] = {
+ ctx match {
+ case idCtx: IdentifierContext =>
+ // identifier can be either strictIdentifier or strictNonReserved.
+ // Recursively process the strictIdentifier.
+ Option(idCtx.strictIdentifier()).map(getIdentifierParts).getOrElse(Seq(ctx.getText))
+
+ case idLitCtx: IdentifierLiteralContext =>
+ // For IDENTIFIER('literal') in strictIdentifier.
+ val literalValue = string(visitStringLit(idLitCtx.stringLit()))
+ // Parse the string to handle qualified identifiers like "`cat`.`schema`".
+ parseMultipartIdentifier(literalValue)
+
+ case errCapture: ErrorCapturingIdentifierContext =>
+ // Regular identifier with errorCapturingIdentifierExtra.
+ // Need to recursively handle identifier which might itself be IDENTIFIER('literal').
+ Option(errCapture.identifier())
+ .flatMap(id => Option(id.strictIdentifier()).map(getIdentifierParts))
+ .getOrElse(Seq(ctx.getText))
+
+ case _ =>
+ // For regular identifiers, just return the text as a single part.
+ Seq(ctx.getText)
+ }
+ }
+
+ /**
+ * Get the text of a SINGLE identifier, handling both regular identifiers and
+ * IDENTIFIER('literal'). This method REQUIRES that the identifier be unqualified (single part
+ * only). If IDENTIFIER('qualified.name') is used where a single identifier is required, this
+ * will error.
+ */
+ protected def getIdentifierText(ctx: ParserRuleContext): String = {
+ val parts = getIdentifierParts(ctx)
+ if (parts.size > 1) {
+ throw new ParseException(
+ errorClass = "IDENTIFIER_TOO_MANY_NAME_PARTS",
+ messageParameters = Map("identifier" -> toSQLId(parts), "limit" -> "1"),
+ ctx)
+ }
+ parts.head
+ }
+
+ /**
+ * Create a multi-part identifier. Handles identifier-lite with qualified identifiers like
+ * IDENTIFIER('`cat`.`schema`').table
*/
override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] =
withOrigin(ctx) {
- ctx.parts.asScala.map(_.getText).toSeq
+ // Each part is an errorCapturingIdentifier (which wraps identifier).
+ // getIdentifierParts recursively handles IDENTIFIER('literal') syntax through
+ // identifier -> strictIdentifier -> identifierLiteral.
+ ctx.parts.asScala.flatMap(getIdentifierParts).toSeq
}
/**
@@ -351,7 +469,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
}
StructField(
- name = colName.getText,
+ name = getIdentifierText(colName),
dataType = typedVisit[DataType](ctx.dataType),
nullable = NULL == null,
metadata = builder.build())
@@ -379,7 +497,7 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
override def visitComplexColType(ctx: ComplexColTypeContext): StructField = withOrigin(ctx) {
import ctx._
val structField = StructField(
- name = errorCapturingIdentifier.getText,
+ name = getIdentifierText(errorCapturingIdentifier),
dataType = typedVisit(dataType()),
nullable = NULL == null)
Option(commentSpec).map(visitCommentSpec).map(structField.withComment).getOrElse(structField)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/SubstituteParmsAstBuilder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/SubstituteParmsAstBuilder.scala
index 8beeb9b17d4c8..f32c1d6f3836d 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/SubstituteParmsAstBuilder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/SubstituteParmsAstBuilder.scala
@@ -81,7 +81,8 @@ class SubstituteParmsAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
*/
override def visitNamedParameterLiteral(ctx: NamedParameterLiteralContext): AnyRef =
withOrigin(ctx) {
- val paramName = ctx.namedParameterMarker().identifier().getText
+ // Named parameters use simpleIdentifier, so .getText() is correct.
+ val paramName = ctx.namedParameterMarker().simpleIdentifier().getText
namedParams += paramName
// Calculate the location of the entire parameter (including the colon)
@@ -117,7 +118,8 @@ class SubstituteParmsAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
*/
override def visitNamedParameterMarkerRule(ctx: NamedParameterMarkerRuleContext): AnyRef =
withOrigin(ctx) {
- val paramName = ctx.namedParameterMarker().identifier().getText
+ // Named parameters use simpleIdentifier, so .getText() is correct.
+ val paramName = ctx.namedParameterMarker().simpleIdentifier().getText
namedParams += paramName
// Calculate the location of the entire parameter (including the colon)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
index e2e320be36546..32270df0a9885 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
@@ -430,7 +430,15 @@ case class UnclosedCommentProcessor(command: String, tokenStream: CommonTokenStr
}
object DataTypeParser extends AbstractParser {
- override protected def astBuilder: DataTypeAstBuilder = new DataTypeAstBuilder
+ override protected def astBuilder: DataTypeAstBuilder = new DataTypeAstBuilder {
+ // DataTypeParser only parses data types, not full SQL.
+ // Multi-part identifiers should not appear in IDENTIFIER() within type definitions.
+ override protected def parseMultipartIdentifier(identifier: String): Seq[String] = {
+ throw SparkException.internalError(
+ "DataTypeParser does not support multi-part identifiers in IDENTIFIER(). " +
+ s"Attempted to parse: $identifier")
+ }
+ }
}
object AbstractParser extends Logging {
@@ -476,6 +484,7 @@ object AbstractParser extends Logging {
parser.SQL_standard_keyword_behavior = conf.enforceReservedKeywords
parser.double_quoted_identifiers = conf.doubleQuotedIdentifiers
parser.parameter_substitution_enabled = !conf.legacyParameterSubstitutionConstantsOnly
+ parser.legacy_identifier_clause_only = conf.legacyIdentifierClauseOnly
}
/**
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala
index 8dff1ceccfcfe..ca8e73a517d56 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/RebaseDateTime.scala
@@ -288,7 +288,7 @@ object RebaseDateTime {
// `JsonRebaseRecord`. Mutable HashMap is used here instead of AnyRefMap due to SPARK-49491.
private[sql] def loadRebaseRecords(fileName: String): HashMap[String, RebaseInfo] = {
val file = SparkClassUtils.getSparkClassLoader.getResource(fileName)
- val jsonRebaseRecords = mapper.readValue[Seq[JsonRebaseRecord]](file)
+ val jsonRebaseRecords = mapper.readValue[Seq[JsonRebaseRecord]](file.openStream())
val hashMap = new HashMap[String, RebaseInfo]
hashMap.sizeHint(jsonRebaseRecords.size)
jsonRebaseRecords.foreach { jsonRecord =>
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
index 630f274a621e9..ea565aeb7febf 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/QueryParsingErrors.scala
@@ -461,7 +461,7 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
ctx)
}
- def computeStatisticsNotExpectedError(ctx: IdentifierContext): Throwable = {
+ def computeStatisticsNotExpectedError(ctx: ParserRuleContext): Throwable = {
new ParseException(
errorClass = "INVALID_SQL_SYNTAX.ANALYZE_TABLE_UNEXPECTED_NOSCAN",
messageParameters = Map("ctx" -> toSQLStmt(ctx.getText)),
@@ -477,7 +477,7 @@ private[sql] object QueryParsingErrors extends DataTypeErrorsBase {
ctx)
}
- def showFunctionsUnsupportedError(identifier: String, ctx: IdentifierContext): Throwable = {
+ def showFunctionsUnsupportedError(identifier: String, ctx: ParserRuleContext): Throwable = {
new ParseException(
errorClass = "INVALID_SQL_SYNTAX.SHOW_FUNCTIONS_INVALID_SCOPE",
messageParameters = Map("scope" -> toSQLId(identifier)),
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/api/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 6a22cbfaf351e..d6da8ce17e832 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql.expressions
-import scala.reflect.runtime.universe.TypeTag
+import scala.reflect.runtime.universe.{typeOf, TypeTag}
import scala.util.Try
import org.apache.spark.annotation.Stable
-import org.apache.spark.sql.{Column, Encoder}
+import org.apache.spark.sql.{Column, Encoder, Row}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.internal.{InvokeInlineUserDefinedFunction, UserDefinedFunctionLike}
@@ -130,8 +130,17 @@ object SparkUserDefinedFunction {
returnTypeTag: TypeTag[_],
inputTypeTags: TypeTag[_]*): SparkUserDefinedFunction = {
val outputEncoder = ScalaReflection.encoderFor(returnTypeTag)
+ // Return None for any input that is a Seq[Row].
+ // This replicates the Spark 3 behavior of passing None as the encoder for such inputs.
+ val seqRowType = typeOf[Seq[_]]
+ val rowType = typeOf[Row]
val inputEncoders = inputTypeTags.map { tag =>
- Try(ScalaReflection.encoderFor(tag)).toOption
+ val tpe = tag.tpe
+ if (tpe <:< seqRowType && tpe.typeArgs.nonEmpty && tpe.typeArgs.head =:= rowType) {
+ None
+ } else {
+ Try(ScalaReflection.encoderFor(tag)).toOption
+ }
}
SparkUserDefinedFunction(
f = function,
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
index ef78c842c5440..d5d6481073908 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1303,6 +1303,329 @@ object functions {
def theta_union_agg(columnName: String): Column =
theta_union_agg(Column(columnName))
+ /**
+ * Aggregate function: returns the compact binary representation of the Datasketches
+ * KllLongsSketch built with the values in the input column. The optional k parameter controls
+ * the size and accuracy of the sketch (default 200, range 8-65535).
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_agg_bigint(e: Column, k: Column): Column =
+ Column.fn("kll_sketch_agg_bigint", e, k)
+
+ /**
+ * Aggregate function: returns the compact binary representation of the Datasketches
+ * KllLongsSketch built with the values in the input column. The optional k parameter controls
+ * the size and accuracy of the sketch (default 200, range 8-65535).
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_agg_bigint(e: Column, k: Int): Column =
+ Column.fn("kll_sketch_agg_bigint", e, lit(k))
+
+ /**
+ * Aggregate function: returns the compact binary representation of the Datasketches
+ * KllLongsSketch built with the values in the input column. The optional k parameter controls
+ * the size and accuracy of the sketch (default 200, range 8-65535).
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_agg_bigint(columnName: String, k: Int): Column =
+ kll_sketch_agg_bigint(Column(columnName), k)
+
+ /**
+ * Aggregate function: returns the compact binary representation of the Datasketches
+ * KllLongsSketch built with the values in the input column with default k value of 200.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_agg_bigint(e: Column): Column =
+ Column.fn("kll_sketch_agg_bigint", e)
+
+ /**
+ * Aggregate function: returns the compact binary representation of the Datasketches
+ * KllLongsSketch built with the values in the input column with default k value of 200.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_agg_bigint(columnName: String): Column =
+ kll_sketch_agg_bigint(Column(columnName))
+
+ /**
+ * Aggregate function: returns the compact binary representation of the Datasketches
+ * KllFloatsSketch built with the values in the input column. The optional k parameter controls
+ * the size and accuracy of the sketch (default 200, range 8-65535).
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_agg_float(e: Column, k: Column): Column =
+ Column.fn("kll_sketch_agg_float", e, k)
+
+ /**
+ * Aggregate function: returns the compact binary representation of the Datasketches
+ * KllFloatsSketch built with the values in the input column. The optional k parameter controls
+ * the size and accuracy of the sketch (default 200, range 8-65535).
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_agg_float(e: Column, k: Int): Column =
+ Column.fn("kll_sketch_agg_float", e, lit(k))
+
+ /**
+ * Aggregate function: returns the compact binary representation of the Datasketches
+ * KllFloatsSketch built with the values in the input column. The optional k parameter controls
+ * the size and accuracy of the sketch (default 200, range 8-65535).
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_agg_float(columnName: String, k: Int): Column =
+ kll_sketch_agg_float(Column(columnName), k)
+
+ /**
+ * Aggregate function: returns the compact binary representation of the Datasketches
+ * KllFloatsSketch built with the values in the input column with default k value of 200.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_agg_float(e: Column): Column =
+ Column.fn("kll_sketch_agg_float", e)
+
+ /**
+ * Aggregate function: returns the compact binary representation of the Datasketches
+ * KllFloatsSketch built with the values in the input column with default k value of 200.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_agg_float(columnName: String): Column =
+ kll_sketch_agg_float(Column(columnName))
+
+ /**
+ * Aggregate function: returns the compact binary representation of the Datasketches
+ * KllDoublesSketch built with the values in the input column. The optional k parameter controls
+ * the size and accuracy of the sketch (default 200, range 8-65535).
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_agg_double(e: Column, k: Column): Column =
+ Column.fn("kll_sketch_agg_double", e, k)
+
+ /**
+ * Aggregate function: returns the compact binary representation of the Datasketches
+ * KllDoublesSketch built with the values in the input column. The optional k parameter controls
+ * the size and accuracy of the sketch (default 200, range 8-65535).
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_agg_double(e: Column, k: Int): Column =
+ Column.fn("kll_sketch_agg_double", e, lit(k))
+
+ /**
+ * Aggregate function: returns the compact binary representation of the Datasketches
+ * KllDoublesSketch built with the values in the input column. The optional k parameter controls
+ * the size and accuracy of the sketch (default 200, range 8-65535).
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_agg_double(columnName: String, k: Int): Column =
+ kll_sketch_agg_double(Column(columnName), k)
+
+ /**
+ * Aggregate function: returns the compact binary representation of the Datasketches
+ * KllDoublesSketch built with the values in the input column with default k value of 200.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_agg_double(e: Column): Column =
+ Column.fn("kll_sketch_agg_double", e)
+
+ /**
+ * Aggregate function: returns the compact binary representation of the Datasketches
+ * KllDoublesSketch built with the values in the input column with default k value of 200.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_agg_double(columnName: String): Column =
+ kll_sketch_agg_double(Column(columnName))
+
+ /**
+ * Aggregate function: merges binary KllLongsSketch representations and returns the merged
+ * sketch. The optional k parameter controls the size and accuracy of the merged sketch (range
+ * 8-65535). If k is not specified, the merged sketch adopts the k value from the first input
+ * sketch.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_merge_agg_bigint(e: Column, k: Column): Column =
+ Column.fn("kll_merge_agg_bigint", e, k)
+
+ /**
+ * Aggregate function: merges binary KllLongsSketch representations and returns the merged
+ * sketch. The optional k parameter controls the size and accuracy of the merged sketch (range
+ * 8-65535). If k is not specified, the merged sketch adopts the k value from the first input
+ * sketch.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_merge_agg_bigint(e: Column, k: Int): Column =
+ Column.fn("kll_merge_agg_bigint", e, lit(k))
+
+ /**
+ * Aggregate function: merges binary KllLongsSketch representations and returns the merged
+ * sketch. The optional k parameter controls the size and accuracy of the merged sketch (range
+ * 8-65535). If k is not specified, the merged sketch adopts the k value from the first input
+ * sketch.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_merge_agg_bigint(columnName: String, k: Int): Column =
+ kll_merge_agg_bigint(Column(columnName), k)
+
+ /**
+ * Aggregate function: merges binary KllLongsSketch representations and returns the merged
+ * sketch. If k is not specified, the merged sketch adopts the k value from the first input
+ * sketch.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_merge_agg_bigint(e: Column): Column =
+ Column.fn("kll_merge_agg_bigint", e)
+
+ /**
+ * Aggregate function: merges binary KllLongsSketch representations and returns the merged
+ * sketch. If k is not specified, the merged sketch adopts the k value from the first input
+ * sketch.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_merge_agg_bigint(columnName: String): Column =
+ kll_merge_agg_bigint(Column(columnName))
+
+ /**
+ * Aggregate function: merges binary KllFloatsSketch representations and returns merged sketch.
+ * The optional k parameter controls the size and accuracy of the merged sketch (range 8-65535).
+ * If k is not specified, the merged sketch adopts the k value from the first input sketch.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_merge_agg_float(e: Column, k: Column): Column =
+ Column.fn("kll_merge_agg_float", e, k)
+
+ /**
+ * Aggregate function: merges binary KllFloatsSketch representations and returns merged sketch.
+ * The optional k parameter controls the size and accuracy of the merged sketch (range 8-65535).
+ * If k is not specified, the merged sketch adopts the k value from the first input sketch.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_merge_agg_float(e: Column, k: Int): Column =
+ Column.fn("kll_merge_agg_float", e, lit(k))
+
+ /**
+ * Aggregate function: merges binary KllFloatsSketch representations and returns merged sketch.
+ * The optional k parameter controls the size and accuracy of the merged sketch (range 8-65535).
+ * If k is not specified, the merged sketch adopts the k value from the first input sketch.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_merge_agg_float(columnName: String, k: Int): Column =
+ kll_merge_agg_float(Column(columnName), k)
+
+ /**
+ * Aggregate function: merges binary KllFloatsSketch representations and returns merged sketch.
+ * If k is not specified, the merged sketch adopts the k value from the first input sketch.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_merge_agg_float(e: Column): Column =
+ Column.fn("kll_merge_agg_float", e)
+
+ /**
+ * Aggregate function: merges binary KllFloatsSketch representations and returns merged sketch.
+ * If k is not specified, the merged sketch adopts the k value from the first input sketch.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_merge_agg_float(columnName: String): Column =
+ kll_merge_agg_float(Column(columnName))
+
+ /**
+ * Aggregate function: merges binary KllDoublesSketch representations and returns merged sketch.
+ * The optional k parameter controls the size and accuracy of the merged sketch (range 8-65535).
+ * If k is not specified, the merged sketch adopts the k value from the first input sketch.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_merge_agg_double(e: Column, k: Column): Column =
+ Column.fn("kll_merge_agg_double", e, k)
+
+ /**
+ * Aggregate function: merges binary KllDoublesSketch representations and returns merged sketch.
+ * The optional k parameter controls the size and accuracy of the merged sketch (range 8-65535).
+ * If k is not specified, the merged sketch adopts the k value from the first input sketch.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_merge_agg_double(e: Column, k: Int): Column =
+ Column.fn("kll_merge_agg_double", e, lit(k))
+
+ /**
+ * Aggregate function: merges binary KllDoublesSketch representations and returns merged sketch.
+ * The optional k parameter controls the size and accuracy of the merged sketch (range 8-65535).
+ * If k is not specified, the merged sketch adopts the k value from the first input sketch.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_merge_agg_double(columnName: String, k: Int): Column =
+ kll_merge_agg_double(Column(columnName), k)
+
+ /**
+ * Aggregate function: merges binary KllDoublesSketch representations and returns merged sketch.
+ * If k is not specified, the merged sketch adopts the k value from the first input sketch.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_merge_agg_double(e: Column): Column =
+ Column.fn("kll_merge_agg_double", e)
+
+ /**
+ * Aggregate function: merges binary KllDoublesSketch representations and returns merged sketch.
+ * If k is not specified, the merged sketch adopts the k value from the first input sketch.
+ *
+ * @group agg_funcs
+ * @since 4.1.0
+ */
+ def kll_merge_agg_double(columnName: String): Column =
+ kll_merge_agg_double(Column(columnName))
+
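A sketch tying the KLL aggregate and merge functions above together, assuming they land as declared in this file; df, the column names, and k = 400 are illustrative:

import org.apache.spark.sql.functions._

// Build one KLL doubles sketch per group; a larger k trades memory for accuracy.
val perGroup = df.groupBy(col("region"))
  .agg(kll_sketch_agg_double(col("latency_ms"), 400).as("latency_sketch"))

// Without an explicit k, the merged sketch adopts the k value of the first input sketch.
val merged = perGroup.agg(kll_merge_agg_double(col("latency_sketch")).as("latency_sketch"))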
/**
* Aggregate function: returns the concatenation of non-null input values.
*
@@ -3809,6 +4132,147 @@ object functions {
def theta_union(c1: Column, c2: Column, lgNomEntries: Column): Column =
Column.fn("theta_union", c1, c2, lgNomEntries)
+ /**
+ * Returns a string with human readable summary information about the KLL bigint sketch.
+ *
+ * @group misc_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_to_string_bigint(e: Column): Column =
+ Column.fn("kll_sketch_to_string_bigint", e)
+
+ /**
+ * Returns a string with human readable summary information about the KLL float sketch.
+ *
+ * @group misc_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_to_string_float(e: Column): Column =
+ Column.fn("kll_sketch_to_string_float", e)
+
+ /**
+ * Returns a string with human readable summary information about the KLL double sketch.
+ *
+ * @group misc_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_to_string_double(e: Column): Column =
+ Column.fn("kll_sketch_to_string_double", e)
+
+ /**
+ * Returns the number of items collected in the KLL bigint sketch.
+ *
+ * @group misc_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_get_n_bigint(e: Column): Column =
+ Column.fn("kll_sketch_get_n_bigint", e)
+
+ /**
+ * Returns the number of items collected in the KLL float sketch.
+ *
+ * @group misc_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_get_n_float(e: Column): Column =
+ Column.fn("kll_sketch_get_n_float", e)
+
+ /**
+ * Returns the number of items collected in the KLL double sketch.
+ *
+ * @group misc_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_get_n_double(e: Column): Column =
+ Column.fn("kll_sketch_get_n_double", e)
+
+ /**
+ * Merges two KLL bigint sketch buffers together into one.
+ *
+ * @group misc_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_merge_bigint(left: Column, right: Column): Column =
+ Column.fn("kll_sketch_merge_bigint", left, right)
+
+ /**
+ * Merges two KLL float sketch buffers together into one.
+ *
+ * @group misc_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_merge_float(left: Column, right: Column): Column =
+ Column.fn("kll_sketch_merge_float", left, right)
+
+ /**
+ * Merges two KLL double sketch buffers together into one.
+ *
+ * @group misc_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_merge_double(left: Column, right: Column): Column =
+ Column.fn("kll_sketch_merge_double", left, right)
+
+ /**
+ * Extracts a quantile value from a KLL bigint sketch given an input rank value. The rank can be
+ * a single value or an array.
+ *
+ * @group misc_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_get_quantile_bigint(sketch: Column, rank: Column): Column =
+ Column.fn("kll_sketch_get_quantile_bigint", sketch, rank)
+
+ /**
+ * Extracts a quantile value from a KLL float sketch given an input rank value. The rank can be
+ * a single value or an array.
+ *
+ * @group misc_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_get_quantile_float(sketch: Column, rank: Column): Column =
+ Column.fn("kll_sketch_get_quantile_float", sketch, rank)
+
+ /**
+ * Extracts a quantile value from a KLL double sketch given an input rank value. The rank can be
+ * a single value or an array.
+ *
+ * @group misc_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_get_quantile_double(sketch: Column, rank: Column): Column =
+ Column.fn("kll_sketch_get_quantile_double", sketch, rank)
+
+ /**
+ * Extracts a rank value from a KLL bigint sketch given an input quantile value. The quantile
+ * can be a single value or an array.
+ *
+ * @group misc_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_get_rank_bigint(sketch: Column, quantile: Column): Column =
+ Column.fn("kll_sketch_get_rank_bigint", sketch, quantile)
+
+ /**
+ * Extracts a rank value from a KLL float sketch given an input quantile value. The quantile can
+ * be a single value or an array.
+ *
+ * @group misc_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_get_rank_float(sketch: Column, quantile: Column): Column =
+ Column.fn("kll_sketch_get_rank_float", sketch, quantile)
+
+ /**
+ * Extracts a rank value from a KLL double sketch given an input quantile value. The quantile
+ * can be a single value or an array.
+ *
+ * @group misc_funcs
+ * @since 4.1.0
+ */
+ def kll_sketch_get_rank_double(sketch: Column, quantile: Column): Column =
+ Column.fn("kll_sketch_get_rank_double", sketch, quantile)
+
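Continuing the KLL sketch above, a hedged example of inspecting a merged sketch; the quantile and rank values are illustrative:

import org.apache.spark.sql.functions._

// merged carries the "latency_sketch" column built in the earlier sketch.
merged.select(
  kll_sketch_get_n_double(col("latency_sketch")).as("n"),
  // the rank argument can be a single value or an array of ranks
  kll_sketch_get_quantile_double(col("latency_sketch"), array(lit(0.5), lit(0.99))).as("p50_p99"),
  kll_sketch_get_rank_double(col("latency_sketch"), lit(250.0)).as("rank_of_250ms"),
  kll_sketch_to_string_double(col("latency_sketch")).as("summary")).show(truncate = false)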
/**
* Returns the user name of current execution context.
*
@@ -9171,6 +9635,33 @@ object functions {
def st_geomfromwkb(wkb: Column): Column =
Column.fn("st_geomfromwkb", wkb)
+ /**
+ * Returns a new GEOGRAPHY or GEOMETRY value whose SRID is the specified SRID value.
+ *
+ * @group st_funcs
+ * @since 4.1.0
+ */
+ def st_setsrid(geo: Column, srid: Column): Column =
+ Column.fn("st_setsrid", geo, srid)
+
+ /**
+ * Returns a new GEOGRAPHY or GEOMETRY value whose SRID is the specified SRID value.
+ *
+ * @group st_funcs
+ * @since 4.1.0
+ */
+ def st_setsrid(geo: Column, srid: Int): Column =
+ Column.fn("st_setsrid", geo, lit(srid))
+
+ /**
+ * Returns the SRID of the input GEOGRAPHY or GEOMETRY value.
+ *
+ * @group st_funcs
+ * @since 4.1.0
+ */
+ def st_srid(geo: Column): Column =
+ Column.fn("st_srid", geo)
+
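A small sketch of the SRID helpers above; geoDf and the geom column are hypothetical, and 4326 is the usual WGS 84 SRID:

import org.apache.spark.sql.functions._

// Tag the geometry column with an explicit SRID, then read it back.
val tagged = geoDf.select(st_setsrid(col("geom"), 4326).as("geom"))
tagged.select(st_srid(col("geom")).as("srid")).show()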
//////////////////////////////////////////////////////////////////////////////////////////////
// Scala UDF functions
//////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
index f715f8f9ed8cd..1be00d75acbce 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
@@ -51,6 +51,7 @@ private[sql] trait SqlApiConf {
def parserDfaCacheFlushThreshold: Int
def parserDfaCacheFlushRatio: Double
def legacyParameterSubstitutionConstantsOnly: Boolean
+ def legacyIdentifierClauseOnly: Boolean
}
private[sql] object SqlApiConf {
@@ -67,6 +68,8 @@ private[sql] object SqlApiConf {
SqlApiConfHelper.LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY
val LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY: String =
SqlApiConfHelper.LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY
+ val LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES_KEY: String =
+ SqlApiConfHelper.LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES_KEY
val PARSER_DFA_CACHE_FLUSH_THRESHOLD_KEY: String =
SqlApiConfHelper.PARSER_DFA_CACHE_FLUSH_THRESHOLD_KEY
val PARSER_DFA_CACHE_FLUSH_RATIO_KEY: String =
@@ -104,4 +107,5 @@ private[sql] object DefaultSqlApiConf extends SqlApiConf {
override def parserDfaCacheFlushThreshold: Int = -1
override def parserDfaCacheFlushRatio: Double = -1.0
override def legacyParameterSubstitutionConstantsOnly: Boolean = false
+ override def legacyIdentifierClauseOnly: Boolean = false
}
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
index b839caba3f547..4fcc2f4e150d1 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
@@ -35,6 +35,8 @@ private[sql] object SqlApiConfHelper {
val LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY: String = "spark.sql.session.localRelationChunkSizeRows"
val LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY: String =
"spark.sql.session.localRelationChunkSizeBytes"
+ val LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES_KEY: String =
+ "spark.sql.session.localRelationBatchOfChunksSizeBytes"
val ARROW_EXECUTION_USE_LARGE_VAR_TYPES = "spark.sql.execution.arrow.useLargeVarTypes"
val PARSER_DFA_CACHE_FLUSH_THRESHOLD_KEY: String =
"spark.sql.parser.parserDfaCacheFlushThreshold"
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/CartesianSpatialReferenceSystemMapper.java b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/CartesianSpatialReferenceSystemMapper.java
index d7729e88a3331..7384bb331d44f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/CartesianSpatialReferenceSystemMapper.java
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/CartesianSpatialReferenceSystemMapper.java
@@ -17,9 +17,12 @@
package org.apache.spark.sql.internal.types;
+import org.apache.spark.annotation.Unstable;
+
/**
* Class for providing SRS mappings for cartesian spatial reference systems.
*/
+@Unstable
public class CartesianSpatialReferenceSystemMapper extends SpatialReferenceSystemMapper {
// Returns the string ID corresponding to the input SRID. If not supported, returns `null`.
public static String getStringId(int srid) {
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/GeographicSpatialReferenceSystemMapper.java b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/GeographicSpatialReferenceSystemMapper.java
index 85e0cc53658c6..f409041c564ee 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/GeographicSpatialReferenceSystemMapper.java
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/GeographicSpatialReferenceSystemMapper.java
@@ -17,9 +17,12 @@
package org.apache.spark.sql.internal.types;
+import org.apache.spark.annotation.Unstable;
+
/**
* Class for providing SRS mappings for geographic spatial reference systems.
*/
+@Unstable
public class GeographicSpatialReferenceSystemMapper extends SpatialReferenceSystemMapper {
// Returns the string ID corresponding to the input SRID. If not supported, returns `null`.
public static String getStringId(int srid) {
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/SpatialReferenceSystemCache.java b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/SpatialReferenceSystemCache.java
index 40094785600fb..34ea42271352f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/SpatialReferenceSystemCache.java
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/SpatialReferenceSystemCache.java
@@ -17,12 +17,15 @@
package org.apache.spark.sql.internal.types;
+import org.apache.spark.annotation.Unstable;
+
import java.util.HashMap;
import java.util.List;
/**
* Class for maintaining the mappings between supported SRID/CRS values and the corresponding SRS.
*/
+@Unstable
public class SpatialReferenceSystemCache {
// Private constructor to prevent external instantiation of this singleton class.
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/SpatialReferenceSystemInformation.java b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/SpatialReferenceSystemInformation.java
index 46fbfd9e9ac17..ba3526d84f079 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/SpatialReferenceSystemInformation.java
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/SpatialReferenceSystemInformation.java
@@ -17,9 +17,12 @@
package org.apache.spark.sql.internal.types;
+import org.apache.spark.annotation.Unstable;
+
/**
* Class for maintaining information about a spatial reference system (SRS).
*/
+@Unstable
public record SpatialReferenceSystemInformation(
// Field storing the spatial reference identifier (SRID) value of this SRS.
int srid,
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/SpatialReferenceSystemMapper.java b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/SpatialReferenceSystemMapper.java
index 2993ba05b76f8..24b82f540b082 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/types/SpatialReferenceSystemMapper.java
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/types/SpatialReferenceSystemMapper.java
@@ -17,9 +17,12 @@
package org.apache.spark.sql.internal.types;
+import org.apache.spark.annotation.Unstable;
+
/**
* Abstract class for providing SRS mappings for spatial reference systems.
*/
+@Unstable
public abstract class SpatialReferenceSystemMapper {
protected static final SpatialReferenceSystemCache srsCache =
SpatialReferenceSystemCache.getInstance();
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala
index 638ae79351846..f74cddc02ac4a 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.types
import org.json4s.JsonAST.{JString, JValue}
-import org.apache.spark.SparkIllegalArgumentException
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
+import org.apache.spark.annotation.Unstable
import org.apache.spark.sql.internal.types.GeographicSpatialReferenceSystemMapper
/**
@@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.types.GeographicSpatialReferenceSystemMappe
* Geospatial Consortium (OGC) Simple Feature Access specification
* (https://portal.ogc.org/files/?artifact_id=25355), with a geographic coordinate system.
*/
-@Experimental
+@Unstable
class GeographyType private (val crs: String, val algorithm: EdgeInterpolationAlgorithm)
extends AtomicType
with Serializable {
@@ -133,9 +133,30 @@ class GeographyType private (val crs: String, val algorithm: EdgeInterpolationAl
// If the SRID is not mixed, we can only accept the same SRID.
isMixedSrid || gt.srid == srid
}
+
+ private[sql] def assertSridAllowedForType(otherSrid: Int): Unit = {
+ // If SRID is not mixed, SRIDs must match.
+ if (!isMixedSrid && otherSrid != srid) {
+ throw new SparkRuntimeException(
+ errorClass = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ messageParameters = Map(
+ "type" -> "GEOGRAPHY",
+ "valueSrid" -> otherSrid.toString,
+ "typeSrid" -> srid.toString))
+ } else if (isMixedSrid) {
+ // For fixed-SRID geography types, the branch above checks that the value SRID matches
+ // the type SRID. For mixed-SRID types we need to check explicitly, because a mixed SRID
+ // type accepts any SRID value, but it should still accept only valid (supported) SRIDs.
+ if (!GeographyType.isSridSupported(otherSrid)) {
+ throw new SparkIllegalArgumentException(
+ errorClass = "ST_INVALID_SRID_VALUE",
+ messageParameters = Map("srid" -> otherSrid.toString))
+ }
+ }
+ }
}
-@Experimental
+@Unstable
object GeographyType extends SpatialType {
/**
@@ -154,9 +175,14 @@ object GeographyType extends SpatialType {
/**
* The default concrete GeographyType in SQL.
*/
- private final val GEOGRAPHY_MIXED_TYPE: GeographyType =
+ private final lazy val GEOGRAPHY_MIXED_TYPE: GeographyType =
GeographyType(MIXED_CRS, GEOGRAPHY_DEFAULT_ALGORITHM)
+ /** Returns whether the given SRID is supported. */
+ def isSridSupported(srid: Int): Boolean = {
+ GeographicSpatialReferenceSystemMapper.getStringId(srid) != null
+ }
+
/**
* Constructors for GeographyType.
*/
@@ -228,8 +254,10 @@ object GeographyType extends SpatialType {
* Edge interpolation algorithm for Geography logical type. Currently, Spark only supports
* spherical algorithm.
*/
+@Unstable
sealed abstract class EdgeInterpolationAlgorithm
+@Unstable
object EdgeInterpolationAlgorithm {
case object SPHERICAL extends EdgeInterpolationAlgorithm
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala
index 77a6b365c042a..c8b475dae2bab 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.types
import org.json4s.JsonAST.{JString, JValue}
-import org.apache.spark.SparkIllegalArgumentException
-import org.apache.spark.annotation.Experimental
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
+import org.apache.spark.annotation.Unstable
import org.apache.spark.sql.internal.types.CartesianSpatialReferenceSystemMapper
/**
@@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.types.CartesianSpatialReferenceSystemMapper
* Geospatial Consortium (OGC) Simple Feature Access specification
* (https://portal.ogc.org/files/?artifact_id=25355), with a Cartesian coordinate system.
*/
-@Experimental
+@Unstable
class GeometryType private (val crs: String) extends AtomicType with Serializable {
/**
@@ -130,9 +130,30 @@ class GeometryType private (val crs: String) extends AtomicType with Serializabl
// If the SRID is not mixed, we can only accept the same SRID.
isMixedSrid || gt.srid == srid
}
+
+ private[sql] def assertSridAllowedForType(otherSrid: Int): Unit = {
+ // If SRID is not mixed, SRIDs must match.
+ if (!isMixedSrid && otherSrid != srid) {
+ throw new SparkRuntimeException(
+ errorClass = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+ messageParameters = Map(
+ "type" -> "GEOMETRY",
+ "valueSrid" -> otherSrid.toString,
+ "typeSrid" -> srid.toString))
+ } else if (isMixedSrid) {
+ // For fixed-SRID geometry types, the branch above checks that the value SRID matches
+ // the type SRID. For mixed-SRID types we need to check explicitly, because a mixed SRID
+ // type accepts any SRID value, but it should still accept only valid (supported) SRIDs.
+ if (!GeometryType.isSridSupported(otherSrid)) {
+ throw new SparkIllegalArgumentException(
+ errorClass = "ST_INVALID_SRID_VALUE",
+ messageParameters = Map("srid" -> otherSrid.toString))
+ }
+ }
+ }
}
-@Experimental
+@Unstable
object GeometryType extends SpatialType {
/**
@@ -146,9 +167,14 @@ object GeometryType extends SpatialType {
/**
* The default concrete GeometryType in SQL.
*/
- private final val GEOMETRY_MIXED_TYPE: GeometryType =
+ private final lazy val GEOMETRY_MIXED_TYPE: GeometryType =
GeometryType(MIXED_CRS)
+ /** Returns whether the given SRID is supported. */
+ def isSridSupported(srid: Int): Boolean = {
+ CartesianSpatialReferenceSystemMapper.getStringId(srid) != null
+ }
+
/**
* Constructors for GeometryType.
*/
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/SpatialType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/SpatialType.scala
index a61158e36c850..92fa6dfd38fb1 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/SpatialType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/SpatialType.scala
@@ -17,8 +17,9 @@
package org.apache.spark.sql.types
-import org.apache.spark.sql.types.AbstractDataType
+import org.apache.spark.annotation.Unstable
+@Unstable
trait SpatialType extends AbstractDataType {
/**
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
index 6caabf20f8f6b..23d8a0bbb65b5 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala
@@ -143,6 +143,43 @@ private[sql] object ArrowUtils {
largeVarTypes)).asJava)
case udt: UserDefinedType[_] =>
toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes)
+ case g: GeometryType =>
+ val fieldType =
+ new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
+
+ // WKB field is tagged with additional metadata so we can identify that the arrow
+ // struct actually represents a geometry schema.
+ val wkbFieldType = new FieldType(
+ false,
+ toArrowType(BinaryType, timeZoneId, largeVarTypes),
+ null,
+ Map("geometry" -> "true", "srid" -> g.srid.toString).asJava)
+
+ new Field(
+ name,
+ fieldType,
+ Seq(
+ toArrowField("srid", IntegerType, false, timeZoneId, largeVarTypes),
+ new Field("wkb", wkbFieldType, Seq.empty[Field].asJava)).asJava)
+
+ case g: GeographyType =>
+ val fieldType =
+ new FieldType(nullable, ArrowType.Struct.INSTANCE, null, null)
+
+ // WKB field is tagged with additional metadata so we can identify that the arrow
+ // struct actually represents a geography schema.
+ val wkbFieldType = new FieldType(
+ false,
+ toArrowType(BinaryType, timeZoneId, largeVarTypes),
+ null,
+ Map("geography" -> "true", "srid" -> g.srid.toString).asJava)
+
+ new Field(
+ name,
+ fieldType,
+ Seq(
+ toArrowField("srid", IntegerType, false, timeZoneId, largeVarTypes),
+ new Field("wkb", wkbFieldType, Seq.empty[Field].asJava)).asJava)
case _: VariantType =>
val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
// The metadata field is tagged with additional metadata so we can identify that the arrow
@@ -175,6 +212,26 @@ private[sql] object ArrowUtils {
}
}
+ def isGeometryField(field: Field): Boolean = {
+ assert(field.getType.isInstanceOf[ArrowType.Struct])
+ field.getChildren.asScala
+ .map(_.getName)
+ .asJava
+ .containsAll(Seq("wkb", "srid").asJava) && field.getChildren.asScala.exists { child =>
+ child.getName == "wkb" && child.getMetadata.getOrDefault("geometry", "false") == "true"
+ }
+ }
+
+ def isGeographyField(field: Field): Boolean = {
+ assert(field.getType.isInstanceOf[ArrowType.Struct])
+ field.getChildren.asScala
+ .map(_.getName)
+ .asJava
+ .containsAll(Seq("wkb", "srid").asJava) && field.getChildren.asScala.exists { child =>
+ child.getName == "wkb" && child.getMetadata.getOrDefault("geography", "false") == "true"
+ }
+ }
+
def fromArrowField(field: Field): DataType = {
field.getType match {
case _: ArrowType.Map =>
@@ -188,6 +245,26 @@ private[sql] object ArrowUtils {
ArrayType(elementType, containsNull = elementField.isNullable)
case ArrowType.Struct.INSTANCE if isVariantField(field) =>
VariantType
+ case ArrowType.Struct.INSTANCE if isGeometryField(field) =>
+ // We expect the type metadata to be associated with the wkb field.
+ val metadataField =
+ field.getChildren.asScala.filter { child => child.getName == "wkb" }.head
+ val srid = metadataField.getMetadata.get("srid").toInt
+ if (srid == GeometryType.MIXED_SRID) {
+ GeometryType("ANY")
+ } else {
+ GeometryType(srid)
+ }
+ case ArrowType.Struct.INSTANCE if isGeographyField(field) =>
+ // We expect the type metadata to be associated with the wkb field.
+ val metadataField =
+ field.getChildren.asScala.filter { child => child.getName == "wkb" }.head
+ val srid = metadataField.getMetadata.get("srid").toInt
+ if (srid == GeographyType.MIXED_SRID) {
+ GeographyType("ANY")
+ } else {
+ GeographyType(srid)
+ }
case ArrowType.Struct.INSTANCE =>
val fields = field.getChildren().asScala.map { child =>
val dt = fromArrowField(child)
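To make the encoding above concrete, a hedged Scala sketch of the Arrow field shape that toArrowField builds for a geometry column (the field name "geom" and the SRID metadata value "0" are illustrative):

    import scala.jdk.CollectionConverters._
    import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType}

    object GeometryArrowLayout {
      // Non-nullable binary "wkb" child tagged with the geometry/srid metadata.
      private val wkbType = new FieldType(false, ArrowType.Binary.INSTANCE, null,
        Map("geometry" -> "true", "srid" -> "0").asJava)

      // Nullable struct with a non-null int32 "srid" child and the tagged "wkb" child.
      val geometryField: Field = new Field(
        "geom",
        new FieldType(true, ArrowType.Struct.INSTANCE, null),
        Seq(
          new Field("srid", new FieldType(false, new ArrowType.Int(32, true), null),
            Seq.empty[Field].asJava),
          new Field("wkb", wkbType, Seq.empty[Field].asJava)).asJava)
    }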
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/CloseableIterator.scala
similarity index 97%
rename from sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/util/CloseableIterator.scala
index 9de585503a500..dc38c75d3ce73 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CloseableIterator.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/CloseableIterator.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.connect.client
+package org.apache.spark.sql.util
private[sql] trait CloseableIterator[E] extends Iterator[E] with AutoCloseable { self =>
def asJava: java.util.Iterator[E] = new java.util.Iterator[E] with AutoCloseable {
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ConcatenatingArrowStreamReader.scala
similarity index 94%
rename from sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala
rename to sql/api/src/main/scala/org/apache/spark/sql/util/ConcatenatingArrowStreamReader.scala
index 90963c831c252..2e5706fe4dcca 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ConcatenatingArrowStreamReader.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ConcatenatingArrowStreamReader.scala
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.connect.client.arrow
+package org.apache.spark.sql.util
import java.io.{InputStream, IOException}
import java.nio.channels.Channels
@@ -25,6 +25,8 @@ import org.apache.arrow.vector.ipc.{ArrowReader, ReadChannel}
import org.apache.arrow.vector.ipc.message.{ArrowDictionaryBatch, ArrowMessage, ArrowRecordBatch, MessageChannelReader, MessageResult, MessageSerializer}
import org.apache.arrow.vector.types.pojo.Schema
+import org.apache.spark.SparkException
+
/**
* An [[ArrowReader]] that concatenates multiple [[MessageIterator]]s into a single stream. Each
* iterator represents a single IPC stream. The concatenated streams all must have the same
@@ -34,7 +36,7 @@ import org.apache.arrow.vector.types.pojo.Schema
* closes its messages when it consumes them. In order to prevent that from happening in
* non-destructive mode we clone the messages before passing them to the reading logic.
*/
-class ConcatenatingArrowStreamReader(
+private[sql] class ConcatenatingArrowStreamReader(
allocator: BufferAllocator,
input: Iterator[AbstractMessageIterator],
destructive: Boolean)
@@ -62,7 +64,7 @@ class ConcatenatingArrowStreamReader(
totalBytesRead += current.bytesRead
current = input.next()
if (current.schema != getVectorSchemaRoot.getSchema) {
- throw new IllegalStateException()
+ throw SparkException.internalError("IPC Streams have different schemas.")
}
}
if (current.hasNext) {
@@ -128,7 +130,7 @@ class ConcatenatingArrowStreamReader(
override def closeReadSource(): Unit = ()
}
-trait AbstractMessageIterator extends Iterator[ArrowMessage] {
+private[sql] trait AbstractMessageIterator extends Iterator[ArrowMessage] {
def schema: Schema
def bytesRead: Long
}
@@ -137,7 +139,7 @@ trait AbstractMessageIterator extends Iterator[ArrowMessage] {
* Decode an Arrow IPC stream into individual messages. Please note that this iterator MUST have a
* valid IPC stream as its input, otherwise construction will fail.
*/
-class MessageIterator(input: InputStream, allocator: BufferAllocator)
+private[sql] class MessageIterator(input: InputStream, allocator: BufferAllocator)
extends AbstractMessageIterator {
private[this] val in = new ReadChannel(Channels.newChannel(input))
private[this] val reader = new MessageChannelReader(in, allocator)
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 6dcedfa6408e3..5fec0441d49f2 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.13</artifactId>
- <version>4.1.0-SNAPSHOT</version>
+ <version>4.1.2-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/Geo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/Geo.java
index 2299f35988638..bf723a8efef91 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/Geo.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/Geo.java
@@ -78,4 +78,7 @@ interface Geo {
// Returns the Spatial Reference Identifier (SRID) value of the geo object.
int srid();
+ // Sets the Spatial Reference Identifier (SRID) value of the geo object.
+ void setSrid(int srid);
+
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/Geography.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/Geography.java
index c46c2368832fe..da513d399f8b0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/Geography.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/Geography.java
@@ -162,7 +162,20 @@ public byte[] toEwkt() {
@Override
public int srid() {
// This method gets the SRID value from the in-memory Geography representation header.
- return ByteBuffer.wrap(getBytes()).order(DEFAULT_ENDIANNESS).getInt(SRID_OFFSET);
+ return getWrapper().getInt(SRID_OFFSET);
+ }
+
+ @Override
+ public void setSrid(int srid) {
+ // This method sets the SRID value in the in-memory Geography representation header.
+ getWrapper().putInt(SRID_OFFSET, srid);
+ }
+
+ /** Other private helper/utility methods used for implementation. */
+
+ // Returns a byte buffer wrapper over the byte buffer of this geography value.
+ private ByteBuffer getWrapper() {
+ return ByteBuffer.wrap(getBytes()).order(DEFAULT_ENDIANNESS);
}
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/Geometry.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/Geometry.java
index c4b6e5d0e4bd1..36fffef2abbd1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/Geometry.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/Geometry.java
@@ -162,7 +162,20 @@ public byte[] toEwkt() {
@Override
public int srid() {
// This method gets the SRID value from the in-memory Geometry representation header.
- return ByteBuffer.wrap(getBytes()).order(DEFAULT_ENDIANNESS).getInt(SRID_OFFSET);
+ return getWrapper().getInt(SRID_OFFSET);
+ }
+
+ @Override
+ public void setSrid(int srid) {
+ // This method sets the SRID value in the in-memory Geometry representation header.
+ getWrapper().putInt(SRID_OFFSET, srid);
+ }
+
+ /** Other private helper/utility methods used for implementation. */
+
+ // Returns a byte buffer wrapper over the byte buffer of this geometry value.
+ private ByteBuffer getWrapper() {
+ return ByteBuffer.wrap(getBytes()).order(DEFAULT_ENDIANNESS);
}
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
index aca3fdf1f1000..0a9942c4cf557 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
@@ -16,6 +16,9 @@
*/
package org.apache.spark.sql.catalyst.util;
+import org.apache.spark.sql.errors.QueryExecutionErrors;
+import org.apache.spark.sql.types.GeographyType;
+import org.apache.spark.sql.types.GeometryType;
import org.apache.spark.unsafe.types.GeographyVal;
import org.apache.spark.unsafe.types.GeometryVal;
@@ -46,6 +49,47 @@ static GeometryVal toPhysVal(Geometry g) {
return g.getValue();
}
+ /** Geospatial type casting utility methods. */
+
+ // Cast geography to geometry.
+ public static GeometryVal geographyToGeometry(GeographyVal geographyVal) {
+ // Geographic SRID is always a valid SRID for geometry, so we don't need to check it.
+ // Also, all geographic coordinates are valid for geometry, so no need to check bounds.
+ return toPhysVal(Geometry.fromBytes(geographyVal.getBytes()));
+ }
+
+ /** Geospatial type encoder/decoder utilities. */
+
+ public static GeometryVal serializeGeomFromWKB(org.apache.spark.sql.types.Geometry geometry,
+ GeometryType gt) {
+ int geometrySrid = geometry.getSrid();
+ gt.assertSridAllowedForType(geometrySrid);
+ return toPhysVal(Geometry.fromWkb(geometry.getBytes(), geometrySrid));
+ }
+
+ public static GeographyVal serializeGeogFromWKB(org.apache.spark.sql.types.Geography geography,
+ GeographyType gt) {
+ int geographySrid = geography.getSrid();
+ gt.assertSridAllowedForType(geographySrid);
+ return toPhysVal(Geography.fromWkb(geography.getBytes(), geographySrid));
+ }
+
+ public static org.apache.spark.sql.types.Geometry deserializeGeom(
+ GeometryVal geometry, GeometryType gt) {
+ int geometrySrid = stSrid(geometry);
+ gt.assertSridAllowedForType(geometrySrid);
+ byte[] wkb = stAsBinary(geometry);
+ return org.apache.spark.sql.types.Geometry.fromWKB(wkb, geometrySrid);
+ }
+
+ public static org.apache.spark.sql.types.Geography deserializeGeog(
+ GeographyVal geography, GeographyType gt) {
+ int geographySrid = stSrid(geography);
+ gt.assertSridAllowedForType(geographySrid);
+ byte[] wkb = stAsBinary(geography);
+ return org.apache.spark.sql.types.Geography.fromWKB(wkb, geographySrid);
+ }
+
/** Methods for implementing ST expressions. */
// ST_AsBinary
@@ -67,6 +111,35 @@ public static GeometryVal stGeomFromWKB(byte[] wkb) {
return toPhysVal(Geometry.fromWkb(wkb));
}
+ public static GeometryVal stGeomFromWKB(byte[] wkb, int srid) {
+ return toPhysVal(Geometry.fromWkb(wkb, srid));
+ }
+
+ // ST_SetSrid
+ public static GeographyVal stSetSrid(GeographyVal geo, int srid) {
+ // We only allow setting the SRID to geographic values.
+ if (!GeographyType.isSridSupported(srid)) {
+ throw QueryExecutionErrors.stInvalidSridValueError(srid);
+ }
+ // Create a copy of the input geography.
+ Geography copy = fromPhysVal(geo).copy();
+ // Set the SRID of the copy to the specified value.
+ copy.setSrid(srid);
+ return toPhysVal(copy);
+ }
+
+ public static GeometryVal stSetSrid(GeometryVal geo, int srid) {
+ // We only allow setting the SRID to valid values.
+ if (!GeometryType.isSridSupported(srid)) {
+ throw QueryExecutionErrors.stInvalidSridValueError(srid);
+ }
+ // Create a copy of the input geometry.
+ Geometry copy = fromPhysVal(geo).copy();
+ // Set the SRID of the copy to the specified value.
+ copy.setSrid(srid);
+ return toPhysVal(copy);
+ }
+
// ST_Srid
public static int stSrid(GeographyVal geog) {
return fromPhysVal(geog).srid();
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsV1OverwriteWithSaveAsTable.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsV1OverwriteWithSaveAsTable.java
new file mode 100644
index 0000000000000..63ee7493cb85b
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsV1OverwriteWithSaveAsTable.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * A marker interface that can be mixed into a {@link TableProvider} to indicate that the data
+ * source needs to distinguish between DataFrameWriter V1 {@code saveAsTable} operations and
+ * DataFrameWriter V2 {@code createOrReplace}/{@code replace} operations.
+ *
+ * Background: DataFrameWriter V1's {@code saveAsTable} with {@code SaveMode.Overwrite} creates
+ * a {@code ReplaceTableAsSelect} logical plan, which is identical to the plan created by
+ * DataFrameWriter V2's {@code createOrReplace}. However, the documented semantics can have
+ * different interpretations:
+ *
+ * - V1 saveAsTable with Overwrite: "if data/table already exists, existing data is expected
+ * to be overwritten by the contents of the DataFrame" - does not define behavior for
+ * metadata (schema) overwriting
+ * - V2 createOrReplace: "The output table's schema, partition layout, properties, and other
+ * configuration will be based on the contents of the data frame... If the table exists,
+ * its configuration and data will be replaced"
+ *
+ *
+ * Data sources that migrated from V1 to V2 may have adopted different behaviors based on these
+ * documented semantics. For example, Delta Lake interprets V1 saveAsTable to not replace table
+ * schema unless the {@code overwriteSchema} option is explicitly set.
+ *
+ * When a {@link TableProvider} implements this interface and
+ * {@link #addV1OverwriteWithSaveAsTableOption()} returns true, DataFrameWriter V1 will add an
+ * internal write option to indicate that the command originated from the saveAsTable API.
+ * The option key used is defined by {@link #OPTION_NAME} and the value will be set to "true".
+ * This allows the data source to distinguish between the two APIs and apply appropriate
+ * semantics.
+ *
+ * @since 4.1.0
+ */
+@Evolving
+public interface SupportsV1OverwriteWithSaveAsTable extends TableProvider {
+ /**
+ * The name of the internal write option that indicates the command originated from
+ * DataFrameWriter V1 saveAsTable API.
+ */
+ String OPTION_NAME = "__v1_save_as_table_overwrite";
+
+ /**
+ * Returns whether to add the "__v1_save_as_table_overwrite" option to write operations
+ * originating from DataFrameWriter V1 saveAsTable with mode Overwrite.
+ * Implementations can override this method to control when the option is added.
+ *
+ * @return true if the option should be added (default), false otherwise
+ */
+ default boolean addV1OverwriteWithSaveAsTableOption() {
+ return true;
+ }
+}
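As a hedged Scala sketch of how a connector might consume this, only OPTION_NAME below comes from the interface above; the helper itself is hypothetical:

    import org.apache.spark.sql.connector.catalog.SupportsV1OverwriteWithSaveAsTable
    import org.apache.spark.sql.util.CaseInsensitiveStringMap

    object SaveAsTableOptionCheck {
      // True when the write was issued through DataFrameWriter V1 saveAsTable with Overwrite.
      def isV1SaveAsTableOverwrite(options: CaseInsensitiveStringMap): Boolean =
        options.getBoolean(SupportsV1OverwriteWithSaveAsTable.OPTION_NAME, false)
    }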
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java
index 3a1e0d9f7011e..a298520760bc0 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java
@@ -50,6 +50,16 @@ public interface Table {
*/
String name();
+ /**
+ * An ID of the table that can be used to reliably check if two table objects refer to the same
+ * metastore entity. If a table is dropped and recreated again with the same name, the new table
+ * ID must be different. This method must return null if connectors don't support the notion of
+ * table ID.
+ */
+ default String id() {
+ return null;
+ }
+
/**
* Returns the schema of this table. If the table is not readable and doesn't have a schema, an
* empty schema can be returned here.
@@ -94,8 +104,10 @@ default Map<String, String> properties() {
default Constraint[] constraints() { return new Constraint[0]; }
/**
- * Returns the current table version if implementation supports versioning.
- * If the table is not versioned, null can be returned here.
+ * Returns the version of this table if versioning is supported, null otherwise.
+ *
+ * This method must not trigger a refresh of the table metadata. It should return
+ * the version that corresponds to the current state of this table instance.
*/
- default String currentVersion() { return null; }
+ default String version() { return null; }
}
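For illustration, a hedged Scala sketch of a table honoring the new contract: id() stays fixed for the lifetime of the metastore entity, while version() reports the snapshot this object represents (class name, table name, and schema are made up):

    import java.util.Collections
    import org.apache.spark.sql.connector.catalog.{Table, TableCapability}
    import org.apache.spark.sql.types.StructType

    class SnapshotTable(tableId: String, snapshotVersion: String) extends Table {
      override def name(): String = "demo_table"
      override def schema(): StructType = new StructType().add("value", "int")
      override def capabilities(): java.util.Set[TableCapability] = Collections.emptySet()
      // Stable for this metastore entity; different after a drop-and-recreate.
      override def id(): String = tableId
      // Must not refresh metadata; reflects the state this Table instance was loaded from.
      override def version(): String = snapshotVersion
    }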
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java
index d983ef656297e..d4837932863fd 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Unique.java
@@ -28,6 +28,15 @@
*
* A UNIQUE constraint specifies one or more columns as unique columns. Such constraint is satisfied
* if and only if no two rows in a table have the same non-null values in the unique columns.
+ * Unlike PRIMARY KEY, UNIQUE allows nullable columns.
+ *
+ * NULL values are treated as distinct from each other (NULLS DISTINCT semantics). Two rows
+ * are considered duplicates only when every column in the unique key has a non-null value and
+ * every value matches. If any column in the unique key is NULL, the row is always considered
+ * unique regardless of other values. In other words, multiple rows with NULL in one or more
+ * unique columns are allowed and do not violate the constraint definition.
+ *
+ * The {@code NULLS NOT DISTINCT} modifier is not currently supported.
*
* Spark doesn't enforce UNIQUE constraints but leverages them for query optimization. Each
* constraint is either valid (the existing data is guaranteed to satisfy the constraint), invalid
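The NULLS DISTINCT rule can be restated compactly; a hedged Scala sketch, independent of Spark, of when two rows count as duplicates of a unique key:

    object UniqueNullsDistinct {
      // Two rows violate a UNIQUE key only if every key value is non-null and equal.
      def duplicates(a: Seq[Option[Any]], b: Seq[Option[Any]]): Boolean =
        a.zip(b).forall { case (x, y) => x.isDefined && y.isDefined && x == y }

      val clash = duplicates(Seq(Some(1), Some("a")), Seq(Some(1), Some("a"))) // true
      val nullsKeepRowsDistinct = duplicates(Seq(Some(1), None), Seq(Some(1), None)) // false
    }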
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GetArrayItem.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GetArrayItem.java
new file mode 100644
index 0000000000000..5d7e05f598adc
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GetArrayItem.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.internal.connector.ExpressionWithToString;
+
+/**
+ * Get array item expression.
+ *
+ * @since 4.1.0
+ */
+
+@Evolving
+public class GetArrayItem extends ExpressionWithToString {
+
+ private final Expression childArray;
+ private final Expression ordinal;
+ private final boolean failOnError;
+
+ /**
+ * Creates GetArrayItem expression.
+ * @param childArray The array to get the element from; a child of this expression.
+ * @param ordinal The zero-based ordinal of the element to get.
+ * @param failOnError Whether the expression should throw an exception for an out-of-bounds
+ * index or return null.
+ */
+ public GetArrayItem(Expression childArray, Expression ordinal, boolean failOnError) {
+ this.childArray = childArray;
+ this.ordinal = ordinal;
+ this.failOnError = failOnError;
+ }
+
+ public Expression childArray() { return this.childArray; }
+ public Expression ordinal() { return this.ordinal; }
+ public boolean failOnError() { return this.failOnError; }
+
+ @Override
+ public Expression[] children() { return new Expression[]{ childArray, ordinal }; }
+}
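A hedged construction sketch in Scala, using the standard V2 expression factories; the column name is illustrative:

    import org.apache.spark.sql.connector.expressions.{Expressions, GetArrayItem}

    object GetArrayItemExample {
      // measurements[0], returning null instead of failing on an out-of-bounds index.
      val firstMeasurement: GetArrayItem =
        new GetArrayItem(Expressions.column("measurements"), Expressions.literal(0), false)
    }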
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
index 5286bbf9f85a1..c12bc14a49c44 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/PartitionReader.java
@@ -58,4 +58,13 @@ default CustomTaskMetric[] currentMetricsValues() {
CustomTaskMetric[] NO_METRICS = {};
return NO_METRICS;
}
+
+ /**
+ * Sets the initial value of metrics before fetching any data from the reader. This is called
+ * when multiple {@link PartitionReader}s are grouped into one partition in case of
+ * {@link org.apache.spark.sql.connector.read.partitioning.KeyGroupedPartitioning} and the reader
+ * is initialized with the metrics returned by the previous reader that belongs to the same
+ * partition. By default, this method does nothing.
+ */
+ default void initMetricsValues(CustomTaskMetric[] metrics) {}
}
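A hedged Scala sketch of a reader using the new hook to carry a counter across readers grouped into the same partition; the "rowsRead" metric name is made up:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.connector.metric.CustomTaskMetric
    import org.apache.spark.sql.connector.read.PartitionReader

    class CountingReader(rows: Iterator[InternalRow]) extends PartitionReader[InternalRow] {
      private var rowsRead: Long = 0L

      // Seed the counter from the previous reader of the same key-grouped partition.
      override def initMetricsValues(metrics: Array[CustomTaskMetric]): Unit =
        metrics.find(_.name() == "rowsRead").foreach(m => rowsRead = m.value())

      override def next(): Boolean = rows.hasNext
      override def get(): InternalRow = { rowsRead += 1; rows.next() }
      override def close(): Unit = ()

      override def currentMetricsValues(): Array[CustomTaskMetric] = Array(
        new CustomTaskMetric {
          override def name(): String = "rowsRead"
          override def value(): Long = rowsRead
        })
    }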
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariantExtractions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariantExtractions.java
new file mode 100644
index 0000000000000..750e0479e542d
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariantExtractions.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Experimental;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * support pushing down variant field extraction operations to the data source.
+ *
+ * When variant columns are accessed with specific field extractions (e.g., variant_get,
+ * try_variant_get), the optimizer can push these extractions down to the data source.
+ * The data source can then read only the required fields from variant columns, reducing
+ * I/O and improving performance.
+ *
+ * Each {@link VariantExtraction} in the input array represents one field extraction operation.
+ * Data sources should examine each extraction and determine which ones can be handled efficiently.
+ * The return value is a boolean array of the same length, where each element indicates whether
+ * the corresponding extraction was accepted.
+ *
+ * @since 4.1.0
+ */
+@Experimental
+public interface SupportsPushDownVariantExtractions extends ScanBuilder {
+
+ /**
+ * Pushes down variant field extractions to the data source.
+ *
+ * Each element in the input array represents one field extraction operation from a variant
+ * column. Data sources should examine each extraction and determine whether it can be
+ * pushed down based on the data source's capabilities (e.g., supported data types,
+ * path complexity, etc.).
+ *
+ * The return value is a boolean array of the same length as the input array, where each
+ * element indicates whether the corresponding extraction was accepted:
+ *
+ * - true: The extraction will be handled by the data source
+ * - false: The extraction will be handled by Spark after reading
+ *
+ *
+ * Data sources can choose to accept all, some, or none of the extractions. Spark will
+ * handle any extractions that are not pushed down.
+ *
+ * @param extractions Array of variant extractions, one per field extraction operation
+ * @return Boolean array indicating which extractions were accepted (same length as input)
+ */
+ boolean[] pushVariantExtractions(VariantExtraction[] extractions);
+}
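A hedged Scala sketch of a ScanBuilder that accepts only extractions targeting simple atomic types and leaves the rest to Spark; the buildScan factory passed in is hypothetical:

    import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownVariantExtractions, VariantExtraction}
    import org.apache.spark.sql.types.{IntegerType, StringType}

    class VariantAwareScanBuilder(buildScan: Seq[VariantExtraction] => Scan)
        extends ScanBuilder with SupportsPushDownVariantExtractions {

      private var accepted: Seq[VariantExtraction] = Seq.empty

      override def pushVariantExtractions(extractions: Array[VariantExtraction]): Array[Boolean] = {
        // Accept int/string extractions; everything else is evaluated by Spark after the read.
        val decisions = extractions.map { e =>
          e.expectedDataType() == IntegerType || e.expectedDataType() == StringType
        }
        accepted = extractions.zip(decisions).collect { case (e, true) => e }.toSeq
        decisions
      }

      override def build(): Scan = buildScan(accepted)
    }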
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java
deleted file mode 100644
index ff82e71bfd586..0000000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownVariants.java
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.connector.read;
-
-import org.apache.spark.annotation.Evolving;
-
-/**
- * A mix-in interface for {@link Scan}. Data sources can implement this interface to
- * support pushing down variant field access operations to the data source.
- *
- * When variant columns are accessed with specific field extractions (e.g., variant_get),
- * the optimizer can push these accesses down to the data source. The data source can then
- * read only the required fields from variant columns, reducing I/O and improving performance.
- *
- * The typical workflow is:
- *
- * - Optimizer analyzes the query plan and identifies variant field accesses
- * - Optimizer calls {@link #pushVariantAccess} with the access information
- * - Data source validates and stores the variant access information
- * - Optimizer retrieves pushed information via {@link #pushedVariantAccess}
- * - Data source uses the information to optimize reading in {@link #readSchema()}
- * and readers
- *
- *
- * @since 4.1.0
- */
-@Evolving
-public interface SupportsPushDownVariants extends Scan {
-
- /**
- * Pushes down variant field access information to the data source.
- *
- * Implementations should validate if the variant accesses can be pushed down based on
- * the data source's capabilities. If some accesses cannot be pushed down, the implementation
- * can choose to:
- *
- * - Push down only the supported accesses and return true
- * - Reject all pushdown and return false
- *
- *
- * The implementation should store the variant access information that can be pushed down.
- * The stored information will be retrieved later via {@link #pushedVariantAccess()}.
- *
- * @param variantAccessInfo Array of variant access information, one per variant column
- * @return true if at least some variant accesses were pushed down, false if none were pushed
- */
- boolean pushVariantAccess(VariantAccessInfo[] variantAccessInfo);
-
- /**
- * Returns the variant access information that has been pushed down to this scan.
- *
- * This method is called by the optimizer after {@link #pushVariantAccess} to retrieve
- * what variant accesses were actually accepted by the data source. The optimizer uses
- * this information to rewrite the query plan.
- *
- * If {@link #pushVariantAccess} was not called or returned false, this should return
- * an empty array.
- *
- * @return Array of pushed down variant access information
- */
- VariantAccessInfo[] pushedVariantAccess();
-}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java
index 0921a90ac22a7..927d4a53e22fc 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeFiltering.java
@@ -49,10 +49,11 @@ public interface SupportsRuntimeFiltering extends SupportsRuntimeV2Filtering {
*
* If the scan also implements {@link SupportsReportPartitioning}, it must preserve
* the originally reported partitioning during runtime filtering. While applying runtime filters,
- * the scan may detect that some {@link InputPartition}s have no matching data. It can omit
- * such partitions entirely only if it does not report a specific partitioning. Otherwise,
- * the scan can replace the initially planned {@link InputPartition}s that have no matching
- * data with empty {@link InputPartition}s but must preserve the overall number of partitions.
+ * the scan may detect that some {@link InputPartition}s have no matching data, in which case
+ * it can either replace the initially planned {@link InputPartition}s that have no matching data
+ * with empty {@link InputPartition}s, or report only a subset of the original partition values
+ * (omitting those with no data) via {@link Batch#planInputPartitions()}. The scan must not report
+ * new partition values that were not present in the original partitioning.
*
* Note that Spark will call {@link Scan#toBatch()} again after filtering the scan at runtime.
*
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java
index 7c238bde969b2..f5acdf885bf5c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsRuntimeV2Filtering.java
@@ -53,11 +53,11 @@ public interface SupportsRuntimeV2Filtering extends Scan {
*
* If the scan also implements {@link SupportsReportPartitioning}, it must preserve
* the originally reported partitioning during runtime filtering. While applying runtime
- * predicates, the scan may detect that some {@link InputPartition}s have no matching data. It
- * can omit such partitions entirely only if it does not report a specific partitioning.
- * Otherwise, the scan can replace the initially planned {@link InputPartition}s that have no
- * matching data with empty {@link InputPartition}s but must preserve the overall number of
- * partitions.
+ * predicates, the scan may detect that some {@link InputPartition}s have no matching data, in
+ * which case it can either replace the initially planned {@link InputPartition}s that have no
+ * matching data with empty {@link InputPartition}s, or report only a subset of the original
+ * partition values (omitting those with no data) via {@link Batch#planInputPartitions()}. The
+ * scan must not report new partition values that were not present in the original partitioning.
*
* Note that Spark will call {@link Scan#toBatch()} again after filtering the scan at runtime.
*
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/VariantAccessInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/VariantAccessInfo.java
deleted file mode 100644
index 4f61a42d05196..0000000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/VariantAccessInfo.java
+++ /dev/null
@@ -1,105 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.connector.read;
-
-import java.io.Serializable;
-import java.util.Objects;
-
-import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.types.StructType;
-
-/**
- * Variant access information that describes how variant fields are accessed in a query.
- *
- * This class captures the information needed by data sources to optimize reading variant columns.
- * Instead of reading the entire variant value, the data source can read only the fields that
- * are actually accessed, represented as a structured schema.
- *
- * For example, if a query accesses `variant_get(v, '$.a', 'int')` and
- * `variant_get(v, '$.b', 'string')`, the extracted schema would be
- * `struct<0:int, 1:string>` where field ordinals correspond to the access order.
- *
- * @since 4.1.0
- */
-@Evolving
-public final class VariantAccessInfo implements Serializable {
- private final String columnName;
- private final StructType extractedSchema;
-
- /**
- * Creates variant access information for a variant column.
- *
- * @param columnName The name of the variant column
- * @param extractedSchema The schema representing extracted fields from the variant.
- * Each field represents one variant field access, with field names
- * typically being ordinals (e.g., "0", "1", "2") and metadata
- * containing variant-specific information like JSON path.
- */
- public VariantAccessInfo(String columnName, StructType extractedSchema) {
- this.columnName = Objects.requireNonNull(columnName, "columnName cannot be null");
- this.extractedSchema =
- Objects.requireNonNull(extractedSchema, "extractedSchema cannot be null");
- }
-
- /**
- * Returns the name of the variant column.
- */
- public String columnName() {
- return columnName;
- }
-
- /**
- * Returns the schema representing fields extracted from the variant column.
- *
- * The schema structure is:
- *
- * - Field names: Typically ordinals ("0", "1", "2", ...) representing access order
- * - Field types: The target data type for each field extraction
- * - Field metadata: Contains variant-specific information such as JSON path,
- * timezone, and error handling mode
- *
- *
- * Data sources should use this schema to determine what fields to extract from the variant
- * and what types they should be converted to.
- */
- public StructType extractedSchema() {
- return extractedSchema;
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
- VariantAccessInfo that = (VariantAccessInfo) o;
- return columnName.equals(that.columnName) &&
- extractedSchema.equals(that.extractedSchema);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(columnName, extractedSchema);
- }
-
- @Override
- public String toString() {
- return "VariantAccessInfo{" +
- "columnName='" + columnName + '\'' +
- ", extractedSchema=" + extractedSchema +
- '}';
- }
-}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/VariantExtraction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/VariantExtraction.java
new file mode 100644
index 0000000000000..64987299934a2
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/VariantExtraction.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import java.io.Serializable;
+
+import org.apache.spark.annotation.Experimental;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.Metadata;
+
+/**
+ * Variant extraction information that describes a single field extraction from a variant column.
+ *
+ * This interface captures the information needed by data sources to optimize reading variant
+ * columns. Each instance represents one field extraction operation (e.g., from variant_get or
+ * try_variant_get).
+ *
+ * For example, if a query contains `variant_get(v, '$.a', 'int')`, this would be represented
+ * as a VariantExtraction with columnName=["v"], path="$.a", and expectedDataType=IntegerType.
+ *
+ * @since 4.1.0
+ */
+@Experimental
+public interface VariantExtraction extends Serializable {
+ /**
+ * Returns the path to the variant column. For top-level variant columns, this is a single
+ * element array containing the column name. For nested variant columns within structs,
+ * this is an array representing the path (e.g., ["structCol", "innerStruct", "variantCol"]).
+ */
+ String[] columnName();
+
+ /**
+ * Returns the expected data type for the extracted value.
+ * This is the target type specified in variant_get (e.g., IntegerType, StringType).
+ */
+ DataType expectedDataType();
+
+ /**
+ * Returns the metadata associated with this variant extraction.
+ * This may include additional information needed by the data source:
+ * - "path": the extraction path from variant_get or try_variant_get.
+ * This follows JSON path syntax (e.g., "$.a", "$.b.c", "$[0]").
+ * - "failOnError": whether the extraction to expected data type should throw an exception
+ * or return null if the cast fails.
+ * - "timeZoneId": a string identifier of a time zone. It is required by timestamp-related casts.
+ *
+ */
+ Metadata metadata();
+}
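For completeness, a hedged Scala sketch that decodes the metadata keys listed above; the key names follow this Javadoc, while the fallbacks used when a key is absent are assumptions:

    import org.apache.spark.sql.connector.read.VariantExtraction

    object VariantExtractionDescribe {
      def describe(e: VariantExtraction): String = {
        val m = e.metadata()
        val path = if (m.contains("path")) m.getString("path") else "$"
        val failOnError = m.contains("failOnError") && m.getBoolean("failOnError")
        s"${e.columnName().mkString(".")} -> $path as ${e.expectedDataType().simpleString} " +
          s"(failOnError=$failOnError)"
      }
    }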
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index e3b8754691693..50921f3de0b40 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -29,6 +29,7 @@
import org.apache.spark.sql.connector.expressions.Extract;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.GeneralScalarExpression;
+import org.apache.spark.sql.connector.expressions.GetArrayItem;
import org.apache.spark.sql.connector.expressions.Literal;
import org.apache.spark.sql.connector.expressions.NullOrdering;
import org.apache.spark.sql.connector.expressions.SortDirection;
@@ -84,6 +85,8 @@ public String build(Expression expr) {
} else if (expr instanceof SortOrder sortOrder) {
return visitSortOrder(
build(sortOrder.expression()), sortOrder.direction(), sortOrder.nullOrdering());
+ } else if (expr instanceof GetArrayItem getArrayItem) {
+ return visitGetArrayItem(getArrayItem);
} else if (expr instanceof GeneralScalarExpression e) {
String name = e.name();
return switch (name) {
@@ -348,6 +351,13 @@ protected String visitTrim(String direction, String[] inputs) {
}
}
+ protected String visitGetArrayItem(GetArrayItem getArrayItem) {
+ throw new SparkUnsupportedOperationException(
+ "EXPRESSION_TRANSLATION_TO_V2_IS_NOT_SUPPORTED",
+ Map.of("expr", getArrayItem.toString())
+ );
+ }
+
protected String visitExtract(Extract extract) {
return visitExtract(extract.field(), build(extract.source()));
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
index 66116d7c952fd..019bc258579a8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java
@@ -25,9 +25,12 @@
import org.apache.spark.SparkUnsupportedOperationException;
import org.apache.spark.annotation.DeveloperApi;
+import org.apache.spark.sql.catalyst.util.STUtils;
import org.apache.spark.sql.util.ArrowUtils;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.GeographyVal;
+import org.apache.spark.unsafe.types.GeometryVal;
import org.apache.spark.unsafe.types.UTF8String;
/**
@@ -146,6 +149,30 @@ public ColumnarMap getMap(int rowId) {
super(type);
}
+ @Override
+ public GeographyVal getGeography(int rowId) {
+ if (isNullAt(rowId)) return null;
+
+ GeographyType gt = (GeographyType) this.type;
+ int srid = getChild(0).getInt(rowId);
+ byte[] bytes = getChild(1).getBinary(rowId);
+ gt.assertSridAllowedForType(srid);
+ // TODO(GEO-602): Geography does not yet support different SRIDs; once it does,
+ // we need to update this.
+ return (bytes == null) ? null : STUtils.stGeogFromWKB(bytes);
+ }
+
+ @Override
+ public GeometryVal getGeometry(int rowId) {
+ if (isNullAt(rowId)) return null;
+
+ GeometryType gt = (GeometryType) this.type;
+ int srid = getChild(0).getInt(rowId);
+ byte[] bytes = getChild(1).getBinary(rowId);
+ gt.assertSridAllowedForType(srid);
+ return (bytes == null) ? null : STUtils.stGeomFromWKB(bytes, srid);
+ }
+
public ArrowColumnVector(ValueVector vector) {
this(ArrowUtils.fromArrowField(vector.getField()));
initAccessor(vector);
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java
index 4be45dc5d399d..3d1e780f6e057 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRow.java
@@ -86,6 +86,8 @@ public InternalRow copy() {
row.update(i, getArray(i).copy());
} else if (pdt instanceof PhysicalMapType) {
row.update(i, getMap(i).copy());
+ } else if (pdt instanceof PhysicalVariantType) {
+ row.update(i, getVariant(i));
} else {
throw new RuntimeException("Not implemented. " + dt);
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
index d9e65afe1cb00..656c5f8a8f30e 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java
@@ -91,6 +91,8 @@ public InternalRow copy() {
row.update(i, getArray(i).copy());
} else if (pdt instanceof PhysicalMapType) {
row.update(i, getMap(i).copy());
+ } else if (pdt instanceof PhysicalVariantType) {
+ row.update(i, getVariant(i));
} else {
throw new RuntimeException("Not implemented. " + dt);
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index c1e0674d391d2..cdbd2d49e8b74 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType._
import org.apache.spark.sql.types.YearMonthIntervalType._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, UTF8String}
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.collection.Utils
@@ -61,6 +61,7 @@ object CatalystTypeConverters {
}
private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = {
+ TypeUtils.failUnsupportedDataType(dataType, SQLConf.get)
val converter = dataType match {
case udt: UserDefinedType[_] => UDTConverter(udt)
case arrayType: ArrayType => ArrayConverter(arrayType.elementType)
@@ -69,6 +70,10 @@ object CatalystTypeConverters {
case CharType(length) => new CharConverter(length)
case VarcharType(length) => new VarcharConverter(length)
case _: StringType => StringConverter
+ case g: GeographyType =>
+ new GeographyConverter(g)
+ case g: GeometryType =>
+ new GeometryConverter(g)
case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateConverter
case DateType => DateConverter
case _: TimeType => TimeConverter
@@ -345,6 +350,64 @@ object CatalystTypeConverters {
row.getUTF8String(column).toString
}
+ private def assertGeospatialEnabled(): Unit = {
+ if (!SQLConf.get.geospatialEnabled) {
+ throw new org.apache.spark.sql.AnalysisException(
+ errorClass = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED",
+ messageParameters = scala.collection.immutable.Map.empty)
+ }
+ }
+
+ private class GeometryConverter(dataType: GeometryType)
+ extends CatalystTypeConverter[Any, org.apache.spark.sql.types.Geometry, GeometryVal] {
+ override def toCatalystImpl(scalaValue: Any): GeometryVal = scalaValue match {
+ case g: org.apache.spark.sql.types.Geometry if SQLConf.get.geospatialEnabled =>
+ STUtils.serializeGeomFromWKB(g, dataType)
+ case other => throw new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_3219",
+ messageParameters = scala.collection.immutable.Map(
+ "other" -> other.toString,
+ "otherClass" -> other.getClass.getCanonicalName,
+ "dataType" -> StringType.sql))
+ }
+ override def toScala(catalystValue: GeometryVal): org.apache.spark.sql.types.Geometry = {
+ assertGeospatialEnabled()
+ if (catalystValue == null) null
+ else STUtils.deserializeGeom(catalystValue, dataType)
+ }
+
+ override def toScalaImpl(row: InternalRow, column: Int):
+ org.apache.spark.sql.types.Geometry = {
+ assertGeospatialEnabled()
+ STUtils.deserializeGeom(row.getGeometry(0), dataType)
+ }
+ }
+
+ private class GeographyConverter(dataType: GeographyType)
+ extends CatalystTypeConverter[Any, org.apache.spark.sql.types.Geography, GeographyVal] {
+ override def toCatalystImpl(scalaValue: Any): GeographyVal = scalaValue match {
+ case g: org.apache.spark.sql.types.Geography if SQLConf.get.geospatialEnabled =>
+ STUtils.serializeGeogFromWKB(g, dataType)
+ case other => throw new SparkIllegalArgumentException(
+ errorClass = "_LEGACY_ERROR_TEMP_3219",
+ messageParameters = scala.collection.immutable.Map(
+ "other" -> other.toString,
+ "otherClass" -> other.getClass.getCanonicalName,
+ "dataType" -> StringType.sql))
+ }
+ override def toScala(catalystValue: GeographyVal): org.apache.spark.sql.types.Geography = {
+ assertGeospatialEnabled()
+ if (catalystValue == null) null
+ else STUtils.deserializeGeog(catalystValue, dataType)
+ }
+
+ override def toScalaImpl(row: InternalRow, column: Int):
+ org.apache.spark.sql.types.Geography = {
+ assertGeospatialEnabled()
+ STUtils.deserializeGeog(row.getGeography(0), dataType)
+ }
+ }
+
private object DateConverter extends CatalystTypeConverter[Any, Date, Any] {
override def toCatalystImpl(scalaValue: Any): Int = scalaValue match {
case d: Date => DateTimeUtils.fromJavaDate(d)
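
The new GeometryConverter and GeographyConverter follow the same two-way contract as the existing converters: toCatalystImpl maps the external Geometry/Geography object to its internal GeometryVal/GeographyVal form, and toScala maps back, with both directions gated on the geospatial flag. A standalone sketch of that gated round-trip shape (the types and names below are illustrative, not Spark API):

// Illustrative stand-ins for the external object and its Catalyst-side encoding.
final case class ExternalShape(wkb: Array[Byte], srid: Int)
final case class InternalShape(bytes: Array[Byte], srid: Int)

class GatedShapeConverter(featureEnabled: => Boolean) {
  private def assertEnabled(): Unit = {
    if (!featureEnabled) {
      throw new UnsupportedOperationException("geospatial support is disabled")
    }
  }

  // External -> internal (mirrors toCatalystImpl).
  def toCatalyst(value: ExternalShape): InternalShape = {
    assertEnabled()
    InternalShape(value.wkb, value.srid)
  }

  // Internal -> external (mirrors toScala), null-safe like the real converters.
  def toScala(value: InternalShape): ExternalShape = {
    assertEnabled()
    if (value == null) null else ExternalShape(value.bytes, value.srid)
  }
}
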
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index a051205829a11..080794643fa0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, GeographyEncoder, GeometryEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{dataTypeForClass, externalDataTypeFor, isNativeEncoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharCodegenUtils, DateTimeUtils, IntervalUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharCodegenUtils, DateTimeUtils, IntervalUtils, STUtils}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
object DeserializerBuildHelper {
@@ -80,6 +81,24 @@ object DeserializerBuildHelper {
returnNullable = false)
}
+ def createDeserializerForGeometryType(inputObject: Expression, gt: GeometryType): Expression = {
+ StaticInvoke(
+ classOf[STUtils],
+ ObjectType(classOf[Geometry]),
+ "deserializeGeom",
+ inputObject :: Literal.fromObject(gt) :: Nil,
+ returnNullable = false)
+ }
+
+ def createDeserializerForGeographyType(inputObject: Expression, gt: GeographyType): Expression = {
+ StaticInvoke(
+ classOf[STUtils],
+ ObjectType(classOf[Geography]),
+ "deserializeGeog",
+ inputObject :: Literal.fromObject(gt) :: Nil,
+ returnNullable = false)
+ }
+
def createDeserializerForChar(
path: Expression,
returnNullable: Boolean,
@@ -290,6 +309,14 @@ object DeserializerBuildHelper {
"withName",
createDeserializerForString(path, returnNullable = false) :: Nil,
returnNullable = false)
+ case _ @ (_: GeographyEncoder | _: GeometryEncoder) if !SQLConf.get.geospatialEnabled =>
+ throw new org.apache.spark.sql.AnalysisException(
+ errorClass = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED",
+ messageParameters = scala.collection.immutable.Map.empty)
+ case g: GeographyEncoder =>
+ createDeserializerForGeographyType(path, g.dt)
+ case g: GeometryEncoder =>
+ createDeserializerForGeometryType(path, g.dt)
case CharEncoder(length) =>
createDeserializerForChar(path, returnNullable = false, length)
case VarcharEncoder(length) =>
@@ -318,6 +345,8 @@ object DeserializerBuildHelper {
createDeserializerForInstant(path)
case LocalDateTimeEncoder =>
createDeserializerForLocalDateTime(path)
+ case LocalTimeEncoder if !SQLConf.get.isTimeTypeEnabled =>
+ throw org.apache.spark.sql.errors.QueryCompilationErrors.unsupportedTimeTypeError()
case LocalTimeEncoder =>
createDeserializerForLocalTime(path)
case UDTEncoder(udt, udtClass) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index 82b3cdc508bf9..b8b2406a58130 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -22,11 +22,11 @@ import scala.language.existentials
import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, GeographyEncoder, GeometryEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayData, CharVarcharCodegenUtils, DateTimeUtils, GenericArrayData, IntervalUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CharVarcharCodegenUtils, DateTimeUtils, GenericArrayData, IntervalUtils, STUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -63,6 +63,24 @@ object SerializerBuildHelper {
Invoke(inputObject, "doubleValue", DoubleType)
}
+ def createSerializerForGeographyType(inputObject: Expression, gt: GeographyType): Expression = {
+ StaticInvoke(
+ classOf[STUtils],
+ gt,
+ "serializeGeogFromWKB",
+ inputObject :: Literal.fromObject(gt) :: Nil,
+ returnNullable = false)
+ }
+
+ def createSerializerForGeometryType(inputObject: Expression, gt: GeometryType): Expression = {
+ StaticInvoke(
+ classOf[STUtils],
+ gt,
+ "serializeGeomFromWKB",
+ inputObject :: Literal.fromObject(gt) :: Nil,
+ returnNullable = false)
+ }
+
def createSerializerForChar(inputObject: Expression, length: Int): Expression = {
StaticInvoke(
classOf[CharVarcharCodegenUtils],
@@ -326,6 +344,12 @@ object SerializerBuildHelper {
case BoxedDoubleEncoder => createSerializerForDouble(input)
case JavaEnumEncoder(_) => createSerializerForJavaEnum(input)
case ScalaEnumEncoder(_, _) => createSerializerForScalaEnum(input)
+ case _ @ (_: GeographyEncoder | _: GeometryEncoder) if !SQLConf.get.geospatialEnabled =>
+ throw new org.apache.spark.sql.AnalysisException(
+ errorClass = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED",
+ messageParameters = scala.collection.immutable.Map.empty)
+ case g: GeographyEncoder => createSerializerForGeographyType(input, g.dt)
+ case g: GeometryEncoder => createSerializerForGeometryType(input, g.dt)
case CharEncoder(length) => createSerializerForChar(input, length)
case VarcharEncoder(length) => createSerializerForVarchar(input, length)
case StringEncoder => createSerializerForString(input)
@@ -343,6 +367,8 @@ object SerializerBuildHelper {
case TimestampEncoder(false) => createSerializerForSqlTimestamp(input)
case InstantEncoder(false) => createSerializerForJavaInstant(input)
case LocalDateTimeEncoder => createSerializerForLocalDateTime(input)
+ case LocalTimeEncoder if !SQLConf.get.isTimeTypeEnabled =>
+ throw org.apache.spark.sql.errors.QueryCompilationErrors.unsupportedTimeTypeError()
case LocalTimeEncoder => createSerializerForLocalTime(input)
case UDTEncoder(udt, udtClass) => createSerializerForUserDefinedType(input, udt, udtClass)
case OptionEncoder(valueEnc) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 98c514925fa04..23d32dd87db12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -80,7 +80,8 @@ object SimpleAnalyzer extends Analyzer(
FunctionRegistry.builtin,
TableFunctionRegistry.builtin) {
override def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {}
- })) {
+ }),
+ RelationCache.empty) {
override def resolver: Resolver = caseSensitiveResolution
}
@@ -285,11 +286,14 @@ object Analyzer {
* Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and
* [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]].
*/
-class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor[LogicalPlan]
+class Analyzer(
+ override val catalogManager: CatalogManager,
+ private[sql] val sharedRelationCache: RelationCache = RelationCache.empty)
+ extends RuleExecutor[LogicalPlan]
with CheckAnalysis with AliasHelper with SQLConfHelper with ColumnResolutionHelper {
private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog
- private val relationResolution = new RelationResolution(catalogManager)
+ private val relationResolution = new RelationResolution(catalogManager, sharedRelationCache)
private val functionResolution = new FunctionResolution(catalogManager, relationResolution)
override protected def validatePlanChanges(
@@ -1185,12 +1189,26 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}
+ private def resolveAsV2Relation(plan: LogicalPlan): Option[DataSourceV2Relation] = {
+ plan match {
+ case ref: V2TableReference =>
+ EliminateSubqueryAliases(relationResolution.resolveReference(ref)) match {
+ case r: DataSourceV2Relation => Some(r)
+ case _ => None
+ }
+ case r: DataSourceV2Relation => Some(r)
+ case _ => None
+ }
+ }
+
def apply(plan: LogicalPlan)
: LogicalPlan = plan.resolveOperatorsUpWithPruning(AlwaysProcess.fn, ruleId) {
case i @ InsertIntoStatement(table, _, _, _, _, _, _) =>
val relation = table match {
case u: UnresolvedRelation if !u.isStreaming =>
resolveRelation(u).getOrElse(u)
+ case r: V2TableReference =>
+ relationResolution.resolveReference(r)
case other => other
}
@@ -1210,13 +1228,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
resolveRelation(u).map(unwrapRelationPlan).map {
case v: View => throw QueryCompilationErrors.writeIntoViewNotAllowedError(
v.desc.identifier, write)
- case r: DataSourceV2Relation => write.withNewTable(r)
case u: UnresolvedCatalogRelation =>
throw QueryCompilationErrors.writeIntoV1TableNotAllowedError(
u.tableMeta.identifier, write)
- case other =>
- throw QueryCompilationErrors.writeIntoTempViewNotAllowedError(
- u.multipartIdentifier.quoted)
+ case plan =>
+ resolveAsV2Relation(plan).map(write.withNewTable).getOrElse {
+ throw QueryCompilationErrors.writeIntoTempViewNotAllowedError(
+ u.multipartIdentifier.quoted)
+ }
}.getOrElse(write)
case _ => write
}
@@ -1224,6 +1243,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case u: UnresolvedRelation =>
resolveRelation(u).map(resolveViews(_, u.options)).getOrElse(u)
+ case r: V2TableReference =>
+ relationResolution.resolveReference(r)
+
case r @ RelationTimeTravel(u: UnresolvedRelation, timestamp, version)
if timestamp.forall(ts => ts.resolved && !SubqueryExpression.hasSubquery(ts)) =>
val timeTravelSpec = TimeTravelSpec.create(timestamp, version, conf.sessionLocalTimeZone)
@@ -1670,7 +1692,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
case u: UpdateTable => resolveReferencesInUpdate(u)
case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _, _)
- if !m.resolved && targetTable.resolved && sourceTable.resolved && !m.needSchemaEvolution =>
+ if !m.resolved && targetTable.resolved && sourceTable.resolved =>
+
+ // Do not throw an exception in the schema evolution case.
+ // This gives unresolved assignment keys a chance to be resolved in a second pass,
+ // against the new columns/nested fields added by schema evolution.
+ val throws = !m.schemaEvolutionEnabled
EliminateSubqueryAliases(targetTable) match {
case r: NamedRelation if r.skipSchemaResolution =>
@@ -1680,29 +1707,45 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
m
case _ =>
+ def findAttrInTarget(name: String): Option[Attribute] = {
+ targetTable.output.find(targetAttr => conf.resolver(name, targetAttr.name))
+ }
val newMatchedActions = m.matchedActions.map {
case DeleteAction(deleteCondition) =>
val resolvedDeleteCondition = deleteCondition.map(
resolveExpressionByPlanChildren(_, m))
DeleteAction(resolvedDeleteCondition)
- case UpdateAction(updateCondition, assignments) =>
+ case UpdateAction(updateCondition, assignments, fromStar) =>
val resolvedUpdateCondition = updateCondition.map(
resolveExpressionByPlanChildren(_, m))
UpdateAction(
resolvedUpdateCondition,
// The update value can access columns from both target and source tables.
- resolveAssignments(assignments, m, MergeResolvePolicy.BOTH))
+ resolveAssignments(assignments, m, MergeResolvePolicy.BOTH,
+ throws = throws),
+ fromStar)
case UpdateStarAction(updateCondition) =>
- // Use only source columns. Missing columns in target will be handled in
- // ResolveRowLevelCommandAssignments.
- val assignments = targetTable.output.flatMap{ targetAttr =>
- sourceTable.output.find(
- sourceCol => conf.resolver(sourceCol.name, targetAttr.name))
- .map(Assignment(targetAttr, _))}
+ // Expand star to top-level source columns. If the source has fewer columns than the target,
+ // assignments will be added by ResolveRowLevelCommandAssignments later.
+ val assignments = if (m.schemaEvolutionEnabled) {
+ // For schema evolution case, generate assignments for missing target columns.
+ // These columns will be added by ResolveMergeIntoTableSchemaEvolution later.
+ sourceTable.output.map { sourceAttr =>
+ val key = findAttrInTarget(sourceAttr.name).getOrElse(
+ UnresolvedAttribute(sourceAttr.name))
+ Assignment(key, sourceAttr)
+ }
+ } else {
+ targetTable.output.map { attr =>
+ Assignment(attr, UnresolvedAttribute(Seq(attr.name)))
+ }
+ }
UpdateAction(
updateCondition.map(resolveExpressionByPlanChildren(_, m)),
// For UPDATE *, the value must be from source table.
- resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE))
+ resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
+ throws = throws),
+ fromStar = true)
case o => o
}
val newNotMatchedActions = m.notMatchedActions.map {
@@ -1713,21 +1756,32 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
resolveExpressionByPlanOutput(_, m.sourceTable))
InsertAction(
resolvedInsertCondition,
- resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE))
+ resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
+ throws = throws))
case InsertStarAction(insertCondition) =>
// The insert action is used when not matched, so its condition and value can only
// access columns from the source table.
val resolvedInsertCondition = insertCondition.map(
resolveExpressionByPlanOutput(_, m.sourceTable))
- // Use only source columns. Missing columns in target will be handled in
- // ResolveRowLevelCommandAssignments.
- val assignments = targetTable.output.flatMap{ targetAttr =>
- sourceTable.output.find(
- sourceCol => conf.resolver(sourceCol.name, targetAttr.name))
- .map(Assignment(targetAttr, _))}
+ // Expand star to top-level source columns. If the source has fewer columns than the target,
+ // assignments will be added by ResolveRowLevelCommandAssignments later.
+ val assignments = if (m.schemaEvolutionEnabled) {
+ // For schema evolution case, generate assignments for missing target columns.
+ // These columns will be added by ResolveMergeIntoTableSchemaEvolution later.
+ sourceTable.output.map { sourceAttr =>
+ val key = findAttrInTarget(sourceAttr.name).getOrElse(
+ UnresolvedAttribute(sourceAttr.name))
+ Assignment(key, sourceAttr)
+ }
+ } else {
+ targetTable.output.map { attr =>
+ Assignment(attr, UnresolvedAttribute(Seq(attr.name)))
+ }
+ }
InsertAction(
resolvedInsertCondition,
- resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE))
+ resolveAssignments(assignments, m, MergeResolvePolicy.SOURCE,
+ throws = throws))
case o => o
}
val newNotMatchedBySourceActions = m.notMatchedBySourceActions.map {
@@ -1735,13 +1789,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
val resolvedDeleteCondition = deleteCondition.map(
resolveExpressionByPlanOutput(_, targetTable))
DeleteAction(resolvedDeleteCondition)
- case UpdateAction(updateCondition, assignments) =>
+ case UpdateAction(updateCondition, assignments, fromStar) =>
val resolvedUpdateCondition = updateCondition.map(
resolveExpressionByPlanOutput(_, targetTable))
UpdateAction(
resolvedUpdateCondition,
// The update value can access columns from the target table only.
- resolveAssignments(assignments, m, MergeResolvePolicy.TARGET))
+ resolveAssignments(assignments, m, MergeResolvePolicy.TARGET,
+ throws = throws),
+ fromStar)
case o => o
}
@@ -1818,11 +1874,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
def resolveAssignments(
assignments: Seq[Assignment],
mergeInto: MergeIntoTable,
- resolvePolicy: MergeResolvePolicy.Value): Seq[Assignment] = {
+ resolvePolicy: MergeResolvePolicy.Value,
+ throws: Boolean): Seq[Assignment] = {
assignments.map { assign =>
val resolvedKey = assign.key match {
case c if !c.resolved =>
- resolveMergeExprOrFail(c, Project(Nil, mergeInto.targetTable))
+ resolveMergeExpr(c, Project(Nil, mergeInto.targetTable), throws)
case o => o
}
val resolvedValue = assign.value match {
@@ -1842,7 +1899,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
} else {
resolvedExpr
}
- checkResolvedMergeExpr(withDefaultResolved, resolvePlan)
+ if (throws) {
+ checkResolvedMergeExpr(withDefaultResolved, resolvePlan)
+ }
withDefaultResolved
case o => o
}
@@ -1850,9 +1909,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}
- private def resolveMergeExprOrFail(e: Expression, p: LogicalPlan): Expression = {
- val resolved = resolveExprInAssignment(e, p)
- checkResolvedMergeExpr(resolved, p)
+ private def resolveMergeExpr(e: Expression, p: LogicalPlan, throws: Boolean): Expression = {
+ val resolved = resolveExprInAssignment(e, p, throws)
+ if (throws) {
+ checkResolvedMergeExpr(resolved, p)
+ }
resolved
}
@@ -2172,8 +2233,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
v1SessionCatalog.resolvePersistentTableFunction(
ident.asFunctionIdentifier, u.functionArgs)
} else {
- throw QueryCompilationErrors.missingCatalogAbilityError(
- catalog, "table-valued functions")
+ throw QueryCompilationErrors.missingCatalogTableValuedFunctionsAbilityError(
+ catalog)
}
}
resolvedFunc.transformAllExpressionsWithPruning(
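
With the star-expansion change above, UPDATE SET * and INSERT * now produce one assignment per source column; when schema evolution is enabled, keys missing from the target stay as UnresolvedAttribute so that a second pass can resolve them against the evolved schema. A hedged usage sketch (table and column names are made up, and it assumes the WITH SCHEMA EVOLUTION clause carried by MergeIntoTable in this patch):

// Source carries an extra column that the target does not have yet; schema
// evolution is expected to add it before the assignments are re-resolved.
spark.sql("""
  MERGE WITH SCHEMA EVOLUTION INTO target t
  USING source s
  ON t.id = s.id
  WHEN MATCHED THEN UPDATE SET *
  WHEN NOT MATCHED THEN INSERT *
""")
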
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
index 13b554eb53d4d..e23e7561f0e36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
@@ -130,6 +130,13 @@ object AnsiTypeCoercion extends TypeCoercionBase {
case (t1: YearMonthIntervalType, t2: YearMonthIntervalType) =>
Some(YearMonthIntervalType(t1.startField.min(t2.startField), t1.endField.max(t2.endField)))
+ // We allow coercion from GEOGRAPHY() types (i.e. fixed SRID types) to the
+ // GEOGRAPHY(ANY) type (i.e. mixed SRID type). This coercion is always safe to do.
+ case (t1: GeographyType, t2: GeographyType) if t1 != t2 => Some(GeographyType("ANY"))
+ // We allow coercion from GEOMETRY() types (i.e. fixed SRID types) to the
+ // GEOMETRY(ANY) type (i.e. mixed SRID type). This coercion is always safe to do.
+ case (t1: GeometryType, t2: GeometryType) if t1 != t2 => Some(GeometryType("ANY"))
+
case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType)
}
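
The two new cases fire only when the parameterized geospatial types differ, widening them to the mixed-SRID ANY variant; identical types fall through unchanged. A standalone model of that widening decision (SridSpec below is illustrative, not the Spark type):

// Illustrative model: a geospatial type is parameterized by an SRID spec.
sealed trait SridSpec
final case class FixedSrid(srid: Int) extends SridSpec
case object AnySrid extends SridSpec

// Mirrors the coercion rule: equal specs stay as-is, differing specs widen to ANY.
def widen(left: SridSpec, right: SridSpec): SridSpec =
  if (left == right) left else AnySrid

// widen(FixedSrid(4326), FixedSrid(3857)) == AnySrid
// widen(FixedSrid(4326), FixedSrid(4326)) == FixedSrid(4326)
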
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
index 145c9077a4c2b..df4b0646ed42f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala
@@ -19,15 +19,17 @@ package org.apache.spark.sql.catalyst.analysis
import scala.collection.mutable
+import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.analysis.TableOutputResolver.DefaultValueFillMode.{NONE, RECURSE}
-import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, CreateNamedStruct, Expression, GetStructField, If, IsNull, Literal}
import org.apache.spark.sql.catalyst.plans.logical.Assignment
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLit
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.ArrayImplicits._
@@ -50,6 +52,9 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
*
* @param attrs table attributes
* @param assignments assignments to align
+ * @param fromStar whether the assignments were resolved from an UPDATE SET * clause.
+ * These updates may assign struct fields individually
+ * (preserving existing fields).
* @param coerceNestedTypes whether to coerce nested types to match the target type
* for complex types
* @return aligned update assignments that match table attributes
@@ -57,6 +62,7 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
def alignUpdateAssignments(
attrs: Seq[Attribute],
assignments: Seq[Assignment],
+ fromStar: Boolean,
coerceNestedTypes: Boolean): Seq[Assignment] = {
val errors = new mutable.ArrayBuffer[String]()
@@ -68,7 +74,8 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
assignments,
addError = err => errors += err,
colPath = Seq(attr.name),
- coerceNestedTypes)
+ coerceNestedTypes,
+ fromStar)
}
if (errors.nonEmpty) {
@@ -152,7 +159,8 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
assignments: Seq[Assignment],
addError: String => Unit,
colPath: Seq[String],
- coerceNestedTypes: Boolean = false): Expression = {
+ coerceNestedTypes: Boolean = false,
+ updateStar: Boolean = false): Expression = {
val (exactAssignments, otherAssignments) = assignments.partition { assignment =>
assignment.key.semanticEquals(colExpr)
@@ -174,9 +182,25 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
} else if (exactAssignments.isEmpty && fieldAssignments.isEmpty) {
TableOutputResolver.checkNullability(colExpr, col, conf, colPath)
} else if (exactAssignments.nonEmpty) {
- val value = exactAssignments.head.value
- val coerceMode = if (coerceNestedTypes) RECURSE else NONE
- TableOutputResolver.resolveUpdate("", value, col, conf, addError, colPath, coerceMode)
+ if (updateStar && SQLConf.get.coerceMergeNestedTypes) {
+ val value = exactAssignments.head.value
+ col.dataType match {
+ case _: StructType =>
+ // Expand assignments to leaf fields (fixNullExpansion is applied inside)
+ applyNestedFieldAssignments(col, colExpr, value, addError, colPath,
+ coerceNestedTypes)
+ case _ =>
+ // For non-struct types, resolve directly
+ val coerceMode = if (coerceNestedTypes) RECURSE else NONE
+ TableOutputResolver.resolveUpdate("", value, col, conf, addError, colPath,
+ coerceMode)
+ }
+ } else {
+ val value = exactAssignments.head.value
+ val coerceMode = if (coerceNestedTypes) RECURSE else NONE
+ TableOutputResolver.resolveUpdate("", value, col, conf, addError,
+ colPath, coerceMode)
+ }
} else {
applyFieldAssignments(col, colExpr, fieldAssignments, addError, colPath, coerceNestedTypes)
}
@@ -188,7 +212,7 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
assignments: Seq[Assignment],
addError: String => Unit,
colPath: Seq[String],
- coerceNestedTyptes: Boolean): Expression = {
+ coerceNestedTypes: Boolean): Expression = {
col.dataType match {
case structType: StructType =>
@@ -198,14 +222,74 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
}
val updatedFieldExprs = fieldAttrs.zip(fieldExprs).map { case (fieldAttr, fieldExpr) =>
applyAssignments(fieldAttr, fieldExpr, assignments, addError, colPath :+ fieldAttr.name,
- coerceNestedTyptes)
+ coerceNestedTypes)
}
toNamedStruct(structType, updatedFieldExprs)
case otherType =>
addError(
"Updating nested fields is only supported for StructType but " +
- s"'${colPath.quoted}' is of type $otherType")
+ s"'${colPath.quoted}' is of type $otherType")
+ colExpr
+ }
+ }
+
+ private def applyNestedFieldAssignments(
+ col: Attribute,
+ colExpr: Expression,
+ value: Expression,
+ addError: String => Unit,
+ colPath: Seq[String],
+ coerceNestedTypes: Boolean): Expression = {
+
+ col.dataType match {
+ case structType: StructType =>
+ val fieldAttrs = DataTypeUtils.toAttributes(structType)
+
+ val updatedFieldExprs = fieldAttrs.zipWithIndex.map { case (fieldAttr, ordinal) =>
+ val fieldPath = colPath :+ fieldAttr.name
+ val targetFieldExpr = GetStructField(colExpr, ordinal, Some(fieldAttr.name))
+
+ // Try to find a corresponding field in the source value by name
+ val sourceFieldValue: Expression = value.dataType match {
+ case valueStructType: StructType =>
+ valueStructType.fields.find(f => conf.resolver(f.name, fieldAttr.name)) match {
+ case Some(matchingField) =>
+ // Found matching field in source, extract it
+ val fieldIndex = valueStructType.fieldIndex(matchingField.name)
+ GetStructField(value, fieldIndex, Some(matchingField.name))
+ case None =>
+ // Field doesn't exist in source, use target's current value with null check
+ TableOutputResolver.checkNullability(targetFieldExpr, fieldAttr, conf, fieldPath)
+ }
+ case _ =>
+ // Value is not a struct, cannot extract field
+ addError(s"Cannot assign non-struct value to struct field '${fieldPath.quoted}'")
+ Literal(null, fieldAttr.dataType)
+ }
+
+ // Recurse or resolve based on field type
+ fieldAttr.dataType match {
+ case _: StructType =>
+ // Field is a struct, recurse
+ applyNestedFieldAssignments(fieldAttr, targetFieldExpr,
+ sourceFieldValue, addError, fieldPath, coerceNestedTypes)
+ case _ =>
+ // Field is not a struct, resolve with TableOutputResolver
+ val coerceMode = if (coerceNestedTypes) RECURSE else NONE
+ TableOutputResolver.resolveUpdate("", sourceFieldValue, fieldAttr, conf, addError,
+ fieldPath, coerceMode)
+ }
+ }
+ val namedStruct = toNamedStruct(structType, updatedFieldExprs)
+
+ // Prevent unnecessary null struct expansion
+ fixNullExpansion(colExpr, value, structType, namedStruct, colPath)
+
+ case otherType =>
+ addError(
+ "Updating nested fields is only supported for StructType but " +
+ s"'${colPath.quoted}' is of type $otherType")
colExpr
}
}
@@ -217,6 +301,73 @@ object AssignmentUtils extends SQLConfHelper with CastSupport {
CreateNamedStruct(namedStructExprs)
}
+ /**
+ * Checks if target struct has extra fields compared to source struct, recursively.
+ */
+ private def hasExtraTargetFields(targetType: StructType, sourceType: DataType): Boolean = {
+ sourceType match {
+ case sourceStructType: StructType =>
+ targetType.fields.exists { targetField =>
+ sourceStructType.fields.find(f => conf.resolver(f.name, targetField.name)) match {
+ case Some(sourceField) =>
+ // Check nested structs recursively
+ (targetField.dataType, sourceField.dataType) match {
+ case (targetNested: StructType, sourceNested) =>
+ hasExtraTargetFields(targetNested, sourceNested)
+ case _ => false
+ }
+ case None => true // target has extra field not in source
+ }
+ }
+ case _ =>
+ // Should be caught earlier
+ throw SparkException.internalError(
+ s"Source type must be StructType but found: $sourceType")
+ }
+ }
+
+ /**
+ * As UPDATE SET * assigns struct fields individually (preserving existing fields),
+ * this leads to indiscriminate null expansion, i.e. a struct is created in which all
+ * fields are null. This method wraps a struct assignment with a condition that returns null
+ * only if both of the following are true:
+ *
+ * - source struct is null
+ * - target struct is null OR target struct is same as source struct
+ *
+ * If the condition does not hold, the original structure is preserved.
+ * This includes cases where the source was a struct of nulls,
+ * or where the target had extra fields (including null ones);
+ * in both cases the assignment to a struct of nulls is retained.
+ *
+ * @param key the original assignment key (target struct) expression
+ * @param value the original assignment value (source struct) expression
+ * @param structType the target struct type
+ * @param structExpression the resulting create-struct expression to wrap
+ * @param colPath the column path for error reporting
+ * @return the wrapped expression with null checks
+ */
+ private def fixNullExpansion(
+ key: Expression,
+ value: Expression,
+ structType: StructType,
+ structExpression: Expression,
+ colPath: Seq[String]): Expression = {
+ if (key.nullable) {
+ val condition = if (hasExtraTargetFields(structType, value.dataType)) {
+ // extra target fields: return null iff source struct is null and target struct is null
+ And(IsNull(value), IsNull(key))
+ } else {
+ // schemas match: return null iff source struct is null
+ IsNull(value)
+ }
+
+ If(condition, Literal(null, structExpression.dataType), structExpression)
+ } else {
+ structExpression
+ }
+ }
+
/**
* Checks whether assignments are aligned and compatible with table columns.
*
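
The guard built by fixNullExpansion reduces to a small decision: the expanded struct collapses to NULL only when the source struct is NULL and, if the target has extra fields, the target struct is NULL as well. A standalone sketch of that decision, with booleans standing in for runtime nullness:

// Decide whether the expanded struct should collapse to NULL.
def collapseToNull(
    sourceIsNull: Boolean,
    targetIsNull: Boolean,
    targetHasExtraFields: Boolean): Boolean = {
  if (targetHasExtraFields) {
    sourceIsNull && targetIsNull // And(IsNull(value), IsNull(key))
  } else {
    sourceIsNull // IsNull(value)
  }
}

// collapseToNull(sourceIsNull = true, targetIsNull = true,  targetHasExtraFields = true)  == true
// collapseToNull(sourceIsNull = true, targetIsNull = false, targetHasExtraFields = true)  == false
// collapseToNull(sourceIsNull = true, targetIsNull = false, targetHasExtraFields = false) == true
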
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 2ff842553bee6..3b8a363e704a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -340,6 +340,12 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
errorClass = "UNSUPPORTED_FEATURE.OVERWRITE_BY_SUBQUERY",
messageParameters = Map.empty)
+ case p: PlanWithUnresolvedIdentifier if !p.identifierExpr.resolved =>
+ p.identifierExpr.failAnalysis(
+ errorClass = "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ messageParameters = Map("name" -> "IDENTIFIER", "expr" -> p.identifierExpr.sql)
+ )
+
case operator: LogicalPlan =>
operator transformExpressionsDown {
case hof: HigherOrderFunction if hof.arguments.exists {
@@ -706,6 +712,7 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
}
create.tableSchema.foreach(f => TypeUtils.failWithIntervalType(f.dataType))
+ TypeUtils.failUnsupportedDataType(create.tableSchema, SQLConf.get)
SchemaUtils.checkIndeterminateCollationInSchema(create.tableSchema)
case write: V2WriteCommand if write.resolved =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 53c92ca5425df..1172ecee72236 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -140,7 +140,9 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
}
matched(ordinal)
- case u @ UnresolvedAttribute(nameParts) =>
+ case u @ UnresolvedAttribute(nameParts)
+ if u.getTagValue(LogicalPlan.PLAN_ID_TAG).isEmpty =>
+ // UnresolvedAttribute with PLAN_ID_TAG should be resolved in resolveDataFrameColumn
val result = withPosition(u) {
resolveColumnByName(nameParts)
.orElse(LiteralFunctionResolution.resolve(nameParts))
@@ -425,7 +427,8 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
def resolveExpressionByPlanChildren(
e: Expression,
q: LogicalPlan,
- includeLastResort: Boolean = false): Expression = {
+ includeLastResort: Boolean = false,
+ throws: Boolean = true): Expression = {
resolveExpression(
tryResolveDataFrameColumns(e, q.children),
resolveColumnByName = nameParts => {
@@ -435,7 +438,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
assert(q.children.length == 1)
q.children.head.output
},
- throws = true,
+ throws,
includeLastResort = includeLastResort)
}
@@ -475,8 +478,14 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
resolveVariables(resolveOuterRef(e))
}
- def resolveExprInAssignment(expr: Expression, hostPlan: LogicalPlan): Expression = {
- resolveExpressionByPlanChildren(expr, hostPlan) match {
+ def resolveExprInAssignment(
+ expr: Expression,
+ hostPlan: LogicalPlan,
+ throws: Boolean = true): Expression = {
+ resolveExpressionByPlanChildren(expr,
+ hostPlan,
+ includeLastResort = false,
+ throws = throws) match {
// Assignment key and value does not need the alias when resolving nested columns.
case Alias(child: ExtractValue, _) => child
case other => other
@@ -488,8 +497,7 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
// 1. extract the attached plan id from UnresolvedAttribute;
// 2. top-down traverse the query plan to find the plan node that matches the plan id;
// 3. if can not find the matching node, fails with 'CANNOT_RESOLVE_DATAFRAME_COLUMN';
- // 4, if the matching node is found, but can not resolve the column, also fails with
- // 'CANNOT_RESOLVE_DATAFRAME_COLUMN';
+ // 4, if the matching node is found, but can not resolve the column, return the original one;
// 5, resolve the expression against the target node, the resolved attribute will be
// filtered by the output attributes of nodes in the path (from matching to root node);
// 6. if more than one resolved attributes are found in the above recursive process,
@@ -564,10 +572,9 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
} else {
None
}
- if (resolved.isEmpty) {
- // The targe plan node is found, but the column cannot be resolved.
- throw QueryCompilationErrors.cannotResolveDataFrameColumn(u)
- }
+ // The target plan node is found, but the column might still fail to resolve.
+ // In this case, return None to delay the failure, so that it can possibly be
+ // resolved in the next iteration.
(resolved.map(r => (r, currentDepth)), true)
} else {
val children = p match {
@@ -610,11 +617,8 @@ trait ColumnResolutionHelper extends Logging with DataTypeErrorsBase {
// the dataframe column 'df.id' will remain unresolved, and the analyzer
// will try to resolve 'id' without plan id later.
val filtered = resolved.filter { r =>
- if (isMetadataAccess) {
- r._1.references.subsetOf(AttributeSet(p.output ++ p.metadataOutput))
- } else {
- r._1.references.subsetOf(p.outputSet)
- }
+ // A DataFrame column can be resolved as a metadata column, so we should keep it.
+ r._1.references.subsetOf(AttributeSet(p.output ++ p.metadataOutput))
}
(filtered, matched)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index bb98e5fa02ed2..e1ca5ad918476 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -30,9 +30,10 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.st._
import org.apache.spark.sql.catalyst.expressions.variant._
import org.apache.spark.sql.catalyst.expressions.xml._
-import org.apache.spark.sql.catalyst.plans.logical.{FunctionBuilderBase, Generate, LogicalPlan, OneRowRelation, Range}
+import org.apache.spark.sql.catalyst.plans.logical.{FunctionBuilderBase, Generate, LogicalPlan, OneRowRelation, PythonWorkerLogs, Range}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
@@ -535,6 +536,12 @@ object FunctionRegistry {
expression[ThetaIntersectionAgg]("theta_intersection_agg"),
expression[ApproxTopKAccumulate]("approx_top_k_accumulate"),
expression[ApproxTopKCombine]("approx_top_k_combine"),
+ expression[KllSketchAggBigint]("kll_sketch_agg_bigint"),
+ expression[KllSketchAggFloat]("kll_sketch_agg_float"),
+ expression[KllSketchAggDouble]("kll_sketch_agg_double"),
+ expression[KllMergeAggBigint]("kll_merge_agg_bigint"),
+ expression[KllMergeAggFloat]("kll_merge_agg_float"),
+ expression[KllMergeAggDouble]("kll_merge_agg_double"),
// string functions
expression[Ascii]("ascii"),
@@ -800,6 +807,21 @@ object FunctionRegistry {
expression[ThetaDifference]("theta_difference"),
expression[ThetaIntersection]("theta_intersection"),
expression[ApproxTopKEstimate]("approx_top_k_estimate"),
+ expression[KllSketchToStringBigint]("kll_sketch_to_string_bigint"),
+ expression[KllSketchToStringFloat]("kll_sketch_to_string_float"),
+ expression[KllSketchToStringDouble]("kll_sketch_to_string_double"),
+ expression[KllSketchGetNBigint]("kll_sketch_get_n_bigint"),
+ expression[KllSketchGetNFloat]("kll_sketch_get_n_float"),
+ expression[KllSketchGetNDouble]("kll_sketch_get_n_double"),
+ expression[KllSketchMergeBigint]("kll_sketch_merge_bigint"),
+ expression[KllSketchMergeFloat]("kll_sketch_merge_float"),
+ expression[KllSketchMergeDouble]("kll_sketch_merge_double"),
+ expression[KllSketchGetQuantileBigint]("kll_sketch_get_quantile_bigint"),
+ expression[KllSketchGetQuantileFloat]("kll_sketch_get_quantile_float"),
+ expression[KllSketchGetQuantileDouble]("kll_sketch_get_quantile_double"),
+ expression[KllSketchGetRankBigint]("kll_sketch_get_rank_bigint"),
+ expression[KllSketchGetRankFloat]("kll_sketch_get_rank_float"),
+ expression[KllSketchGetRankDouble]("kll_sketch_get_rank_double"),
// grouping sets
expression[Grouping]("grouping"),
@@ -878,6 +900,7 @@ object FunctionRegistry {
expression[ST_GeogFromWKB]("st_geogfromwkb"),
expression[ST_GeomFromWKB]("st_geomfromwkb"),
expression[ST_Srid]("st_srid"),
+ expression[ST_SetSrid]("st_setsrid"),
// cast
expression[Cast]("cast"),
@@ -1213,7 +1236,8 @@ object TableFunctionRegistry {
generator[Collations]("collations"),
generator[SQLKeywords]("sql_keywords"),
generatorBuilder("variant_explode", VariantExplodeGeneratorBuilder),
- generatorBuilder("variant_explode_outer", VariantExplodeOuterGeneratorBuilder)
+ generatorBuilder("variant_explode_outer", VariantExplodeOuterGeneratorBuilder),
+ PythonWorkerLogs.functionBuilder
)
val builtin: SimpleTableFunctionRegistry = {
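
The newly registered KLL and ST functions are exposed through SQL. A hedged usage sketch — the argument shapes (a single numeric column for kll_sketch_agg_double, a rank in [0, 1] for kll_sketch_get_quantile_double, and a geometry plus target SRID for st_setsrid) are assumptions based on common DataSketches/PostGIS conventions, not spelled out by this patch; the table and column names are made up:

// Build a KLL sketch over a double column, then read the median back out.
val sketch = spark.sql("SELECT kll_sketch_agg_double(price) AS sk FROM sales")
sketch.createOrReplaceTempView("sk_view")
spark.sql("SELECT kll_sketch_get_quantile_double(sk, 0.5) AS median FROM sk_view").show()

// Re-tag a geometry with a different SRID and read it back.
spark.sql("SELECT st_srid(st_setsrid(geom, 4326)) FROM shapes").show()
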
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationCache.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationCache.scala
new file mode 100644
index 0000000000000..770a5e780b24a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationCache.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
+private[sql] trait RelationCache {
+ def lookup(nameParts: Seq[String], resolver: Resolver): Option[LogicalPlan]
+}
+
+private[sql] object RelationCache {
+ val empty: RelationCache = (_, _) => None
+}
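
RelationCache is a single-method lookup interface, and RelationCache.empty is just the lambda that always misses. A minimal map-backed sketch; since the trait is private[sql], this assumes the implementation lives in the analysis package itself:

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// A simple in-memory cache keyed by fully qualified name parts, compared with the
// session's resolver so case sensitivity is honored.
class InMemoryRelationCache(entries: Seq[(Seq[String], LogicalPlan)]) extends RelationCache {
  override def lookup(nameParts: Seq[String], resolver: Resolver): Option[LogicalPlan] = {
    entries.collectFirst {
      case (cachedName, plan)
          if cachedName.length == nameParts.length &&
            cachedName.zip(nameParts).forall { case (a, b) => resolver(a, b) } =>
        plan
    }
  }
}
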
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala
index ea5836b8ec2df..15d5e4874dbb5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.catalog.{
TemporaryViewRelation,
UnresolvedCatalogRelation
}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PythonWorkerLogs, SubqueryAlias}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
import org.apache.spark.sql.connector.catalog.{
CatalogManager,
@@ -46,7 +46,9 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._
-class RelationResolution(override val catalogManager: CatalogManager)
+class RelationResolution(
+ override val catalogManager: CatalogManager,
+ sharedRelationCache: RelationCache)
extends DataTypeErrorsBase
with Logging
with LookupCatalog
@@ -112,44 +114,68 @@ class RelationResolution(override val catalogManager: CatalogManager)
u.isStreaming,
finalTimeTravelSpec.isDefined
).orElse {
- resolveSystemSessionView(u.multipartIdentifier)
- }.orElse {
expandIdentifier(u.multipartIdentifier) match {
case CatalogAndIdentifier(catalog, ident) =>
val key = toCacheKey(catalog, ident, finalTimeTravelSpec)
val planId = u.getTagValue(LogicalPlan.PLAN_ID_TAG)
relationCache
.get(key)
- .map { cache =>
- val cachedRelation = cache.transform {
- case multi: MultiInstanceRelation =>
- val newRelation = multi.newInstance()
- newRelation.copyTagsFrom(multi)
- newRelation
- }
- cloneWithPlanId(cachedRelation, planId)
- }
+ .map(adaptCachedRelation(_, planId))
.orElse {
- val writePrivilegesString =
- Option(u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES))
- val table =
- CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec, writePrivilegesString)
- val loaded = createRelation(
+ val writePrivileges = u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)
+ val finalOptions = u.clearWritePrivileges.options
+ val table = CatalogV2Util.loadTable(
catalog,
ident,
- table,
- u.clearWritePrivileges.options,
- u.isStreaming,
- finalTimeTravelSpec
- )
- loaded.foreach(relationCache.update(key, _))
- loaded.map(cloneWithPlanId(_, planId))
- }
+ finalTimeTravelSpec,
+ Option(writePrivileges))
+
+ val sharedRelationCacheMatch = for {
+ t <- table
+ if finalTimeTravelSpec.isEmpty && writePrivileges == null && !u.isStreaming
+ cached <- lookupSharedRelationCache(catalog, ident, t)
+ } yield {
+ val updatedRelation = cached.copy(options = finalOptions)
+ val nameParts = ident.toQualifiedNameParts(catalog)
+ val aliasedRelation = SubqueryAlias(nameParts, updatedRelation)
+ relationCache.update(key, aliasedRelation)
+ adaptCachedRelation(aliasedRelation, planId)
+ }
+
+ sharedRelationCacheMatch.orElse {
+ val loaded = createRelation(
+ catalog,
+ ident,
+ table,
+ finalOptions,
+ u.isStreaming,
+ finalTimeTravelSpec)
+ loaded.foreach(relationCache.update(key, _))
+ loaded.map(cloneWithPlanId(_, planId))
+ }
+ }
case _ => None
}
}
}
+ private def lookupSharedRelationCache(
+ catalog: CatalogPlugin,
+ ident: Identifier,
+ table: Table): Option[DataSourceV2Relation] = {
+ CatalogV2Util.lookupCachedRelation(sharedRelationCache, catalog, ident, table, conf)
+ }
+
+ private def adaptCachedRelation(cached: LogicalPlan, planId: Option[Long]): LogicalPlan = {
+ val plan = cached transform {
+ case multi: MultiInstanceRelation =>
+ val newRelation = multi.newInstance()
+ newRelation.copyTagsFrom(multi)
+ newRelation
+ }
+ cloneWithPlanId(plan, planId)
+ }
+
private def createRelation(
catalog: CatalogPlugin,
ident: Identifier,
@@ -227,6 +253,45 @@ class RelationResolution(override val catalogManager: CatalogManager)
}
}
+ def resolveReference(ref: V2TableReference): LogicalPlan = {
+ val relation = getOrLoadRelation(ref)
+ val planId = ref.getTagValue(LogicalPlan.PLAN_ID_TAG)
+ cloneWithPlanId(relation, planId)
+ }
+
+ private def getOrLoadRelation(ref: V2TableReference): LogicalPlan = {
+ val key = toCacheKey(ref.catalog, ref.identifier)
+ relationCache.get(key) match {
+ case Some(cached) =>
+ adaptCachedRelation(cached, ref)
+ case None =>
+ val relation = loadRelation(ref)
+ relationCache.update(key, relation)
+ relation
+ }
+ }
+
+ private def loadRelation(ref: V2TableReference): LogicalPlan = {
+ val table = ref.catalog.loadTable(ref.identifier)
+ V2TableReferenceUtils.validateLoadedTable(table, ref)
+ val tableName = ref.identifier.toQualifiedNameParts(ref.catalog)
+ SubqueryAlias(tableName, ref.toRelation(table))
+ }
+
+ private def adaptCachedRelation(cached: LogicalPlan, ref: V2TableReference): LogicalPlan = {
+ cached transform {
+ case r: DataSourceV2Relation if matchesReference(r, ref) =>
+ V2TableReferenceUtils.validateLoadedTable(r.table, ref)
+ r.copy(output = ref.output, options = ref.options)
+ }
+ }
+
+ private def matchesReference(
+ relation: DataSourceV2Relation,
+ ref: V2TableReference): Boolean = {
+ relation.catalog.contains(ref.catalog) && relation.identifier.contains(ref.identifier)
+ }
+
private def isResolvingView: Boolean = AnalysisContext.get.catalogAndNamespace.nonEmpty
private def isReferredTempViewName(nameParts: Seq[String]): Boolean = {
@@ -238,22 +303,6 @@ class RelationResolution(override val catalogManager: CatalogManager)
}
}
- private def isSystemSessionIdentifier(identifier: Seq[String]): Boolean = {
- identifier.length > 2 &&
- identifier(0).equalsIgnoreCase(CatalogManager.SYSTEM_CATALOG_NAME) &&
- identifier(1).equalsIgnoreCase(CatalogManager.SESSION_NAMESPACE)
- }
-
- private def resolveSystemSessionView(
- identifier: Seq[String]): Option[LogicalPlan] = {
- if (isSystemSessionIdentifier(identifier)) {
- Option(identifier.drop(2)).collect {
- case Seq(viewName) if viewName.equalsIgnoreCase(PythonWorkerLogs.ViewName) =>
- PythonWorkerLogs.viewDefinition()
- }
- } else None
- }
-
private def toCacheKey(
catalog: CatalogPlugin,
ident: Identifier,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
index 7e7776098a04a..bbb8e7852b2c7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveMergeIntoSchemaEvolution.scala
@@ -17,13 +17,17 @@
package org.apache.spark.sql.catalyst.analysis
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{CatalogV2Util, TableCatalog, TableChange}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.types.StructType
/**
@@ -34,24 +38,40 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
object ResolveMergeIntoSchemaEvolution extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- case m @ MergeIntoTable(_, _, _, _, _, _, _)
- if m.needSchemaEvolution =>
+ // This rule should run only if all assignments are resolved, except those
+ // that will be satisfied by schema evolution
+ case m @ MergeIntoTable(_, _, _, _, _, _, _) if m.evaluateSchemaEvolution =>
+ val changes = m.changesForSchemaEvolution
+ if (changes.isEmpty) {
+ m
+ } else {
+ val finalAttrMapping = ArrayBuffer.empty[(Attribute, Attribute)]
val newTarget = m.targetTable.transform {
- case r : DataSourceV2Relation => performSchemaEvolution(r, m.sourceTable)
+ case r: DataSourceV2Relation =>
+ val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(m)
+ val newTarget = performSchemaEvolution(r, referencedSourceSchema, changes)
+ val oldTargetOutput = m.targetTable.output
+ val newTargetOutput = newTarget.output
+ val attributeMapping = oldTargetOutput.zip(newTargetOutput)
+ finalAttrMapping ++= attributeMapping
+ newTarget
}
- m.copy(targetTable = newTarget)
+ val res = m.copy(targetTable = newTarget)
+ res.rewriteAttrs(AttributeMap(finalAttrMapping.toSeq))
+ }
}
- private def performSchemaEvolution(relation: DataSourceV2Relation, source: LogicalPlan)
- : DataSourceV2Relation = {
+ private def performSchemaEvolution(
+ relation: DataSourceV2Relation,
+ referencedSourceSchema: StructType,
+ changes: Array[TableChange]): DataSourceV2Relation = {
(relation.catalog, relation.identifier) match {
case (Some(c: TableCatalog), Some(i)) =>
- val changes = MergeIntoTable.schemaChanges(relation.schema, source.schema)
c.alterTable(i, changes: _*)
val newTable = c.loadTable(i)
val newSchema = CatalogV2Util.v2ColumnsToStructType(newTable.columns())
// Check if there are any remaining changes not applied.
- val remainingChanges = MergeIntoTable.schemaChanges(newSchema, source.schema)
+ val remainingChanges = MergeIntoTable.schemaChanges(newSchema, referencedSourceSchema)
if (remainingChanges.nonEmpty) {
throw QueryCompilationErrors.unsupportedTableChangesInAutoSchemaEvolutionError(
remainingChanges, i.toQualifiedNameParts(c))
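The evolved target relation reloads the table and therefore produces fresh attribute instances, so the rule pairs old and new output attributes and rewrites the rest of the MERGE plan through that mapping. A minimal standalone sketch of the remapping idea, using hypothetical Attr/ExprId case classes rather than Catalyst's Attribute/AttributeMap:

    // Standalone sketch of the attribute-remapping idea, using toy classes
    // (Attr, ExprId) instead of Catalyst's Attribute/AttributeMap.
    object AttrRemapSketch {
      final case class ExprId(id: Long)
      final case class Attr(name: String, exprId: ExprId)

      def main(args: Array[String]): Unit = {
        // Output of the target table before and after schema evolution;
        // evolution reloads the table, so every attribute gets a fresh ExprId.
        val oldOutput = Seq(Attr("id", ExprId(1)), Attr("value", ExprId(2)))
        val newOutput =
          Seq(Attr("id", ExprId(10)), Attr("value", ExprId(11)), Attr("extra", ExprId(12)))

        // Zip the shared prefix, like oldTargetOutput.zip(newTargetOutput).
        val mapping: Map[ExprId, Attr] =
          oldOutput.zip(newOutput).map { case (o, n) => o.exprId -> n }.toMap

        // Any expression that still references an old attribute is rewritten
        // to the freshly loaded one.
        val assignmentRefs = Seq(Attr("value", ExprId(2)))
        val rewritten = assignmentRefs.map(a => mapping.getOrElse(a.exprId, a))
        println(rewritten) // List(Attr(value,ExprId(11)))
      }
    }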
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
index 3eb528954b352..bf1016ba82684 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveRowLevelCommandAssignments.scala
@@ -44,7 +44,7 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
validateStoreAssignmentPolicy()
val newTable = cleanAttrMetadata(u.table)
val newAssignments = AssignmentUtils.alignUpdateAssignments(u.table.output, u.assignments,
- coerceNestedTypes = false)
+ fromStar = false, coerceNestedTypes = false)
u.copy(table = newTable, assignments = newAssignments)
case u: UpdateTable if !u.skipSchemaResolution && u.resolved && !u.aligned =>
@@ -53,10 +53,11 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
case m: MergeIntoTable if !m.skipSchemaResolution && m.resolved && m.rewritable && !m.aligned &&
!m.needSchemaEvolution =>
validateStoreAssignmentPolicy()
- val coerceNestedTypes = SQLConf.get.coerceMergeNestedTypes
+ val coerceNestedTypes = SQLConf.get.coerceMergeNestedTypes && m.withSchemaEvolution
m.copy(
targetTable = cleanAttrMetadata(m.targetTable),
- matchedActions = alignActions(m.targetTable.output, m.matchedActions, coerceNestedTypes),
+ matchedActions = alignActions(m.targetTable.output, m.matchedActions,
+ coerceNestedTypes),
notMatchedActions = alignActions(m.targetTable.output, m.notMatchedActions,
coerceNestedTypes),
notMatchedBySourceActions = alignActions(m.targetTable.output, m.notMatchedBySourceActions,
@@ -117,9 +118,9 @@ object ResolveRowLevelCommandAssignments extends Rule[LogicalPlan] {
actions: Seq[MergeAction],
coerceNestedTypes: Boolean): Seq[MergeAction] = {
actions.map {
- case u @ UpdateAction(_, assignments) =>
+ case u @ UpdateAction(_, assignments, fromStar) =>
u.copy(assignments = AssignmentUtils.alignUpdateAssignments(attrs, assignments,
- coerceNestedTypes))
+ fromStar, coerceNestedTypes))
case d: DeleteAction =>
d
case i @ InsertAction(_, assignments) =>
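UpdateAction now carries a fromStar flag, and alignment threads it into the assignment-alignment helper. A simplified, self-contained sketch of that threading, with a toy MergeAction hierarchy and an align function whose star-handling semantics are assumed for illustration (they do not reproduce AssignmentUtils.alignUpdateAssignments):

    object AlignActionsSketch {
      final case class Assignment(key: String, value: String)
      sealed trait MergeAction
      final case class UpdateAction(
          cond: Option[String], assignments: Seq[Assignment], fromStar: Boolean) extends MergeAction
      final case class DeleteAction(cond: Option[String]) extends MergeAction

      // Toy stand-in for the alignment helper: a star-expanded UPDATE keeps only
      // assignments for columns the target actually has (assumed semantics).
      def align(targetCols: Seq[String], assignments: Seq[Assignment], fromStar: Boolean): Seq[Assignment] =
        if (fromStar) assignments.filter(a => targetCols.contains(a.key)) else assignments

      def alignActions(targetCols: Seq[String], actions: Seq[MergeAction]): Seq[MergeAction] =
        actions.map {
          case u @ UpdateAction(_, assignments, fromStar) =>
            u.copy(assignments = align(targetCols, assignments, fromStar))
          case other => other
        }

      def main(args: Array[String]): Unit = {
        val actions = Seq(
          UpdateAction(None, Seq(Assignment("id", "s.id"), Assignment("tmp", "s.tmp")), fromStar = true))
        println(alignActions(Seq("id", "value"), actions))
      }
    }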
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
index 8b5b690aa7403..1d2e2fef20965 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
@@ -334,7 +334,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
// original row ID values must be preserved and passed back to the table to encode updates
// if there are any assignments to row ID attributes, add extra columns for original values
val updateAssignments = (matchedActions ++ notMatchedBySourceActions).flatMap {
- case UpdateAction(_, assignments) => assignments
+ case UpdateAction(_, assignments, _) => assignments
case _ => Nil
}
buildOriginalRowIdValues(rowIdAttrs, updateAssignments)
@@ -434,7 +434,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
// converts a MERGE action into an instruction on top of the joined plan for group-based plans
private def toInstruction(action: MergeAction, metadataAttrs: Seq[Attribute]): Instruction = {
action match {
- case UpdateAction(cond, assignments) =>
+ case UpdateAction(cond, assignments, _) =>
val rowValues = assignments.map(_.value)
val metadataValues = nullifyMetadataOnUpdate(metadataAttrs)
val output = Seq(Literal(WRITE_WITH_METADATA_OPERATION)) ++ rowValues ++ metadataValues
@@ -466,12 +466,12 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
splitUpdates: Boolean): Instruction = {
action match {
- case UpdateAction(cond, assignments) if splitUpdates =>
+ case UpdateAction(cond, assignments, _) if splitUpdates =>
val output = deltaDeleteOutput(rowAttrs, rowIdAttrs, metadataAttrs, originalRowIdValues)
val otherOutput = deltaReinsertOutput(assignments, metadataAttrs, originalRowIdValues)
Split(cond.getOrElse(TrueLiteral), output, otherOutput)
- case UpdateAction(cond, assignments) =>
+ case UpdateAction(cond, assignments, _) =>
val output = deltaUpdateOutput(assignments, metadataAttrs, originalRowIdValues)
Keep(Update, cond.getOrElse(TrueLiteral), output)
@@ -495,7 +495,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelCommand with PredicateHelper
val actions = merge.matchedActions ++ merge.notMatchedActions ++ merge.notMatchedBySourceActions
actions.foreach {
case DeleteAction(Some(cond)) => checkMergeIntoCondition("DELETE", cond)
- case UpdateAction(Some(cond), _) => checkMergeIntoCondition("UPDATE", cond)
+ case UpdateAction(Some(cond), _, _) => checkMergeIntoCondition("UPDATE", cond)
case InsertAction(Some(cond), _) => checkMergeIntoCondition("INSERT", cond)
case _ => // OK
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 3e5f14810935b..ce387ef397aca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -98,6 +98,13 @@ object TypeCoercion extends TypeCoercionBase {
case (t1: YearMonthIntervalType, t2: YearMonthIntervalType) =>
Some(YearMonthIntervalType(t1.startField.min(t2.startField), t1.endField.max(t2.endField)))
+ // We allow coercion from GEOGRAPHY() types (i.e. fixed SRID types) to the
+ // GEOGRAPHY(ANY) type (i.e. mixed SRID type). This coercion is always safe to do.
+ case (t1: GeographyType, t2: GeographyType) if t1 != t2 => Some(GeographyType("ANY"))
+ // We allow coercion from GEOMETRY() types (i.e. fixed SRID types) to the
+ // GEOMETRY(ANY) type (i.e. mixed SRID type). This coercion is always safe to do.
+ case (t1: GeometryType, t2: GeometryType) if t1 != t2 => Some(GeometryType("ANY"))
+
case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType)
}
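The new branches widen two fixed-SRID geography (or geometry) types that disagree to the mixed-SRID ANY variant. A small standalone sketch of that coercion rule over toy GeographyType/GeometryType case classes (not the Spark types):

    object GeoCoercionSketch {
      sealed trait DataType
      final case class GeographyType(srid: String) extends DataType
      final case class GeometryType(srid: String) extends DataType

      // Mirrors the shape of the new branches: two fixed-SRID types that disagree
      // widen to the mixed-SRID ANY variant; everything else is None in this toy.
      def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match {
        case (a, b) if a == b => Some(a)
        case (_: GeographyType, _: GeographyType) => Some(GeographyType("ANY"))
        case (_: GeometryType, _: GeometryType) => Some(GeometryType("ANY"))
        case _ => None
      }

      def main(args: Array[String]): Unit = {
        println(findTightestCommonType(GeographyType("4326"), GeographyType("4269"))) // Some(GeographyType(ANY))
        println(findTightestCommonType(GeometryType("3857"), GeographyType("4326")))  // None
      }
    }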
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index fd4e081c91b52..d658d83f066f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -584,6 +584,25 @@ object UnsupportedOperationChecker extends Logging {
}
}
+ // Verifies that a query using real-time mode is valid. It is meant to be used in addition
+ // to the checkForStreaming method; hence the name: it checks *additional* real-time mode
+ // constraints.
+ //
+ // It should be called during resolution of the WriteToStreamStatement if and only if
+ // the query is using the real-time trigger.
+ def checkAdditionalRealTimeModeConstraints(plan: LogicalPlan, outputMode: OutputMode): Unit = {
+ if (outputMode != InternalOutputModes.Update) {
+ throwRealTimeError("OUTPUT_MODE_NOT_SUPPORTED", Map("outputMode" -> outputMode.toString))
+ }
+ }
+
+ private def throwRealTimeError(subClass: String, args: Map[String, String]): Unit = {
+ throw new AnalysisException(
+ errorClass = s"STREAMING_REAL_TIME_MODE.$subClass",
+ messageParameters = args
+ )
+ }
+
private def throwErrorIf(
condition: Boolean,
msg: String)(implicit operator: LogicalPlan): Unit = {
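A minimal sketch of the additional real-time mode check, using toy OutputMode and error types instead of Spark's InternalOutputModes and AnalysisException:

    object RealTimeModeCheckSketch {
      sealed trait OutputMode
      case object Append extends OutputMode
      case object Update extends OutputMode
      case object Complete extends OutputMode

      final case class AnalysisError(errorClass: String, params: Map[String, String])
        extends RuntimeException(s"$errorClass ${params.mkString(", ")}")

      // Mirrors checkAdditionalRealTimeModeConstraints: real-time triggers only
      // work with Update mode, so any other mode fails analysis up front.
      def checkAdditionalRealTimeModeConstraints(outputMode: OutputMode): Unit =
        if (outputMode != Update) {
          throw AnalysisError(
            "STREAMING_REAL_TIME_MODE.OUTPUT_MODE_NOT_SUPPORTED",
            Map("outputMode" -> outputMode.toString))
        }

      def main(args: Array[String]): Unit = {
        checkAdditionalRealTimeModeConstraints(Update)       // passes
        try checkAdditionalRealTimeModeConstraints(Append)   // throws
        catch { case e: AnalysisError => println(e.getMessage) }
      }
    }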
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala
new file mode 100644
index 0000000000000..85c36d452b309
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.analysis.V2TableReference.Context
+import org.apache.spark.sql.catalyst.analysis.V2TableReference.TableInfo
+import org.apache.spark.sql.catalyst.analysis.V2TableReference.TemporaryViewContext
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.logical.LeafNode
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
+import org.apache.spark.sql.connector.catalog.Column
+import org.apache.spark.sql.connector.catalog.Identifier
+import org.apache.spark.sql.connector.catalog.MetadataColumn
+import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.connector.catalog.TableCatalog
+import org.apache.spark.sql.connector.catalog.V2TableUtil
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.sql.util.SchemaValidationMode.ALLOW_NEW_TOP_LEVEL_FIELDS
+import org.apache.spark.util.ArrayImplicits._
+
+/**
+ * A reference to a V2 table.
+ *
+ * References are placeholders for the latest table metadata and are replaced with actual table
+ * versions during analysis, allowing Spark to reload tables with up-to-date metadata. The newly
+ * loaded table metadata is validated against the original metadata depending on the context.
+ * For instance, temporary views with fully resolved logical plans don't allow schema changes
+ * in underlying tables.
+ */
+private[sql] case class V2TableReference private(
+ catalog: TableCatalog,
+ identifier: Identifier,
+ options: CaseInsensitiveStringMap,
+ info: TableInfo,
+ output: Seq[AttributeReference],
+ context: Context)
+ extends LeafNode with MultiInstanceRelation with NamedRelation {
+
+ override def name: String = V2TableUtil.toQualifiedName(catalog, identifier)
+
+ override def newInstance(): V2TableReference = {
+ copy(output = output.map(_.newInstance()))
+ }
+
+ override def computeStats(): Statistics = Statistics.DUMMY
+
+ override def simpleString(maxFields: Int): String = {
+ val outputString = truncatedString(output, "[", ", ", "]", maxFields)
+ s"TableReference$outputString $name"
+ }
+
+ def toRelation(table: Table): DataSourceV2Relation = {
+ DataSourceV2Relation(table, output, Some(catalog), Some(identifier), options)
+ }
+}
+
+private[sql] object V2TableReference {
+
+ case class TableInfo(
+ columns: Seq[Column],
+ metadataColumns: Seq[MetadataColumn])
+
+ sealed trait Context
+ case class TemporaryViewContext(viewName: Seq[String]) extends Context
+
+ def createForTempView(relation: DataSourceV2Relation, viewName: Seq[String]): V2TableReference = {
+ create(relation, TemporaryViewContext(viewName))
+ }
+
+ private def create(relation: DataSourceV2Relation, context: Context): V2TableReference = {
+ val ref = V2TableReference(
+ relation.catalog.get.asTableCatalog,
+ relation.identifier.get,
+ relation.options,
+ TableInfo(
+ columns = relation.table.columns.toImmutableArraySeq,
+ metadataColumns = V2TableUtil.extractMetadataColumns(relation)),
+ relation.output,
+ context)
+ ref.copyTagsFrom(relation)
+ ref
+ }
+}
+
+private[sql] object V2TableReferenceUtils extends SQLConfHelper {
+
+ def validateLoadedTable(table: Table, ref: V2TableReference): Unit = {
+ ref.context match {
+ case ctx: TemporaryViewContext =>
+ validateLoadedTableInTempView(table, ref, ctx)
+ case ctx =>
+ throw SparkException.internalError(s"Unknown table ref context: ${ctx.getClass.getName}")
+ }
+ }
+
+ private def validateLoadedTableInTempView(
+ table: Table,
+ ref: V2TableReference,
+ ctx: TemporaryViewContext): Unit = {
+ val tableName = ref.identifier.toQualifiedNameParts(ref.catalog)
+
+ val dataErrors = V2TableUtil.validateCapturedColumns(
+ table,
+ ref.info.columns,
+ mode = ALLOW_NEW_TOP_LEVEL_FIELDS)
+ if (dataErrors.nonEmpty) {
+ throw QueryCompilationErrors.columnsChangedAfterViewWithPlanCreation(
+ ctx.viewName,
+ tableName,
+ dataErrors)
+ }
+
+ val metaErrors = V2TableUtil.validateCapturedMetadataColumns(table, ref.info.metadataColumns)
+ if (metaErrors.nonEmpty) {
+ throw QueryCompilationErrors.metadataColumnsChangedAfterViewWithPlanCreation(
+ ctx.viewName,
+ tableName,
+ metaErrors)
+ }
+ }
+}
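A V2TableReference captures the identifier and the column snapshot when the temporary view is created; each later use reloads the table, validates the captured columns against the freshly loaded ones, and only then turns the reference into a relation. A simplified standalone sketch of that validation step, assuming ALLOW_NEW_TOP_LEVEL_FIELDS means "captured columns must be unchanged, trailing new columns are fine" (toy Column/Table/TableRef types, not the Spark classes):

    object TableReferenceSketch {
      final case class Column(name: String, dataType: String)
      final case class Table(columns: Seq[Column])

      // Captured at temp-view creation time: identifier plus the column snapshot.
      final case class TableRef(identifier: String, capturedColumns: Seq[Column])

      // Toy stand-in for the captured-column validation: every captured column
      // must still exist unchanged; newly added top-level columns are allowed.
      def validateLoadedTable(table: Table, ref: TableRef): Seq[String] =
        ref.capturedColumns.flatMap { captured =>
          table.columns.find(_.name == captured.name) match {
            case Some(now) if now == captured => None
            case Some(now) => Some(s"${captured.name}: type changed to ${now.dataType}")
            case None => Some(s"${captured.name}: dropped")
          }
        }

      def main(args: Array[String]): Unit = {
        val ref = TableRef("cat.db.t", Seq(Column("id", "bigint")))
        println(validateLoadedTable(Table(Seq(Column("id", "bigint"), Column("extra", "string"))), ref)) // List()
        println(validateLoadedTable(Table(Seq(Column("id", "string"))), ref)) // type changed
      }
    }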
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala
index 0117b3fc2fb55..d346969be8eff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/HybridAnalyzer.scala
@@ -302,6 +302,7 @@ object HybridAnalyzer {
resolverGuard = new ResolverGuard(legacyAnalyzer.catalogManager),
resolver = new Resolver(
catalogManager = legacyAnalyzer.catalogManager,
+ sharedRelationCache = legacyAnalyzer.sharedRelationCache,
extensions = legacyAnalyzer.singlePassResolverExtensions,
metadataResolverExtensions = legacyAnalyzer.singlePassMetadataResolverExtensions,
externalRelationResolution = Some(legacyAnalyzer.getRelationResolution)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala
index 75d23f29ecfc5..78029d593df13 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/Resolver.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{
AnalysisErrorAt,
FunctionResolution,
MultiInstanceRelation,
+ RelationCache,
RelationResolution,
ResolvedInlineTable,
UnresolvedHaving,
@@ -71,6 +72,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
*/
class Resolver(
catalogManager: CatalogManager,
+ sharedRelationCache: RelationCache = RelationCache.empty,
override val extensions: Seq[ResolverExtension] = Seq.empty,
metadataResolverExtensions: Seq[ResolverExtension] = Seq.empty,
externalRelationResolution: Option[RelationResolution] = None)
@@ -81,8 +83,9 @@ class Resolver(
private val cteRegistry = new CteRegistry
private val subqueryRegistry = new SubqueryRegistry
private val identifierAndCteSubstitutor = new IdentifierAndCteSubstitutor
- private val relationResolution =
- externalRelationResolution.getOrElse(Resolver.createRelationResolution(catalogManager))
+ private val relationResolution = externalRelationResolution.getOrElse {
+ Resolver.createRelationResolution(catalogManager, sharedRelationCache)
+ }
private val functionResolution = new FunctionResolution(catalogManager, relationResolution)
private val expressionResolver = new ExpressionResolver(this, functionResolution, planLogger)
private val aggregateResolver = new AggregateResolver(this, expressionResolver)
@@ -788,7 +791,9 @@ object Resolver {
/**
* Create a new instance of the [[RelationResolution]].
*/
- def createRelationResolution(catalogManager: CatalogManager): RelationResolution = {
- new RelationResolution(catalogManager)
+ def createRelationResolution(
+ catalogManager: CatalogManager,
+ sharedRelationCache: RelationCache = RelationCache.empty): RelationResolution = {
+ new RelationResolution(catalogManager, sharedRelationCache)
}
}
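The shared relation cache is passed through a defaulted constructor parameter so the legacy analyzer and the single-pass resolver resolve relations against the same cache instance. A toy sketch of that wiring (hypothetical RelationCache and RelationResolution stand-ins, not the Catalyst classes):

    object SharedRelationCacheSketch {
      import scala.collection.mutable

      // Toy stand-in for RelationCache: a mutable map shared by reference.
      final class RelationCache(private val entries: mutable.Map[String, String] = mutable.Map.empty) {
        def getOrLoad(key: String)(load: => String): String = entries.getOrElseUpdate(key, load)
        def size: Int = entries.size
      }
      object RelationCache { def empty: RelationCache = new RelationCache() }

      // Both resolvers accept the cache with a default, like Resolver / RelationResolution.
      final class RelationResolution(cache: RelationCache = RelationCache.empty) {
        def resolve(name: String): String = cache.getOrLoad(name)(s"loaded:$name")
      }

      def main(args: Array[String]): Unit = {
        val shared = RelationCache.empty
        val legacy = new RelationResolution(shared)
        val singlePass = new RelationResolution(shared)
        legacy.resolve("cat.db.t")
        singlePass.resolve("cat.db.t") // hits the entry the legacy path already loaded
        println(shared.size)           // 1
      }
    }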
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index 5d0184579faac..f5c732ee1412c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -192,7 +192,10 @@ class InMemoryCatalog(
override def createTable(
tableDefinition: CatalogTable,
ignoreIfExists: Boolean): Unit = synchronized {
- assert(tableDefinition.identifier.database.isDefined)
+ assert(tableDefinition.identifier.database.isDefined,
+ "Table identifier " + tableDefinition.identifier.quotedString +
+ " is missing database name. " +
+ "Cannot create table without a database specified.")
val db = tableDefinition.identifier.database.get
requireDbExists(db)
val table = tableDefinition.identifier.table
@@ -313,7 +316,10 @@ class InMemoryCatalog(
}
override def alterTable(tableDefinition: CatalogTable): Unit = synchronized {
- assert(tableDefinition.identifier.database.isDefined)
+ assert(tableDefinition.identifier.database.isDefined,
+ "Table identifier " + tableDefinition.identifier.quotedString +
+ " is missing database name. " +
+ "Cannot alter table without a database specified.")
val db = tableDefinition.identifier.database.get
requireTableExists(db, tableDefinition.identifier.table)
val updatedProperties = tableDefinition.properties.filter(kv => kv._1 != "comment")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
index f2fd3b90f6468..84d87fab8b060 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SQLFunction.scala
@@ -61,8 +61,15 @@ case class SQLFunction(
owner: Option[String] = None,
createTimeMs: Long = System.currentTimeMillis) extends UserDefinedFunction {
- assert(exprText.nonEmpty || queryText.nonEmpty)
- assert((isTableFunc && returnType.isRight) || (!isTableFunc && returnType.isLeft))
+ assert(exprText.nonEmpty || queryText.nonEmpty,
+ "SQL function '" + name + "' is missing function body. " +
+ "Either exprText or queryText must be defined. " +
+ "Found: exprText=" + exprText + ", queryText=" + queryText + ".")
+ assert((isTableFunc && returnType.isRight) || (!isTableFunc && returnType.isLeft),
+ "SQL function '" + name + "' has mismatched function type and return type. " +
+ "isTableFunc=" + isTableFunc + ", returnType.isRight=" + returnType.isRight + ", " +
+ "returnType.isLeft=" + returnType.isLeft + ". " +
+ "Table functions require Right[StructType] and scalar functions require Left[DataType].")
import SQLFunction._
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index c351aacd45a4e..be90c7ad3656c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -1663,7 +1663,9 @@ class SessionCatalog(
.putString("__funcInputAlias", "true")
.build()
}
- assert(!function.isTableFunc)
+ assert(!function.isTableFunc,
+ "Function '" + function.name + "' is a table function. " +
+ "Use makeSQLTableFunctionPlan() instead of makeSQLFunctionPlan().")
val funcName = function.name.funcName
// Use captured SQL configs when parsing a SQL function.
@@ -1674,7 +1676,10 @@ class SessionCatalog(
val inputParam = function.inputParam
val returnType = function.getScalarFuncReturnType
val (expression, query) = function.getExpressionAndQuery(parser, isTableFunc = false)
- assert(expression.isDefined || query.isDefined)
+ assert(expression.isDefined || query.isDefined,
+ "SQL function '" + function.name + "' could not be parsed. " +
+ "Neither expression nor query could be extracted from function body. " +
+ "exprText=" + function.exprText + ", queryText=" + function.queryText + ".")
// Check function arguments
val paramSize = inputParam.map(_.size).getOrElse(0)
@@ -1763,12 +1768,17 @@ class SessionCatalog(
function: SQLFunction,
input: Seq[Expression],
outputAttrs: Seq[Attribute]): LogicalPlan = {
- assert(function.isTableFunc)
+ assert(function.isTableFunc,
+ "Function '" + function.name + "' is a scalar function. " +
+ "Use makeSQLFunctionPlan() instead of makeSQLTableFunctionPlan().")
val funcName = function.name.funcName
val inputParam = function.inputParam
val returnParam = function.getTableFuncReturnCols
val (_, query) = function.getExpressionAndQuery(parser, isTableFunc = true)
- assert(query.isDefined)
+ assert(query.isDefined,
+ "SQL table function '" + function.name + "' could not be parsed. " +
+ "Query could not be extracted from function body. " +
+ "queryText=" + function.queryText + ".")
// Check function arguments
val paramSize = inputParam.map(_.size).getOrElse(0)
@@ -1807,7 +1817,12 @@ class SessionCatalog(
query.get
}
- assert(returnParam.length == outputAttrs.length)
+ assert(returnParam.length == outputAttrs.length,
+ "SQL table function '" + function.name + "' has mismatched return columns. " +
+ "Expected " + outputAttrs.length + " output attributes but found " +
+ returnParam.length + " return parameters. " +
+ "Return parameters: [" + returnParam.fieldNames.mkString(", ") + "], " +
+ "Output attributes: [" + outputAttrs.map(_.name).mkString(", ") + "].")
val output = returnParam.fields.zipWithIndex.map { case (param, i) =>
// Since we cannot get the output of a unresolved logical plan, we need
// to reference the output column of the lateral join by its position.
@@ -2390,7 +2405,9 @@ class SessionCatalog(
requireTableNotExists(newName)
val oldTable = getTableMetadata(oldName)
if (oldTable.tableType == CatalogTableType.MANAGED) {
- assert(oldName.database.nonEmpty)
+ assert(oldName.database.nonEmpty,
+ "Table identifier " + oldName.quotedString + " is missing database name. " +
+ "Managed tables must have a database defined.")
val databaseLocation =
externalCatalog.getDatabase(oldName.database.get).locationUri
val newTableLocation = new Path(new Path(databaseLocation), format(newName.table))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
index 8ed2414683522..3365b11b07424 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/UserDefinedFunction.scala
@@ -132,7 +132,7 @@ object UserDefinedFunction {
*/
private def getObjectMapper: ObjectMapper = {
val mapper = new ObjectMapper with ClassTagExtensions
- mapper.setSerializationInclusion(Include.NON_ABSENT)
+ mapper.setDefaultPropertyInclusion(Include.NON_ABSENT)
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
mapper.registerModule(DefaultScalaModule)
mapper
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 64d816587f4de..eab99a96f4c3e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -278,7 +278,7 @@ case class ClusterBySpec(columnNames: Seq[NamedReference]) {
object ClusterBySpec {
private val mapper = {
val ret = new ObjectMapper() with ClassTagExtensions
- ret.setSerializationInclusion(Include.NON_ABSENT)
+ ret.setDefaultPropertyInclusion(Include.NON_ABSENT)
ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
ret.registerModule(DefaultScalaModule)
ret
@@ -1070,7 +1070,9 @@ case class UnresolvedCatalogRelation(
tableMeta: CatalogTable,
options: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty(),
override val isStreaming: Boolean = false) extends UnresolvedLeafNode {
- assert(tableMeta.identifier.database.isDefined)
+ assert(tableMeta.identifier.database.isDefined,
+ "Table identifier " + tableMeta.identifier.quotedString + " is missing database name. " +
+ "UnresolvedCatalogRelation requires a fully qualified table identifier with database.")
}
/**
@@ -1097,7 +1099,9 @@ case class HiveTableRelation(
tableStats: Option[Statistics] = None,
@transient prunedPartitions: Option[Seq[CatalogTablePartition]] = None)
extends LeafNode with MultiInstanceRelation with NormalizeableRelation {
- assert(tableMeta.identifier.database.isDefined)
+ assert(tableMeta.identifier.database.isDefined,
+ "Table identifier " + tableMeta.identifier.quotedString + " is missing database name. " +
+ "HiveTableRelation requires a fully qualified table identifier with database.")
assert(DataTypeUtils.sameType(tableMeta.partitionSchema, partitionCols.toStructType))
assert(DataTypeUtils.sameType(tableMeta.dataSchema, dataCols.toStructType))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index b317cacc061b7..9b6430c9ff0f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -48,14 +48,14 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute, A)])
override def contains(k: Attribute): Boolean = get(k).isDefined
override def + [B1 >: A](kv: (Attribute, B1)): AttributeMap[B1] =
- AttributeMap(baseMap.values.toMap + kv)
+ new AttributeMap(baseMap + (kv._1.exprId -> kv))
override def updated[B1 >: A](key: Attribute, value: B1): Map[Attribute, B1] =
- baseMap.values.toMap + (key -> value)
+ this + (key -> value)
override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator
- override def removed(key: Attribute): Map[Attribute, A] = baseMap.values.toMap - key
+ override def removed(key: Attribute): Map[Attribute, A] = new AttributeMap(baseMap - key.exprId)
def ++(other: AttributeMap[A]): AttributeMap[A] = new AttributeMap(baseMap ++ other.baseMap)
}
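The point of the AttributeMap change is that `+`, `updated`, and `removed` stay keyed by ExprId instead of rebuilding a plain Map keyed by the whole Attribute, so lookups keep ignoring cosmetic differences such as qualifiers. A standalone sketch of that behaviour with toy Attr/ExprId classes:

    object AttributeMapSketch {
      final case class ExprId(id: Long)
      // The qualifier can differ; the ExprId is the identity lookups should use.
      final case class Attr(name: String, exprId: ExprId, qualifier: Seq[String] = Nil)

      // Toy AttributeMap keyed by ExprId, like the fixed `+` / `removed`.
      final class AttrMap[A](val baseMap: Map[ExprId, (Attr, A)]) {
        def get(k: Attr): Option[A] = baseMap.get(k.exprId).map(_._2)
        def +(kv: (Attr, A)): AttrMap[A] = new AttrMap(baseMap + (kv._1.exprId -> kv))
        def removed(key: Attr): AttrMap[A] = new AttrMap(baseMap - key.exprId)
      }

      def main(args: Array[String]): Unit = {
        val a = Attr("id", ExprId(1))
        val m = new AttrMap[Int](Map.empty) + (a -> 42)
        // Same ExprId but a different qualifier still resolves, which a Map
        // keyed by the whole Attribute would miss.
        println(m.get(a.copy(qualifier = Seq("t"))))            // Some(42)
        println(m.removed(a.copy(qualifier = Seq("t"))).get(a)) // None
      }
    }

Rebuilding via baseMap.values.toMap, as before, would both lose the ExprId keying and pay an O(n) copy on every single-entry update.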
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala
index 784bea899c4c8..e3ff7c5f05f0c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BloomFilterMightContain.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLId, toSQLType}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, JavaCode, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
+import org.apache.spark.sql.catalyst.optimizer.ScalarSubqueryReference
import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE
import org.apache.spark.sql.types._
import org.apache.spark.util.sketch.BloomFilter
@@ -58,6 +59,7 @@ case class BloomFilterMightContain(
case GetStructField(subquery: PlanExpression[_], _, _)
if !subquery.containsPattern(OUTER_REFERENCE) =>
TypeCheckResult.TypeCheckSuccess
+ case _: ScalarSubqueryReference => TypeCheckResult.TypeCheckSuccess
case _ =>
DataTypeMismatch(
errorSubClass = "BLOOM_FILTER_BINARY_OP_WRONG_TYPE",
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 5b76c7d225e11..00b0a83f6d533 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte,
import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
+import org.apache.spark.unsafe.types.{GeographyVal, UTF8String, VariantVal}
import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}
import org.apache.spark.util.ArrayImplicits._
@@ -164,6 +164,16 @@ object Cast extends QueryErrorsBase {
case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt2.acceptsType(udt1) => true
+ // Casting from concrete GEOGRAPHY(srid) to mixed GEOGRAPHY(ANY) is allowed.
+ case (gt1: GeographyType, gt2: GeographyType) if !gt1.isMixedSrid && gt2.isMixedSrid =>
+ true
+ // Casting from GEOGRAPHY to GEOMETRY with the same SRID is allowed.
+ case (geog: GeographyType, geom: GeometryType) if geog.srid == geom.srid =>
+ true
+ // Casting from concrete GEOMETRY(srid) to mixed GEOMETRY(ANY) is allowed.
+ case (gt1: GeometryType, gt2: GeometryType) if !gt1.isMixedSrid && gt2.isMixedSrid =>
+ true
+
case _ => false
}
@@ -290,6 +300,16 @@ object Cast extends QueryErrorsBase {
case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt2.acceptsType(udt1) => true
+ // Casting from concrete GEOGRAPHY(srid) to mixed GEOGRAPHY(ANY) is allowed.
+ case (gt1: GeographyType, gt2: GeographyType) if !gt1.isMixedSrid && gt2.isMixedSrid =>
+ true
+ // Casting from GEOGRAPHY to GEOMETRY with the same SRID is allowed.
+ case (geog: GeographyType, geom: GeometryType) if geog.srid == geom.srid =>
+ true
+ // Casting from concrete GEOMETRY(srid) to mixed GEOMETRY(ANY) is allowed.
+ case (gt1: GeometryType, gt2: GeometryType) if !gt1.isMixedSrid && gt2.isMixedSrid =>
+ true
+
case _ => false
}
@@ -328,6 +348,31 @@ object Cast extends QueryErrorsBase {
*/
def canUpCast(from: DataType, to: DataType): Boolean = UpCastRule.canUpCast(from, to)
+ /**
+ * Returns true iff it is safe to assign a default value of the `from` type (typically defined
+ * in the data source metadata) to the `to` type (typically in the read schema of a query).
+ */
+ def canAssignDefaultValue(from: DataType, to: DataType): Boolean = {
+ def isVariantStruct(st: StructType): Boolean = {
+ st.fields.length > 0 && st.fields.forall(_.metadata.contains("__VARIANT_METADATA_KEY"))
+ }
+ (from, to) match {
+ case (s1: StructType, s2: StructType) =>
+ s1.length == s2.length && s1.fields.zip(s2.fields).forall {
+ case (f1, f2) => resolvableNullability(f1.nullable, f2.nullable) &&
+ canAssignDefaultValue(f1.dataType, f2.dataType)
+ }
+ case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
+ resolvableNullability(fn, tn) && canAssignDefaultValue(fromType, toType)
+ case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+ resolvableNullability(fn, tn) && canAssignDefaultValue(fromKey, toKey) &&
+ canAssignDefaultValue(fromValue, toValue)
+ // A VARIANT field can be read as StructType due to shredding.
+ case (VariantType, s: StructType) => isVariantStruct(s)
+ case _ => canUpCast(from, to)
+ }
+ }
+
/**
* Returns true iff we can cast the `from` type to `to` type as per the ANSI SQL.
* In practice, the behavior is mostly the same as PostgreSQL. It disallows certain unreasonable
@@ -545,6 +590,7 @@ case class Cast(
}
override def checkInputDataTypes(): TypeCheckResult = {
+ TypeUtils.failUnsupportedDataType(dataType, SQLConf.get)
val canCast = evalMode match {
case EvalMode.LEGACY => Cast.canCast(child.dataType, dataType)
case EvalMode.ANSI => Cast.canAnsiCast(child.dataType, dataType)
@@ -1139,6 +1185,14 @@ case class Cast(
b => numeric.toFloat(b)
}
+ // GeometryConverter
+ private[this] def castToGeometry(from: DataType): Any => Any = from match {
+ case _: GeographyType =>
+ buildCast[GeographyVal](_, STUtils.geographyToGeometry)
+ case _: GeometryType =>
+ identity
+ }
+
private[this] def castArray(fromType: DataType, toType: DataType): Any => Any = {
val elementCast = cast(fromType, toType)
// TODO: Could be faster?
@@ -1218,6 +1272,8 @@ case class Cast(
case FloatType => castToFloat(from)
case LongType => castToLong(from)
case DoubleType => castToDouble(from)
+ case _: GeographyType => identity
+ case _: GeometryType => castToGeometry(from)
case array: ArrayType =>
castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
case map: MapType => castMap(from.asInstanceOf[MapType], map)
@@ -1326,6 +1382,8 @@ case class Cast(
case FloatType => castToFloatCode(from, ctx)
case LongType => castToLongCode(from, ctx)
case DoubleType => castToDoubleCode(from, ctx)
+ case _: GeographyType => (c, evPrim, _) => code"$evPrim = $c;"
+ case _: GeometryType => castToGeometryCode(from)
case array: ArrayType =>
castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
@@ -2172,6 +2230,17 @@ case class Cast(
}
}
+ private[this] def castToGeometryCode(from: DataType): CastFunction = {
+ from match {
+ case _: GeographyType =>
+ (c, evPrim, _) =>
+ code"$evPrim = org.apache.spark.sql.catalyst.util.STUtils.geographyToGeometry($c);"
+ case _: GeometryType =>
+ (c, evPrim, _) =>
+ code"$evPrim = $c;"
+ }
+ }
+
private[this] def castArrayCode(
fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = {
val elementCast = nullSafeCastFunction(fromType, toType, ctx)
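canAssignDefaultValue walks structs field by field, arrays by element, and maps by key and value, and falls back to canUpCast at the leaves. A simplified standalone sketch of that recursion over a toy type algebra (only structs, arrays, and an Int-to-Long up-cast; the variant-shredding case is omitted):

    object DefaultValueAssignSketch {
      sealed trait DT
      case object IntT extends DT
      case object LongT extends DT
      final case class ArrayT(element: DT, containsNull: Boolean) extends DT
      final case class StructT(fields: Seq[(String, DT, Boolean)]) extends DT

      // A nullable source going into a non-nullable target is the only disallowed combination.
      def resolvableNullability(fromNullable: Boolean, toNullable: Boolean): Boolean =
        !fromNullable || toNullable

      // Recursive shape of canAssignDefaultValue: structs by position, arrays by
      // element, leaves by an up-cast check (here just Int -> Long).
      def canAssign(from: DT, to: DT): Boolean = (from, to) match {
        case (StructT(f1), StructT(f2)) =>
          f1.length == f2.length && f1.zip(f2).forall { case ((_, t1, n1), (_, t2, n2)) =>
            resolvableNullability(n1, n2) && canAssign(t1, t2)
          }
        case (ArrayT(e1, n1), ArrayT(e2, n2)) => resolvableNullability(n1, n2) && canAssign(e1, e2)
        case (IntT, LongT) => true
        case (a, b) => a == b
      }

      def main(args: Array[String]): Unit = {
        println(canAssign(StructT(Seq(("a", IntT, false))), StructT(Seq(("a", LongT, true)))))      // true
        println(canAssign(ArrayT(LongT, containsNull = true), ArrayT(LongT, containsNull = false))) // false
      }
    }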
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
index 5d2fd14eee298..9a0aaea75f810 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
@@ -18,10 +18,11 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.sql.catalyst.util.{CollationFactory, UnsafeRowUtils}
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.ArrayImplicits.SparkArrayOps
case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] =
@@ -46,3 +47,61 @@ case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsIn
override def child: Expression = expr
}
+
+object CollationKey {
+ /**
+ * Recursively processes the expression, replacing non-binary collated strings
+ * with their associated collation keys.
+ */
+ def injectCollationKey(expr: Expression): Expression = {
+ injectCollationKey(expr, expr.dataType)
+ }
+
+ private def injectCollationKey(expr: Expression, dt: DataType): Expression = {
+ dt match {
+ // For binary stable expressions, no special handling is needed.
+ case _ if UnsafeRowUtils.isBinaryStable(dt) =>
+ expr
+
+ // Inject CollationKey for non-binary collated strings.
+ case _: StringType =>
+ CollationKey(expr)
+
+ // Recursively process struct fields for non-binary structs.
+ case StructType(fields) =>
+ val transformed = fields.zipWithIndex.map { case (f, i) =>
+ val originalField = GetStructField(expr, i, Some(f.name))
+ val injected = injectCollationKey(originalField, f.dataType)
+ (f, injected, injected.fastEquals(originalField))
+ }
+ val anyChanged = transformed.exists { case (_, _, same) => !same }
+ if (!anyChanged) {
+ expr
+ } else {
+ val struct = CreateNamedStruct(
+ transformed.flatMap { case (f, injected, _) =>
+ Seq(Literal(f.name), injected)
+ }.toImmutableArraySeq)
+ if (expr.nullable) {
+ If(IsNull(expr), Literal(null, struct.dataType), struct)
+ } else {
+ struct
+ }
+ }
+
+ // Recursively process array elements for non-binary arrays.
+ case ArrayType(et, containsNull) =>
+ val param: NamedExpression = NamedLambdaVariable("a", et, containsNull)
+ val funcBody: Expression = injectCollationKey(param, et)
+ if (!funcBody.fastEquals(param)) {
+ ArrayTransform(expr, LambdaFunction(funcBody, Seq(param)))
+ } else {
+ expr
+ }
+
+ // Joins are not supported on maps, so there's no special handling for MapType.
+ case _ =>
+ expr
+ }
+ }
+}
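injectCollationKey wraps non-binary string leaves in CollationKey, rebuilds a struct only when some field actually changed, and maps a lambda over array elements. A simplified standalone sketch of the same recursion shape over a toy value tree (the binary-stability shortcut and null handling are omitted):

    object InjectKeySketch {
      sealed trait Value
      final case class Str(s: String) extends Value
      final case class Arr(items: Seq[Value]) extends Value
      final case class Struct(fields: Seq[(String, Value)]) extends Value
      final case class CollKey(of: Value) extends Value

      // Wrap string leaves, rebuild a struct only if some field actually changed,
      // and map over array elements; everything else passes through unchanged.
      def inject(v: Value): Value = v match {
        case s: Str => CollKey(s)
        case Struct(fields) =>
          val rewritten = fields.map { case (n, f) => (n, inject(f)) }
          if (rewritten == fields) v else Struct(rewritten)
        case Arr(items) =>
          val rewritten = items.map(inject)
          if (rewritten == items) v else Arr(rewritten)
        case other => other
      }

      def main(args: Array[String]): Unit = {
        println(inject(Struct(Seq("name" -> Str("Åsa"), "tags" -> Arr(Seq(Str("x")))))))
      }
    }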
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala
index cfcb53769e253..8fb1bf51319cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/datasketchesAggregates.scala
@@ -66,6 +66,9 @@ case class HllSketchAgg(
// Hllsketch config - mark as lazy so that they're not evaluated during tree transformation.
lazy val lgConfigK: Int = {
+ if (!right.foldable) {
+ throw QueryExecutionErrors.hllKMustBeConstantError(prettyName)
+ }
val lgConfigK = right.eval().asInstanceOf[Int]
HllSketchAgg.checkLgK(lgConfigK)
lgConfigK
@@ -335,7 +338,8 @@ case class HllUnionAgg(
union.update(sketch)
Some(union)
} catch {
- case _: SketchesArgumentException | _: java.lang.Error =>
+ case _: SketchesArgumentException | _: java.lang.Error
+ | _: ArrayIndexOutOfBoundsException =>
throw QueryExecutionErrors.hllInvalidInputSketchBuffer(prettyName)
}
case _ =>
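lgConfigK is now guarded so that a non-foldable k expression fails with a dedicated error instead of being evaluated without an input row. A toy sketch of that guard (hypothetical Expr/Literal/ColumnRef stand-ins; the real code throws QueryExecutionErrors.hllKMustBeConstantError):

    object FoldableParamSketch {
      sealed trait Expr { def foldable: Boolean; def eval(): Any }
      final case class Literal(value: Any) extends Expr {
        val foldable = true
        def eval(): Any = value
      }
      final case class ColumnRef(name: String) extends Expr {
        val foldable = false
        def eval(): Any = sys.error("needs a row")
      }

      // Mirrors the new guard: refuse a non-constant k before calling eval() with no row.
      def lgConfigK(k: Expr): Int = {
        require(k.foldable, "lgConfigK must be a constant expression")
        k.eval().asInstanceOf[Int]
      }

      def main(args: Array[String]): Unit = {
        println(lgConfigK(Literal(12)))   // 12
        try lgConfigK(ColumnRef("k"))
        catch { case e: IllegalArgumentException => println(e.getMessage) }
      }
    }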
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/kllAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/kllAggregates.scala
new file mode 100644
index 0000000000000..6e3ea19425d9c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/kllAggregates.scala
@@ -0,0 +1,856 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.datasketches.kll.{KllDoublesSketch, KllFloatsSketch, KllLongsSketch, KllSketch}
+import org.apache.datasketches.memory.Memory
+
+import org.apache.spark.SparkUnsupportedOperationException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types.{AbstractDataType, BinaryType, ByteType, DataType, DoubleType, FloatType, IntegerType, LongType, ShortType, TypeCollection}
+
+/**
+ * The KllSketchAggBigint function utilizes an Apache DataSketches KllLongsSketch instance to
+ * compute quantiles of the values of an input expression (such as an input column in a table).
+ * It outputs the binary representation of the KllLongsSketch.
+ *
+ * See [[https://datasketches.apache.org/docs/KLL/KLLSketch.html]] for more information.
+ *
+ * @param child
+ * child expression against which quantile computation will occur
+ * @param kExpr
+ * optional expression for the k parameter from the Apache DataSketches library that controls
+ * the size and accuracy of the sketch. Must be a constant integer between 8 and 65535.
+ * Default is 200 (normalized rank error ~1.65%). Larger k values provide more accurate
+ * estimates but result in larger, slower sketches.
+ * @param mutableAggBufferOffset
+ * offset for mutable aggregation buffer
+ * @param inputAggBufferOffset
+ * offset for input aggregation buffer
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr[, k]) - Returns the KllLongsSketch compact binary representation.
+ The optional k parameter controls the size and accuracy of the sketch (default 200, range 8-65535).
+ Larger k values provide more accurate quantile estimates but result in larger, slower sketches.
+ """,
+ examples = """
+ Examples:
+ > SELECT LENGTH(kll_sketch_to_string_bigint(_FUNC_(col))) > 0 FROM VALUES (1), (2), (3), (4), (5) tab(col);
+ true
+ > SELECT LENGTH(kll_sketch_to_string_bigint(_FUNC_(col, 400))) > 0 FROM VALUES (1), (2), (3), (4), (5) tab(col);
+ true
+ """,
+ group = "agg_funcs",
+ since = "4.1.0")
+// scalastyle:on line.size.limit
+case class KllSketchAggBigint(
+ child: Expression,
+ kExpr: Option[Expression] = None,
+ override val mutableAggBufferOffset: Int = 0,
+ override val inputAggBufferOffset: Int = 0)
+ extends TypedImperativeAggregate[KllLongsSketch]
+ with KllSketchAggBase
+ with ExpectsInputTypes {
+ def this(child: Expression) = this(child, None, 0, 0)
+ def this(child: Expression, kExpr: Expression) = this(child, Some(kExpr), 0, 0)
+
+ override def children: Seq[Expression] = child +: kExpr.toSeq
+
+ override def withNewMutableAggBufferOffset(
+ newMutableAggBufferOffset: Int): KllSketchAggBigint =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+ override def withNewInputAggBufferOffset(
+ newInputAggBufferOffset: Int): KllSketchAggBigint =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): KllSketchAggBigint = {
+ if (newChildren.length == 1) {
+ copy(child = newChildren(0), kExpr = None)
+ } else {
+ copy(child = newChildren(0), kExpr = Some(newChildren(1)))
+ }
+ }
+
+ override def dataType: DataType = BinaryType
+ override def inputTypes: Seq[AbstractDataType] = {
+ val baseTypes = Seq(
+ TypeCollection(
+ ByteType,
+ IntegerType,
+ LongType,
+ ShortType))
+ if (kExpr.isDefined) baseTypes :+ IntegerType else baseTypes
+ }
+ override def nullable: Boolean = false
+ override def prettyName: String = "kll_sketch_agg_bigint"
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else {
+ checkKInputDataTypes()
+ }
+ }
+
+ override def createAggregationBuffer(): KllLongsSketch =
+ KllLongsSketch.newHeapInstance(kValue)
+
+ /**
+ * Evaluate the input row and update the KllLongsSketch instance with the row's value. The update
+ * function only supports a subset of Spark SQL types, and an exception will be thrown for
+ * unsupported types.
+ * Note, null values are ignored.
+ */
+ override def update(sketch: KllLongsSketch, input: InternalRow): KllLongsSketch = {
+ val v = child.eval(input)
+ if (v == null) {
+ sketch
+ } else {
+ // Handle the different data types for sketch updates.
+ child.dataType match {
+ case ByteType =>
+ sketch.update(v.asInstanceOf[Byte].toLong)
+ case IntegerType =>
+ sketch.update(v.asInstanceOf[Int].toLong)
+ case LongType =>
+ sketch.update(v.asInstanceOf[Long])
+ case ShortType =>
+ sketch.update(v.asInstanceOf[Short].toLong)
+ case _ =>
+ throw unexpectedInputDataTypeError(child)
+ }
+ sketch
+ }
+ }
+
+ /** Merges an input sketch into the current aggregation buffer. */
+ override def merge(updateBuffer: KllLongsSketch, input: KllLongsSketch): KllLongsSketch = {
+ try {
+ updateBuffer.merge(input)
+ updateBuffer
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ }
+
+ /** Returns a sketch derived from the input column or expression. */
+ override def eval(sketch: KllLongsSketch): Any = sketch.toByteArray
+
+ /** Converts the underlying sketch state into a byte array. */
+ override def serialize(sketch: KllLongsSketch): Array[Byte] = sketch.toByteArray
+
+ /** Wraps the byte array into a sketch instance. */
+ override def deserialize(buffer: Array[Byte]): KllLongsSketch = if (buffer.nonEmpty) {
+ try {
+ KllLongsSketch.heapify(Memory.wrap(buffer))
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ } else {
+ this.createAggregationBuffer()
+ }
+}
+
+/**
+ * The KllSketchAggFloat function utilizes an Apache DataSketches KllFloatsSketch instance to
+ * compute quantiles of the values of an input expression (such as an input column in a table).
+ * It outputs the binary representation of the KllFloatsSketch.
+ *
+ * See [[https://datasketches.apache.org/docs/KLL/KLLSketch.html]] for more information.
+ *
+ * @param child
+ * child expression against which quantile computation will occur
+ * @param kExpr
+ * optional expression for the k parameter from the Apache DataSketches library that controls
+ * the size and accuracy of the sketch. Must be a constant integer between 8 and 65535.
+ * Default is 200 (normalized rank error ~1.65%). Larger k values provide more accurate
+ * estimates but result in larger, slower sketches.
+ * @param mutableAggBufferOffset
+ * offset for mutable aggregation buffer
+ * @param inputAggBufferOffset
+ * offset for input aggregation buffer
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr[, k]) - Returns the KllFloatsSketch compact binary representation.
+ The optional k parameter controls the size and accuracy of the sketch (default 200, range 8-65535).
+ Larger k values provide more accurate quantile estimates but result in larger, slower sketches.
+ """,
+ examples = """
+ Examples:
+ > SELECT LENGTH(kll_sketch_to_string_float(_FUNC_(col))) > 0 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col);
+ true
+ > SELECT LENGTH(kll_sketch_to_string_float(_FUNC_(col, 400))) > 0 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col);
+ true
+ """,
+ group = "agg_funcs",
+ since = "4.1.0")
+// scalastyle:on line.size.limit
+case class KllSketchAggFloat(
+ child: Expression,
+ kExpr: Option[Expression] = None,
+ override val mutableAggBufferOffset: Int = 0,
+ override val inputAggBufferOffset: Int = 0)
+ extends TypedImperativeAggregate[KllFloatsSketch]
+ with KllSketchAggBase
+ with ExpectsInputTypes {
+ def this(child: Expression) = this(child, None, 0, 0)
+ def this(child: Expression, kExpr: Expression) = this(child, Some(kExpr), 0, 0)
+
+ override def children: Seq[Expression] = child +: kExpr.toSeq
+
+ override def withNewMutableAggBufferOffset(
+ newMutableAggBufferOffset: Int): KllSketchAggFloat =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+ override def withNewInputAggBufferOffset(
+ newInputAggBufferOffset: Int): KllSketchAggFloat =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): KllSketchAggFloat = {
+ if (newChildren.length == 1) {
+ copy(child = newChildren(0), kExpr = None)
+ } else {
+ copy(child = newChildren(0), kExpr = Some(newChildren(1)))
+ }
+ }
+
+ override def dataType: DataType = BinaryType
+ override def inputTypes: Seq[AbstractDataType] = {
+ val baseTypes = Seq(FloatType)
+ if (kExpr.isDefined) baseTypes :+ IntegerType else baseTypes
+ }
+ override def nullable: Boolean = false
+ override def prettyName: String = "kll_sketch_agg_float"
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else {
+ checkKInputDataTypes()
+ }
+ }
+
+ override def createAggregationBuffer(): KllFloatsSketch =
+ KllFloatsSketch.newHeapInstance(kValue)
+
+ /**
+ * Evaluate the input row and update the KllFloatsSketch instance with the row's value. The update
+ * function only supports FloatType to avoid precision loss from integer-to-float conversion.
+ * Users should use kll_sketch_agg_bigint for integer types.
+ * Note, null values are ignored.
+ */
+ override def update(sketch: KllFloatsSketch, input: InternalRow): KllFloatsSketch = {
+ val v = child.eval(input)
+ if (v == null) {
+ sketch
+ } else {
+ // Handle the different data types for sketch updates.
+ child.dataType match {
+ case FloatType =>
+ sketch.update(v.asInstanceOf[Float])
+ case _ =>
+ throw unexpectedInputDataTypeError(child)
+ }
+ sketch
+ }
+ }
+
+ /** Merges an input sketch into the current aggregation buffer. */
+ override def merge(updateBuffer: KllFloatsSketch, input: KllFloatsSketch): KllFloatsSketch = {
+ try {
+ updateBuffer.merge(input)
+ updateBuffer
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ }
+
+ /** Returns a sketch derived from the input column or expression. */
+ override def eval(sketch: KllFloatsSketch): Any = sketch.toByteArray
+
+ /** Converts the underlying sketch state into a byte array. */
+ override def serialize(sketch: KllFloatsSketch): Array[Byte] = sketch.toByteArray
+
+ /** Wraps the byte array into a sketch instance. */
+ override def deserialize(buffer: Array[Byte]): KllFloatsSketch = if (buffer.nonEmpty) {
+ try {
+ KllFloatsSketch.heapify(Memory.wrap(buffer))
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ } else {
+ this.createAggregationBuffer()
+ }
+}
+
+/**
+ * The KllSketchAggDouble function utilizes an Apache DataSketches KllDoublesSketch instance to
+ * compute quantiles of the values of an input expression (such as an input column in a table).
+ * It outputs the binary representation of the KllDoublesSketch.
+ *
+ * See [[https://datasketches.apache.org/docs/KLL/KLLSketch.html]] for more information.
+ *
+ * @param child
+ * child expression against which quantile computation will occur
+ * @param kExpr
+ * optional expression for the k parameter from the Apache DataSketches library that controls
+ * the size and accuracy of the sketch. Must be a constant integer between 8 and 65535.
+ * Default is 200 (normalized rank error ~1.65%). Larger k values provide more accurate
+ * estimates but result in larger, slower sketches.
+ * @param mutableAggBufferOffset
+ * offset for mutable aggregation buffer
+ * @param inputAggBufferOffset
+ * offset for input aggregation buffer
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr[, k]) - Returns the KllDoublesSketch compact binary representation.
+ The optional k parameter controls the size and accuracy of the sketch (default 200, range 8-65535).
+ Larger k values provide more accurate quantile estimates but result in larger, slower sketches.
+ """,
+ examples = """
+ Examples:
+ > SELECT LENGTH(kll_sketch_to_string_double(_FUNC_(col))) > 0 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col);
+ true
+ > SELECT LENGTH(kll_sketch_to_string_double(_FUNC_(col, 400))) > 0 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col);
+ true
+ """,
+ group = "agg_funcs",
+ since = "4.1.0")
+// scalastyle:on line.size.limit
+case class KllSketchAggDouble(
+ child: Expression,
+ kExpr: Option[Expression] = None,
+ override val mutableAggBufferOffset: Int = 0,
+ override val inputAggBufferOffset: Int = 0)
+ extends TypedImperativeAggregate[KllDoublesSketch]
+ with KllSketchAggBase
+ with ExpectsInputTypes {
+ def this(child: Expression) = this(child, None, 0, 0)
+ def this(child: Expression, kExpr: Expression) = this(child, Some(kExpr), 0, 0)
+
+ override def children: Seq[Expression] = child +: kExpr.toSeq
+
+ override def withNewMutableAggBufferOffset(
+ newMutableAggBufferOffset: Int): KllSketchAggDouble =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+ override def withNewInputAggBufferOffset(
+ newInputAggBufferOffset: Int): KllSketchAggDouble =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): KllSketchAggDouble = {
+ if (newChildren.length == 1) {
+ copy(child = newChildren(0), kExpr = None)
+ } else {
+ copy(child = newChildren(0), kExpr = Some(newChildren(1)))
+ }
+ }
+
+ override def dataType: DataType = BinaryType
+ override def inputTypes: Seq[AbstractDataType] = {
+ val baseTypes = Seq(TypeCollection(FloatType, DoubleType))
+ if (kExpr.isDefined) baseTypes :+ IntegerType else baseTypes
+ }
+ override def nullable: Boolean = false
+ override def prettyName: String = "kll_sketch_agg_double"
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else {
+ checkKInputDataTypes()
+ }
+ }
+
+ override def createAggregationBuffer(): KllDoublesSketch =
+ KllDoublesSketch.newHeapInstance(kValue)
+
+ /**
+ * Evaluate the input row and update the KllDoublesSketch instance with the row's value.
+ * The update function only supports FloatType and DoubleType to avoid precision loss from
+ * integer-to-double conversion. Users should use kll_sketch_agg_bigint for integer types.
+ * Note, null values are ignored.
+ */
+ override def update(sketch: KllDoublesSketch, input: InternalRow): KllDoublesSketch = {
+ val v = child.eval(input)
+ if (v == null) {
+ sketch
+ } else {
+ // Handle the different data types for sketch updates.
+ child.dataType match {
+ case DoubleType =>
+ sketch.update(v.asInstanceOf[Double])
+ case FloatType =>
+ sketch.update(v.asInstanceOf[Float].toDouble)
+ case _ =>
+ throw unexpectedInputDataTypeError(child)
+ }
+ sketch
+ }
+ }
+
+ /** Merges an input sketch into the current aggregation buffer. */
+ override def merge(updateBuffer: KllDoublesSketch, input: KllDoublesSketch): KllDoublesSketch = {
+ try {
+ updateBuffer.merge(input)
+ updateBuffer
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ }
+
+ /** Returns the compact binary representation of the aggregated sketch. */
+ override def eval(sketch: KllDoublesSketch): Any = sketch.toByteArray
+
+ /** Converts the underlying sketch state into a byte array. */
+ override def serialize(sketch: KllDoublesSketch): Array[Byte] = sketch.toByteArray
+
+ /** Deserializes the byte array into a sketch instance. */
+ override def deserialize(buffer: Array[Byte]): KllDoublesSketch = if (buffer.nonEmpty) {
+ try {
+ KllDoublesSketch.heapify(Memory.wrap(buffer))
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ } else {
+ this.createAggregationBuffer()
+ }
+}
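+
+// Illustrative usage, not part of this change: the aggregate above composes with the scalar
+// quantile extractor added in kllExpressions.scala. For example, estimating the median of a
+// DOUBLE column:
+//   SELECT kll_sketch_get_quantile_double(kll_sketch_agg_double(col), 0.5)
+//   FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)) tab(col);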
+
+/**
+ * The KllMergeAggBigint function merges multiple Apache DataSketches KllLongsSketch instances
+ * that have been serialized to binary format. This is useful for combining sketches created
+ * in separate aggregations (e.g., from different partitions or time windows).
+ * It outputs the merged binary representation of the KllLongsSketch.
+ *
+ * See [[https://datasketches.apache.org/docs/KLL/KLLSketch.html]] for more information.
+ *
+ * @param child
+ * child expression containing binary KllLongsSketch representations to merge
+ * @param kExpr
+ * optional expression for the k parameter from the Apache DataSketches library that controls
+ * the size and accuracy of the sketch. Must be a constant integer between 8 and 65535.
+ * If not specified, the merged sketch adopts the k value from the first input sketch.
+ * If specified, the value is used to initialize the aggregation buffer. The merge operation
+ * can handle input sketches with different k values. Larger k values provide more accurate
+ * estimates but result in larger, slower sketches.
+ * @param mutableAggBufferOffset
+ * offset for mutable aggregation buffer
+ * @param inputAggBufferOffset
+ * offset for input aggregation buffer
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr[, k]) - Merges binary KllLongsSketch representations and returns the merged sketch.
+ The input expression should contain binary sketch representations (e.g., from kll_sketch_agg_bigint).
+ The optional k parameter controls the size and accuracy of the merged sketch (range 8-65535).
+ If k is not specified, the merged sketch adopts the k value from the first input sketch.
+ """,
+ examples = """
+ Examples:
+ > SELECT kll_sketch_get_n_bigint(_FUNC_(sketch)) FROM (SELECT kll_sketch_agg_bigint(col) as sketch FROM VALUES (1), (2), (3) tab(col) UNION ALL SELECT kll_sketch_agg_bigint(col) as sketch FROM VALUES (4), (5), (6) tab(col)) t;
+ 6
+ """,
+ group = "agg_funcs",
+ since = "4.1.0")
+// scalastyle:on line.size.limit
+case class KllMergeAggBigint(
+ child: Expression,
+ kExpr: Option[Expression] = None,
+ override val mutableAggBufferOffset: Int = 0,
+ override val inputAggBufferOffset: Int = 0)
+ extends KllMergeAggBase[KllLongsSketch] {
+ def this(child: Expression) = this(child, None, 0, 0)
+ def this(child: Expression, kExpr: Expression) = this(child, Some(kExpr), 0, 0)
+
+ override def withNewMutableAggBufferOffset(
+ newMutableAggBufferOffset: Int): KllMergeAggBigint =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+ override def withNewInputAggBufferOffset(
+ newInputAggBufferOffset: Int): KllMergeAggBigint =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): KllMergeAggBigint = {
+ if (newChildren.length == 1) {
+ copy(child = newChildren(0), kExpr = None)
+ } else {
+ copy(child = newChildren(0), kExpr = Some(newChildren(1)))
+ }
+ }
+
+ override def prettyName: String = "kll_merge_agg_bigint"
+
+ // Factory method implementations
+ protected def newHeapInstance(k: Int): KllLongsSketch = KllLongsSketch.newHeapInstance(k)
+ protected def wrapSketch(bytes: Array[Byte]): KllLongsSketch =
+ KllLongsSketch.wrap(Memory.wrap(bytes))
+ protected def heapifySketch(bytes: Array[Byte]): KllLongsSketch =
+ KllLongsSketch.heapify(Memory.wrap(bytes))
+ protected def toByteArray(sketch: KllLongsSketch): Array[Byte] = sketch.toByteArray
+}
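+
+// Illustrative usage, not part of this change: kll_merge_agg_bigint re-aggregates pre-computed
+// sketches, e.g. rolling per-group sketches up into a single one. The table and column names
+// below are hypothetical:
+//   SELECT kll_merge_agg_bigint(day_sketch)
+//   FROM (SELECT day, kll_sketch_agg_bigint(value) AS day_sketch FROM events GROUP BY day) t;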
+
+/**
+ * The KllMergeAggFloat function merges multiple Apache DataSketches KllFloatsSketch instances
+ * that have been serialized to binary format. This is useful for combining sketches created
+ * in separate aggregations (e.g., from different partitions or time windows).
+ * It outputs the merged binary representation of the KllFloatsSketch.
+ *
+ * See [[https://datasketches.apache.org/docs/KLL/KLLSketch.html]] for more information.
+ *
+ * @param child
+ * child expression containing binary KllFloatsSketch representations to merge
+ * @param kExpr
+ * optional expression for the k parameter from the Apache DataSketches library that controls
+ * the size and accuracy of the sketch. Must be a constant integer between 8 and 65535.
+ * If not specified, the merged sketch adopts the k value from the first input sketch.
+ * If specified, the value is used to initialize the aggregation buffer. The merge operation
+ * can handle input sketches with different k values. Larger k values provide more accurate
+ * estimates but result in larger, slower sketches.
+ * @param mutableAggBufferOffset
+ * offset for mutable aggregation buffer
+ * @param inputAggBufferOffset
+ * offset for input aggregation buffer
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr[, k]) - Merges binary KllFloatsSketch representations and returns the merged sketch.
+ The input expression should contain binary sketch representations (e.g., from kll_sketch_agg_float).
+ The optional k parameter controls the size and accuracy of the merged sketch (range 8-65535).
+ If k is not specified, the merged sketch adopts the k value from the first input sketch.
+ """,
+ examples = """
+ Examples:
+ > SELECT kll_sketch_get_n_float(_FUNC_(sketch)) FROM (SELECT kll_sketch_agg_float(col) as sketch FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)) tab(col) UNION ALL SELECT kll_sketch_agg_float(col) as sketch FROM VALUES (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)), (CAST(6.0 AS FLOAT)) tab(col)) t;
+ 6
+ """,
+ group = "agg_funcs",
+ since = "4.1.0")
+// scalastyle:on line.size.limit
+case class KllMergeAggFloat(
+ child: Expression,
+ kExpr: Option[Expression] = None,
+ override val mutableAggBufferOffset: Int = 0,
+ override val inputAggBufferOffset: Int = 0)
+ extends KllMergeAggBase[KllFloatsSketch] {
+ def this(child: Expression) = this(child, None, 0, 0)
+ def this(child: Expression, kExpr: Expression) = this(child, Some(kExpr), 0, 0)
+
+ override def withNewMutableAggBufferOffset(
+ newMutableAggBufferOffset: Int): KllMergeAggFloat =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+ override def withNewInputAggBufferOffset(
+ newInputAggBufferOffset: Int): KllMergeAggFloat =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): KllMergeAggFloat = {
+ if (newChildren.length == 1) {
+ copy(child = newChildren(0), kExpr = None)
+ } else {
+ copy(child = newChildren(0), kExpr = Some(newChildren(1)))
+ }
+ }
+
+ override def prettyName: String = "kll_merge_agg_float"
+
+ // Factory method implementations
+ protected def newHeapInstance(k: Int): KllFloatsSketch = KllFloatsSketch.newHeapInstance(k)
+ protected def wrapSketch(bytes: Array[Byte]): KllFloatsSketch =
+ KllFloatsSketch.wrap(Memory.wrap(bytes))
+ protected def heapifySketch(bytes: Array[Byte]): KllFloatsSketch =
+ KllFloatsSketch.heapify(Memory.wrap(bytes))
+ protected def toByteArray(sketch: KllFloatsSketch): Array[Byte] = sketch.toByteArray
+}
+
+/**
+ * The KllMergeAggDouble function merges multiple Apache DataSketches KllDoublesSketch instances
+ * that have been serialized to binary format. This is useful for combining sketches created
+ * in separate aggregations (e.g., from different partitions or time windows).
+ * It outputs the merged binary representation of the KllDoublesSketch.
+ *
+ * See [[https://datasketches.apache.org/docs/KLL/KLLSketch.html]] for more information.
+ *
+ * @param child
+ * child expression containing binary KllDoublesSketch representations to merge
+ * @param kExpr
+ * optional expression for the k parameter from the Apache DataSketches library that controls
+ * the size and accuracy of the sketch. Must be a constant integer between 8 and 65535.
+ * If not specified, the merged sketch adopts the k value from the first input sketch.
+ * If specified, the value is used to initialize the aggregation buffer. The merge operation
+ * can handle input sketches with different k values. Larger k values provide more accurate
+ * estimates but result in larger, slower sketches.
+ * @param mutableAggBufferOffset
+ * offset for mutable aggregation buffer
+ * @param inputAggBufferOffset
+ * offset for input aggregation buffer
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr[, k]) - Merges binary KllDoublesSketch representations and returns the merged sketch.
+ The input expression should contain binary sketch representations (e.g., from kll_sketch_agg_double).
+ The optional k parameter controls the size and accuracy of the merged sketch (range 8-65535).
+ If k is not specified, the merged sketch adopts the k value from the first input sketch.
+ """,
+ examples = """
+ Examples:
+ > SELECT kll_sketch_get_n_double(_FUNC_(sketch)) FROM (SELECT kll_sketch_agg_double(col) as sketch FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)) tab(col) UNION ALL SELECT kll_sketch_agg_double(col) as sketch FROM VALUES (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)), (CAST(6.0 AS DOUBLE)) tab(col)) t;
+ 6
+ """,
+ group = "agg_funcs",
+ since = "4.1.0")
+// scalastyle:on line.size.limit
+case class KllMergeAggDouble(
+ child: Expression,
+ kExpr: Option[Expression] = None,
+ override val mutableAggBufferOffset: Int = 0,
+ override val inputAggBufferOffset: Int = 0)
+ extends KllMergeAggBase[KllDoublesSketch] {
+ def this(child: Expression) = this(child, None, 0, 0)
+ def this(child: Expression, kExpr: Expression) = this(child, Some(kExpr), 0, 0)
+
+ override def withNewMutableAggBufferOffset(
+ newMutableAggBufferOffset: Int): KllMergeAggDouble =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+ override def withNewInputAggBufferOffset(
+ newInputAggBufferOffset: Int): KllMergeAggDouble =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[Expression]): KllMergeAggDouble = {
+ if (newChildren.length == 1) {
+ copy(child = newChildren(0), kExpr = None)
+ } else {
+ copy(child = newChildren(0), kExpr = Some(newChildren(1)))
+ }
+ }
+
+ override def prettyName: String = "kll_merge_agg_double"
+
+ // Factory method implementations
+ protected def newHeapInstance(k: Int): KllDoublesSketch = KllDoublesSketch.newHeapInstance(k)
+ protected def wrapSketch(bytes: Array[Byte]): KllDoublesSketch =
+ KllDoublesSketch.wrap(Memory.wrap(bytes))
+ protected def heapifySketch(bytes: Array[Byte]): KllDoublesSketch =
+ KllDoublesSketch.heapify(Memory.wrap(bytes))
+ protected def toByteArray(sketch: KllDoublesSketch): Array[Byte] = sketch.toByteArray
+}
+
+/**
+ * Base abstract class for KLL merge aggregate functions that provides common implementation
+ * for merging serialized KLL sketches with optional k parameter.
+ *
+ * @tparam T The KLL sketch type (KllLongsSketch, KllFloatsSketch, or KllDoublesSketch)
+ */
+abstract class KllMergeAggBase[T <: KllSketch]
+ extends TypedImperativeAggregate[Option[T]]
+ with KllSketchAggBase
+ with ExpectsInputTypes {
+
+ def child: Expression
+
+ // Abstract factory methods for sketch-specific instantiation
+ protected def newHeapInstance(k: Int): T
+ protected def wrapSketch(bytes: Array[Byte]): T
+ protected def heapifySketch(bytes: Array[Byte]): T
+ protected def toByteArray(sketch: T): Array[Byte]
+
+ // Common implementations for all merge aggregates
+ override def children: Seq[Expression] = child +: kExpr.toSeq
+
+ override def dataType: DataType = BinaryType
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ val baseTypes = Seq(BinaryType)
+ if (kExpr.isDefined) baseTypes :+ IntegerType else baseTypes
+ }
+
+ override def nullable: Boolean = false
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else {
+ checkKInputDataTypes()
+ }
+ }
+
+ /**
+ * Defers instantiation of the sketch buffer until the first input sketch is seen
+ * (when kExpr is not provided), so that the buffer adopts that sketch's k value.
+ *
+ * @return None if kExpr was not provided, otherwise Some(sketch with the specified k)
+ */
+ override def createAggregationBuffer(): Option[T] = {
+ if (kExpr.isDefined) {
+ Some(newHeapInstance(kValue))
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Evaluates the input row, wraps the binary sketch, and merges it into
+ * the current aggregation buffer.
+ * Note: null values are ignored.
+ */
+ override def update(sketchOption: Option[T], input: InternalRow): Option[T] = {
+ val v = child.eval(input)
+ if (v == null) {
+ sketchOption
+ } else {
+ try {
+ val sketchBytes = v.asInstanceOf[Array[Byte]]
+ val inputSketch = wrapSketch(sketchBytes)
+ val sketch = sketchOption.getOrElse(newHeapInstance(inputSketch.getK()))
+ sketch.merge(inputSketch)
+ Some(sketch)
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ }
+ }
+
+ /** Merges an input sketch into the current aggregation buffer. */
+ override def merge(updateBufferOption: Option[T], inputOption: Option[T]): Option[T] = {
+ (updateBufferOption, inputOption) match {
+ case (Some(updateBuffer), Some(input)) =>
+ try {
+ updateBuffer.merge(input)
+ Some(updateBuffer)
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ case (Some(_), None) => updateBufferOption
+ case (None, Some(_)) => inputOption
+ case (None, None) => None
+ }
+ }
+
+ /** Returns the binary representation of the merged sketch. */
+ override def eval(sketchOption: Option[T]): Any = {
+ sketchOption match {
+ case Some(sketch) => toByteArray(sketch)
+ case None => toByteArray(newHeapInstance(kValue))
+ }
+ }
+
+ /** Converts the underlying sketch state into a byte array. */
+ override def serialize(sketchOption: Option[T]): Array[Byte] = {
+ sketchOption match {
+ case Some(sketch) => toByteArray(sketch)
+ case None => toByteArray(newHeapInstance(kValue))
+ }
+ }
+
+ /** Deserializes the byte array into a sketch instance. */
+ override def deserialize(buffer: Array[Byte]): Option[T] = {
+ if (buffer.nonEmpty) {
+ try {
+ Some(heapifySketch(buffer))
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ } else {
+ createAggregationBuffer()
+ }
+ }
+}
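+
+// Illustrative sketch of the adopt-k behavior above, not part of this change, using only the
+// DataSketches calls already referenced in this file. When kExpr is absent, the buffer is created
+// on the first input sketch, so that sketch's k is adopted instead of the default of 200:
+//   val first  = KllDoublesSketch.newHeapInstance(400)           // producer chose k = 400
+//   val buffer = KllDoublesSketch.newHeapInstance(first.getK())  // buffer adopts k = 400
+//   buffer.merge(first)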
+
+/**
+ * Common trait for KLL sketch aggregate functions that support an optional k parameter.
+ */
+trait KllSketchAggBase {
+ def kExpr: Option[Expression]
+ def prettyName: String
+
+ // Constants from the Apache DataSketches library.
+ private val MIN_K = 8
+ private val MAX_K = 65535
+ private val DEFAULT_K = 200
+
+ // Validate and extract k value
+ protected lazy val kValue: Int = {
+ kExpr match {
+ case Some(expr) =>
+ if (!expr.foldable) {
+ throw QueryExecutionErrors.kllSketchKMustBeConstantError(prettyName)
+ }
+ val k = expr.eval().asInstanceOf[Int]
+ if (k < MIN_K || k > MAX_K) {
+ throw QueryExecutionErrors.kllSketchKOutOfRangeError(prettyName, k)
+ }
+ k
+ case None => DEFAULT_K
+ }
+ }
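+
+ // Illustrative behavior of the validation above, not part of this change:
+ //   SELECT kll_sketch_agg_double(col, 7)    -- fails: k must be between 8 and 65535
+ //   SELECT kll_sketch_agg_double(col, 400)  -- ok: larger k, more accurate but larger sketch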
+
+ protected def checkKInputDataTypes(): TypeCheckResult = {
+ kExpr match {
+ case Some(expr) =>
+ if (!expr.foldable) {
+ DataTypeMismatch(
+ errorSubClass = "NON_FOLDABLE_INPUT",
+ messageParameters = Map(
+ "inputName" -> "k",
+ "inputType" -> "int",
+ "inputExpr" -> expr.sql))
+ } else if (expr.eval() == null) {
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_NULL",
+ messageParameters = Map("exprName" -> "k"))
+ } else {
+ // Trigger validation
+ try {
+ kValue
+ TypeCheckResult.TypeCheckSuccess
+ } catch {
+ case e: Exception => TypeCheckResult.TypeCheckFailure(e.getMessage)
+ }
+ }
+ case None => TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
+ protected def unexpectedInputDataTypeError(
+ child: Expression): SparkUnsupportedOperationException =
+ new SparkUnsupportedOperationException(
+ errorClass = "_LEGACY_ERROR_TEMP_3121",
+ messageParameters = Map("dataType" -> child.dataType.toString))
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/thetasketchesAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/thetasketchesAggregates.scala
index 7e55c006782cf..0f148d03cd70b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/thetasketchesAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/thetasketchesAggregates.scala
@@ -94,6 +94,9 @@ case class ThetaSketchAgg(
// ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation.
lazy val lgNomEntries: Int = {
+ if (!right.foldable) {
+ throw QueryExecutionErrors.thetaLgNomEntriesMustBeConstantError(prettyName)
+ }
val lgNomEntriesInput = right.eval().asInstanceOf[Int]
ThetaSketchUtils.checkLgNomLongs(lgNomEntriesInput, prettyName)
lgNomEntriesInput
@@ -216,7 +219,7 @@ case class ThetaSketchAgg(
messageParameters = Map("dataType" -> left.dataType.toString))
}
- UpdatableSketchBuffer(sketch)
+ updateBuffer
}
/**
@@ -243,13 +246,13 @@ case class ThetaSketchAgg(
// Reuse the existing union in the next iteration. This is the most efficient path.
case (UnionAggregationBuffer(existingUnion), UpdatableSketchBuffer(sketch)) =>
existingUnion.union(sketch.compact)
- UnionAggregationBuffer(existingUnion)
+ updateBuffer
case (UnionAggregationBuffer(existingUnion), FinalizedSketch(sketch)) =>
existingUnion.union(sketch)
- UnionAggregationBuffer(existingUnion)
+ updateBuffer
case (UnionAggregationBuffer(union1), UnionAggregationBuffer(union2)) =>
union1.union(union2.getResult)
- UnionAggregationBuffer(union1)
+ updateBuffer
// Create a new union only when necessary.
case (UpdatableSketchBuffer(sketch1), UpdatableSketchBuffer(sketch2)) =>
createUnionWith(sketch1.compact, sketch2.compact)
@@ -332,6 +335,9 @@ case class ThetaUnionAgg(
// ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation.
lazy val lgNomEntries: Int = {
+ if (!right.foldable) {
+ throw QueryExecutionErrors.thetaLgNomEntriesMustBeConstantError(prettyName)
+ }
val lgNomEntriesInput = right.eval().asInstanceOf[Int]
ThetaSketchUtils.checkLgNomLongs(lgNomEntriesInput, prettyName)
lgNomEntriesInput
@@ -414,7 +420,7 @@ case class ThetaUnionAgg(
case _ => throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName)
}
union.union(inputSketch)
- UnionAggregationBuffer(union)
+ unionBuffer
}
/**
@@ -430,11 +436,11 @@ case class ThetaUnionAgg(
// If both arguments are union objects, merge them directly.
case (UnionAggregationBuffer(unionLeft), UnionAggregationBuffer(unionRight)) =>
unionLeft.union(unionRight.getResult)
- UnionAggregationBuffer(unionLeft)
+ unionBuffer
// The input was serialized then deserialized.
case (UnionAggregationBuffer(union), FinalizedSketch(sketch)) =>
union.union(sketch)
- UnionAggregationBuffer(union)
+ unionBuffer
// The program should never make it here, the cases are for defensive programming.
case (FinalizedSketch(sketch1), FinalizedSketch(sketch2)) =>
val union = SetOperation.builder.setLogNominalEntries(lgNomEntries).buildUnion
@@ -443,7 +449,7 @@ case class ThetaUnionAgg(
UnionAggregationBuffer(union)
case (FinalizedSketch(sketch), UnionAggregationBuffer(union)) =>
union.union(sketch)
- UnionAggregationBuffer(union)
+ input
case _ => throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName)
}
}
@@ -491,7 +497,7 @@ case class ThetaUnionAgg(
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
- _FUNC_(expr, lgNomEntries) - Returns the ThetaSketch's Compact binary representation
+ _FUNC_(expr) - Returns the ThetaSketch's Compact binary representation
by intersecting all the Theta sketches in the input column.""",
examples = """
Examples:
@@ -576,7 +582,7 @@ case class ThetaIntersectionAgg(
case _ => throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName)
}
intersection.intersect(inputSketch)
- IntersectionAggregationBuffer(intersection)
+ intersectionBuffer
}
/**
@@ -597,11 +603,11 @@ case class ThetaIntersectionAgg(
IntersectionAggregationBuffer(intersectLeft),
IntersectionAggregationBuffer(intersectRight)) =>
intersectLeft.intersect(intersectRight.getResult)
- IntersectionAggregationBuffer(intersectLeft)
+ intersectionBuffer
// The input was serialized then deserialized.
case (IntersectionAggregationBuffer(intersection), FinalizedSketch(sketch)) =>
intersection.intersect(sketch)
- IntersectionAggregationBuffer(intersection)
+ intersectionBuffer
// The program should never make it here, the cases are for defensive programming.
case (FinalizedSketch(sketch1), FinalizedSketch(sketch2)) =>
val intersection =
@@ -611,7 +617,7 @@ case class ThetaIntersectionAgg(
IntersectionAggregationBuffer(intersection)
case (FinalizedSketch(sketch), IntersectionAggregationBuffer(intersection)) =>
intersection.intersect(sketch)
- IntersectionAggregationBuffer(intersection)
+ input
case _ => throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index dba061eeb870d..f40077c53311b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -431,7 +431,7 @@ trait GetArrayItemUtil {
true
}
} else {
- if (failOnError) arrayElementNullable else true
+ if (failOnError) arrayElementNullable || child.nullable else true
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala
index a27460e2be1cd..1a7e3b03c0e6a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraints.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder
import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.{BooleanType, DataType}
+import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType}
trait TableConstraint extends Expression with Unevaluable {
/** Convert to a data source v2 constraint */
@@ -122,9 +122,12 @@ case class CheckConstraint(
override val tableName: String = null,
override val userProvidedCharacteristic: ConstraintCharacteristic = ConstraintCharacteristic.empty)
extends UnaryExpression
- with TableConstraint {
+ with TableConstraint
+ with ImplicitCastInputTypes {
// scalastyle:on line.size.limit
+ override def inputTypes: Seq[AbstractDataType] = Seq(BooleanType)
+
def toV2Constraint: Constraint = {
val predicate = new V2ExpressionBuilder(child, true).buildPredicate().orNull
val enforced = userProvidedCharacteristic.enforced.getOrElse(true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala
index a4ac0bdbb11d3..1880d71e7d542 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datasketchesExpressions.scala
@@ -56,7 +56,8 @@ case class HllSketchEstimate(child: Expression)
try {
Math.round(HllSketch.heapify(Memory.wrap(buffer)).getEstimate)
} catch {
- case _: SketchesArgumentException | _: java.lang.Error =>
+ case _: SketchesArgumentException | _: java.lang.Error
+ | _: ArrayIndexOutOfBoundsException =>
throw QueryExecutionErrors.hllInvalidInputSketchBuffer(prettyName)
}
}
@@ -108,13 +109,15 @@ case class HllUnion(first: Expression, second: Expression, third: Expression)
val sketch1 = try {
HllSketch.heapify(Memory.wrap(value1.asInstanceOf[Array[Byte]]))
} catch {
- case _: SketchesArgumentException | _: java.lang.Error =>
+ case _: SketchesArgumentException | _: java.lang.Error
+ | _: ArrayIndexOutOfBoundsException =>
throw QueryExecutionErrors.hllInvalidInputSketchBuffer(prettyName)
}
val sketch2 = try {
HllSketch.heapify(Memory.wrap(value2.asInstanceOf[Array[Byte]]))
} catch {
- case _: SketchesArgumentException | _: java.lang.Error =>
+ case _: SketchesArgumentException | _: java.lang.Error
+ | _: ArrayIndexOutOfBoundsException =>
throw QueryExecutionErrors.hllInvalidInputSketchBuffer(prettyName)
}
val allowDifferentLgConfigK = value3.asInstanceOf[Boolean]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/kllExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/kllExpressions.scala
new file mode 100644
index 0000000000000..af6c1a32e229f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/kllExpressions.scala
@@ -0,0 +1,701 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.datasketches.kll.{KllDoublesSketch, KllFloatsSketch, KllLongsSketch}
+import org.apache.datasketches.memory.Memory
+
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLExpr, toSQLId, toSQLType}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types.{AbstractDataType, ArrayType, BinaryType, DataType, DoubleType, FloatType, LongType, StringType, TypeCollection}
+import org.apache.spark.unsafe.types.UTF8String
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr) - Returns a human-readable summary of the sketch.
+ """,
+ examples = """
+ Examples:
+ > SELECT LENGTH(_FUNC_(kll_sketch_agg_bigint(col))) > 0 FROM VALUES (1), (2), (3), (4), (5) tab(col);
+ true
+ """,
+ group = "misc_funcs",
+ since = "4.1.0")
+case class KllSketchToStringBigint(child: Expression) extends KllSketchToStringBase {
+ override protected def withNewChildInternal(newChild: Expression): KllSketchToStringBigint =
+ copy(child = newChild)
+ override def prettyName: String = "kll_sketch_to_string_bigint"
+ override def nullSafeEval(input: Any): Any = {
+ try {
+ val buffer = input.asInstanceOf[Array[Byte]]
+ val sketch = KllLongsSketch.heapify(Memory.wrap(buffer))
+ UTF8String.fromString(sketch.toString())
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr) - Returns a human-readable summary of the sketch.
+ """,
+ examples = """
+ Examples:
+ > SELECT LENGTH(_FUNC_(kll_sketch_agg_float(col))) > 0 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col);
+ true
+ """,
+ group = "misc_funcs",
+ since = "4.1.0")
+case class KllSketchToStringFloat(child: Expression) extends KllSketchToStringBase {
+ override protected def withNewChildInternal(newChild: Expression): KllSketchToStringFloat =
+ copy(child = newChild)
+ override def prettyName: String = "kll_sketch_to_string_float"
+ override def nullSafeEval(input: Any): Any = {
+ try {
+ val buffer = input.asInstanceOf[Array[Byte]]
+ val sketch = KllFloatsSketch.heapify(Memory.wrap(buffer))
+ UTF8String.fromString(sketch.toString())
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr) - Returns a human-readable summary of the sketch.
+ """,
+ examples = """
+ Examples:
+ > SELECT LENGTH(_FUNC_(kll_sketch_agg_double(col))) > 0 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col);
+ true
+ """,
+ group = "misc_funcs",
+ since = "4.1.0")
+case class KllSketchToStringDouble(child: Expression) extends KllSketchToStringBase {
+ override protected def withNewChildInternal(newChild: Expression): KllSketchToStringDouble =
+ copy(child = newChild)
+ override def prettyName: String = "kll_sketch_to_string_double"
+ override def nullSafeEval(input: Any): Any = {
+ try {
+ val buffer = input.asInstanceOf[Array[Byte]]
+ val sketch = KllDoublesSketch.heapify(Memory.wrap(buffer))
+ UTF8String.fromString(sketch.toString())
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ }
+}
+
+/** This is a base class for the above expressions to reduce boilerplate. */
+abstract class KllSketchToStringBase
+ extends UnaryExpression
+ with CodegenFallback
+ with ImplicitCastInputTypes {
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
+ override def nullIntolerant: Boolean = true
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr) - Returns the number of items collected in the sketch.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(kll_sketch_agg_bigint(col)) FROM VALUES (1), (2), (3), (4), (5) tab(col);
+ 5
+ """,
+ group = "misc_funcs",
+ since = "4.1.0")
+case class KllSketchGetNBigint(child: Expression) extends KllSketchGetNBase {
+ override protected def withNewChildInternal(newChild: Expression): KllSketchGetNBigint =
+ copy(child = newChild)
+ override def prettyName: String = "kll_sketch_get_n_bigint"
+ override def nullSafeEval(input: Any): Any = {
+ try {
+ val buffer = input.asInstanceOf[Array[Byte]]
+ val sketch = KllLongsSketch.heapify(Memory.wrap(buffer))
+ sketch.getN()
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr) - Returns the number of items collected in the sketch.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(kll_sketch_agg_float(col)) FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col);
+ 5
+ """,
+ group = "misc_funcs",
+ since = "4.1.0")
+case class KllSketchGetNFloat(child: Expression) extends KllSketchGetNBase {
+ override protected def withNewChildInternal(newChild: Expression): KllSketchGetNFloat =
+ copy(child = newChild)
+ override def prettyName: String = "kll_sketch_get_n_float"
+ override def nullSafeEval(input: Any): Any = {
+ try {
+ val buffer = input.asInstanceOf[Array[Byte]]
+ val sketch = KllFloatsSketch.heapify(Memory.wrap(buffer))
+ sketch.getN()
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(expr) - Returns the number of items collected in the sketch.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(kll_sketch_agg_double(col)) FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col);
+ 5
+ """,
+ group = "misc_funcs",
+ since = "4.1.0")
+case class KllSketchGetNDouble(child: Expression) extends KllSketchGetNBase {
+ override protected def withNewChildInternal(newChild: Expression): KllSketchGetNDouble =
+ copy(child = newChild)
+ override def prettyName: String = "kll_sketch_get_n_double"
+ override def nullSafeEval(input: Any): Any = {
+ try {
+ val buffer = input.asInstanceOf[Array[Byte]]
+ val sketch = KllDoublesSketch.heapify(Memory.wrap(buffer))
+ sketch.getN()
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ }
+}
+
+/** This is a base class for the above expressions to reduce boilerplate. */
+abstract class KllSketchGetNBase
+ extends UnaryExpression
+ with CodegenFallback
+ with ImplicitCastInputTypes {
+ override def dataType: DataType = LongType
+ override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
+ override def nullIntolerant: Boolean = true
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(left, right) - Merges two binary KLL sketch buffers into one.
+ """,
+ examples = """
+ Examples:
+ > SELECT LENGTH(kll_sketch_to_string_bigint(_FUNC_(kll_sketch_agg_bigint(col), kll_sketch_agg_bigint(col)))) > 0 FROM VALUES (1), (2), (3), (4), (5) tab(col);
+ true
+ """,
+ group = "misc_funcs",
+ since = "4.1.0")
+case class KllSketchMergeBigint(left: Expression, right: Expression) extends KllSketchMergeBase {
+ override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
+ copy(left = newLeft, right = newRight)
+ override def prettyName: String = "kll_sketch_merge_bigint"
+ override def nullSafeEval(left: Any, right: Any): Any = {
+ try {
+ val leftBuffer = left.asInstanceOf[Array[Byte]]
+ val rightBuffer = right.asInstanceOf[Array[Byte]]
+ val leftSketch = KllLongsSketch.heapify(Memory.wrap(leftBuffer))
+ val rightSketch = KllLongsSketch.wrap(Memory.wrap(rightBuffer))
+ leftSketch.merge(rightSketch)
+ leftSketch.toByteArray
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(left, right) - Merges two binary KLL sketch buffers into one.
+ """,
+ examples = """
+ Examples:
+ > SELECT LENGTH(kll_sketch_to_string_float(_FUNC_(kll_sketch_agg_float(col), kll_sketch_agg_float(col)))) > 0 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col);
+ true
+ """,
+ group = "misc_funcs",
+ since = "4.1.0")
+case class KllSketchMergeFloat(left: Expression, right: Expression) extends KllSketchMergeBase {
+ override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
+ copy(left = newLeft, right = newRight)
+ override def prettyName: String = "kll_sketch_merge_float"
+ override def nullSafeEval(left: Any, right: Any): Any = {
+ try {
+ val leftBuffer = left.asInstanceOf[Array[Byte]]
+ val rightBuffer = right.asInstanceOf[Array[Byte]]
+ val leftSketch = KllFloatsSketch.heapify(Memory.wrap(leftBuffer))
+ val rightSketch = KllFloatsSketch.wrap(Memory.wrap(rightBuffer))
+ leftSketch.merge(rightSketch)
+ leftSketch.toByteArray
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(left, right) - Merges two binary KLL sketch buffers into one.
+ """,
+ examples = """
+ Examples:
+ > SELECT LENGTH(kll_sketch_to_string_double(_FUNC_(kll_sketch_agg_double(col), kll_sketch_agg_double(col)))) > 0 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col);
+ true
+ """,
+ group = "misc_funcs",
+ since = "4.1.0")
+case class KllSketchMergeDouble(left: Expression, right: Expression) extends KllSketchMergeBase {
+ override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
+ copy(left = newLeft, right = newRight)
+ override def prettyName: String = "kll_sketch_merge_double"
+ override def nullSafeEval(left: Any, right: Any): Any = {
+ try {
+ val leftBuffer = left.asInstanceOf[Array[Byte]]
+ val rightBuffer = right.asInstanceOf[Array[Byte]]
+ val leftSketch = KllDoublesSketch.heapify(Memory.wrap(leftBuffer))
+ val rightSketch = KllDoublesSketch.wrap(Memory.wrap(rightBuffer))
+ leftSketch.merge(rightSketch)
+ leftSketch.toByteArray
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ }
+}
+
+/** This is a base class for the above expressions to reduce boilerplate. */
+abstract class KllSketchMergeBase
+ extends BinaryExpression
+ with CodegenFallback
+ with ImplicitCastInputTypes {
+ override def dataType: DataType = BinaryType
+ override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType, BinaryType)
+ override def nullIntolerant: Boolean = true
+}
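+
+// Illustrative usage, not part of this change: unlike the kll_merge_agg_* aggregates, these
+// scalar functions merge exactly two binary sketches per row. The table and columns below are
+// hypothetical binary sketch columns:
+//   SELECT kll_sketch_get_n_double(kll_sketch_merge_double(s1, s2)) FROM sketch_pairs;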
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(left, right) - Returns the value(s) from the quantiles sketch corresponding to the
+ given rank(s). The rank argument can be either a single value or an array; in the latter
+ case, the function returns an array of results of the same length as the input array.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(kll_sketch_agg_bigint(col), 0.5) > 1 FROM VALUES (1), (2), (3), (4), (5) tab(col);
+ true
+ """,
+ group = "misc_funcs",
+ since = "4.1.0")
+case class KllSketchGetQuantileBigint(left: Expression, right: Expression)
+ extends KllSketchGetQuantileBase {
+ override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
+ copy(left = newLeft, right = newRight)
+ override def prettyName: String = "kll_sketch_get_quantile_bigint"
+ override def outputDataType: DataType = LongType
+ override def kllSketchGetQuantile(memory: Memory, rank: Double): Any = {
+ withQuantileErrorHandling(rank) {
+ KllLongsSketch.wrap(memory).getQuantile(rank)
+ }
+ }
+ override def kllSketchGetQuantiles(memory: Memory, ranks: Array[Double]): Array[Any] = {
+ withQuantileErrorHandling(if (ranks.length > 0) ranks(0) else 0.0) {
+ KllLongsSketch.wrap(memory).getQuantiles(ranks).map(_.asInstanceOf[Any])
+ }
+ }
+}
+
+@ExpressionDescription(
+ usage = """
+ _FUNC_(left, right) - Returns the value(s) from the quantiles sketch corresponding to the
+ given rank(s). The rank argument can be either a single value or an array; in the latter
+ case, the function returns an array of results of the same length as the input array.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(kll_sketch_agg_float(col), 0.5) > 1 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col);
+ true
+ """,
+ group = "misc_funcs",
+ since = "4.1.0")
+case class KllSketchGetQuantileFloat(left: Expression, right: Expression)
+ extends KllSketchGetQuantileBase {
+ override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
+ copy(left = newLeft, right = newRight)
+ override def prettyName: String = "kll_sketch_get_quantile_float"
+ override def outputDataType: DataType = FloatType
+ override def kllSketchGetQuantile(memory: Memory, rank: Double): Any = {
+ withQuantileErrorHandling(rank) {
+ KllFloatsSketch.wrap(memory).getQuantile(rank)
+ }
+ }
+ override def kllSketchGetQuantiles(memory: Memory, ranks: Array[Double]): Array[Any] = {
+ withQuantileErrorHandling(if (ranks.length > 0) ranks(0) else 0.0) {
+ KllFloatsSketch.wrap(memory).getQuantiles(ranks).map(_.asInstanceOf[Any])
+ }
+ }
+}
+
+@ExpressionDescription(
+ usage = """
+ _FUNC_(left, right) - Returns the value(s) from the quantiles sketch corresponding to the
+ given rank(s). The rank argument can be either a single value or an array; in the latter
+ case, the function returns an array of results of the same length as the input array.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(kll_sketch_agg_double(col), 0.5) > 1 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col);
+ true
+ """,
+ group = "misc_funcs",
+ since = "4.1.0")
+case class KllSketchGetQuantileDouble(left: Expression, right: Expression)
+ extends KllSketchGetQuantileBase {
+ override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
+ copy(left = newLeft, right = newRight)
+ override def prettyName: String = "kll_sketch_get_quantile_double"
+ override def outputDataType: DataType = DoubleType
+ override def kllSketchGetQuantile(memory: Memory, rank: Double): Any = {
+ withQuantileErrorHandling(rank) {
+ KllDoublesSketch.wrap(memory).getQuantile(rank)
+ }
+ }
+ override def kllSketchGetQuantiles(memory: Memory, ranks: Array[Double]): Array[Any] = {
+ withQuantileErrorHandling(if (ranks.length > 0) ranks(0) else 0.0) {
+ KllDoublesSketch.wrap(memory).getQuantiles(ranks).map(_.asInstanceOf[Any])
+ }
+ }
+}
+
+/**
+ * This is a base class for the above expressions to reduce boilerplate.
+ * Each implementor is expected to define three methods: one to specify the output data type,
+ * one to compute the quantile of an input sketch buffer given a single input rank,
+ * and one to compute multiple quantiles given an array of ranks (batch API for performance).
+ */
+abstract class KllSketchGetQuantileBase
+ extends BinaryExpression
+ with CodegenFallback
+ with ImplicitCastInputTypes {
+ /**
+ * This method accepts a KLL quantiles Memory segment, wraps it with the corresponding
+ * Kll*Sketch.wrap method, and then calls getQuantile on the result.
+ * @param memory The input KLL quantiles sketch buffer to extract the quantile from
+ * @param rank The input rank to use to compute the quantile
+ * @return The result quantile
+ */
+ protected def kllSketchGetQuantile(memory: Memory, rank: Double): Any
+
+ /**
+ * This method accepts a KLL quantiles Memory segment, wraps it with the corresponding
+ * Kll*Sketch.wrap method, and then calls getQuantiles on the result (batch API).
+ * @param memory The input KLL quantiles sketch buffer to extract the quantiles from
+ * @param ranks The input ranks array to use to compute the quantiles
+ * @return The result quantiles as an array
+ */
+ protected def kllSketchGetQuantiles(memory: Memory, ranks: Array[Double]): Array[Any]
+
+ /**
+ * Helper method to wrap quantile operations with consistent error handling.
+ * @param rankForError The rank value to include in error messages
+ * @param operation The operation to execute
+ * @return The result of the operation
+ */
+ protected def withQuantileErrorHandling[T](rankForError: Double)(operation: => T): T = {
+ try {
+ operation
+ } catch {
+ case e: org.apache.datasketches.common.SketchesArgumentException =>
+ if (e.getMessage.contains("normalized rank")) {
+ throw QueryExecutionErrors.kllSketchInvalidQuantileRangeError(prettyName)
+ } else {
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ }
+
+ /** The output data type for a single value (not array) */
+ protected def outputDataType: DataType
+
+ // The rank argument must be foldable (compile-time constant).
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (!right.foldable) {
+ TypeCheckResult.DataTypeMismatch(
+ errorSubClass = "NON_FOLDABLE_INPUT",
+ messageParameters = Map(
+ "inputName" -> toSQLId("rank"),
+ "inputType" -> toSQLType(right.dataType),
+ "inputExpr" -> toSQLExpr(right)))
+ } else {
+ super.checkInputDataTypes()
+ }
+ }
+
+ override def nullIntolerant: Boolean = true
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(
+ BinaryType,
+ TypeCollection(
+ DoubleType,
+ ArrayType(DoubleType, containsNull = false)))
+
+ override def dataType: DataType = {
+ right.dataType match {
+ case ArrayType(_, _) => ArrayType(outputDataType, false)
+ case _ => outputDataType
+ }
+ }
+
+ override def nullSafeEval(leftInput: Any, rightInput: Any): Any = {
+ val buffer = leftInput.asInstanceOf[Array[Byte]]
+ val memory = Memory.wrap(buffer)
+
+ rightInput match {
+ case null => null
+ case num: Double =>
+ // Single value case
+ kllSketchGetQuantile(memory, num)
+ case arrayData: ArrayData =>
+ // Array case - use batch API for better performance
+ val ranks = arrayData.toDoubleArray()
+ val results = kllSketchGetQuantiles(memory, ranks)
+ new GenericArrayData(results)
+ }
+ }
+}
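+
+// Illustrative usage of the array path above, not part of this change: passing an array of ranks
+// returns an array of quantiles of the same length, served by the batch getQuantiles call:
+//   SELECT kll_sketch_get_quantile_double(kll_sketch_agg_double(col), array(0.25D, 0.5D, 0.75D))
+//   FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)) tab(col);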
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(left, right) - Returns the rank(s) from the quantiles sketch corresponding to the
+ given quantile value(s). The quantile argument can be either a single value or an array;
+ in the latter case, the function returns an array of results of the same length as the
+ input array.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(kll_sketch_agg_bigint(col), 3) > 0.3 FROM VALUES (1), (2), (3), (4), (5) tab(col);
+ true
+ """,
+ group = "misc_funcs",
+ since = "4.1.0")
+case class KllSketchGetRankBigint(left: Expression, right: Expression)
+ extends KllSketchGetRankBase {
+ override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
+ copy(left = newLeft, right = newRight)
+ override def prettyName: String = "kll_sketch_get_rank_bigint"
+ override def inputDataType: DataType = LongType
+ override def kllSketchGetRank(memory: Memory, quantile: Any): Double = {
+ withRankErrorHandling {
+ KllLongsSketch.wrap(memory).getRank(quantile.asInstanceOf[Long])
+ }
+ }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(left, right) - Returns the rank(s) from the quantiles sketch corresponding to the
+ given quantile value(s). The quantile argument can be either a single value or an array;
+ in the latter case, the function returns an array of results of the same length as the
+ input array.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(kll_sketch_agg_float(col), 3.0) > 0.3 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col);
+ true
+ """,
+ group = "misc_funcs",
+ since = "4.1.0")
+case class KllSketchGetRankFloat(left: Expression, right: Expression)
+ extends KllSketchGetRankBase {
+ override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
+ copy(left = newLeft, right = newRight)
+ override def prettyName: String = "kll_sketch_get_rank_float"
+ override def inputDataType: DataType = FloatType
+ override def kllSketchGetRank(memory: Memory, quantile: Any): Double = {
+ withRankErrorHandling {
+ KllFloatsSketch.wrap(memory).getRank(quantile.asInstanceOf[Float])
+ }
+ }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_(left, right) - Returns the rank(s) from the quantiles sketch corresponding to the
+ given quantile value(s). The quantile argument can be either a single value or an array;
+ in the latter case, the function returns an array of results of the same length as the
+ input array.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(kll_sketch_agg_double(col), 3.0) > 0.3 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col);
+ true
+ """,
+ group = "misc_funcs",
+ since = "4.1.0")
+case class KllSketchGetRankDouble(left: Expression, right: Expression)
+ extends KllSketchGetRankBase {
+ override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
+ copy(left = newLeft, right = newRight)
+ override def prettyName: String = "kll_sketch_get_rank_double"
+ override def inputDataType: DataType = DoubleType
+ override def kllSketchGetRank(memory: Memory, quantile: Any): Double = {
+ withRankErrorHandling {
+ KllDoublesSketch.wrap(memory).getRank(quantile.asInstanceOf[Double])
+ }
+ }
+}
+
+/**
+ * This is a base class for the above expressions to reduce boilerplate.
+ * Each implementor is expected to define two methods: one to specify the input argument data type,
+ * and another to compute the rank of an input sketch buffer given the input quantile.
+ */
+abstract class KllSketchGetRankBase
+ extends BinaryExpression
+ with CodegenFallback
+ with ImplicitCastInputTypes {
+ /**
+ * Helper method to wrap rank operations with consistent error handling.
+ * @param operation The operation to execute
+ * @return The result of the operation
+ */
+ protected def withRankErrorHandling[T](operation: => T): T = {
+ try {
+ operation
+ } catch {
+ case _: Exception =>
+ throw QueryExecutionErrors.kllInvalidInputSketchBuffer(prettyName)
+ }
+ }
+
+ protected def inputDataType: DataType
+
+ /**
+ * This method accepts a KLL quantiles Memory segment, wraps it with the corresponding
+ * Kll*Sketch.wrap method, and then calls getRank on the result.
+ * @param memory The input KLL quantiles sketch buffer to extract the rank from
+ * @param quantile The input quantile to use to compute the rank
+ * @return The result rank
+ */
+ protected def kllSketchGetRank(memory: Memory, quantile: Any): Double
+
+ // The quantile argument must be foldable (compile-time constant).
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (!right.foldable) {
+ TypeCheckResult.DataTypeMismatch(
+ errorSubClass = "NON_FOLDABLE_INPUT",
+ messageParameters = Map(
+ "inputName" -> toSQLId("quantile"),
+ "inputType" -> toSQLType(right.dataType),
+ "inputExpr" -> toSQLExpr(right)))
+ } else {
+ super.checkInputDataTypes()
+ }
+ }
+
+ override def nullIntolerant: Boolean = true
+ override def inputTypes: Seq[AbstractDataType] = {
+ Seq(
+ BinaryType,
+ TypeCollection(
+ inputDataType,
+ ArrayType(inputDataType, containsNull = false)))
+ }
+ override def dataType: DataType = {
+ right.dataType match {
+ case ArrayType(_, _) => ArrayType(DoubleType, false)
+ case _ => DoubleType
+ }
+ }
+
+ override def nullSafeEval(leftInput: Any, rightInput: Any): Any = {
+ val buffer: Array[Byte] = leftInput.asInstanceOf[Array[Byte]]
+ val memory: Memory = Memory.wrap(buffer)
+
+ rightInput match {
+ case null => null
+ case value if !value.isInstanceOf[ArrayData] =>
+ // Single value case
+ kllSketchGetRank(memory, value)
+ case arrayData: ArrayData =>
+ // Array case - use direct iteration to avoid multiple array allocations
+ val numElements = arrayData.numElements()
+ val results = new Array[Double](numElements)
+ var i = 0
+ inputDataType match {
+ case LongType =>
+ while (i < numElements) {
+ results(i) = kllSketchGetRank(memory, arrayData.getLong(i))
+ i += 1
+ }
+ case FloatType =>
+ while (i < numElements) {
+ results(i) = kllSketchGetRank(memory, arrayData.getFloat(i))
+ i += 1
+ }
+ case DoubleType =>
+ while (i < numElements) {
+ results(i) = kllSketchGetRank(memory, arrayData.getDouble(i))
+ i += 1
+ }
+ }
+ new GenericArrayData(results)
+ }
+ }
+}
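+
+// Illustrative usage, not part of this change: ranks are returned as DOUBLE values in [0, 1],
+// and an array of quantile values yields an array of ranks of the same length:
+//   SELECT kll_sketch_get_rank_bigint(kll_sketch_agg_bigint(col), array(2L, 4L))
+//   FROM VALUES (1), (2), (3), (4), (5) tab(col);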
+
+
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index ee3e3e0272767..0643e5fba2f32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -1572,7 +1572,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
val decimal = input1.asInstanceOf[Decimal]
if (_scale >= 0) {
// Overflow cannot happen, so no need to control nullOnOverflow
- decimal.toPrecision(decimal.precision, s, mode)
+ decimal.toPrecision(p, s, mode)
} else {
Decimal(decimal.toBigDecimal.setScale(_scale, mode), p, s)
}
@@ -1644,10 +1644,9 @@ abstract class RoundBase(child: Expression, scale: Expression,
case DecimalType.Fixed(p, s) =>
if (_scale >= 0) {
s"""
- ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s,
- Decimal.$modeStr(), true, null);
+ ${ev.value} = ${ce.value}.toPrecision($p, $s, Decimal.$modeStr(), true, null);
${ev.isNull} = ${ev.value} == null;"""
- } else {
+ } else {
s"""
${ev.value} = new Decimal().set(${ce.value}.toBigDecimal()
.setScale(${_scale}, Decimal.$modeStr()), $p, $s);
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/st/STExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/st/STExpressionUtils.scala
index 055173ec39ade..bfecf8c28ef45 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/st/STExpressionUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/st/STExpressionUtils.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.catalyst.expressions.st
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
private[sql] object STExpressionUtils {
@@ -29,4 +32,54 @@ private[sql] object STExpressionUtils {
case _ => false
}
+ /**
+ * Returns the input GEOMETRY or GEOGRAPHY type with the specified SRID. Only geospatial types
+ * are allowed as the source type, and calls are delegated to the corresponding helper methods.
+ */
+ def geospatialTypeWithSrid(sourceType: DataType, srid: Expression): DataType = {
+ sourceType match {
+ case _ if !SQLConf.get.geospatialEnabled =>
+ throw new AnalysisException(
+ errorClass = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED",
+ messageParameters = Map.empty
+ )
+ case _: GeometryType =>
+ geometryTypeWithSrid(srid)
+ case _: GeographyType =>
+ geographyTypeWithSrid(srid)
+ case _ =>
+ throw new IllegalArgumentException(s"Unexpected data type: $sourceType.")
+ }
+ }
+
+ /**
+ * Returns the GEOMETRY type with the specified SRID. If the SRID expression is a literal,
+ * the SRID value can be directly extracted. Otherwise, only the mixed SRID value can be used.
+ */
+ private def geometryTypeWithSrid(srid: Expression): GeometryType = {
+ srid match {
+ case Literal(sridValue: Int, IntegerType) =>
+ // If the SRID expression is a literal, the SRID value can be directly extracted.
+ GeometryType(sridValue)
+ case _ =>
+ // Otherwise, only the mixed SRID value can be used for the output GEOMETRY value.
+ GeometryType("ANY")
+ }
+ }
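+
+ // Illustrative behavior, not part of this change (the non-literal expression name below is
+ // hypothetical):
+ //   geometryTypeWithSrid(Literal(4326))   returns GeometryType(4326)
+ //   geometryTypeWithSrid(nonLiteralExpr)  returns GeometryType("ANY")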
+
+ /**
+ * Returns the GEOGRAPHY type with the specified SRID. If the SRID expression is a literal,
+ * the SRID value can be directly extracted. Otherwise, only the mixed SRID value can be used.
+ */
+ private def geographyTypeWithSrid(srid: Expression): GeographyType = {
+ srid match {
+ case Literal(sridValue: Int, IntegerType) =>
+ // If the SRID expression is a literal, the SRID value can be directly extracted.
+ GeographyType(sridValue)
+ case _ =>
+ // Otherwise, only the mixed SRID value can be used for the output GEOGRAPHY value.
+ GeographyType("ANY")
+ }
+ }
+
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/st/stExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/st/stExpressions.scala
index a08f88b679525..0a032d191a26f 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/st/stExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/st/stExpressions.scala
@@ -15,13 +15,32 @@
* limitations under the License.
*/
-package org.apache.spark.sql.catalyst.expressions
+package org.apache.spark.sql.catalyst.expressions.st
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.trees._
import org.apache.spark.sql.catalyst.util.{Geography, Geometry, STUtils}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
+/**
+ * ST expressions are behind a feature flag while the geospatial module is under development.
+ */
+
+sealed trait GeospatialInputTypes extends ImplicitCastInputTypes {
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (!SQLConf.get.geospatialEnabled) {
+ throw new AnalysisException(
+ errorClass = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED",
+ messageParameters = Map.empty
+ )
+ }
+ super.checkInputDataTypes()
+ }
+}
////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines expressions for geospatial operations.
@@ -60,7 +79,7 @@ private[sql] object ExpressionDefaults {
)
case class ST_AsBinary(geo: Expression)
extends RuntimeReplaceable
- with ImplicitCastInputTypes
+ with GeospatialInputTypes
with UnaryLike[Expression] {
override def inputTypes: Seq[AbstractDataType] = Seq(
@@ -108,7 +127,7 @@ case class ST_AsBinary(geo: Expression)
)
case class ST_GeogFromWKB(wkb: Expression)
extends RuntimeReplaceable
- with ImplicitCastInputTypes
+ with GeospatialInputTypes
with UnaryLike[Expression] {
override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
@@ -152,7 +171,7 @@ case class ST_GeogFromWKB(wkb: Expression)
)
case class ST_GeomFromWKB(wkb: Expression)
extends RuntimeReplaceable
- with ImplicitCastInputTypes
+ with GeospatialInputTypes
with UnaryLike[Expression] {
override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
@@ -200,7 +219,7 @@ case class ST_GeomFromWKB(wkb: Expression)
)
case class ST_Srid(geo: Expression)
extends RuntimeReplaceable
- with ImplicitCastInputTypes
+ with GeospatialInputTypes
with UnaryLike[Expression] {
override def inputTypes: Seq[AbstractDataType] = Seq(
@@ -222,3 +241,56 @@ case class ST_Srid(geo: Expression)
override protected def withNewChildInternal(newChild: Expression): ST_Srid =
copy(geo = newChild)
}
+
+/** ST modifier expressions. */
+
+/**
+ * Returns a new GEOGRAPHY or GEOMETRY value whose SRID is the specified SRID value.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(geo, srid) - Returns a new GEOGRAPHY or GEOMETRY value whose SRID is " +
+ "the specified SRID value.",
+ arguments = """
+ Arguments:
+ * geo - A GEOGRAPHY or GEOMETRY value.
+ * srid - The new SRID value of the geography or geometry.
+ """,
+ examples = """
+ Examples:
+ > SELECT st_srid(_FUNC_(ST_GeogFromWKB(X'0101000000000000000000F03F0000000000000040'), 4326));
+ 4326
+ > SELECT st_srid(_FUNC_(ST_GeomFromWKB(X'0101000000000000000000F03F0000000000000040'), 3857));
+ 3857
+ """,
+ since = "4.1.0",
+ group = "st_funcs"
+)
+case class ST_SetSrid(geo: Expression, srid: Expression)
+ extends RuntimeReplaceable
+ with GeospatialInputTypes
+ with BinaryLike[Expression] {
+
+ override def inputTypes: Seq[AbstractDataType] =
+ Seq(
+ TypeCollection(GeographyType, GeometryType),
+ IntegerType
+ )
+
+ override lazy val replacement: Expression = StaticInvoke(
+ classOf[STUtils],
+ STExpressionUtils.geospatialTypeWithSrid(geo.dataType, srid),
+ "stSetSrid",
+ Seq(geo, srid),
+ returnNullable = false
+ )
+
+ override def prettyName: String = "st_setsrid"
+
+ override def left: Expression = geo
+
+ override def right: Expression = srid
+
+ override protected def withNewChildrenInternal(
+ newLeft: Expression,
+ newRight: Expression): ST_SetSrid = copy(geo = newLeft, srid = newRight)
+}
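With GeospatialInputTypes mixed in, the feature flag is enforced during analysis before any input-type checking runs. A rough sketch of the user-visible effect (assumes a SparkSession named `spark` and that the geospatial flag is left disabled; the config key itself is not shown in this patch):

import org.apache.spark.sql.AnalysisException

try {
  spark.sql(
    "SELECT st_srid(ST_GeomFromWKB(X'0101000000000000000000F03F0000000000000040'))").collect()
} catch {
  case e: AnalysisException =>
    // Expected error condition: UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED
    println(e.getMessage)
}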
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/timeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/timeExpressions.scala
index ff088876969bd..692dd5b1f3987 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/timeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/timeExpressions.scala
@@ -32,11 +32,22 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.TimeFormatter
import org.apache.spark.sql.catalyst.util.TypeUtils.ordinalNumber
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.StringTypeWithCollation
import org.apache.spark.sql.types.{AbstractDataType, AnyTimeType, ByteType, DataType, DayTimeIntervalType, DecimalType, IntegerType, LongType, ObjectType, TimeType}
import org.apache.spark.sql.types.DayTimeIntervalType.{HOUR, SECOND}
import org.apache.spark.unsafe.types.UTF8String
+trait TimeExpression extends Expression {
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (SQLConf.get.isTimeTypeEnabled) {
+ super.checkInputDataTypes()
+ } else {
+ throw QueryCompilationErrors.unsupportedTimeTypeError()
+ }
+ }
+}
+
/**
* Parses a column to a time based on the given format.
*/
@@ -64,7 +75,7 @@ import org.apache.spark.unsafe.types.UTF8String
since = "4.1.0")
// scalastyle:on line.size.limit
case class ToTime(str: Expression, format: Option[Expression])
- extends RuntimeReplaceable with ExpectsInputTypes {
+ extends RuntimeReplaceable with ExpectsInputTypes with TimeExpression {
def this(str: Expression, format: Expression) = this(str, Option(format))
def this(str: Expression) = this(str, None)
@@ -200,7 +211,7 @@ object TryToTimeExpressionBuilder extends ExpressionBuilder {
// scalastyle:on line.size.limit
case class MinutesOfTime(child: Expression)
extends RuntimeReplaceable
- with ExpectsInputTypes {
+ with ExpectsInputTypes with TimeExpression {
override def replacement: Expression = StaticInvoke(
classOf[DateTimeUtils.type],
@@ -259,7 +270,7 @@ object MinuteExpressionBuilder extends ExpressionBuilder {
case class HoursOfTime(child: Expression)
extends RuntimeReplaceable
- with ExpectsInputTypes {
+ with ExpectsInputTypes with TimeExpression {
override def replacement: Expression = StaticInvoke(
classOf[DateTimeUtils.type],
@@ -316,7 +327,7 @@ object HourExpressionBuilder extends ExpressionBuilder {
case class SecondsOfTimeWithFraction(child: Expression)
extends RuntimeReplaceable
- with ExpectsInputTypes {
+ with ExpectsInputTypes with TimeExpression {
override def replacement: Expression = {
val precision = child.dataType match {
case TimeType(p) => p
@@ -342,7 +353,7 @@ case class SecondsOfTimeWithFraction(child: Expression)
case class SecondsOfTime(child: Expression)
extends RuntimeReplaceable
- with ExpectsInputTypes {
+ with ExpectsInputTypes with TimeExpression {
override def replacement: Expression = StaticInvoke(
classOf[DateTimeUtils.type],
@@ -433,7 +444,8 @@ object SecondExpressionBuilder extends ExpressionBuilder {
case class CurrentTime(
child: Expression = Literal(TimeType.MICROS_PRECISION),
timeZoneId: Option[String] = None) extends UnaryExpression
- with TimeZoneAwareExpression with ImplicitCastInputTypes with CodegenFallback {
+ with TimeZoneAwareExpression with ImplicitCastInputTypes with CodegenFallback
+ with TimeExpression {
def this() = {
this(Literal(TimeType.MICROS_PRECISION), None)
@@ -545,7 +557,7 @@ case class MakeTime(
secsAndMicros: Expression)
extends RuntimeReplaceable
with ImplicitCastInputTypes
- with ExpectsInputTypes {
+ with ExpectsInputTypes with TimeExpression {
// Accept `sec` as DecimalType to avoid losing precision of microseconds while converting
// it to the fractional part of `sec`. If `sec` is an IntegerType, it can be cast into decimal
@@ -570,7 +582,8 @@ case class MakeTime(
* Adds day-time interval to time.
*/
case class TimeAddInterval(time: Expression, interval: Expression)
- extends BinaryExpression with RuntimeReplaceable with ExpectsInputTypes {
+ extends BinaryExpression with RuntimeReplaceable with ExpectsInputTypes
+ with TimeExpression {
override def nullIntolerant: Boolean = true
override def left: Expression = time
@@ -611,7 +624,8 @@ case class TimeAddInterval(time: Expression, interval: Expression)
* Returns a day-time interval between time values.
*/
case class SubtractTimes(left: Expression, right: Expression)
- extends BinaryExpression with RuntimeReplaceable with ExpectsInputTypes {
+ extends BinaryExpression with RuntimeReplaceable with ExpectsInputTypes
+ with TimeExpression {
override def nullIntolerant: Boolean = true
override def inputTypes: Seq[AbstractDataType] = Seq(AnyTimeType, AnyTimeType)
@@ -668,7 +682,8 @@ case class TimeDiff(
end: Expression)
extends TernaryExpression
with RuntimeReplaceable
- with ImplicitCastInputTypes {
+ with ImplicitCastInputTypes
+ with TimeExpression {
override def first: Expression = unit
override def second: Expression = start
@@ -723,7 +738,8 @@ case class TimeDiff(
since = "4.1.0")
// scalastyle:on line.size.limit
case class TimeTrunc(unit: Expression, time: Expression)
- extends BinaryExpression with RuntimeReplaceable with ImplicitCastInputTypes {
+ extends BinaryExpression with RuntimeReplaceable with ImplicitCastInputTypes
+ with TimeExpression {
override def left: Expression = unit
override def right: Expression = time
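TimeExpression relies on Scala's stackable-trait linearization: mixed in after ExpectsInputTypes, its checkInputDataTypes() runs first and only falls through to the normal type check when the flag is on. A stripped-down sketch of that pattern with stand-in names (not the Spark classes themselves):

trait TypeCheckDemo { def checkInputDataTypes(): String = "type check ran" }

trait TimeGateDemo extends TypeCheckDemo {
  def timeTypeEnabled: Boolean
  override def checkInputDataTypes(): String =
    if (timeTypeEnabled) super.checkInputDataTypes()   // fall through to the mixed-in check
    else throw new UnsupportedOperationException("TIME type is disabled")
}

val enabled  = new TypeCheckDemo with TimeGateDemo { val timeTypeEnabled = true }
val disabled = new TypeCheckDemo with TimeGateDemo { val timeTypeEnabled = false }

enabled.checkInputDataTypes()     // "type check ran"
// disabled.checkInputDataTypes() would throw before any input-type validation happens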
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
new file mode 100644
index 0000000000000..6c0bca0e1104f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/normalizer/NormalizeCTEIds.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.normalizer
+
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.plans.logical.{CacheTableAsSelect, CTERelationRef, LogicalPlan, UnionLoop, UnionLoopRef, WithCTE}
+import org.apache.spark.sql.catalyst.rules.Rule
+
+object NormalizeCTEIds extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ val curId = new java.util.concurrent.atomic.AtomicLong()
+ val cteIdToNewId = mutable.Map.empty[Long, Long]
+ applyInternal(plan, curId, cteIdToNewId)
+ }
+
+ private def applyInternal(
+ plan: LogicalPlan,
+ curId: AtomicLong,
+ cteIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
+ plan transformDownWithSubqueries {
+ case ctas @ CacheTableAsSelect(_, plan, _, _, _, _, _) =>
+ ctas.copy(plan = applyInternal(plan, curId, cteIdToNewId))
+
+ case withCTE @ WithCTE(plan, cteDefs) =>
+ val newCteDefs = cteDefs.map { cteDef =>
+ cteIdToNewId.getOrElseUpdate(cteDef.id, curId.getAndIncrement())
+ val normalizedCteDef = canonicalizeCTE(cteDef.child, cteIdToNewId)
+ cteDef.copy(child = normalizedCteDef, id = cteIdToNewId(cteDef.id))
+ }
+ val normalizedPlan = canonicalizeCTE(plan, cteIdToNewId)
+ withCTE.copy(plan = normalizedPlan, cteDefs = newCteDefs)
+ }
+ }
+
+ private def canonicalizeCTE(
+ plan: LogicalPlan,
+ defIdToNewId: mutable.Map[Long, Long]): LogicalPlan = {
+ plan.transformDownWithSubqueries {
+ // For nested WithCTE: if defIdToNewId does not contain the cteId,
+ // the ref does not belong to the current WithCTE.
+ case ref: CTERelationRef if defIdToNewId.contains(ref.cteId) =>
+ ref.copy(cteId = defIdToNewId(ref.cteId))
+ case unionLoop: UnionLoop if defIdToNewId.contains(unionLoop.id) =>
+ unionLoop.copy(id = defIdToNewId(unionLoop.id))
+ case unionLoopRef: UnionLoopRef if defIdToNewId.contains(unionLoopRef.loopId) =>
+ unionLoopRef.copy(loopId = defIdToNewId(unionLoopRef.loopId))
+ }
+ }
+}
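The rule renumbers CTE definition ids from zero and rewires the matching CTERelationRef/UnionLoop/UnionLoopRef ids, so two plans that differ only in their auto-generated CTE ids become comparable. A minimal usage sketch, assuming a SparkSession named `spark`:

// Two analyzed plans for the same query text get different globally-unique CTE ids.
val p1 = spark.sql("WITH t AS (SELECT 1 AS c) SELECT * FROM t").queryExecution.analyzed
val p2 = spark.sql("WITH t AS (SELECT 1 AS c) SELECT * FROM t").queryExecution.analyzed

// After normalization both plans use ids 0, 1, ... in definition order and can be compared.
val n1 = NormalizeCTEIds(p1)
val n2 = NormalizeCTEIds(p2)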
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala
index 46815969e7ece..d36a71b043901 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/InferWindowGroupLimit.scala
@@ -26,12 +26,29 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{FILTER, WINDOW}
* Inserts a `WindowGroupLimit` below `Window` if the `Window` has rank-like functions
* and the function results are further filtered by limit-like predicates. Example query:
* {{{
- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE rn = 5
- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE 5 = rn
- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE rn < 5
- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE 5 > rn
- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE rn <= 5
- * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1 WHERE 5 >= rn
+ * SELECT * FROM (
+ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1
+ * ) WHERE rn = 5;
+ *
+ * SELECT * FROM (
+ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1
+ * ) WHERE 5 = rn;
+ *
+ * SELECT * FROM (
+ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1
+ * ) WHERE rn < 5;
+ *
+ * SELECT * FROM (
+ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1
+ * ) WHERE 5 > rn;
+ *
+ * SELECT * FROM (
+ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1
+ * ) WHERE rn <= 5;
+ *
+ * SELECT * FROM (
+ * SELECT *, ROW_NUMBER() OVER(PARTITION BY k ORDER BY a) AS rn FROM Tab1
+ * ) WHERE 5 >= rn;
* }}}
*/
object InferWindowGroupLimit extends Rule[LogicalPlan] with PredicateHelper {
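For any of the queries above, the rule's effect shows up as a WindowGroupLimit operator under the Window in the plan. A quick way to observe it (sketch; assumes a table Tab1(k, a) exists and default configs):

spark.sql("""
  SELECT * FROM (
    SELECT *, ROW_NUMBER() OVER (PARTITION BY k ORDER BY a) AS rn FROM Tab1
  ) WHERE rn <= 5
""").explain()
// The optimized/physical plan should contain a WindowGroupLimit node below the Window.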
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
index 7134c3daf3baa..9a676571d1071 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicate.scala
@@ -149,7 +149,8 @@ object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] {
private def replaceNullWithFalse(mergeActions: Seq[MergeAction]): Seq[MergeAction] = {
mergeActions.map {
- case u @ UpdateAction(Some(cond), _) => u.copy(condition = Some(replaceNullWithFalse(cond)))
+ case u @ UpdateAction(Some(cond), _, _) =>
+ u.copy(condition = Some(replaceNullWithFalse(cond)))
case u @ UpdateStarAction(Some(cond)) => u.copy(condition = Some(replaceNullWithFalse(cond)))
case d @ DeleteAction(Some(cond)) => d.copy(condition = Some(replaceNullWithFalse(cond)))
case i @ InsertAction(Some(cond), _) => i.copy(condition = Some(replaceNullWithFalse(cond)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 71eb3e5ea2bd7..c0fc8d0bae42b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, TreeNodeTag}
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.CharVarcharUtils.CHAR_VARCHAR_TYPE_STRING_METADATA_KEY
+import org.apache.spark.sql.catalyst.util.UnsafeRowUtils.isBinaryStable
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -217,8 +218,17 @@ object ConstantPropagation extends Rule[LogicalPlan] {
// substituted into `1 + 1 = 1` if 'c' isn't nullable. If 'c' is nullable then the enclosing
// NOT prevents us from doing the substitution as NOT flips the context (`nullIsFalse`) of what a
// null result of the enclosed expression means.
+ //
+ // Also, we shouldn't replace attributes with non-binary-stable data types, since this can lead
+ // to incorrect results. For example:
+ // `CREATE TABLE t (c STRING COLLATE UTF8_LCASE);`
+ // `INSERT INTO t VALUES ('HELLO'), ('hello');`
+ // `SELECT * FROM t WHERE c = 'hello' AND c = 'HELLO' COLLATE UNICODE;`
+ // If we replace `c` with `'hello'`, we get `'hello' = 'HELLO' COLLATE UNICODE` for the right
+ // condition, which is false, while the original `c = 'HELLO' COLLATE UNICODE` is true for
+ // 'HELLO' and false for 'hello'.
private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) =
- !ar.nullable || nullIsFalse
+ (!ar.nullable || nullIsFalse) && isBinaryStable(ar.dataType)
private def replaceConstants(
condition: Expression,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 971633b9a46a8..6649568a00b20 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -118,6 +118,15 @@ class AstBuilder extends DataTypeAstBuilder
}
}
+ /**
+ * Override to provide actual multi-part identifier parsing using CatalystSqlParser. This allows
+ * the base class to handle IDENTIFIER('qualified.identifier') without needing special case
+ * logic in getIdentifierParts.
+ */
+ override protected def parseMultipartIdentifier(identifier: String): Seq[String] = {
+ CatalystSqlParser.parseMultipartIdentifier(identifier)
+ }
+
/**
* Retrieves the original input text for a given parser context, preserving all whitespace and
* formatting.
@@ -255,12 +264,6 @@ class AstBuilder extends DataTypeAstBuilder
private def visitDeclareConditionStatementImpl(
ctx: DeclareConditionStatementContext): ErrorCondition = {
- // Qualified user defined condition name is not allowed.
- if (ctx.multipartIdentifier().parts.size() > 1) {
- throw SqlScriptingErrors
- .conditionCannotBeQualified(CurrentOrigin.get, ctx.multipartIdentifier().getText)
- }
-
// If SQLSTATE is not provided, default to 45000.
val sqlState = Option(ctx.sqlStateValue())
.map(sqlStateValueContext => string(visitStringLit(sqlStateValueContext.stringLit())))
@@ -269,7 +272,7 @@ class AstBuilder extends DataTypeAstBuilder
assertSqlState(sqlState)
// Get condition name.
- val conditionName = visitMultipartIdentifier(ctx.multipartIdentifier()).head
+ val conditionName = getIdentifierText(ctx.strictIdentifier())
assertConditionName(conditionName)
@@ -297,12 +300,21 @@ class AstBuilder extends DataTypeAstBuilder
parsingCtx)
} else {
// If there is no compound body, then there must be a statement or set statement.
+ // Single-statement handler bodies need a label for the CompoundBody, just like
+ // BEGIN-END blocks do (see visitBeginEndCompoundBlockImpl). Generate a random UUID
+ // label since no explicit label is defined.
+ val labelText = parsingCtx.labelContext.enterLabeledScope(
+ beginLabelCtx = None,
+ endLabelCtx = None
+ )
val statement = Option(ctx.statement().asInstanceOf[ParserRuleContext])
.orElse(Option(ctx.setStatementInsideSqlScript().asInstanceOf[ParserRuleContext]))
.map { s =>
SingleStatement(parsedPlan = visit(s).asInstanceOf[LogicalPlan])
}
- CompoundBody(Seq(statement.get), None, isScope = false)
+ val compoundBody = CompoundBody(Seq(statement.get), Some(labelText), isScope = false)
+ parsingCtx.labelContext.exitLabeledScope(None)
+ compoundBody
}
ExceptionHandler(exceptionHandlerTriggers, body, handlerType)
@@ -561,15 +573,15 @@ class AstBuilder extends DataTypeAstBuilder
val query = withOrigin(queryCtx) {
SingleStatement(visitQuery(queryCtx))
}
- parsingCtx.labelContext.enterForScope(Option(ctx.multipartIdentifier()))
- val varName = Option(ctx.multipartIdentifier()).map(_.getText)
+ parsingCtx.labelContext.enterForScope(Option(ctx.strictIdentifier()))
+ val varName = Option(ctx.strictIdentifier()).map(getIdentifierText)
val body = visitCompoundBodyImpl(
ctx.compoundBody(),
None,
parsingCtx,
isScope = false
)
- parsingCtx.labelContext.exitForScope(Option(ctx.multipartIdentifier()))
+ parsingCtx.labelContext.exitForScope(Option(ctx.strictIdentifier()))
parsingCtx.labelContext.exitLabeledScope(Option(ctx.beginLabel()))
ForStatement(query, varName, body, Some(labelText))
@@ -580,26 +592,26 @@ class AstBuilder extends DataTypeAstBuilder
ctx match {
case c: BeginEndCompoundBlockContext
if Option(c.beginLabel()).exists { b =>
- b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
+ getIdentifierText(b.strictIdentifier()).toLowerCase(Locale.ROOT).equals(label)
} => if (isIterate) {
throw SqlScriptingErrors.invalidIterateLabelUsageForCompound(CurrentOrigin.get, label)
}
true
case c: WhileStatementContext
if Option(c.beginLabel()).exists { b =>
- b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
+ getIdentifierText(b.strictIdentifier()).toLowerCase(Locale.ROOT).equals(label)
} => true
case c: RepeatStatementContext
if Option(c.beginLabel()).exists { b =>
- b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
+ getIdentifierText(b.strictIdentifier()).toLowerCase(Locale.ROOT).equals(label)
} => true
case c: LoopStatementContext
if Option(c.beginLabel()).exists { b =>
- b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
+ getIdentifierText(b.strictIdentifier()).toLowerCase(Locale.ROOT).equals(label)
} => true
case c: ForStatementContext
if Option(c.beginLabel()).exists { b =>
- b.multipartIdentifier().getText.toLowerCase(Locale.ROOT).equals(label)
+ getIdentifierText(b.strictIdentifier()).toLowerCase(Locale.ROOT).equals(label)
} => true
case _ => false
}
@@ -607,7 +619,7 @@ class AstBuilder extends DataTypeAstBuilder
override def visitLeaveStatement(ctx: LeaveStatementContext): LeaveStatement =
withOrigin(ctx) {
- val labelText = ctx.multipartIdentifier().getText.toLowerCase(Locale.ROOT)
+ val labelText = getIdentifierText(ctx.strictIdentifier()).toLowerCase(Locale.ROOT)
var parentCtx = ctx.parent
while (Option(parentCtx).isDefined) {
@@ -623,7 +635,7 @@ class AstBuilder extends DataTypeAstBuilder
override def visitIterateStatement(ctx: IterateStatementContext): IterateStatement =
withOrigin(ctx) {
- val labelText = ctx.multipartIdentifier().getText.toLowerCase(Locale.ROOT)
+ val labelText = getIdentifierText(ctx.strictIdentifier()).toLowerCase(Locale.ROOT)
var parentCtx = ctx.parent
while (Option(parentCtx).isDefined) {
@@ -797,7 +809,8 @@ class AstBuilder extends DataTypeAstBuilder
(columnAliases, plan) =>
UnresolvedSubqueryColumnAliases(visitIdentifierList(columnAliases), plan)
)
- SubqueryAlias(ctx.name.getText, subQuery)
+ // Use getIdentifierText to handle both regular identifiers and IDENTIFIER('literal')
+ SubqueryAlias(getIdentifierText(ctx.name), subQuery)
}
/**
@@ -1222,7 +1235,7 @@ class AstBuilder extends DataTypeAstBuilder
if (pVal.DEFAULT != null) {
throw QueryParsingErrors.defaultColumnReferencesNotAllowedInPartitionSpec(ctx)
}
- val name = pVal.identifier.getText
+ val name = getIdentifierText(pVal.identifier)
val value = Option(pVal.constant).map(v => {
visitStringConstant(v, legacyNullAsString, keepPartitionSpecAsString)
})
@@ -1786,7 +1799,8 @@ class AstBuilder extends DataTypeAstBuilder
// Collect all window specifications defined in the WINDOW clause.
val baseWindowTuples = ctx.namedWindow.asScala.map {
wCtx =>
- (wCtx.name.getText, typedVisit[WindowSpec](wCtx.windowSpec))
+ // Use getIdentifierText to handle both regular identifiers and IDENTIFIER('literal')
+ (getIdentifierText(wCtx.name), typedVisit[WindowSpec](wCtx.windowSpec))
}
baseWindowTuples.groupBy(_._1).foreach { kv =>
if (kv._2.size > 1) {
@@ -1927,6 +1941,7 @@ class AstBuilder extends DataTypeAstBuilder
query: LogicalPlan): LogicalPlan = withOrigin(ctx) {
var plan = query
ctx.hintStatements.asScala.reverse.foreach { stmt =>
+ // Hint names use simpleIdentifier, so .getText() is correct.
plan = UnresolvedHint(stmt.hintName.getText,
stmt.parameters.asScala.map(expression).toSeq, plan)
}
@@ -1943,11 +1958,11 @@ class AstBuilder extends DataTypeAstBuilder
.flatMap(_.namedExpression.asScala)
.map(typedVisit[Expression])
val pivotColumn = if (ctx.pivotColumn.identifiers.size == 1) {
- UnresolvedAttribute.quoted(ctx.pivotColumn.errorCapturingIdentifier.getText)
+ UnresolvedAttribute.quoted(getIdentifierText(ctx.pivotColumn.errorCapturingIdentifier))
} else {
CreateStruct(
ctx.pivotColumn.identifiers.asScala.map(
- identifier => UnresolvedAttribute.quoted(identifier.getText)).toSeq)
+ identifier => UnresolvedAttribute.quoted(getIdentifierText(identifier))).toSeq)
}
val pivotValues = ctx.pivotValues.asScala.map(visitPivotValue)
Pivot(None, pivotColumn, pivotValues.toSeq, aggregates, query)
@@ -1959,7 +1974,7 @@ class AstBuilder extends DataTypeAstBuilder
override def visitPivotValue(ctx: PivotValueContext): Expression = withOrigin(ctx) {
val e = expression(ctx.expression)
if (ctx.errorCapturingIdentifier != null) {
- Alias(e, ctx.errorCapturingIdentifier.getText)()
+ Alias(e, getIdentifierText(ctx.errorCapturingIdentifier))()
} else {
e
}
@@ -1974,17 +1989,18 @@ class AstBuilder extends DataTypeAstBuilder
// this is needed to create unpivot and to filter unpivot for nulls further down
val valueColumnNames =
Option(ctx.unpivotOperator().unpivotSingleValueColumnClause())
- .map(_.unpivotValueColumn().identifier().getText)
+ .map(vc => getIdentifierText(vc.unpivotValueColumn().identifier()))
.map(Seq(_))
.getOrElse(
Option(ctx.unpivotOperator().unpivotMultiValueColumnClause())
- .map(_.unpivotValueColumns.asScala.map(_.identifier().getText).toSeq)
+ .map(_.unpivotValueColumns.asScala.map(vc =>
+ getIdentifierText(vc.identifier())).toSeq)
.get
)
val unpivot = if (ctx.unpivotOperator().unpivotSingleValueColumnClause() != null) {
val unpivotClause = ctx.unpivotOperator().unpivotSingleValueColumnClause()
- val variableColumnName = unpivotClause.unpivotNameColumn().identifier().getText
+ val variableColumnName = getIdentifierText(unpivotClause.unpivotNameColumn().identifier())
val (unpivotColumns, unpivotAliases) =
unpivotClause.unpivotColumns.asScala.map(visitUnpivotColumnAndAlias).toSeq.unzip
@@ -1999,7 +2015,7 @@ class AstBuilder extends DataTypeAstBuilder
)
} else {
val unpivotClause = ctx.unpivotOperator().unpivotMultiValueColumnClause()
- val variableColumnName = unpivotClause.unpivotNameColumn().identifier().getText
+ val variableColumnName = getIdentifierText(unpivotClause.unpivotNameColumn().identifier())
val (unpivotColumns, unpivotAliases) =
unpivotClause.unpivotColumnSets.asScala.map(visitUnpivotColumnSet).toSeq.unzip
@@ -2023,7 +2039,7 @@ class AstBuilder extends DataTypeAstBuilder
// alias unpivot result
if (ctx.errorCapturingIdentifier() != null) {
- val alias = ctx.errorCapturingIdentifier().getText
+ val alias = getIdentifierText(ctx.errorCapturingIdentifier())
SubqueryAlias(alias, filtered)
} else {
filtered
@@ -2043,7 +2059,7 @@ class AstBuilder extends DataTypeAstBuilder
override def visitUnpivotColumnAndAlias(ctx: UnpivotColumnAndAliasContext):
(NamedExpression, Option[String]) = withOrigin(ctx) {
val attr = visitUnpivotColumn(ctx.unpivotColumn())
- val alias = Option(ctx.unpivotAlias()).map(_.errorCapturingIdentifier().getText)
+ val alias = Option(ctx.unpivotAlias()).map(a => getIdentifierText(a.errorCapturingIdentifier()))
(attr, alias)
}
@@ -2055,7 +2071,8 @@ class AstBuilder extends DataTypeAstBuilder
(Seq[NamedExpression], Option[String]) =
withOrigin(ctx) {
val exprs = ctx.unpivotColumns.asScala.map(visitUnpivotColumn).toSeq
- val alias = Option(ctx.unpivotAlias()).map(_.errorCapturingIdentifier().getText)
+ val alias =
+ Option(ctx.unpivotAlias()).map(a => getIdentifierText(a.errorCapturingIdentifier()))
(exprs, alias)
}
@@ -2071,9 +2088,9 @@ class AstBuilder extends DataTypeAstBuilder
unrequiredChildIndex = Nil,
outer = ctx.OUTER != null,
// scalastyle:off caselocale
- Some(ctx.tblName.getText.toLowerCase),
+ Some(getIdentifierText(ctx.tblName).toLowerCase),
// scalastyle:on caselocale
- ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.quoted).toSeq,
+ ctx.colName.asScala.map(getIdentifierText).map(UnresolvedAttribute.quoted).toSeq,
query)
}
@@ -2514,7 +2531,8 @@ class AstBuilder extends DataTypeAstBuilder
* Create an alias ([[SubqueryAlias]]) for a [[LogicalPlan]].
*/
private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = {
- SubqueryAlias(alias.getText, plan)
+ // Use getIdentifierText to handle both regular identifiers and IDENTIFIER('literal')
+ SubqueryAlias(getIdentifierText(alias), plan)
}
/**
@@ -2523,7 +2541,7 @@ class AstBuilder extends DataTypeAstBuilder
*/
private def mayApplyAliasPlan(tableAlias: TableAliasContext, plan: LogicalPlan): LogicalPlan = {
if (tableAlias.strictIdentifier != null) {
- val alias = tableAlias.strictIdentifier.getText
+ val alias = getIdentifierText(tableAlias.strictIdentifier)
if (tableAlias.identifierList != null) {
val columnNames = visitIdentifierList(tableAlias.identifierList)
SubqueryAlias(alias, UnresolvedSubqueryColumnAliases(columnNames, plan))
@@ -2544,9 +2562,11 @@ class AstBuilder extends DataTypeAstBuilder
/**
* Create a Sequence of Strings for an identifier list.
+ * Each identifier must be unqualified.
+ * Handles both regular identifiers and IDENTIFIER('literal').
*/
override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) {
- ctx.ident.asScala.map(_.getText).toSeq
+ ctx.ident.asScala.map(id => getIdentifierText(id)).toSeq
}
/* ********************************************************************************************
@@ -2554,18 +2574,20 @@ class AstBuilder extends DataTypeAstBuilder
* ******************************************************************************************** */
/**
* Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern.
+ * Handles identifier-lite with qualified identifiers.
*/
override def visitTableIdentifier(
ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) {
- TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText))
+ TableIdentifier(getIdentifierText(ctx.table), Option(ctx.db).map(getIdentifierText))
}
/**
* Create a [[FunctionIdentifier]] from a 'functionName' or 'databaseName'.'functionName' pattern.
+ * Handles identifier-lite with qualified identifiers.
*/
override def visitFunctionIdentifier(
ctx: FunctionIdentifierContext): FunctionIdentifier = withOrigin(ctx) {
- FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText))
+ FunctionIdentifier(getIdentifierText(ctx.function), Option(ctx.db).map(getIdentifierText))
}
/* ********************************************************************************************
@@ -2639,7 +2661,8 @@ class AstBuilder extends DataTypeAstBuilder
override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) {
val e = expression(ctx.expression)
if (ctx.name != null) {
- Alias(e, ctx.name.getText)()
+ // Use getIdentifierText to handle both regular identifiers and IDENTIFIER('literal')
+ Alias(e, getIdentifierText(ctx.name))()
} else if (ctx.identifierList != null) {
MultiAlias(e, visitIdentifierList(ctx.identifierList))
} else {
@@ -2972,7 +2995,8 @@ class AstBuilder extends DataTypeAstBuilder
}
} else {
// If the parser is not in ansi mode, we should return `UnresolvedAttribute`, in case there
- // are columns named `CURRENT_DATE` or `CURRENT_TIMESTAMP` or `CURRENT_TIME`
+ // are columns named `CURRENT_DATE` or `CURRENT_TIMESTAMP` or `CURRENT_TIME`.
+ // ctx.name is a token, not an identifier context.
UnresolvedAttribute.quoted(ctx.name.getText)
}
}
@@ -3205,7 +3229,7 @@ class AstBuilder extends DataTypeAstBuilder
*/
override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
val arguments = ctx.identifier().asScala.map { name =>
- UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts)
+ UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(getIdentifierText(name)).nameParts)
}
val function = expression(ctx.expression).transformUp {
case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts)
@@ -3217,7 +3241,8 @@ class AstBuilder extends DataTypeAstBuilder
* Create a reference to a window frame, i.e. [[WindowSpecReference]].
*/
override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) {
- WindowSpecReference(ctx.name.getText)
+ // Use getIdentifierText to handle both regular identifiers and IDENTIFIER('literal')
+ WindowSpecReference(getIdentifierText(ctx.name))
}
/**
@@ -3353,9 +3378,11 @@ class AstBuilder extends DataTypeAstBuilder
* it can be [[UnresolvedExtractValue]].
*/
override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
- val attr = ctx.fieldName.getText
+ // Use getIdentifierText to handle both regular identifiers and IDENTIFIER('literal')
+ val attr = getIdentifierText(ctx.fieldName)
expression(ctx.base) match {
case unresolved_attr @ UnresolvedAttribute(nameParts) =>
+ // For regex check, we need the original text before identifier-lite resolution
ctx.fieldName.getStart.getText match {
case escapedIdentifier(columnNameRegex)
if conf.supportQuotedRegexColumnName &&
@@ -3393,13 +3420,17 @@ class AstBuilder extends DataTypeAstBuilder
* quoted in ``
*/
override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) {
+ // For regex check, we need the original text before identifier-lite resolution
ctx.getStart.getText match {
case escapedIdentifier(columnNameRegex)
if conf.supportQuotedRegexColumnName &&
isRegex(columnNameRegex) && canApplyRegex(ctx) =>
UnresolvedRegex(columnNameRegex, None, conf.caseSensitiveAnalysis)
case _ =>
- UnresolvedAttribute.quoted(ctx.getText)
+ // Use getIdentifierParts to handle IDENTIFIER('literal') correctly
+ // This allows IDENTIFIER('t').c1 to work like t.c1
+ val parts = getIdentifierParts(ctx.identifier())
+ UnresolvedAttribute(parts)
}
}
@@ -4046,7 +4077,7 @@ class AstBuilder extends DataTypeAstBuilder
ctx: ColDefinitionContext): ColumnAndConstraint = withOrigin(ctx) {
import ctx._
- val name: String = colName.getText
+ val name: String = getIdentifierText(colName)
// Check that no duplicates exist among any CREATE TABLE column options specified.
var nullable = true
var defaultExpression: Option[DefaultExpressionContext] = None
@@ -4118,7 +4149,7 @@ class AstBuilder extends DataTypeAstBuilder
ctx: ColumnConstraintDefinitionContext): TableConstraint = {
withOrigin(ctx) {
val name = if (ctx.name != null) {
- ctx.name.getText
+ getIdentifierText(ctx.name)
} else {
null
}
@@ -4230,7 +4261,7 @@ class AstBuilder extends DataTypeAstBuilder
if (!SQLConf.get.objectLevelCollationsEnabled) {
throw QueryCompilationErrors.objectLevelCollationsNotEnabledError()
}
- val collationName = ctx.identifier.getText
+ val collationName = getIdentifierText(ctx.identifier)
CollationFactory.fetchCollation(collationName).collationName
}
@@ -4469,7 +4500,7 @@ class AstBuilder extends DataTypeAstBuilder
def getFieldReference(
ctx: ApplyTransformContext,
arg: V2Expression): FieldReference = {
- lazy val name: String = ctx.identifier.getText
+ lazy val name: String = getIdentifierText(ctx.identifier)
arg match {
case ref: FieldReference =>
ref
@@ -4481,7 +4512,7 @@ class AstBuilder extends DataTypeAstBuilder
def getSingleFieldReference(
ctx: ApplyTransformContext,
arguments: Seq[V2Expression]): FieldReference = {
- lazy val name: String = ctx.identifier.getText
+ lazy val name: String = getIdentifierText(ctx.identifier)
if (arguments.size > 1) {
throw QueryParsingErrors.wrongNumberArgumentsForTransformError(name, arguments.size, ctx)
} else if (arguments.isEmpty) {
@@ -4766,7 +4797,7 @@ class AstBuilder extends DataTypeAstBuilder
string(visitStringLit(c.outFmt)))))
// Expected format: SEQUENCEFILE | TEXTFILE | RCFILE | ORC | PARQUET | AVRO
case (c: GenericFileFormatContext, null) =>
- SerdeInfo(storedAs = Some(c.identifier.getText))
+ SerdeInfo(storedAs = Some(c.simpleIdentifier.getText))
case (null, storageHandler) =>
invalidStatement("STORED BY", ctx)
case _ =>
@@ -4856,7 +4887,7 @@ class AstBuilder extends DataTypeAstBuilder
(rowFormatCtx, createFileFormatCtx.fileFormat) match {
case (_, ffTable: TableFileFormatContext) => // OK
case (rfSerde: RowFormatSerdeContext, ffGeneric: GenericFileFormatContext) =>
- ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match {
+ ffGeneric.simpleIdentifier.getText.toLowerCase(Locale.ROOT) match {
case ("sequencefile" | "textfile" | "rcfile") => // OK
case fmt =>
operationNotAllowed(
@@ -4864,7 +4895,7 @@ class AstBuilder extends DataTypeAstBuilder
parentCtx)
}
case (rfDelimited: RowFormatDelimitedContext, ffGeneric: GenericFileFormatContext) =>
- ffGeneric.identifier.getText.toLowerCase(Locale.ROOT) match {
+ ffGeneric.simpleIdentifier.getText.toLowerCase(Locale.ROOT) match {
case "textfile" => // OK
case fmt => operationNotAllowed(
s"ROW FORMAT DELIMITED is only compatible with 'textfile', not '$fmt'", parentCtx)
@@ -5094,6 +5125,11 @@ class AstBuilder extends DataTypeAstBuilder
"Partition column types may not be specified in Create Table As Select (CTAS)",
ctx)
+ case Some(_) if constraints.nonEmpty =>
+ operationNotAllowed(
+ "Constraints may not be specified in a Create Table As Select (CTAS) statement",
+ ctx)
+
case Some(query) =>
CreateTableAsSelect(identifier, partitioning, query, tableSpec, Map.empty, ifNotExists)
@@ -5173,6 +5209,11 @@ class AstBuilder extends DataTypeAstBuilder
"Partition column types may not be specified in Replace Table As Select (RTAS)",
ctx)
+ case Some(_) if constraints.nonEmpty =>
+ operationNotAllowed(
+ "Constraints may not be specified in a Replace Table As Select (RTAS) statement",
+ ctx)
+
case Some(query) =>
ReplaceTableAsSelect(identifier, partitioning, query, tableSpec,
writeOptions = Map.empty, orCreate = orCreate)
@@ -5460,7 +5501,8 @@ class AstBuilder extends DataTypeAstBuilder
invalidStatement("ALTER TABLE ... PARTITION ... CHANGE COLUMN", ctx)
}
val columnNameParts = typedVisit[Seq[String]](ctx.colName)
- if (!conf.resolver(columnNameParts.last, ctx.colType().colName.getText)) {
+ if (!conf.resolver(columnNameParts.last,
+ getIdentifierText(ctx.colType().colName))) {
throw QueryParsingErrors.operationInHiveStyleCommandUnsupportedError("Renaming column",
"ALTER COLUMN", ctx, Some("please run RENAME COLUMN instead"))
}
@@ -5587,7 +5629,7 @@ class AstBuilder extends DataTypeAstBuilder
ctx: TableConstraintDefinitionContext): TableConstraint =
withOrigin(ctx) {
val name = if (ctx.name != null) {
- ctx.name.getText
+ getIdentifierText(ctx.name)
} else {
null
}
@@ -5691,7 +5733,7 @@ class AstBuilder extends DataTypeAstBuilder
ctx.identifierReference, "ALTER TABLE ... DROP CONSTRAINT")
DropConstraint(
table,
- ctx.name.getText,
+ getIdentifierText(ctx.name),
ifExists = ctx.EXISTS() != null,
cascade = ctx.CASCADE() != null)
}
@@ -5805,9 +5847,9 @@ class AstBuilder extends DataTypeAstBuilder
log"${MDC(PARTITION_SPECIFICATION, ctx.partitionSpec.getText)}")
}
}
- if (ctx.identifier != null &&
- ctx.identifier.getText.toLowerCase(Locale.ROOT) != "noscan") {
- throw QueryParsingErrors.computeStatisticsNotExpectedError(ctx.identifier())
+ if (ctx.simpleIdentifier != null &&
+ ctx.simpleIdentifier.getText.toLowerCase(Locale.ROOT) != "noscan") {
+ throw QueryParsingErrors.computeStatisticsNotExpectedError(ctx.simpleIdentifier)
}
if (ctx.ALL() != null) {
@@ -5828,7 +5870,7 @@ class AstBuilder extends DataTypeAstBuilder
"ANALYZE TABLE",
allowTempView = false),
partitionSpec,
- noScan = ctx.identifier != null)
+ noScan = ctx.simpleIdentifier != null)
} else {
checkPartitionSpec()
AnalyzeColumn(
@@ -5846,16 +5888,16 @@ class AstBuilder extends DataTypeAstBuilder
* }}}
*/
override def visitAnalyzeTables(ctx: AnalyzeTablesContext): LogicalPlan = withOrigin(ctx) {
- if (ctx.identifier != null &&
- ctx.identifier.getText.toLowerCase(Locale.ROOT) != "noscan") {
- throw QueryParsingErrors.computeStatisticsNotExpectedError(ctx.identifier())
+ if (ctx.simpleIdentifier != null &&
+ ctx.simpleIdentifier.getText.toLowerCase(Locale.ROOT) != "noscan") {
+ throw QueryParsingErrors.computeStatisticsNotExpectedError(ctx.simpleIdentifier())
}
val ns = if (ctx.identifierReference() != null) {
withIdentClause(ctx.identifierReference, UnresolvedNamespace(_))
} else {
CurrentNamespace
}
- AnalyzeTables(ns, noScan = ctx.identifier != null)
+ AnalyzeTables(ns, noScan = ctx.simpleIdentifier != null)
}
/**
@@ -5924,8 +5966,6 @@ class AstBuilder extends DataTypeAstBuilder
* }}}
*/
override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) {
- import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-
val query = Option(ctx.query).map(plan)
withIdentClause(ctx.identifierReference, query.toSeq, (ident, children) => {
if (query.isDefined && ident.length > 1) {
@@ -6303,12 +6343,14 @@ class AstBuilder extends DataTypeAstBuilder
* Create a plan for a SHOW FUNCTIONS command.
*/
override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) {
- val (userScope, systemScope) = Option(ctx.identifier)
- .map(_.getText.toLowerCase(Locale.ROOT)) match {
+ // Function scope uses simpleIdentifier, so .getText() is correct.
+ val scope = Option(ctx.functionScope)
+ val (userScope, systemScope) = scope.map(_.getText.toLowerCase(Locale.ROOT)) match {
case None | Some("all") => (true, true)
case Some("system") => (false, true)
case Some("user") => (true, false)
- case Some(x) => throw QueryParsingErrors.showFunctionsUnsupportedError(x, ctx.identifier())
+ case Some(x) =>
+ throw QueryParsingErrors.showFunctionsUnsupportedError(x, ctx.functionScope)
}
val legacy = Option(ctx.legacy).map(visitMultipartIdentifier)
@@ -6401,7 +6443,7 @@ class AstBuilder extends DataTypeAstBuilder
* }}}
*/
override def visitDropIndex(ctx: DropIndexContext): LogicalPlan = withOrigin(ctx) {
- val indexName = ctx.identifier.getText
+ val indexName = getIdentifierText(ctx.identifier)
DropIndex(
createUnresolvedTable(ctx.identifierReference, "DROP INDEX"),
indexName,
@@ -6435,11 +6477,13 @@ class AstBuilder extends DataTypeAstBuilder
*/
override def visitTimestampadd(ctx: TimestampaddContext): Expression = withOrigin(ctx) {
if (ctx.invalidUnit != null) {
+ // ctx.name and ctx.invalidUnit are tokens, not identifier contexts.
throw QueryParsingErrors.invalidDatetimeUnitError(
ctx,
ctx.name.getText,
ctx.invalidUnit.getText)
} else {
+ // ctx.unit is a token, not an identifier context.
TimestampAdd(ctx.unit.getText, expression(ctx.unitsAmount), expression(ctx.timestamp))
}
}
@@ -6449,11 +6493,13 @@ class AstBuilder extends DataTypeAstBuilder
*/
override def visitTimestampdiff(ctx: TimestampdiffContext): Expression = withOrigin(ctx) {
if (ctx.invalidUnit != null) {
+ // ctx.name and ctx.invalidUnit are tokens, not identifier contexts.
throw QueryParsingErrors.invalidDatetimeUnitError(
ctx,
ctx.name.getText,
ctx.invalidUnit.getText)
} else {
+ // ctx.unit is a token, not an identifier context.
TimestampDiff(ctx.unit.getText, expression(ctx.startTimestamp), expression(ctx.endTimestamp))
}
}
@@ -6463,7 +6509,8 @@ class AstBuilder extends DataTypeAstBuilder
* */
override def visitNamedParameterLiteral(
ctx: NamedParameterLiteralContext): Expression = withOrigin(ctx) {
- NamedParameter(ctx.namedParameterMarker().identifier().getText)
+ // Named parameters use simpleIdentifier, so .getText() is correct.
+ NamedParameter(ctx.namedParameterMarker().simpleIdentifier().getText)
}
/**
@@ -6618,7 +6665,7 @@ class AstBuilder extends DataTypeAstBuilder
target = None, excepts = ids.map(s => Seq(s)), replacements = None))
Project(projectList, left)
}.getOrElse(Option(ctx.AS).map { _ =>
- SubqueryAlias(ctx.errorCapturingIdentifier().getText, left)
+ SubqueryAlias(getIdentifierText(ctx.errorCapturingIdentifier()), left)
}.getOrElse(Option(ctx.whereClause).map { c =>
if (ctx.windowClause() != null) {
throw QueryParsingErrors.windowClauseInPipeOperatorWhereClauseNotAllowedError(ctx)
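Most of the getIdentifierText/getIdentifierParts call sites above exist so that IDENTIFIER('...') can appear wherever a plain identifier could. Two illustrative queries (a sketch of the intended behavior, not taken from this patch's tests; assumes a SparkSession named `spark`):

// IDENTIFIER(...) as a column alias and as a table alias.
spark.sql("SELECT c1 AS IDENTIFIER('renamed') FROM VALUES (1) AS IDENTIFIER('t')(c1)").show()

// IDENTIFIER(...) as the qualifier of a column reference, resolving like t.c1.
spark.sql("SELECT IDENTIFIER('t').c1 FROM VALUES (1) AS t(c1)").show()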
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index a19b4cca28173..336db1382f898 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -28,7 +28,7 @@ import org.antlr.v4.runtime.tree.{ParseTree, TerminalNodeImpl}
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.analysis.UnresolvedIdentifier
-import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{BeginLabelContext, EndLabelContext, MultipartIdentifierContext}
+import org.apache.spark.sql.catalyst.parser.SqlBaseParser.{BeginLabelContext, EndLabelContext, StrictIdentifierContext}
import org.apache.spark.sql.catalyst.plans.logical.{CreateVariable, ErrorCondition}
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.util.SparkParserUtils
@@ -279,34 +279,49 @@ class SqlScriptingLabelContext {
* @param beginLabelCtx Begin label context.
* @param endLabelCtx The end label context.
*/
+ /**
+ * Get label text from label context, handling IDENTIFIER() syntax.
+ */
+ private def getLabelText(ctx: ParserRuleContext): String = {
+ val astBuilder = new DataTypeAstBuilder {
+ override protected def parseMultipartIdentifier(identifier: String): Seq[String] = {
+ CatalystSqlParser.parseMultipartIdentifier(identifier)
+ }
+ }
+ val parts = astBuilder.extractIdentifierParts(ctx)
+ if (parts.size > 1) {
+ throw new ParseException(
+ errorClass = "IDENTIFIER_TOO_MANY_NAME_PARTS",
+ messageParameters = Map("identifier" -> parts.map(part => s"`$part`").mkString("."),
+ "limit" -> "1"),
+ ctx)
+ }
+ parts.head
+ }
+
private def checkLabels(
beginLabelCtx: Option[BeginLabelContext],
- endLabelCtx: Option[EndLabelContext]) : Unit = {
+ endLabelCtx: Option[EndLabelContext]): Unit = {
+ // Check label matching and other constraints.
(beginLabelCtx, endLabelCtx) match {
// Throw an error if labels do not match.
- case (Some(bl: BeginLabelContext), Some(el: EndLabelContext))
- if bl.multipartIdentifier().getText.toLowerCase(Locale.ROOT) !=
- el.multipartIdentifier().getText.toLowerCase(Locale.ROOT) =>
- withOrigin(bl) {
- throw SqlScriptingErrors.labelsMismatch(
- CurrentOrigin.get,
- bl.multipartIdentifier().getText,
- el.multipartIdentifier().getText)
- }
- // Throw an error if label is qualified.
- case (Some(bl: BeginLabelContext), _)
- if bl.multipartIdentifier().parts.size() > 1 =>
- withOrigin(bl) {
- throw SqlScriptingErrors.labelCannotBeQualified(
- CurrentOrigin.get,
- bl.multipartIdentifier().getText.toLowerCase(Locale.ROOT)
- )
+ case (Some(bl: BeginLabelContext), Some(el: EndLabelContext)) =>
+ val beginLabel = getLabelText(bl.strictIdentifier()).toLowerCase(Locale.ROOT)
+ val endLabel = getLabelText(el.strictIdentifier()).toLowerCase(Locale.ROOT)
+ if (beginLabel != endLabel) {
+ withOrigin(bl) {
+ throw SqlScriptingErrors.labelsMismatch(
+ CurrentOrigin.get,
+ getLabelText(bl.strictIdentifier()),
+ getLabelText(el.strictIdentifier()))
+ }
}
// Throw an error if end label exists without begin label.
case (None, Some(el: EndLabelContext)) =>
withOrigin(el) {
throw SqlScriptingErrors.endLabelWithoutBeginLabel(
- CurrentOrigin.get, el.multipartIdentifier().getText)
+ CurrentOrigin.get,
+ getLabelText(el.strictIdentifier()))
}
case _ =>
}
@@ -314,7 +329,7 @@ class SqlScriptingLabelContext {
/** Check if the label is defined. */
private def isLabelDefined(beginLabelCtx: Option[BeginLabelContext]): Boolean = {
- beginLabelCtx.map(_.multipartIdentifier().getText).isDefined
+ beginLabelCtx.isDefined
}
/**
@@ -322,13 +337,13 @@ class SqlScriptingLabelContext {
* If the identifier is contained within seenLabels, raise an exception.
*/
private def assertIdentifierNotInSeenLabels(
- identifierCtx: Option[MultipartIdentifierContext]): Unit = {
+ identifierCtx: Option[StrictIdentifierContext]): Unit = {
identifierCtx.foreach { ctx =>
- val identifierName = ctx.getText
- if (seenLabels.contains(identifierName.toLowerCase(Locale.ROOT))) {
+ val identifierName = getLabelText(ctx).toLowerCase(Locale.ROOT)
+ if (seenLabels.contains(identifierName)) {
withOrigin(ctx) {
throw SqlScriptingErrors
- .duplicateLabels(CurrentOrigin.get, identifierName.toLowerCase(Locale.ROOT))
+ .duplicateLabels(CurrentOrigin.get, identifierName)
}
}
}
@@ -348,7 +363,7 @@ class SqlScriptingLabelContext {
// Get label text and add it to seenLabels.
val labelText = if (isLabelDefined(beginLabelCtx)) {
- val txt = beginLabelCtx.get.multipartIdentifier().getText.toLowerCase(Locale.ROOT)
+ val txt = getLabelText(beginLabelCtx.get.strictIdentifier()).toLowerCase(Locale.ROOT)
if (seenLabels.contains(txt)) {
withOrigin(beginLabelCtx.get) {
throw SqlScriptingErrors.duplicateLabels(CurrentOrigin.get, txt)
@@ -374,18 +389,18 @@ class SqlScriptingLabelContext {
*/
def exitLabeledScope(beginLabelCtx: Option[BeginLabelContext]): Unit = {
if (isLabelDefined(beginLabelCtx)) {
- seenLabels.remove(beginLabelCtx.get.multipartIdentifier().getText.toLowerCase(Locale.ROOT))
+ seenLabels.remove(getLabelText(beginLabelCtx.get.strictIdentifier()).toLowerCase(Locale.ROOT))
}
}
/**
* Enter a for loop scope.
- * If the for loop variable is defined, it will be asserted to not be inside seenLabels;
+ * If the for loop variable is defined, it will be asserted to not be inside seenLabels.
* Then, if the for loop variable is defined, it will be added to seenLabels.
*/
- def enterForScope(identifierCtx: Option[MultipartIdentifierContext]): Unit = {
+ def enterForScope(identifierCtx: Option[StrictIdentifierContext]): Unit = {
identifierCtx.foreach { ctx =>
- val identifierName = ctx.getText
+ val identifierName = getLabelText(ctx)
assertIdentifierNotInSeenLabels(identifierCtx)
seenLabels.add(identifierName.toLowerCase(Locale.ROOT))
@@ -403,9 +418,9 @@ class SqlScriptingLabelContext {
* Exit a for loop scope.
* If the for loop variable is defined, it will be removed from seenLabels.
*/
- def exitForScope(identifierCtx: Option[MultipartIdentifierContext]): Unit = {
+ def exitForScope(identifierCtx: Option[StrictIdentifierContext]): Unit = {
identifierCtx.foreach { ctx =>
- val identifierName = ctx.getText
+ val identifierName = getLabelText(ctx)
seenLabels.remove(identifierName.toLowerCase(Locale.ROOT))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SubstituteParamsParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SubstituteParamsParser.scala
index 54c8c2ec089f9..9beead0e64875 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SubstituteParamsParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/SubstituteParamsParser.scala
@@ -186,6 +186,8 @@ class SubstituteParamsParser extends Logging {
/**
* Apply a list of substitutions to the SQL text.
+ * Inserts a space separator when a parameter is immediately preceded by a quote
+ * to avoid back-to-back quotes after substitution.
*/
private def applySubstitutions(sqlText: String, substitutions: List[Substitution]): String = {
// Sort substitutions by start position in reverse order to avoid offset issues
@@ -193,9 +195,18 @@ class SubstituteParamsParser extends Logging {
var result = sqlText
sortedSubstitutions.foreach { substitution =>
- result = result.substring(0, substitution.start) +
- substitution.replacement +
- result.substring(substitution.end)
+ val prefix = result.substring(0, substitution.start)
+ val replacement = substitution.replacement
+ val suffix = result.substring(substitution.end)
+
+ // Check if replacement is immediately preceded by a quote and doesn't already
+ // start with whitespace
+ val needsSpace = substitution.start > 0 &&
+ (result(substitution.start - 1) == '\'' || result(substitution.start - 1) == '"') &&
+ replacement.nonEmpty && !replacement(0).isWhitespace
+
+ val space = if (needsSpace) " " else ""
+ result = s"$prefix$space$replacement$suffix"
}
result
}
@@ -211,4 +222,3 @@ object SubstituteParamsParser {
positionalParams: List[String] = List.empty): (String, Int, PositionMapper) =
instance.substitute(sqlText, namedParams, positionalParams)
}
-
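The quote check above guards against a substitution gluing a new quoted literal directly onto a preceding one. A standalone sketch of just the string splice (plain string manipulation, not the parser's real API; the positions and parameter name are illustrative):

// Splice `replacement` over sql[start, end), adding a space when the substitution would
// otherwise sit directly against a preceding quote.
def splice(sql: String, start: Int, end: Int, replacement: String): String = {
  val needsSpace = start > 0 &&
    (sql(start - 1) == '\'' || sql(start - 1) == '"') &&
    replacement.nonEmpty && !replacement(0).isWhitespace
  sql.substring(0, start) + (if (needsSpace) " " else "") + replacement + sql.substring(end)
}

splice("SELECT 'x':p", 10, 12, "'abc'")   // "SELECT 'x' 'abc'"
// Without the space the result would be "SELECT 'x''abc'", which a SQL lexer reads as a single
// literal containing an escaped quote.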
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 142420ee258ae..b87d018f2ab1d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -580,19 +580,18 @@ case class Union(
allowMissingCol: Boolean = false) extends UnionBase {
assert(!allowMissingCol || byName, "`allowMissingCol` can be true only if `byName` is true.")
- override def maxRows: Option[Long] = {
- var sum = BigInt(0)
- children.foreach { child =>
- if (child.maxRows.isDefined) {
- sum += child.maxRows.get
- if (!sum.isValidLong) {
- return None
+ override lazy val maxRows: Option[Long] = {
+ val sum = children.foldLeft(Option(BigInt(0))) {
+ case (Some(acc), child) =>
+ child.maxRows match {
+ case Some(n) =>
+ val newSum = acc + n
+ if (newSum.isValidLong) Some(newSum) else None
+ case None => None
}
- } else {
- return None
- }
+ case (None, _) => None
}
- Some(sum.toLong)
+ sum.map(_.toLong)
}
final override val nodePatterns: Seq[TreePattern] = Seq(UNION)
@@ -600,19 +599,18 @@ case class Union(
/**
* Note the definition has assumption about how union is implemented physically.
*/
- override def maxRowsPerPartition: Option[Long] = {
- var sum = BigInt(0)
- children.foreach { child =>
- if (child.maxRowsPerPartition.isDefined) {
- sum += child.maxRowsPerPartition.get
- if (!sum.isValidLong) {
- return None
+ override lazy val maxRowsPerPartition: Option[Long] = {
+ val sum = children.foldLeft(Option(BigInt(0))) {
+ case (Some(acc), child) =>
+ child.maxRowsPerPartition match {
+ case Some(n) =>
+ val newSum = acc + n
+ if (newSum.isValidLong) Some(newSum) else None
+ case None => None
}
- } else {
- return None
- }
+ case (None, _) => None
}
- Some(sum.toLong)
+ sum.map(_.toLong)
}
private def duplicatesResolvedPerBranch: Boolean =
@@ -666,7 +664,7 @@ case class Join(
hint: JoinHint)
extends BinaryNode with PredicateHelper {
- override def maxRows: Option[Long] = {
+ override lazy val maxRows: Option[Long] = {
joinType match {
case Inner | Cross | FullOuter | LeftOuter | RightOuter | LeftSingle
if left.maxRows.isDefined && right.maxRows.isDefined =>
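The fold replaces the early-return loop: the accumulator collapses to None as soon as any child is unbounded or the running total stops fitting in a Long, and stays None afterwards. A self-contained sketch of the same fold over plain Option[Long] values:

def sumMaxRows(children: Seq[Option[Long]]): Option[Long] =
  children.foldLeft(Option(BigInt(0))) {
    case (Some(acc), Some(n)) =>
      val sum = acc + n
      if (sum.isValidLong) Some(sum) else None
    case _ => None                       // unbounded child, or already overflowed
  }.map(_.toLong)

sumMaxRows(Seq(Some(3L), Some(5L)))             // Some(8)
sumMaxRows(Seq(Some(3L), None, Some(5L)))       // None: one child is unbounded
sumMaxRows(Seq(Some(Long.MaxValue), Some(1L)))  // None: the total no longer fits in a Long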
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index 802678bc2c6ce..bcfcae2ee16c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -19,11 +19,13 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.resource.ResourceProfile
import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedAttribute, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, JsonToStructs, PythonUDF, PythonUDTF}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase, MultiInstanceRelation, UnresolvedAttribute, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, ExpressionDescription, ExpressionInfo, JsonToStructs, PythonUDF, PythonUDTF}
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.truncatedString
+import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, TimeMode}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.util.LogUtils
@@ -386,9 +388,26 @@ case class AttachDistributedSequence(
}
}
+// scalastyle:off line.contains.tab line.size.limit
+@ExpressionDescription(
+ usage = """
+ _FUNC_() - Returns a table of logs collected from Python workers.
+ """,
+ examples = """
+ Examples:
+ > SET spark.sql.pyspark.worker.logging.enabled=true;
+ spark.sql.pyspark.worker.logging.enabled true
+ > SELECT * FROM _FUNC_();
+
+ """,
+ since = "4.1.0",
+ group = "table_funcs")
+// scalastyle:on line.contains.tab line.size.limit
case class PythonWorkerLogs(jsonAttr: Attribute)
extends LeafNode with MultiInstanceRelation with SQLConfHelper {
+ def this() = this(DataTypeUtils.toAttribute(StructField("message", StringType)))
+
override def output: Seq[Attribute] = Seq(jsonAttr)
override def newInstance(): PythonWorkerLogs =
@@ -403,23 +422,28 @@ case class PythonWorkerLogs(jsonAttr: Attribute)
)
}
-object PythonWorkerLogs {
- val ViewName = "python_worker_logs"
-
- def apply(): LogicalPlan = {
- PythonWorkerLogs(DataTypeUtils.toAttribute(StructField("message", StringType)))
- }
-
- def viewDefinition(): LogicalPlan = {
- Project(
- Seq(UnresolvedStar(Some(Seq("from_json")))),
- Project(
- Seq(Alias(
- JsonToStructs(
- schema = StructType.fromDDL(LogUtils.SPARK_LOG_SCHEMA),
- options = Map.empty,
- child = UnresolvedAttribute("message")),
- "from_json")()),
- PythonWorkerLogs()))
+object PythonWorkerLogs extends SQLConfHelper {
+ val TableFunctionName = "python_worker_logs"
+
+ val functionBuilder: (String, (ExpressionInfo, TableFunctionBuilder)) = {
+ val (info, builder) = FunctionRegistryBase.build[PythonWorkerLogs](
+ TableFunctionName, None)
+ val funcBuilder = (expressions: Seq[Expression]) => {
+ if (conf.pythonWorkerLoggingEnabled) {
+ Project(
+ Seq(UnresolvedStar(Some(Seq("from_json")))),
+ Project(
+ Seq(Alias(
+ JsonToStructs(
+ schema = StructType.fromDDL(LogUtils.SPARK_LOG_SCHEMA),
+ options = Map.empty,
+ child = UnresolvedAttribute("message")),
+ "from_json")()),
+ builder(expressions)))
+ } else {
+ throw QueryCompilationErrors.pythonWorkerLoggingNotEnabledError()
+ }
+ }
+ TableFunctionName -> (info, funcBuilder)
}
}
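[Editor's note] With this change `python_worker_logs` is registered as a built-in table-valued function instead of being materialized through a temp view, and its builder fails analysis unless worker logging is enabled. A hedged usage sketch (the exact output columns come from `LogUtils.SPARK_LOG_SCHEMA`):

```scala
// Assumes a running SparkSession `spark`.
spark.conf.set("spark.sql.pyspark.worker.logging.enabled", "true")
spark.sql("SELECT * FROM python_worker_logs()").show(truncate = false)

// With the conf unset or false, analysis fails with FEATURE_NOT_ENABLED via
// QueryCompilationErrors.pythonWorkerLoggingNotEnabledError.
```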
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala
index d1eb561f3add1..843ce22061d8c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2AlterTableCommands.scala
@@ -48,7 +48,11 @@ trait AlterTableCommand extends UnaryCommand {
*/
case class CommentOnTable(table: LogicalPlan, comment: String) extends AlterTableCommand {
override def changes: Seq[TableChange] = {
- Seq(TableChange.setProperty(TableCatalog.PROP_COMMENT, comment))
+ if (comment == null) {
+ Seq(TableChange.removeProperty(TableCatalog.PROP_COMMENT))
+ } else {
+ Seq(TableChange.setProperty(TableCatalog.PROP_COMMENT, comment))
+ }
}
override protected def withNewChildInternal(newChild: LogicalPlan): LogicalPlan =
copy(table = newChild)
@@ -108,6 +112,7 @@ case class AddColumns(
columnsToAdd: Seq[QualifiedColType]) extends AlterTableCommand {
columnsToAdd.foreach { c =>
TypeUtils.failWithIntervalType(c.dataType)
+ TypeUtils.failUnsupportedDataType(c.dataType, conf)
}
override lazy val resolved: Boolean = table.resolved && columnsToAdd.forall(_.resolved)
@@ -140,6 +145,7 @@ case class ReplaceColumns(
columnsToAdd: Seq[QualifiedColType]) extends AlterTableCommand {
columnsToAdd.foreach { c =>
TypeUtils.failWithIntervalType(c.dataType)
+ TypeUtils.failUnsupportedDataType(c.dataType, conf)
}
override lazy val resolved: Boolean = table.resolved && columnsToAdd.forall(_.resolved)
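[Editor's note] The `CommentOnTable` change lets a null comment translate into a property removal rather than setting a null value, and the new `failUnsupportedDataType` guard rejects TIME and geospatial columns in ADD/REPLACE COLUMNS when those features are disabled. An illustrative sketch of the comment behaviour, assuming a v2 catalog table `cat.db.t`:

```scala
// Produces TableChange.setProperty(TableCatalog.PROP_COMMENT, "fact table").
spark.sql("COMMENT ON TABLE cat.db.t IS 'fact table'")

// Now produces TableChange.removeProperty(PROP_COMMENT) instead of
// setProperty(PROP_COMMENT, null).
spark.sql("COMMENT ON TABLE cat.db.t IS NULL")
```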
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index cd0c2742df3d5..3f5a006c505ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperationException}
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedException, UnresolvedProcedure, ViewSchemaMode}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedAttribute, UnresolvedException, UnresolvedProcedure, ViewSchemaMode}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.catalog.{FunctionResource, RoutineLanguage}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
@@ -61,8 +61,7 @@ trait KeepAnalyzedQuery extends Command {
trait V2WriteCommand
extends UnaryCommand
with KeepAnalyzedQuery
- with CTEInChildren
- with IgnoreCachedData {
+ with CTEInChildren {
def table: NamedRelation
def query: LogicalPlan
def isByName: Boolean
@@ -257,7 +256,7 @@ case class ReplaceData(
write: Option[Write] = None) extends RowLevelWrite {
override val isByName: Boolean = false
- override val stringArgs: Iterator[Any] = Iterator(table, query, write)
+ override def stringArgs: Iterator[Any] = Iterator(table, query, write)
override lazy val references: AttributeSet = query.outputSet
@@ -339,7 +338,7 @@ case class WriteDelta(
write: Option[DeltaWrite] = None) extends RowLevelWrite {
override val isByName: Boolean = false
- override val stringArgs: Iterator[Any] = Iterator(table, query, write)
+ override def stringArgs: Iterator[Any] = Iterator(table, query, write)
override lazy val references: AttributeSet = query.outputSet
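[Editor's note] Switching `stringArgs` from `val` to `def` matters because the member returns an `Iterator`: a `val` is computed once, so the iterator is exhausted after the first traversal (for example the first tree-string rendering) and later renderings would see no arguments. A tiny sketch of the pitfall, outside Spark:

```scala
class Node {
  // Exhausted after one traversal.
  val argsOnce: Iterator[Any] = Iterator("a", "b")
  // Fresh iterator on every call.
  def argsFresh: Iterator[Any] = Iterator("a", "b")
}

val n = new Node
n.argsOnce.mkString(",")  // "a,b"
n.argsOnce.mkString(",")  // ""   -- the val's iterator is already consumed
n.argsFresh.mkString(",") // "a,b" every time
```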
@@ -545,7 +544,6 @@ case class CreateTableAsSelect(
* The base command representation for a statement that can be part of a Declarative Pipeline to
* define a pipeline dataset (MV or ST).
*/
-
trait CreatePipelineDataset extends Command {
// The name of the dataset.
val name: LogicalPlan
@@ -566,7 +564,8 @@ trait CreatePipelineDataset extends Command {
/**
* An extension of the base command representation that represents a CTAS style CREATE statement.
*/
-trait CreatePipelineDatasetAsSelect extends BinaryCommand
+trait CreatePipelineDatasetAsSelect
+ extends BinaryCommand
with CreatePipelineDataset
with CTEInChildren {
@@ -687,7 +686,9 @@ case class ReplaceTableAsSelect(
isAnalyzed: Boolean = false)
extends V2CreateTableAsSelectPlan {
- override def markAsAnalyzed(ac: AnalysisContext): LogicalPlan = copy(isAnalyzed = true)
+ override def markAsAnalyzed(ac: AnalysisContext): LogicalPlan = {
+ copy(isAnalyzed = true)
+ }
override def withPartitioning(rewritten: Seq[Transform]): V2CreateTablePlan = {
this.copy(partitioning = rewritten)
@@ -865,7 +866,7 @@ case class MergeIntoTable(
lazy val aligned: Boolean = {
val actions = matchedActions ++ notMatchedActions ++ notMatchedBySourceActions
actions.forall {
- case UpdateAction(_, assignments) =>
+ case UpdateAction(_, assignments, _) =>
AssignmentUtils.aligned(targetTable.output, assignments)
case _: DeleteAction =>
true
@@ -893,16 +894,57 @@ case class MergeIntoTable(
}
lazy val needSchemaEvolution: Boolean =
+ evaluateSchemaEvolution && changesForSchemaEvolution.nonEmpty
+
+ lazy val evaluateSchemaEvolution: Boolean =
schemaEvolutionEnabled &&
- MergeIntoTable.schemaChanges(targetTable.schema, sourceTable.schema).nonEmpty
+ canEvaluateSchemaEvolution
- private def schemaEvolutionEnabled: Boolean = withSchemaEvolution && {
+ lazy val schemaEvolutionEnabled: Boolean = withSchemaEvolution && {
EliminateSubqueryAliases(targetTable) match {
case r: DataSourceV2Relation if r.autoSchemaEvolution() => true
case _ => false
}
}
+ // Guard that assignments are either resolved or candidates for evolution before
+ // evaluating schema evolution. We need to use resolved assignment values to check
+ // candidates, see MergeIntoTable.sourceSchemaForSchemaEvolution for details.
+ lazy val canEvaluateSchemaEvolution: Boolean = {
+ if ((!targetTable.resolved) || (!sourceTable.resolved)) {
+ false
+ } else {
+ val actions = matchedActions ++ notMatchedActions
+ val hasStarActions = actions.exists {
+ case _: UpdateStarAction => true
+ case _: InsertStarAction => true
+ case _ => false
+ }
+ if (hasStarActions) {
+ // need to resolve star actions first
+ false
+ } else {
+ val assignments = actions.collect {
+ case a: UpdateAction => a.assignments
+ case a: InsertAction => a.assignments
+ }.flatten
+ val sourcePaths = DataTypeUtils.extractAllFieldPaths(sourceTable.schema)
+ assignments.forall { assignment =>
+ assignment.resolved ||
+ (assignment.value.resolved && sourcePaths.exists {
+ path => MergeIntoTable.isEqual(assignment, path)
+ })
+ }
+ }
+ }
+ }
+
+ private lazy val sourceSchemaForEvolution: StructType =
+ MergeIntoTable.sourceSchemaForSchemaEvolution(this)
+
+ lazy val changesForSchemaEvolution: Array[TableChange] =
+ MergeIntoTable.schemaChanges(targetTable.schema, sourceSchemaForEvolution)
+
override def left: LogicalPlan = targetTable
override def right: LogicalPlan = sourceTable
override protected def withNewChildrenInternal(
@@ -911,6 +953,7 @@ case class MergeIntoTable(
}
object MergeIntoTable {
+
def getWritePrivileges(
matchedActions: Iterable[MergeAction],
notMatchedActions: Iterable[MergeAction],
@@ -948,11 +991,12 @@ object MergeIntoTable {
case currentField: StructField if newFieldMap.contains(currentField.name) =>
schemaChanges(currentField.dataType, newFieldMap(currentField.name).dataType,
originalTarget, originalSource, fieldPath ++ Seq(currentField.name))
- }}.flatten
+ }
+ }.flatten
// Identify the newly added fields and append to the end
val currentFieldMap = toFieldMap(currentFields)
- val adds = newFields.filterNot (f => currentFieldMap.contains (f.name))
+ val adds = newFields.filterNot(f => currentFieldMap.contains(f.name))
.map(f => TableChange.addColumn(fieldPath ++ Set(f.name), f.dataType))
updates ++ adds
@@ -990,6 +1034,84 @@ object MergeIntoTable {
CaseInsensitiveMap(fieldMap)
}
}
+
+ // A pruned version of source schema that only contains columns/nested fields
+ // explicitly and directly assigned to a target counterpart in MERGE INTO actions,
+ // which are relevant for schema evolution.
+ // Examples:
+ // * UPDATE SET target.a = source.a
+ // * UPDATE SET nested.a = source.nested.a
+ // * INSERT (a, nested.b) VALUES (source.a, source.nested.b)
+  // New columns/nested fields in this schema that do not exist in the target schema
+  // will be added via schema evolution.
+ def sourceSchemaForSchemaEvolution(merge: MergeIntoTable): StructType = {
+ val actions = merge.matchedActions ++ merge.notMatchedActions
+ val assignments = actions.collect {
+ case a: UpdateAction => a.assignments
+ case a: InsertAction => a.assignments
+ }.flatten
+
+ val containsStarAction = actions.exists {
+ case _: UpdateStarAction => true
+ case _: InsertStarAction => true
+ case _ => false
+ }
+
+ def filterSchema(sourceSchema: StructType, basePath: Seq[String]): StructType =
+ StructType(sourceSchema.flatMap { field =>
+ val fieldPath = basePath :+ field.name
+
+ field.dataType match {
+ // Specifically assigned to in one clause:
+ // always keep, including all nested attributes
+ case _ if assignments.exists(isEqual(_, fieldPath)) => Some(field)
+ // If this is a struct and one of the children is being assigned to in a merge clause,
+ // keep it and continue filtering children.
+ case struct: StructType if assignments.exists(assign =>
+ isPrefix(fieldPath, extractFieldPath(assign.key, allowUnresolved = true))) =>
+ Some(field.copy(dataType = filterSchema(struct, fieldPath)))
+ // The field isn't assigned to directly or indirectly (i.e. its children) in any non-*
+ // clause. Check if it should be kept with any * action.
+ case struct: StructType if containsStarAction =>
+ Some(field.copy(dataType = filterSchema(struct, fieldPath)))
+ case _ if containsStarAction => Some(field)
+ // The field and its children are not assigned to in any * or non-* action, drop it.
+ case _ => None
+ }
+ })
+
+ filterSchema(merge.sourceTable.schema, Seq.empty)
+ }
+
+ // Helper method to extract field path from an Expression.
+ private def extractFieldPath(expr: Expression, allowUnresolved: Boolean): Seq[String] = {
+ expr match {
+ case UnresolvedAttribute(nameParts) if allowUnresolved => nameParts
+ case a: AttributeReference => Seq(a.name)
+ case GetStructField(child, ordinal, nameOpt) =>
+ extractFieldPath(child, allowUnresolved) :+ nameOpt.getOrElse(s"col$ordinal")
+ case _ => Seq.empty
+ }
+ }
+
+ // Helper method to check if a given field path is a prefix of another path.
+ private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean =
+ prefix.length <= path.length && prefix.zip(path).forall {
+ case (prefixNamePart, pathNamePart) =>
+ SQLConf.get.resolver(prefixNamePart, pathNamePart)
+ }
+
+ // Helper method to check if an assignment key is equal to a source column
+ // and if the assignment value is that same source column.
+ // Example: UPDATE SET target.a = source.a
+ private def isEqual(assignment: Assignment, sourceFieldPath: Seq[String]): Boolean = {
+ // key must be a non-qualified field path that may be added to target schema via evolution
+    val assignmentKeyExpr = extractFieldPath(assignment.key, allowUnresolved = true)
+    // value should always be resolved (from source)
+    val assignmentValueExpr = extractFieldPath(assignment.value, allowUnresolved = false)
+    assignmentKeyExpr == assignmentValueExpr &&
+      assignmentKeyExpr == sourceFieldPath
+ }
}
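[Editor's note] `sourceSchemaForSchemaEvolution` keeps only source columns (or nested fields) that are directly assigned to an identically named target path, or everything when a `*` action is present; the pruned schema drives `schemaChanges` and therefore `needSchemaEvolution`. A hedged worked example of the pruning (illustrative schemas, not a runnable test):

```scala
// target:  id INT, nested STRUCT<a INT>
// source:  id INT, nested STRUCT<a INT, b INT>, extra STRING
//
// MERGE ... WHEN MATCHED THEN UPDATE SET nested.b = source.nested.b
//
// Pruned source schema for evolution (only the directly assigned path survives):
//   nested STRUCT<b INT>
// schemaChanges(targetSchema, prunedSchema)
//   => Seq(TableChange.addColumn(Seq("nested", "b"), IntegerType))
//
// With an UPDATE SET * or INSERT * action instead, no pruning applies and both
// `nested.b` and `extra` become candidate additions.
```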
sealed abstract class MergeAction extends Expression with Unevaluable {
@@ -1007,7 +1129,8 @@ case class DeleteAction(condition: Option[Expression]) extends MergeAction {
case class UpdateAction(
condition: Option[Expression],
- assignments: Seq[Assignment]) extends MergeAction {
+ assignments: Seq[Assignment],
+ fromStar: Boolean = false) extends MergeAction {
override def children: Seq[Expression] = condition.toSeq ++ assignments
override protected def withNewChildrenInternal(
@@ -1618,7 +1741,8 @@ case class CacheTableAsSelect(
isLazy: Boolean,
options: Map[String, String],
isAnalyzed: Boolean = false,
- referredTempFunctions: Seq[String] = Seq.empty) extends AnalysisOnlyCommand {
+ referredTempFunctions: Seq[String] = Seq.empty)
+ extends AnalysisOnlyCommand with CTEInChildren {
override protected def withNewChildrenInternal(
newChildren: IndexedSeq[LogicalPlan]): CacheTableAsSelect = {
assert(!isAnalyzed)
@@ -1633,6 +1757,10 @@ case class CacheTableAsSelect(
// Collect the referred temporary functions from AnalysisContext
referredTempFunctions = ac.referredTempFunctionNames.toSeq)
}
+
+ override def withCTEDefs(cteDefs: Seq[CTERelationDef]): LogicalPlan = {
+ copy(plan = WithCTE(plan, cteDefs))
+ }
}
/**
@@ -1868,7 +1996,7 @@ case class Call(
* representation of the matching SQL syntax and cannot be executed. Instead, it is interpreted by
* the pipelines submodule during a pipeline execution
*
- * @param name the name of this flow
+ * @param name the name of this flow
* @param flowOperation the logical plan of the actual transformation this flow should execute
* @param comment an optional comment describing this flow
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 1cbb49c7a1f73..fd71e22c555cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -379,36 +379,36 @@ case class KeyGroupedPartitioning(
expressions: Seq[Expression],
numPartitions: Int,
partitionValues: Seq[InternalRow] = Seq.empty,
- originalPartitionValues: Seq[InternalRow] = Seq.empty) extends HashPartitioningLike {
+ originalPartitionValues: Seq[InternalRow] = Seq.empty,
+ isPartiallyClustered: Boolean = false) extends HashPartitioningLike {
+ // See SPARK-55848. We must check ClusteredDistribution BEFORE delegating to
+ // super.satisfies0(), because HashPartitioningLike.satisfies0() also matches
+ // ClusteredDistribution and returns true, which would short-circuit the
+ // isPartiallyClustered guard.
override def satisfies0(required: Distribution): Boolean = {
- super.satisfies0(required) || {
- required match {
- case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
- if (requireAllClusterKeys) {
- // Checks whether this partitioning is partitioned on exactly same clustering keys of
- // `ClusteredDistribution`.
- c.areAllClusterKeysMatched(expressions)
+ required match {
+ case c @ ClusteredDistribution(requiredClustering, requireAllClusterKeys, _) =>
+ if (isPartiallyClustered) {
+ false
+ } else if (requireAllClusterKeys) {
+ c.areAllClusterKeysMatched(expressions)
+ } else {
+ val attributes = expressions.flatMap(_.collectLeaves())
+
+ if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
+ requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
+ expressions.forall(_.collectLeaves().size == 1)
} else {
- // We'll need to find leaf attributes from the partition expressions first.
- val attributes = expressions.flatMap(_.collectLeaves())
-
- if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) {
- // check that join keys (required clustering keys)
- // overlap with partition keys (KeyGroupedPartitioning attributes)
- requiredClustering.exists(x => attributes.exists(_.semanticEquals(x))) &&
- expressions.forall(_.collectLeaves().size == 1)
- } else {
- attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
- }
+ attributes.forall(x => requiredClustering.exists(_.semanticEquals(x)))
}
+ }
- case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting =>
- o.areAllClusterKeysMatched(expressions)
+ case o @ OrderedDistribution(_) if SQLConf.get.v2BucketingAllowSorting =>
+ o.areAllClusterKeysMatched(expressions)
- case _ =>
- false
- }
+ case _ =>
+ super.satisfies0(required)
}
}
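[Editor's note] The reordering above matters because `HashPartitioningLike.satisfies0` already answers `true` for a compatible `ClusteredDistribution`; with the old `super.satisfies0(required) || ...` shape, the partially-clustered veto could never run. A minimal sketch of the pattern, independent of Spark's classes:

```scala
trait Base { def satisfies(req: String): Boolean = req == "clustered" }

class PartiallyClustered(partial: Boolean) extends Base {
  // Wrong: once super answers true, the partial-clustering veto is unreachable.
  def satisfiesWrong(req: String): Boolean =
    super.satisfies(req) || (req == "clustered" && !partial)

  // Right: handle the special case first and fall back to super otherwise.
  override def satisfies(req: String): Boolean = req match {
    case "clustered" => !partial
    case other       => super.satisfies(other)
  }
}

// new PartiallyClustered(partial = true).satisfiesWrong("clustered") == true  (bug)
// new PartiallyClustered(partial = true).satisfies("clustered")      == false
```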
@@ -420,7 +420,7 @@ case class KeyGroupedPartitioning(
// the returned shuffle spec.
val joinKeyPositions = result.keyPositions.map(_.nonEmpty).zipWithIndex.filter(_._1).map(_._2)
val projectedPartitioning = KeyGroupedPartitioning(expressions, joinKeyPositions,
- partitionValues, originalPartitionValues)
+ partitionValues, originalPartitionValues, isPartiallyClustered)
result.copy(partitioning = projectedPartitioning, joinKeyPositions = Some(joinKeyPositions))
} else {
result
@@ -435,7 +435,7 @@ case class KeyGroupedPartitioning(
}
override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
- copy(expressions = newChildren)
+ copy(expressions = newChildren, isPartiallyClustered = isPartiallyClustered)
}
object KeyGroupedPartitioning {
@@ -443,7 +443,8 @@ object KeyGroupedPartitioning {
expressions: Seq[Expression],
projectionPositions: Seq[Int],
partitionValues: Seq[InternalRow],
- originalPartitionValues: Seq[InternalRow]): KeyGroupedPartitioning = {
+ originalPartitionValues: Seq[InternalRow],
+ isPartiallyClustered: Boolean): KeyGroupedPartitioning = {
val projectedExpressions = projectionPositions.map(expressions(_))
val projectedPartitionValues = partitionValues.map(project(expressions, projectionPositions, _))
val projectedOriginalPartitionValues =
@@ -455,7 +456,7 @@ object KeyGroupedPartitioning {
.map(_.row)
KeyGroupedPartitioning(projectedExpressions, finalPartitionValues.length,
- finalPartitionValues, projectedOriginalPartitionValues)
+ finalPartitionValues, projectedOriginalPartitionValues, isPartiallyClustered)
}
def project(
@@ -867,7 +868,10 @@ case class KeyGroupedShuffleSpec(
// transform functions.
// 4. the partition values from both sides are following the same order.
case otherSpec @ KeyGroupedShuffleSpec(otherPartitioning, otherDistribution, _) =>
- distribution.clustering.length == otherDistribution.clustering.length &&
+ // SPARK-55848: partially-clustered partitioning is not compatible for SPJ
+ !partitioning.isPartiallyClustered &&
+ !otherPartitioning.isPartiallyClustered &&
+ distribution.clustering.length == otherDistribution.clustering.length &&
numPartitions == other.numPartitions && areKeysCompatible(otherSpec) &&
partitioning.partitionValues.zip(otherPartitioning.partitionValues).forall {
case (left, right) =>
@@ -943,10 +947,13 @@ case class KeyGroupedShuffleSpec(
}
override def createPartitioning(clustering: Seq[Expression]): Partitioning = {
- val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map {
- case (c, e: TransformExpression) => TransformExpression(
- e.function, Seq(c), e.numBucketsOpt)
- case (c, _) => c
+ assert(clustering.size == distribution.clustering.size,
+ "Required distributions of join legs should be the same size.")
+
+ val newExpressions = partitioning.expressions.zip(keyPositions).map {
+ case (te: TransformExpression, positionSet) =>
+ te.copy(children = te.children.map(_ => clustering(positionSet.head)))
+ case (_, positionSet) => clustering(positionSet.head)
}
KeyGroupedPartitioning(newExpressions,
partitioning.numPartitions,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala
index a6204b317d249..7015d0dd3b2cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/WriteToStreamStatement.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog}
-import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.streaming.{OutputMode, Trigger}
/**
* A statement for Stream writing. It contains all neccessary param and will be resolved in the
@@ -39,7 +39,9 @@ import org.apache.spark.sql.streaming.OutputMode
* @param sink Sink to write the streaming outputs.
* @param outputMode Output mode for the sink.
* @param hadoopConf The Hadoop Configuration to get a FileSystem instance
- * @param isContinuousTrigger Whether the statement is triggered by a continuous query or not.
+ * @param trigger The trigger being used for this streaming query. It is not used to create the
+ * resolved [[WriteToStream]] node; rather, it is only used while checking the plan
+ * for unsupported operations, which happens during resolution.
* @param inputQuery The analyzed query plan from the streaming DataFrame.
* @param catalogAndIdent Catalog and identifier for the sink, set when it is a V2 catalog table
*/
@@ -51,7 +53,7 @@ case class WriteToStreamStatement(
sink: Table,
outputMode: OutputMode,
hadoopConf: Configuration,
- isContinuousTrigger: Boolean,
+ trigger: Trigger,
inputQuery: LogicalPlan,
catalogAndIdent: Option[(TableCatalog, Identifier)] = None,
catalogTable: Option[CatalogTable] = None) extends UnaryNode {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
index c6e51aab4584b..e7bd5bd1aa2d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
@@ -249,5 +249,38 @@ object DataTypeUtils {
case v: Long => fromDecimal(Decimal(BigDecimal(v)))
case _ => forType(literal.dataType)
}
+
+ /**
+ * Extracts all struct field paths from a nested StructType.
+ */
+ def extractAllFieldPaths(schema: StructType, basePath: Seq[String] = Seq.empty):
+ Seq[Seq[String]] = {
+ schema.flatMap { field =>
+ val fieldPath = basePath :+ field.name
+ field.dataType match {
+ case struct: StructType =>
+ fieldPath +: extractAllFieldPaths(struct, fieldPath)
+ case _ =>
+ Seq(fieldPath)
+ }
+ }
+ }
+
+ /**
+ * Extracts only leaf-level field paths from a nested StructType.
+ * Unlike extractAllFieldPaths, this method does not include intermediate struct paths.
+ */
+ def extractLeafFieldPaths(schema: StructType, basePath: Seq[String] = Seq.empty):
+ Seq[Seq[String]] = {
+ schema.flatMap { field =>
+ val fieldPath = basePath :+ field.name
+ field.dataType match {
+ case struct: StructType =>
+ extractLeafFieldPaths(struct, fieldPath)
+ case _ =>
+ Seq(fieldPath)
+ }
+ }
+ }
}
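[Editor's note] The two new helpers differ only in whether intermediate struct paths are included. A small illustrative example, assuming the usual `StructType.fromDDL` helper:

```scala
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.types.StructType

val schema = StructType.fromDDL("id INT, nested STRUCT<a: INT, b: STRING>")

DataTypeUtils.extractAllFieldPaths(schema)
// Seq(Seq("id"), Seq("nested"), Seq("nested", "a"), Seq("nested", "b"))

DataTypeUtils.extractLeafFieldPaths(schema)
// Seq(Seq("id"), Seq("nested", "a"), Seq("nested", "b"))
```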
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
index 8793c0407a9b5..ef10a308cff9b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/IntervalUtils.scala
@@ -594,14 +594,19 @@ object IntervalUtils extends SparkIntervalUtils {
interval: CalendarInterval,
targetUnit: TimeUnit,
daysPerMonth: Int = 31): Long = {
- val monthsDuration = Math.multiplyExact(
- daysPerMonth * MICROS_PER_DAY,
- interval.months)
- val daysDuration = Math.multiplyExact(
- MICROS_PER_DAY,
- interval.days)
- val result = Math.addExact(interval.microseconds, Math.addExact(daysDuration, monthsDuration))
- targetUnit.convert(result, TimeUnit.MICROSECONDS)
+ try {
+ val monthsDuration = Math.multiplyExact(
+ daysPerMonth * MICROS_PER_DAY,
+ interval.months)
+ val daysDuration = Math.multiplyExact(
+ MICROS_PER_DAY,
+ interval.days)
+ val result = Math.addExact(interval.microseconds, Math.addExact(daysDuration, monthsDuration))
+ targetUnit.convert(result, TimeUnit.MICROSECONDS)
+ } catch {
+ case _: ArithmeticException =>
+ throw QueryExecutionErrors.withoutSuggestionIntervalArithmeticOverflowError(context = null)
+ }
}
/**
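[Editor's note] Wrapping the exact-arithmetic path means an overflowing interval now surfaces as Spark's interval-overflow error (via `withoutSuggestionIntervalArithmeticOverflowError`) instead of a raw `java.lang.ArithmeticException`. A small sketch of the overflow being guarded against (plain JVM arithmetic, not the Spark entry point):

```scala
// MICROS_PER_DAY is 86_400_000_000L; with the default 31 days per month,
// a large month count overflows Long when converted to microseconds.
val microsPerDay = 86400000000L
try {
  Math.multiplyExact(31 * microsPerDay, Int.MaxValue.toLong)
} catch {
  case _: ArithmeticException =>
    // Previously this escaped to the caller; now it is rethrown as a Spark error.
}
```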
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
index 4bef21d0a091e..488d1acf43ac4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala
@@ -480,7 +480,7 @@ object ResolveDefaultColumns extends QueryErrorsBase
val ret = analyzed match {
case equivalent if equivalent.dataType == supplanted =>
equivalent
- case canUpCast if Cast.canUpCast(canUpCast.dataType, supplanted) =>
+ case _ if Cast.canAssignDefaultValue(analyzed.dataType, supplanted) =>
Cast(analyzed, supplanted, Some(conf.sessionLocalTimeZone))
case other =>
defaultValueFromWiderTypeLiteral(other, supplanted, colName).getOrElse(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala
index ea2f48fafc0dd..ffcf8ba2cb93d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ToNumberParser.scala
@@ -495,6 +495,9 @@ class ToNumberParser(numberFormat: String, errorOnFail: Boolean) extends Seriali
// If we have consumed all the tokens in the format string, but characters remain unconsumed
// in the input string, then the input string does not match the format string.
formatMatchFailure(input, numberFormat)
+ } else if (parsedBeforeDecimalPoint.isEmpty && parsedAfterDecimalPoint.isEmpty) {
+ // If no digits were collected (e.g. input was all whitespace), treat as format match failure.
+ formatMatchFailure(input, numberFormat)
} else {
parseResultToDecimalValue(negateResult)
}
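[Editor's note] The new branch makes an input with no digits at all (for example, only whitespace) a format-match failure rather than an empty parse result. Illustrative behaviour under this change (hedged; `to_number` raises on a mismatch, `try_to_number` returns NULL):

```scala
// Assumes a SparkSession `spark`; expected results shown as comments.
spark.sql("SELECT try_to_number('   ', '999')").show()  // NULL (format mismatch)
spark.sql("SELECT try_to_number('454', '999')").show()  // 454
```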
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
index 9f89f068b7568..9c5df04f9569a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -20,8 +20,10 @@ package org.apache.spark.sql.catalyst.util
import org.apache.spark.sql.catalyst.analysis.{AnalysisErrorAt, TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.{Expression, RowOrdering}
+import org.apache.spark.sql.catalyst.expressions.st.STExpressionUtils.isGeoSpatialType
import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalNumericType}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
/**
@@ -136,4 +138,15 @@ object TypeUtils extends QueryErrorsBase {
}
if (dataType.existsRecursively(isInterval)) f
}
+
+ def failUnsupportedDataType(dataType: DataType, conf: SQLConf): Unit = {
+ if (!conf.isTimeTypeEnabled && dataType.existsRecursively(_.isInstanceOf[TimeType])) {
+ throw QueryCompilationErrors.unsupportedTimeTypeError()
+ }
+ if (!conf.geospatialEnabled && dataType.existsRecursively(isGeoSpatialType)) {
+ throw new org.apache.spark.sql.AnalysisException(
+ errorClass = "UNSUPPORTED_FEATURE.GEOSPATIAL_DISABLED",
+        messageParameters = Map.empty)
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 4e391208d984b..72b466f5a0f9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.catalyst.optimizer.ConstantFolding
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
-import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, NullOrdering, SortDirection, SortValue, UserDefinedScalarFunc}
+import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, GetArrayItem => V2GetArrayItem, LiteralValue, NullOrdering, SortDirection, SortValue, UserDefinedScalarFunc}
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
import org.apache.spark.sql.internal.SQLConf
@@ -326,6 +326,13 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L
case _: Sha2 => generateExpressionWithName("SHA2", expr, isPredicate)
case _: StringLPad => generateExpressionWithName("LPAD", expr, isPredicate)
case _: StringRPad => generateExpressionWithName("RPAD", expr, isPredicate)
+ case GetArrayItem(child, ordinal, failOnError) =>
+ (generateExpression(child), generateExpression(ordinal)) match {
+ case (Some(v2ArrayChild), Some(v2Ordinal)) =>
+ Some(new V2GetArrayItem(v2ArrayChild, v2Ordinal, failOnError))
+ case _ =>
+ None
+ }
// TODO supports other expressions
case ApplyFunctionExpression(function, children) =>
val childrenExpressions = children.flatMap(generateExpression(_))
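[Editor's note] With the new case, an array element access can be translated into a DS v2 `GetArrayItem` expression and offered to connectors for pushdown. A hedged DataFrame-level illustration (whether it is actually pushed down depends on the source; `df` is assumed to have an array column `tags`):

```scala
import org.apache.spark.sql.functions.col

// Indexing an array column resolves to a Catalyst GetArrayItem;
// V2ExpressionBuilder can now convert it instead of giving up.
val pushable = df.filter(col("tags")(0) === "urgent")
```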
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
index 1a47fa7bd43f0..3d6c57c9f7465 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/xml/StaxXmlParser.scala
@@ -1278,18 +1278,19 @@ object StaxXmlParser {
// Try parsing the value as decimal
val decimalParser = ExprUtils.getDecimalParser(options.locale)
- allCatch opt decimalParser(value) match {
- case Some(decimalValue) =>
- var d = decimalValue
- if (d.scale() < 0) {
- d = d.setScale(0)
- }
- if (d.scale <= VariantUtil.MAX_DECIMAL16_PRECISION &&
- d.precision <= VariantUtil.MAX_DECIMAL16_PRECISION) {
- builder.appendDecimal(d)
- return
- }
- case _ =>
+ try {
+ var d = decimalParser(value)
+ if (d.scale() < 0) {
+ d = d.setScale(0)
+ }
+ if (d.scale <= VariantUtil.MAX_DECIMAL16_PRECISION &&
+ d.precision <= VariantUtil.MAX_DECIMAL16_PRECISION) {
+ builder.appendDecimal(d)
+ return
+ }
+ } catch {
+ case NonFatal(_) =>
+ // Ignore the exception and parse it as a string below
}
// If the character is of other primitive types, parse it as a string
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
index 282350dda67d3..726527394deb6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ClusterBySpec}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, QuotingUtils}
+import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, quoteNameParts, QuotingUtils}
import org.apache.spark.sql.connector.expressions.{BucketTransform, ClusterByTransform, FieldReference, IdentityTransform, LogicalExpressions, Transform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types.StructType
@@ -105,14 +105,14 @@ private[sql] object CatalogV2Implicits {
case tableCatalog: TableCatalog =>
tableCatalog
case _ =>
- throw QueryCompilationErrors.missingCatalogAbilityError(plugin, "tables")
+ throw QueryCompilationErrors.missingCatalogTablesAbilityError(plugin)
}
def asNamespaceCatalog: SupportsNamespaces = plugin match {
case namespaceCatalog: SupportsNamespaces =>
namespaceCatalog
case _ =>
- throw QueryCompilationErrors.missingCatalogAbilityError(plugin, "namespaces")
+ throw QueryCompilationErrors.missingCatalogNamespacesAbilityError(plugin)
}
def isFunctionCatalog: Boolean = plugin match {
@@ -124,14 +124,14 @@ private[sql] object CatalogV2Implicits {
case functionCatalog: FunctionCatalog =>
functionCatalog
case _ =>
- throw QueryCompilationErrors.missingCatalogAbilityError(plugin, "functions")
+ throw QueryCompilationErrors.missingCatalogFunctionsAbilityError(plugin)
}
def asProcedureCatalog: ProcedureCatalog = plugin match {
case procedureCatalog: ProcedureCatalog =>
procedureCatalog
case _ =>
- throw QueryCompilationErrors.missingCatalogAbilityError(plugin, "procedures")
+ throw QueryCompilationErrors.missingCatalogProceduresAbilityError(plugin)
}
}
@@ -223,6 +223,8 @@ private[sql] object CatalogV2Implicits {
def quoted: String = parts.map(quoteIfNeeded).mkString(".")
+ def fullyQuoted: String = quoteNameParts(parts)
+
def original: String = parts.mkString(".")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
index 28bca400f5b8b..2d4ef5cd9e07a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
@@ -25,7 +25,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.{SparkException, SparkIllegalArgumentException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.CurrentUserContext
-import org.apache.spark.sql.catalyst.analysis.{AsOfTimestamp, AsOfVersion, NamedRelation, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException, TimeTravelSpec}
+import org.apache.spark.sql.catalyst.analysis.{AsOfTimestamp, AsOfVersion, NamedRelation, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException, RelationCache, TimeTravelSpec}
import org.apache.spark.sql.catalyst.catalog.ClusterBySpec
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.plans.logical.{SerdeInfo, TableSpec}
@@ -36,6 +36,7 @@ import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
import org.apache.spark.sql.connector.expressions.{ClusterByTransform, LiteralValue, Transform}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, MapType, Metadata, MetadataBuilder, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._
@@ -497,6 +498,27 @@ private[sql] object CatalogV2Util {
loadTable(catalog, ident).map(DataSourceV2Relation.create(_, Some(catalog), Some(ident)))
}
+ def isSameTable(
+ rel: DataSourceV2Relation,
+ catalog: CatalogPlugin,
+ ident: Identifier,
+ table: Table): Boolean = {
+ rel.catalog.contains(catalog) && rel.identifier.contains(ident) && rel.table.id == table.id
+ }
+
+ def lookupCachedRelation(
+ cache: RelationCache,
+ catalog: CatalogPlugin,
+ ident: Identifier,
+ table: Table,
+ conf: SQLConf): Option[DataSourceV2Relation] = {
+ val nameParts = ident.toQualifiedNameParts(catalog)
+ val cached = cache.lookup(nameParts, conf.resolver)
+ cached.collect {
+ case r: DataSourceV2Relation if isSameTable(r, catalog, ident, table) => r
+ }
+ }
+
def isSessionCatalog(catalog: CatalogPlugin): Boolean = {
catalog.name().equalsIgnoreCase(CatalogManager.SESSION_CATALOG_NAME)
}
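[Editor's note] `lookupCachedRelation` centralizes the "same table as when it was cached" check: it turns the identifier into qualified name parts, consults the relation cache with the session resolver, and only returns a hit when catalog, identifier, and table id still match. A hedged usage sketch from a hypothetical resolution rule:

```scala
// Sketch only: `cache`, `catalog`, `ident`, `table`, and `conf` come from the caller.
val relation = CatalogV2Util
  .lookupCachedRelation(cache, catalog, ident, table, conf)
  .getOrElse(DataSourceV2Relation.create(table, Some(catalog), Some(ident)))
```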
@@ -565,12 +587,29 @@ private[sql] object CatalogV2Util {
.asTableCatalog
}
+ def toStructType(cols: Seq[MetadataColumn]): StructType = {
+ StructType(cols.map(toStructField))
+ }
+
+ private def toStructField(col: MetadataColumn): StructField = {
+ val metadata = Option(col.metadataInJSON).map(Metadata.fromJson).getOrElse(Metadata.empty)
+ var f = StructField(col.name, col.dataType, col.isNullable, metadata)
+ if (col.comment != null) {
+ f = f.withComment(col.comment)
+ }
+ f
+ }
+
+ def v2ColumnsToStructType(columns: Array[Column]): StructType = {
+ v2ColumnsToStructType(columns.toImmutableArraySeq)
+ }
+
/**
* Converts DS v2 columns to StructType, which encodes column comment and default value to
* StructField metadata. This is mainly used to define the schema of v2 scan, w.r.t. the columns
* of the v2 table.
*/
- def v2ColumnsToStructType(columns: Array[Column]): StructType = {
+ def v2ColumnsToStructType(columns: Seq[Column]): StructType = {
StructType(columns.map(v2ColumnToStructField))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/Catalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/Catalogs.scala
index 419191f8f9c00..e6c70fdabb159 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/Catalogs.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/Catalogs.scala
@@ -22,6 +22,7 @@ import java.util
import java.util.regex.Pattern
import org.apache.spark.SparkException
+import org.apache.spark.internal.config.ConfigReader
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -93,10 +94,15 @@ private[sql] object Catalogs {
private def catalogOptions(name: String, conf: SQLConf) = {
val prefix = Pattern.compile("^spark\\.sql\\.catalog\\." + name + "\\.(.+)")
val options = new util.HashMap[String, String]
+ val reader = new ConfigReader(options)
conf.getAllConfs.foreach {
case (key, value) =>
val matcher = prefix.matcher(key)
- if (matcher.matches && matcher.groupCount > 0) options.put(matcher.group(1), value)
+ if (matcher.matches && matcher.groupCount > 0) {
+ // pass config entries through default ConfigReader mechanics,
+ // substituting prefixes from bindings: ${env:XYZ} -> sys.env.get("XYZ")
+ options.put(matcher.group(1), reader.substitute(value))
+ }
}
new CaseInsensitiveStringMap(options)
}
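[Editor's note] Running catalog options through `ConfigReader.substitute` means option values can use the standard substitution bindings (such as `${env:...}`) instead of requiring literal values in the session conf. A hedged example of the effect, with an illustrative catalog name and class:

```scala
// With these session confs:
//   spark.sql.catalog.my_cat      com.example.MyCatalog
//   spark.sql.catalog.my_cat.uri  ${env:MY_CAT_URI}
// the options map handed to my_cat.initialize() now contains
//   "uri" -> sys.env("MY_CAT_URI")
// because each matched value is passed through ConfigReader.substitute.
```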
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala
new file mode 100644
index 0000000000000..42181c6c8389d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util.Locale
+
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, MetadataColumnHelper}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.util.SchemaUtils
+import org.apache.spark.sql.util.SchemaValidationMode
+import org.apache.spark.sql.util.SchemaValidationMode.PROHIBIT_CHANGES
+import org.apache.spark.util.ArrayImplicits._
+
+private[sql] object V2TableUtil extends SQLConfHelper {
+
+ def toQualifiedName(catalog: CatalogPlugin, ident: Identifier): String = {
+ s"${quoteIfNeeded(catalog.name)}.${ident.quoted}"
+ }
+
+ /**
+ * Validates that captured data columns match the current table schema.
+ *
+ * @param table the current table metadata
+ * @param relation the relation with captured columns
+ * @param mode validation mode that defines what changes are acceptable
+ * @return validation errors, or empty sequence if valid
+ */
+ def validateCapturedColumns(
+ table: Table,
+ relation: DataSourceV2Relation,
+ mode: SchemaValidationMode): Seq[String] = {
+ validateCapturedColumns(table, relation.table.columns.toImmutableArraySeq, mode)
+ }
+
+ /**
+ * Validates that captured data columns match the current table schema.
+ *
+ * Checks for:
+ * - Column type or nullability changes
+ * - Removed columns (missing from the current table schema)
+ * - Added columns (new in the current table schema)
+ *
+ * @param table the current table metadata
+ * @param originCols the originally captured columns
+ * @param mode validation mode that defines what changes are acceptable
+ * @return validation errors, or empty sequence if valid
+ */
+ def validateCapturedColumns(
+ table: Table,
+ originCols: Seq[Column],
+ mode: SchemaValidationMode = PROHIBIT_CHANGES): Seq[String] = {
+ val originSchema = CatalogV2Util.v2ColumnsToStructType(originCols)
+ val schema = CatalogV2Util.v2ColumnsToStructType(table.columns)
+ SchemaUtils.validateSchemaCompatibility(originSchema, schema, resolver, mode)
+ }
+
+ /**
+ * Validates that captured metadata columns are consistent with the current table metadata.
+ *
+ * @param table the current table metadata
+ * @param relation the relation with captured metadata columns
+ * @param mode validation mode that defines what changes are acceptable
+ * @return validation errors, or empty sequence if valid
+ */
+ def validateCapturedMetadataColumns(
+ table: Table,
+ relation: DataSourceV2Relation,
+ mode: SchemaValidationMode): Seq[String] = {
+ validateCapturedMetadataColumns(table, extractMetadataColumns(relation), mode)
+ }
+
+ /**
+ * Extracts original column info for all metadata attributes in the relation.
+ *
+ * @param relation the relation with captured metadata columns
+ * @return metadata columns captured by the relation
+ */
+ def extractMetadataColumns(relation: DataSourceV2Relation): Seq[MetadataColumn] = {
+ val metaAttrNames = relation.output.filter(_.isMetadataCol).map(_.name)
+ filter(metaAttrNames, metadataColumns(relation.table))
+ }
+
+ /**
+ * Validates that captured metadata columns are consistent with the current table metadata.
+ *
+ * Checks for:
+ * - Metadata column type or nullability changes
+ * - Removed metadata columns (missing from current table)
+ *
+ * @param table the current table metadata
+ * @param originMetaCols the originally captured metadata columns
+ * @param mode validation mode that defines what changes are acceptable
+ * @return validation errors, or empty sequence if valid
+ */
+ def validateCapturedMetadataColumns(
+ table: Table,
+ originMetaCols: Seq[MetadataColumn],
+ mode: SchemaValidationMode = PROHIBIT_CHANGES): Seq[String] = {
+ val originMetaColNames = originMetaCols.map(_.name)
+ val originMetaSchema = CatalogV2Util.toStructType(originMetaCols)
+ val metaCols = filter(originMetaColNames, metadataColumns(table))
+ val metaSchema = CatalogV2Util.toStructType(metaCols)
+ SchemaUtils.validateSchemaCompatibility(originMetaSchema, metaSchema, resolver, mode)
+ }
+
+ private def filter(colNames: Seq[String], cols: Seq[MetadataColumn]): Seq[MetadataColumn] = {
+ val normalizedColNames = colNames.map(normalize).toSet
+ cols.filter(col => normalizedColNames.contains(normalize(col.name)))
+ }
+
+ private def metadataColumns(table: Table): Seq[MetadataColumn] = table match {
+ case hasMeta: SupportsMetadataColumns => hasMeta.metadataColumns.toImmutableArraySeq
+ case _ => Seq.empty
+ }
+
+ private def normalize(name: String): String = {
+ if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT)
+ }
+
+ private def resolver: Resolver = conf.resolver
+}
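[Editor's note] The new `V2TableUtil` helpers re-validate a previously analyzed relation against the table's current metadata; the returned strings are human-readable mismatch descriptions that callers can wrap into the new INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS errors. A hedged usage sketch:

```scala
// Sketch: `table` is the freshly loaded Table, `relation` the captured
// DataSourceV2Relation; PROHIBIT_CHANGES is the default validation mode.
import org.apache.spark.sql.util.SchemaValidationMode.PROHIBIT_CHANGES

val name = V2TableUtil.toQualifiedName(catalog, ident)
val errors = V2TableUtil.validateCapturedColumns(table, relation, PROHIBIT_CHANGES)
if (errors.nonEmpty) {
  throw QueryCompilationErrors.columnsChangedAfterAnalysis(name, errors)
}
```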
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 7d79c5d5d642d..1cee321846a46 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -1770,6 +1770,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
messageParameters = Map("provider" -> provider))
}
+ def failedToCreatePlanForDirectQueryError(
+ dataSourceType: String, cause: Throwable): Throwable = {
+ new AnalysisException(
+ errorClass = "FAILED_TO_CREATE_PLAN_FOR_DIRECT_QUERY",
+ messageParameters = Map("dataSourceType" -> dataSourceType),
+ cause = Some(cause))
+ }
+
def findMultipleDataSourceError(provider: String, sourceNames: Seq[String]): Throwable = {
new AnalysisException(
errorClass = "_LEGACY_ERROR_TEMP_1141",
@@ -2113,6 +2121,38 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
}
}
+ def tableIdChangedAfterAnalysis(
+ tableName: String,
+ capturedTableId: String,
+ currentTableId: String): Throwable = {
+ new AnalysisException(
+ errorClass = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.TABLE_ID_MISMATCH",
+ messageParameters = Map(
+ "tableName" -> toSQLId(tableName),
+ "capturedTableId" -> capturedTableId,
+ "currentTableId" -> currentTableId))
+ }
+
+ def columnsChangedAfterAnalysis(
+ tableName: String,
+ errors: Seq[String]): Throwable = {
+ new AnalysisException(
+ errorClass = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH",
+ messageParameters = Map(
+ "tableName" -> toSQLId(tableName),
+ "errors" -> errors.mkString("- ", "\n- ", "")))
+ }
+
+ def metadataColumnsChangedAfterAnalysis(
+ tableName: String,
+ errors: Seq[String]): Throwable = {
+ new AnalysisException(
+ errorClass = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.METADATA_COLUMNS_MISMATCH",
+ messageParameters = Map(
+ "tableName" -> toSQLId(tableName),
+ "errors" -> errors.mkString("- ", "\n- ", "")))
+ }
+
def numberOfPartitionsNotAllowedWithUnspecifiedDistributionError(): Throwable = {
new AnalysisException(
errorClass = "INVALID_WRITE_DISTRIBUTION.PARTITION_NUM_WITH_UNSPECIFIED_DISTRIBUTION",
@@ -2198,12 +2238,58 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
messageParameters = Map.empty)
}
- def missingCatalogAbilityError(plugin: CatalogPlugin, ability: String): Throwable = {
+ def missingCatalogFunctionsAbilityError(plugin: CatalogPlugin): Throwable = {
new AnalysisException(
- errorClass = "_LEGACY_ERROR_TEMP_1184",
- messageParameters = Map(
- "plugin" -> plugin.name,
- "ability" -> ability))
+ errorClass = "MISSING_CATALOG_ABILITY.FUNCTIONS",
+ messageParameters = Map("plugin" -> plugin.name))
+ }
+
+ def missingCatalogTableValuedFunctionsAbilityError(plugin: CatalogPlugin): Throwable = {
+ new AnalysisException(
+ errorClass = "MISSING_CATALOG_ABILITY.TABLE_VALUED_FUNCTIONS",
+ messageParameters = Map("plugin" -> plugin.name))
+ }
+
+ def missingCatalogCreateFunctionAbilityError(plugin: CatalogPlugin): Throwable = {
+ new AnalysisException(
+ errorClass = "MISSING_CATALOG_ABILITY.CREATE_FUNCTION",
+ messageParameters = Map("plugin" -> plugin.name))
+ }
+
+ def missingCatalogDropFunctionAbilityError(plugin: CatalogPlugin): Throwable = {
+ new AnalysisException(
+ errorClass = "MISSING_CATALOG_ABILITY.DROP_FUNCTION",
+ messageParameters = Map("plugin" -> plugin.name))
+ }
+
+ def missingCatalogRefreshFunctionAbilityError(plugin: CatalogPlugin): Throwable = {
+ new AnalysisException(
+ errorClass = "MISSING_CATALOG_ABILITY.REFRESH_FUNCTION",
+ messageParameters = Map("plugin" -> plugin.name))
+ }
+
+ def missingCatalogNamespacesAbilityError(plugin: CatalogPlugin): Throwable = {
+ new AnalysisException(
+ errorClass = "MISSING_CATALOG_ABILITY.NAMESPACES",
+ messageParameters = Map("plugin" -> plugin.name))
+ }
+
+ def missingCatalogProceduresAbilityError(plugin: CatalogPlugin): Throwable = {
+ new AnalysisException(
+ errorClass = "MISSING_CATALOG_ABILITY.PROCEDURES",
+ messageParameters = Map("plugin" -> plugin.name))
+ }
+
+ def missingCatalogTablesAbilityError(plugin: CatalogPlugin): Throwable = {
+ new AnalysisException(
+ errorClass = "MISSING_CATALOG_ABILITY.TABLES",
+ messageParameters = Map("plugin" -> plugin.name))
+ }
+
+ def missingCatalogViewsAbilityError(plugin: CatalogPlugin): Throwable = {
+ new AnalysisException(
+ errorClass = "MISSING_CATALOG_ABILITY.VIEWS",
+ messageParameters = Map("plugin" -> plugin.name))
}
def tableValuedArgumentsNotYetImplementedForSqlFunctions(
@@ -2239,13 +2325,17 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
def identifierTooManyNamePartsError(originalIdentifier: String): Throwable = {
new AnalysisException(
errorClass = "IDENTIFIER_TOO_MANY_NAME_PARTS",
- messageParameters = Map("identifier" -> toSQLId(originalIdentifier)))
+ messageParameters = Map(
+ "identifier" -> toSQLId(originalIdentifier),
+ "limit" -> "2"))
}
def identifierTooManyNamePartsError(names: Seq[String]): Throwable = {
new AnalysisException(
errorClass = "IDENTIFIER_TOO_MANY_NAME_PARTS",
- messageParameters = Map("identifier" -> toSQLId(names)))
+ messageParameters = Map(
+ "identifier" -> toSQLId(names),
+ "limit" -> "2"))
}
def emptyMultipartIdentifierError(): Throwable = {
@@ -4365,4 +4455,47 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
origin = origin
)
}
+
+ def pythonWorkerLoggingNotEnabledError(): Throwable = {
+ new AnalysisException(
+ errorClass = "FEATURE_NOT_ENABLED",
+ messageParameters = Map(
+ "featureName" -> "Python Worker Logging",
+ "configKey" -> "spark.sql.pyspark.worker.logging.enabled",
+ "configValue" -> "true"
+ )
+ )
+ }
+
+ def columnsChangedAfterViewWithPlanCreation(
+ viewName: Seq[String],
+ tableName: Seq[String],
+ errors: Seq[String]): Throwable = {
+ new AnalysisException(
+ errorClass = "INCOMPATIBLE_COLUMN_CHANGES_AFTER_VIEW_WITH_PLAN_CREATION",
+ messageParameters = Map(
+ "viewName" -> toSQLId(viewName),
+ "tableName" -> toSQLId(tableName),
+ "colType" -> "data",
+ "errors" -> errors.mkString("- ", "\n- ", "")))
+ }
+
+ def metadataColumnsChangedAfterViewWithPlanCreation(
+ viewName: Seq[String],
+ tableName: Seq[String],
+ errors: Seq[String]): Throwable = {
+ new AnalysisException(
+ errorClass = "INCOMPATIBLE_COLUMN_CHANGES_AFTER_VIEW_WITH_PLAN_CREATION",
+ messageParameters = Map(
+ "viewName" -> toSQLId(viewName),
+ "tableName" -> toSQLId(tableName),
+ "colType" -> "metadata",
+ "errors" -> errors.mkString("- ", "\n- ", "")))
+ }
+
+ def unsupportedTimeTypeError(): Throwable = {
+ new AnalysisException(
+ errorClass = "UNSUPPORTED_TIME_TYPE",
+ messageParameters = Map.empty)
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 5f5e1da47184c..351868fcc2e29 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -665,6 +665,17 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
summary = "")
}
+ def stInvalidSridValueError(srid: String): SparkIllegalArgumentException = {
+ new SparkIllegalArgumentException(
+ errorClass = "ST_INVALID_SRID_VALUE",
+ messageParameters = Map("srid" -> srid)
+ )
+ }
+
+ def stInvalidSridValueError(srid: Int): SparkIllegalArgumentException = {
+ stInvalidSridValueError(srid.toString)
+ }
+
def withSuggestionIntervalArithmeticOverflowError(
suggestedFunc: String,
context: QueryContext): ArithmeticException = {
@@ -1109,10 +1120,12 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
cause = e)
}
- def ddlUnsupportedTemporarilyError(ddl: String): SparkUnsupportedOperationException = {
+ def ddlUnsupportedTemporarilyError(
+ ddl: String,
+ tableName: String): SparkUnsupportedOperationException = {
new SparkUnsupportedOperationException(
- errorClass = "_LEGACY_ERROR_TEMP_2096",
- messageParameters = Map("ddl" -> ddl))
+ errorClass = "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ messageParameters = Map("tableName" -> toSQLId(tableName), "operation" -> ddl))
}
def executeBroadcastTimeoutError(timeout: Long, ex: Option[TimeoutException]): Throwable = {
@@ -2804,6 +2817,12 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
"value" -> toSQLValue(value, IntegerType)))
}
+ def hllKMustBeConstantError(function: String): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "HLL_K_MUST_BE_CONSTANT",
+ messageParameters = Map("function" -> toSQLId(function)))
+ }
+
def hllInvalidInputSketchBuffer(function: String): Throwable = {
new SparkRuntimeException(
errorClass = "HLL_INVALID_INPUT_SKETCH_BUFFER",
@@ -2811,6 +2830,13 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
"function" -> toSQLId(function)))
}
+ def kllInvalidInputSketchBuffer(function: String, reason: String = ""): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "KLL_INVALID_INPUT_SKETCH_BUFFER",
+ messageParameters = Map(
+ "function" -> toSQLId(function)))
+ }
+
def hllUnionDifferentLgK(left: Int, right: Int, function: String): Throwable = {
new SparkRuntimeException(
errorClass = "HLL_UNION_DIFFERENT_LG_K",
@@ -3169,4 +3195,31 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE
"max" -> toSQLValue(max, IntegerType),
"value" -> toSQLValue(value, IntegerType)))
}
+
+ def thetaLgNomEntriesMustBeConstantError(function: String): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "THETA_LG_NOM_ENTRIES_MUST_BE_CONSTANT",
+ messageParameters = Map("function" -> toSQLId(function)))
+ }
+
+ def kllSketchInvalidQuantileRangeError(function: String): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "KLL_SKETCH_INVALID_QUANTILE_RANGE",
+ messageParameters = Map(
+ "functionName" -> toSQLId(function)))
+ }
+
+ def kllSketchKMustBeConstantError(function: String): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "KLL_SKETCH_K_MUST_BE_CONSTANT",
+ messageParameters = Map("functionName" -> toSQLId(function)))
+ }
+
+ def kllSketchKOutOfRangeError(function: String, k: Int): Throwable = {
+ new SparkRuntimeException(
+ errorClass = "KLL_SKETCH_K_OUT_OF_RANGE",
+ messageParameters = Map(
+ "functionName" -> toSQLId(function),
+ "k" -> toSQLValue(k, IntegerType)))
+ }
}
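A minimal sketch of how the new KLL error helpers might be used at a call site; the function name and the k bounds below are assumptions for illustration, not taken from this change:

    import org.apache.spark.sql.errors.QueryExecutionErrors

    // Hypothetical validation of a user-supplied k before building a KLL sketch.
    def validateKllK(function: String, k: Int): Int = {
      if (k < 8 || k > 65535) {  // assumed bounds, for illustration only
        throw QueryExecutionErrors.kllSketchKOutOfRangeError(function, k)
      }
      k
    }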
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
index 275fecebdafb8..b5269da035f3b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala
@@ -24,6 +24,7 @@ import org.apache.arrow.vector.complex._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.util.STUtils
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ArrowUtils
@@ -92,6 +93,16 @@ object ArrowWriter {
createFieldWriter(vector.getChildByOrdinal(ordinal))
}
new StructWriter(vector, children.toArray)
+ case (dt: GeometryType, vector: StructVector) =>
+ val children = (0 until vector.size()).map { ordinal =>
+ createFieldWriter(vector.getChildByOrdinal(ordinal))
+ }
+ new GeometryWriter(dt, vector, children.toArray)
+ case (dt: GeographyType, vector: StructVector) =>
+ val children = (0 until vector.size()).map { ordinal =>
+ createFieldWriter(vector.getChildByOrdinal(ordinal))
+ }
+ new GeographyWriter(dt, vector, children.toArray)
case (dt, _) =>
throw ExecutionErrors.unsupportedDataTypeError(dt)
}
@@ -113,9 +124,9 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) {
count += 1
}
- def sizeInBytes(): Int = {
+ def sizeInBytes(): Long = {
var i = 0
- var bytes = 0
+ var bytes = 0L
while (i < fields.size) {
bytes += fields(i).getSizeInBytes()
i += 1
@@ -446,6 +457,42 @@ private[arrow] class StructWriter(
}
}
+private[arrow] class GeographyWriter(
+ dt: GeographyType,
+ valueVector: StructVector,
+ children: Array[ArrowFieldWriter]) extends StructWriter(valueVector, children) {
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setIndexDefined(count)
+
+ val geom = STUtils.deserializeGeog(input.getGeography(ordinal), dt)
+ val bytes = geom.getBytes
+ val srid = geom.getSrid
+
+ val row = InternalRow(srid, bytes)
+ children(0).write(row, 0)
+ children(1).write(row, 1)
+ }
+}
+
+private[arrow] class GeometryWriter(
+ dt: GeometryType,
+ valueVector: StructVector,
+ children: Array[ArrowFieldWriter]) extends StructWriter(valueVector, children) {
+
+ override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+ valueVector.setIndexDefined(count)
+
+ val geom = STUtils.deserializeGeom(input.getGeometry(ordinal), dt)
+ val bytes = geom.getBytes
+ val srid = geom.getSrid
+
+ val row = InternalRow(srid, bytes)
+ children(0).write(row, 0)
+ children(1).write(row, 1)
+ }
+}
+
private[arrow] class MapWriter(
val valueVector: MapVector,
val structVector: StructVector,
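The two writers above assume the geospatial values arrive as Arrow structs with the SRID in child 0 and the serialized shape bytes in child 1. A minimal sketch of that layout, with illustrative child names and nullability:

    import java.util.Arrays
    import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType}

    // struct<srid: int, wkb: binary>; the field names here are assumptions for illustration.
    val geometryField = new Field(
      "geometry",
      FieldType.nullable(ArrowType.Struct.INSTANCE),
      Arrays.asList(
        new Field("srid", FieldType.nullable(new ArrowType.Int(32, true)), null),
        new Field("wkb", FieldType.nullable(ArrowType.Binary.INSTANCE), null)))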
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriterWrapper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriterWrapper.scala
index 6c5799bd241b9..c04bae07f67dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriterWrapper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriterWrapper.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.arrow
import java.io.DataOutputStream
import org.apache.arrow.memory.BufferAllocator
-import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark.TaskContext
@@ -34,6 +34,7 @@ case class ArrowWriterWrapper(
var arrowWriter: SparkArrowWriter,
var root: VectorSchemaRoot,
var allocator: BufferAllocator,
+ var unloader: VectorUnloader,
context: TaskContext) {
@volatile var isClosed = false
@@ -58,6 +59,7 @@ case class ArrowWriterWrapper(
arrowWriter = null
root = null
allocator = null
+ unloader = null
}
}
}
@@ -77,8 +79,10 @@ object ArrowWriterWrapper {
s"stdout writer for $allocatorOwner", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
val arrowWriter = SparkArrowWriter.create(root)
+
val streamWriter = new ArrowStreamWriter(root, null, dataOut)
streamWriter.start()
- ArrowWriterWrapper(streamWriter, arrowWriter, root, allocator, context)
+ // Unloader will be set by the caller after creation
+ ArrowWriterWrapper(streamWriter, arrowWriter, root, allocator, null, context)
}
}
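A minimal sketch of the caller-side wiring implied by the comment above, assuming `wrapper` is the value returned by the factory in this hunk:

    import org.apache.arrow.vector.VectorUnloader

    // The caller builds the unloader from the root and assigns it after creation.
    wrapper.unloader = new VectorUnloader(wrapper.root)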
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 9b8d48c3f3a85..fda415d1ab29d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -23,8 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Attri
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, ExposesMetadataColumns, Histogram, HistogramBin, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
-import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils}
-import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, SupportsMetadataColumns, Table, TableCapability}
+import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, SupportsMetadataColumns, Table, TableCapability, TableCatalog, V2TableUtil}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics}
import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -57,9 +58,8 @@ abstract class DataSourceV2RelationBase(
}
override def name: String = {
- import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
(catalog, identifier) match {
- case (Some(cat), Some(ident)) => s"${quoteIfNeeded(cat.name())}.${ident.quoted}"
+ case (Some(cat), Some(ident)) => V2TableUtil.toQualifiedName(cat, ident)
case _ => table.name()
}
}
@@ -133,6 +133,8 @@ case class DataSourceV2Relation(
def autoSchemaEvolution(): Boolean =
table.capabilities().contains(TableCapability.AUTOMATIC_SCHEMA_EVOLUTION)
+
+ def isVersioned: Boolean = table.version != null
}
/**
@@ -182,7 +184,13 @@ case class DataSourceV2ScanRelation(
relation = this.relation.copy(
output = this.relation.output.map(QueryPlan.normalizeExpressions(_, this.relation.output))
),
- output = this.output.map(QueryPlan.normalizeExpressions(_, this.output))
+ output = this.output.map(QueryPlan.normalizeExpressions(_, this.output)),
+ keyGroupedPartitioning = keyGroupedPartitioning.map(
+ _.map(QueryPlan.normalizeExpressions(_, output))
+ ),
+ ordering = ordering.map(
+ _.map(o => o.copy(child = QueryPlan.normalizeExpressions(o.child, output)))
+ )
)
}
}
@@ -259,10 +267,10 @@ object ExtractV2Table {
}
object ExtractV2CatalogAndIdentifier {
- def unapply(relation: DataSourceV2Relation): Option[(CatalogPlugin, Identifier)] = {
+ def unapply(relation: DataSourceV2Relation): Option[(TableCatalog, Identifier)] = {
relation match {
case DataSourceV2Relation(_, _, Some(catalog), Some(identifier), _, _) =>
- Some((catalog, identifier))
+ Some((catalog.asTableCatalog, identifier))
case _ =>
None
}
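Illustrative only: with the extractor now returning a TableCatalog, a caller can use the catalog API directly without casting. The helper below is hypothetical:

    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
    import org.apache.spark.sql.execution.datasources.v2.ExtractV2CatalogAndIdentifier

    def invalidate(relation: DataSourceV2Relation): Unit = relation match {
      case ExtractV2CatalogAndIdentifier(tableCatalog, ident) =>
        tableCatalog.invalidateTable(ident)  // TableCatalog methods available without asTableCatalog
      case _ => ()
    }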
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d88cbe326cfbe..100149a39211f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -243,40 +243,43 @@ object SQLConf {
val PREFER_COLUMN_OVER_LCA_IN_ARRAY_INDEX =
buildConf("spark.sql.analyzer.preferColumnOverLcaInArrayIndex")
- .internal()
- .doc(
- "When true, prefer the column from the underlying relation over the lateral column alias " +
- "reference with the same name (see SPARK-53734)."
- )
- .booleanConf
- .createWithDefault(true)
+ .internal()
+ .version("4.1.0")
+ .doc(
+ "When true, prefer the column from the underlying relation over the lateral column alias " +
+ "reference with the same name (see SPARK-53734).")
+ .booleanConf
+ .createWithDefault(true)
val DONT_DEDUPLICATE_EXPRESSION_IF_EXPR_ID_IN_OUTPUT =
buildConf("spark.sql.analyzer.dontDeduplicateExpressionIfExprIdInOutput")
- .internal()
- .doc(
- "DeduplicateRelations shouldn't remap expressions to new ExprIds if old ExprId still " +
- "exists in output.")
- .booleanConf
- .createWithDefault(true)
+ .internal()
+ .version("4.1.0")
+ .doc(
+ "DeduplicateRelations shouldn't remap expressions to new ExprIds if old ExprId still " +
+ "exists in output.")
+ .booleanConf
+ .createWithDefault(true)
val UNION_IS_RESOLVED_WHEN_DUPLICATES_PER_CHILD_RESOLVED =
buildConf("spark.sql.analyzer.unionIsResolvedWhenDuplicatesPerChildResolved")
- .internal()
- .doc(
- "When true, union should only be resolved once there are no duplicate attributes in " +
- "each branch.")
- .booleanConf
- .createWithDefault(true)
+ .internal()
+ .version("4.1.0")
+ .doc(
+ "When true, union should only be resolved once there are no duplicate attributes in " +
+ "each branch.")
+ .booleanConf
+ .createWithDefault(true)
val ONLY_NECESSARY_AND_UNIQUE_METADATA_COLUMNS =
buildConf("spark.sql.analyzer.uniqueNecessaryMetadataColumns")
- .internal()
- .doc(
- "When this conf is enabled, AddMetadataColumns rule should only add necessary metadata " +
- "columns and only if those columns are not already present in the project list.")
- .booleanConf
- .createWithDefault(true)
+ .internal()
+ .version("4.1.0")
+ .doc(
+ "When this conf is enabled, AddMetadataColumns rule should only add necessary metadata " +
+ "columns and only if those columns are not already present in the project list.")
+ .booleanConf
+ .createWithDefault(true)
val BLOCK_CREATE_TEMP_TABLE_USING_PROVIDER =
buildConf("spark.sql.legacy.blockCreateTempTableUsingProvider")
@@ -324,7 +327,7 @@ object SQLConf {
"(AliasResolution.resolve, FunctionResolution.resolveFunction, etc)." +
"This feature is currently under development."
)
- .version("4.0.0")
+ .version("4.1.0")
.booleanConf
.createWithDefault(false)
@@ -570,6 +573,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val GEOSPATIAL_ENABLED =
+ buildConf("spark.sql.geospatial.enabled")
+ .internal()
+ .doc("When true, enables geospatial types (GEOGRAPHY/GEOMETRY) and ST functions.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefaultFunction(() => Utils.isTesting)
+
val EXTENDED_EXPLAIN_PROVIDERS = buildConf("spark.sql.extendedExplainProviders")
.doc("A comma-separated list of classes that implement the" +
" org.apache.spark.sql.ExtendedExplainGenerator trait. If provided, Spark will print" +
@@ -924,6 +935,7 @@ object SQLConf {
val ADAPTIVE_EXECUTION_ENABLED_IN_STATELESS_STREAMING =
buildConf("spark.sql.adaptive.streaming.stateless.enabled")
+ .internal()
.doc("When true, enable adaptive query execution for stateless streaming query. To " +
"enable this config, `spark.sql.adaptive.enabled` needs to be also enabled.")
.version("4.1.0")
@@ -1050,7 +1062,7 @@ object SQLConf {
"An object with an explicitly set collation will not inherit the collation from the " +
"schema."
)
- .version("4.0.0")
+ .version("4.1.0")
.booleanConf
.createWithDefault(false)
@@ -1211,6 +1223,7 @@ object SQLConf {
val MAP_ZIP_WITH_USES_JAVA_COLLECTIONS =
buildConf("spark.sql.mapZipWithUsesJavaCollections")
+ .internal()
.doc("When true, the `map_zip_with` function uses Java collections instead of Scala " +
"collections. This is useful for avoiding NaN equality issues.")
.version("4.1.0")
@@ -1552,6 +1565,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val PARQUET_VECTORIZED_READER_NULL_TYPE_ENABLED =
+ buildConf("spark.sql.parquet.enableNullTypeVectorizedReader")
+ .internal()
+ .doc("Enables vectorized Parquet reader support for NullType columns.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(true)
+
val PARQUET_RECORD_FILTER_ENABLED = buildConf("spark.sql.parquet.recordLevelFilter.enabled")
.doc("If true, enables Parquet's native record-level filtering using the pushed down " +
"filters. " +
@@ -1578,6 +1599,24 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE =
+ buildConf("spark.sql.parquet.variant.annotateLogicalType.enabled")
+ .internal()
+ .doc("When enabled, Spark annotates the variant groups written to Parquet as the parquet " +
+ "variant logical type.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(true)
+
+ val PARQUET_IGNORE_VARIANT_ANNOTATION =
+ buildConf("spark.sql.parquet.ignoreVariantAnnotation")
+ .internal()
+ .doc("When true, ignore the variant logical type annotation and treat the Parquet " +
+ "column in the same way as the underlying struct type")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(false)
+
val PARQUET_FIELD_ID_READ_ENABLED =
buildConf("spark.sql.parquet.fieldId.read.enabled")
.doc("Field ID is a native field of the Parquet schema spec. When enabled, Parquet readers " +
@@ -1880,6 +1919,7 @@ object SQLConf {
val DATA_SOURCE_V2_JOIN_PUSHDOWN =
buildConf("spark.sql.optimizer.datasourceV2JoinPushdown")
.internal()
+ .version("4.1.0")
.doc("When this config is set to true, join is tried to be pushed down" +
"for DSv2 data sources in V2ScanRelationPushdown optimization rule.")
.booleanConf
@@ -1888,6 +1928,7 @@ object SQLConf {
val DATA_SOURCE_V2_EXPR_FOLDING =
buildConf("spark.sql.optimizer.datasourceV2ExprFolding")
.internal()
+ .version("4.1.0")
.doc("When this config is set to true, do safe constant folding for the " +
"expressions before translation and pushdown.")
.booleanConf
@@ -2518,6 +2559,7 @@ object SQLConf {
val STATE_STORE_MAINTENANCE_SHUTDOWN_TIMEOUT =
buildConf("spark.sql.streaming.stateStore.maintenanceShutdownTimeout")
.internal()
+ .version("4.1.0")
.doc("Timeout in seconds for maintenance pool operations to complete on shutdown")
.timeConf(TimeUnit.SECONDS)
.createWithDefault(300L)
@@ -2525,6 +2567,7 @@ object SQLConf {
val STATE_STORE_MAINTENANCE_PROCESSING_TIMEOUT =
buildConf("spark.sql.streaming.stateStore.maintenanceProcessingTimeout")
.internal()
+ .version("4.1.0")
.doc("Timeout in seconds to wait for maintenance to process this partition.")
.timeConf(TimeUnit.SECONDS)
.createWithDefault(30L)
@@ -2643,6 +2686,7 @@ object SQLConf {
"Note: For structured streaming, this configuration cannot be changed between query " +
"restarts from the same checkpoint location.")
.internal()
+ .version("4.1.0")
.intConf
.checkValue(_ > 0,
"The value of spark.sql.streaming.internal.stateStore.partitions must be a positive " +
@@ -3032,6 +3076,13 @@ object SQLConf {
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefault(5000)
+ val STREAMING_REAL_TIME_MODE_ALLOWLIST_CHECK = buildConf(
+ "spark.sql.streaming.realTimeMode.allowlistCheck")
+ .doc("Whether to check all operators, sinks used in real-time mode are in the allowlist.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(true)
+
val VARIABLE_SUBSTITUTE_ENABLED =
buildConf("spark.sql.variable.substitute")
.doc("This enables substitution using syntax like `${var}`, `${system:var}`, " +
@@ -3439,6 +3490,16 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val STREAMING_CHECKPOINT_FILE_CHECKSUM_SKIP_CREATION_IF_FILE_MISSING_CHECKSUM =
+ buildConf("spark.sql.streaming.checkpoint.fileChecksum.skipCreationIfFileMissingChecksum")
+ .internal()
+ .doc("When true, if a microbatch is retried, if a file already exists but its checksum " +
+ "file does not exist, the file checksum will not be created. This is useful for " +
+ "compatibility with files created before file checksums were enabled.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(true)
+
val PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION =
buildConf("spark.sql.statistics.parallelFileListingInStatsComputation.enabled")
.internal()
@@ -3881,6 +3942,14 @@ object SQLConf {
.version("4.1.0")
.fallbackConf(Python.PYTHON_WORKER_TRACEBACK_DUMP_INTERVAL_SECONDS)
+ val PYTHON_UDF_DAEMON_KILL_WORKER_ON_FLUSH_FAILURE =
+ buildConf("spark.sql.execution.pyspark.udf.daemonKillWorkerOnFlushFailure")
+ .doc(
+ s"Same as ${Python.PYTHON_DAEMON_KILL_WORKER_ON_FLUSH_FAILURE.key} " +
+ "for Python execution with DataFrame and SQL. It can change during runtime.")
+ .version("4.1.0")
+ .fallbackConf(Python.PYTHON_DAEMON_KILL_WORKER_ON_FLUSH_FAILURE)
+
val PYTHON_WORKER_LOGGING_ENABLED =
buildConf("spark.sql.pyspark.worker.logging.enabled")
.doc("When set to true, this configuration enables comprehensive logging within " +
@@ -3987,6 +4056,32 @@ object SQLConf {
"than zero and less than INT_MAX.")
.createWithDefaultString("64MB")
+ val ARROW_EXECUTION_COMPRESSION_CODEC =
+ buildConf("spark.sql.execution.arrow.compression.codec")
+ .doc("Compression codec used to compress Arrow IPC data when transferring data " +
+ "between JVM and Python processes (e.g., toPandas, toArrow). This can significantly " +
+ "reduce memory usage and network bandwidth when transferring large datasets. " +
+ "Supported codecs: 'none' (no compression), 'zstd' (Zstandard), 'lz4' (LZ4). " +
+ "Note that compression may add CPU overhead but can provide substantial memory savings " +
+ "especially for datasets with high compression ratios.")
+ .version("4.1.0")
+ .stringConf
+ .transform(_.toLowerCase(java.util.Locale.ROOT))
+ .checkValues(Set("none", "zstd", "lz4"))
+ .createWithDefault("none")
+
+ val ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL =
+ buildConf("spark.sql.execution.arrow.compression.zstd.level")
+ .doc("Compression level for Zstandard (zstd) codec when compressing Arrow IPC data. " +
+ "This config is only used when spark.sql.execution.arrow.compression.codec is set to " +
+ "'zstd'. Negative values provide ultra-fast compression with lower " +
+ "compression ratios. Positive values provide normal to maximum compression, " +
+ "with higher values giving better compression but slower speed. The default value 3 " +
+ "provides a good balance between compression speed and compression ratio.")
+ .version("4.1.0")
+ .intConf
+ .createWithDefault(3)
+
val ARROW_TRANSFORM_WITH_STATE_IN_PYSPARK_MAX_STATE_RECORDS_PER_BATCH =
buildConf("spark.sql.execution.arrow.transformWithStateInPySpark.maxStateRecordsPerBatch")
.doc("When using TransformWithState in PySpark (both Python Row and Pandas), limit " +
@@ -4946,6 +5041,17 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val LEGACY_IDENTIFIER_CLAUSE_ONLY =
+ buildConf("spark.sql.legacy.identifierClause")
+ .internal()
+ .doc("When set to false, IDENTIFIER('literal') is resolved to an identifier at parse time " +
+ "anywhere identifiers can occur. When set to true, only the legacy " +
+ " IDENTIFIER(constantExpr) clause is allowed, which evaluates the expression at analysis " +
+ " and is limited to a narrow subset of scenarios.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(false)
+
val LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED =
buildConf("spark.sql.legacy.allowNegativeScaleOfDecimal")
.internal()
@@ -5179,6 +5285,7 @@ object SQLConf {
.createWithDefault(LegacyBehaviorPolicy.CORRECTED)
val CTE_RECURSION_LEVEL_LIMIT = buildConf("spark.sql.cteRecursionLevelLimit")
+ .internal()
.doc("Maximum level of recursion that is allowed while executing a recursive CTE definition." +
"If a query does not get exhausted before reaching this limit it fails. Use -1 for " +
"unlimited.")
@@ -5187,6 +5294,7 @@ object SQLConf {
.createWithDefault(100)
val CTE_RECURSION_ROW_LIMIT = buildConf("spark.sql.cteRecursionRowLimit")
+ .internal()
.doc("Maximum number of rows that can be returned when executing a recursive CTE definition." +
"If a query does not get exhausted before reaching this limit it fails. Use -1 for " +
"unlimited.")
@@ -5196,6 +5304,7 @@ object SQLConf {
val CTE_RECURSION_ANCHOR_ROWS_LIMIT_TO_CONVERT_TO_LOCAL_RELATION =
buildConf("spark.sql.cteRecursionAnchorRowsLimitToConvertToLocalRelation")
+ .internal()
.doc("Maximum number of rows that the anchor in a recursive CTE can return for it to be" +
"converted to a localRelation.")
.version("4.1.0")
@@ -5357,6 +5466,7 @@ object SQLConf {
.createWithDefault(false)
val PYTHON_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.python.filterPushdown.enabled")
+ .internal()
.doc("When true, enable filter pushdown to Python datasource, at the cost of running " +
"Python worker one additional time during planning.")
.version("4.1.0")
@@ -5424,7 +5534,7 @@ object SQLConf {
"When false, it only reads unshredded variant.")
.version("4.0.0")
.booleanConf
- .createWithDefault(false)
+ .createWithDefault(true)
val PUSH_VARIANT_INTO_SCAN =
buildConf("spark.sql.variant.pushVariantIntoScan")
@@ -5433,7 +5543,7 @@ object SQLConf {
"requested fields.")
.version("4.0.0")
.booleanConf
- .createWithDefault(false)
+ .createWithDefault(true)
val VARIANT_WRITE_SHREDDING_ENABLED =
buildConf("spark.sql.variant.writeShredding.enabled")
@@ -5441,7 +5551,7 @@ object SQLConf {
.doc("When true, the Parquet writer is allowed to write shredded variant. ")
.version("4.0.0")
.booleanConf
- .createWithDefault(false)
+ .createWithDefault(true)
val VARIANT_FORCE_SHREDDING_SCHEMA_FOR_TEST =
buildConf("spark.sql.variant.forceShreddingSchemaForTest")
@@ -5474,7 +5584,7 @@ object SQLConf {
.doc("Infer shredding schema when writing Variant columns in Parquet tables.")
.version("4.1.0")
.booleanConf
- .createWithDefault(false)
+ .createWithDefault(true)
val LEGACY_CSV_ENABLE_DATE_TIME_PARSING_FALLBACK =
buildConf("spark.sql.legacy.csv.enableDateTimeParsingFallback")
@@ -6005,7 +6115,9 @@ object SQLConf {
.doc("The chunk size in bytes when splitting ChunkedCachedLocalRelation.data " +
"into batches. A new chunk is created when either " +
"spark.sql.session.localRelationChunkSizeBytes " +
- "or spark.sql.session.localRelationChunkSizeRows is reached.")
+ "or spark.sql.session.localRelationChunkSizeRows is reached. " +
+ "Limited by the spark.sql.session.localRelationBatchOfChunksSizeBytes, " +
+ "a minimum of the two confs is used to determine the chunk size.")
.version("4.1.0")
.longConf
.checkValue(_ > 0, "The chunk size in bytes must be positive")
@@ -6029,6 +6141,21 @@ object SQLConf {
.bytesConf(ByteUnit.BYTE)
.createWithDefaultString("3GB")
+ val LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES =
+ buildConf(SqlApiConfHelper.LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES_KEY)
+ .internal()
+ .doc("Limit on how much memory the client can use when uploading a local relation to the " +
+ "server. The client collects multiple local relation chunks into a single batch in " +
+ "memory until the limit is reached, then uploads the batch to the server. " +
+ "This helps reduce memory pressure on the client when dealing with very large local " +
+ "relations because the client does not have to materialize all chunks in memory. " +
+ "Limits the spark.sql.session.localRelationChunkSizeBytes, " +
+ "a minimum of the two confs is used to determine the chunk size.")
+ .version("4.1.0")
+ .longConf
+ .checkValue(_ > 0, "The batch size in bytes must be positive")
+ .createWithDefault(1 * 1024 * 1024 * 1024L)
+
val DECORRELATE_JOIN_PREDICATE_ENABLED =
buildConf("spark.sql.optimizer.decorrelateJoinPredicate.enabled")
.internal()
@@ -6353,6 +6480,7 @@ object SQLConf {
val PIPELINES_STREAM_STATE_POLLING_INTERVAL = {
buildConf("spark.sql.pipelines.execution.streamstate.pollingInterval")
+ .internal()
.doc(
"Interval in seconds at which the stream state is polled for changes. This is used to " +
"check if the stream has failed and needs to be restarted."
@@ -6364,6 +6492,7 @@ object SQLConf {
val PIPELINES_WATCHDOG_MIN_RETRY_TIME_IN_SECONDS = {
buildConf("spark.sql.pipelines.execution.watchdog.minRetryTime")
+ .internal()
.doc(
"Initial duration in seconds between the time when we notice a flow has failed and " +
"when we try to restart the flow. The interval between flow restarts doubles with " +
@@ -6378,6 +6507,7 @@ object SQLConf {
val PIPELINES_WATCHDOG_MAX_RETRY_TIME_IN_SECONDS = {
buildConf("spark.sql.pipelines.execution.watchdog.maxRetryTime")
+ .internal()
.doc(
"Maximum time interval in seconds at which flows will be restarted."
)
@@ -6388,6 +6518,7 @@ object SQLConf {
val PIPELINES_MAX_CONCURRENT_FLOWS = {
buildConf("spark.sql.pipelines.execution.maxConcurrentFlows")
+ .internal()
.doc(
"Max number of flows to execute at once. Used to tune performance for triggered " +
"pipelines. Has no effect on continuous pipelines."
@@ -6400,6 +6531,7 @@ object SQLConf {
val PIPELINES_TIMEOUT_MS_FOR_TERMINATION_JOIN_AND_LOCK = {
buildConf("spark.sql.pipelines.timeoutMsForTerminationJoinAndLock")
+ .internal()
.doc("Timeout in milliseconds to grab a lock for stopping update - default is 1hr.")
.version("4.1.0")
.timeConf(TimeUnit.MILLISECONDS)
@@ -6417,6 +6549,7 @@ object SQLConf {
val PIPELINES_EVENT_QUEUE_CAPACITY = {
buildConf("spark.sql.pipelines.event.queue.capacity")
+ .internal()
.doc("Capacity of the event queue used in pipelined execution. When the queue is full, " +
"non-terminal FlowProgressEvents will be dropped.")
.version("4.1.0")
@@ -6474,7 +6607,7 @@ object SQLConf {
.createWithDefault(false)
val LEGACY_XML_PARSER_ENABLED = {
- buildConf("spark.sql.xml.legacyXMLParser.enabled")
+ buildConf("spark.sql.legacy.useLegacyXMLParser")
.internal()
.doc(
"When set to true, use the legacy XML parser for parsing XML files. " +
@@ -6495,14 +6628,23 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
- val MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED =
- buildConf("spark.sql.merge.source.nested.type.coercion.enabled")
+ val MERGE_INTO_NESTED_TYPE_COERCION_ENABLED =
+ buildConf("spark.sql.mergeNestedTypeCoercion.enabled")
.internal()
.doc("If enabled, allow MERGE INTO to coerce source nested types if they have less" +
- "nested fields than the target table's nested types.")
+ "nested fields than the target table's nested types. This is experimental and" +
+ "the semantics may change.")
.version("4.1.0")
.booleanConf
- .createWithDefault(true)
+ .createWithDefault(false)
+
+ val TIME_TYPE_ENABLED =
+ buildConf("spark.sql.timeType.enabled")
+ .internal()
+ .doc("When true, the TIME data type is supported.")
+ .version("4.1.0")
+ .booleanConf
+ .createWithDefault(Utils.isTesting)
/**
* Holds information about keys that have been deprecated.
@@ -6678,6 +6820,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def nameResolutionLogLevel: Level = getConf(NAME_RESOLUTION_LOG_LEVEL)
+ def geospatialEnabled: Boolean = getConf(GEOSPATIAL_ENABLED)
+
def dataSourceV2JoinPushdown: Boolean = getConf(DATA_SOURCE_V2_JOIN_PUSHDOWN)
def dynamicPartitionPruningEnabled: Boolean = getConf(DYNAMIC_PARTITION_PRUNING_ENABLED)
@@ -6739,6 +6883,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def checkpointFileChecksumEnabled: Boolean = getConf(STREAMING_CHECKPOINT_FILE_CHECKSUM_ENABLED)
+ def checkpointFileChecksumSkipCreationIfFileMissingChecksum: Boolean =
+ getConf(STREAMING_CHECKPOINT_FILE_CHECKSUM_SKIP_CREATION_IF_FILE_MISSING_CHECKSUM)
+
def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED)
def useDeprecatedKafkaOffsetFetching: Boolean = getConf(USE_DEPRECATED_KAFKA_OFFSET_FETCHING)
@@ -6812,6 +6959,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def parquetVectorizedReaderNestedColumnEnabled: Boolean =
getConf(PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED)
+ def parquetVectorizedReaderNullTypeEnabled: Boolean =
+ getConf(PARQUET_VECTORIZED_READER_NULL_TYPE_ENABLED)
+
def parquetVectorizedReaderBatchSize: Int = getConf(PARQUET_VECTORIZED_READER_BATCH_SIZE)
def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE)
@@ -7312,6 +7462,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def pythonUDFWorkerTracebackDumpIntervalSeconds: Long =
getConf(PYTHON_UDF_WORKER_TRACEBACK_DUMP_INTERVAL_SECONDS)
+ def pythonUDFDaemonKillWorkerOnFlushFailure: Boolean =
+ getConf(PYTHON_UDF_DAEMON_KILL_WORKER_ON_FLUSH_FAILURE)
+
def pythonWorkerLoggingEnabled: Boolean = getConf(PYTHON_WORKER_LOGGING_ENABLED)
def pythonUDFArrowConcurrencyLevel: Option[Int] = getConf(PYTHON_UDF_ARROW_CONCURRENCY_LEVEL)
@@ -7332,6 +7485,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def arrowMaxBytesPerBatch: Long = getConf(ARROW_EXECUTION_MAX_BYTES_PER_BATCH)
+ def arrowCompressionCodec: String = getConf(ARROW_EXECUTION_COMPRESSION_CODEC)
+
+ def arrowZstdCompressionLevel: Int = getConf(ARROW_EXECUTION_ZSTD_COMPRESSION_LEVEL)
+
def arrowTransformWithStateInPySparkMaxStateRecordsPerBatch: Int =
getConf(ARROW_TRANSFORM_WITH_STATE_IN_PYSPARK_MAX_STATE_RECORDS_PER_BATCH)
@@ -7561,6 +7718,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
def parquetFieldIdWriteEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED)
+ def parquetAnnotateVariantLogicalType: Boolean = getConf(PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE)
+
+ def parquetIgnoreVariantAnnotation: Boolean = getConf(SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION)
+
def ignoreMissingParquetFieldId: Boolean = getConf(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID)
def legacyParquetNanosAsLong: Boolean = getConf(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG)
@@ -7609,6 +7770,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
override def legacyParameterSubstitutionConstantsOnly: Boolean =
getConf(SQLConf.LEGACY_PARAMETER_SUBSTITUTION_CONSTANTS_ONLY)
+ override def legacyIdentifierClauseOnly: Boolean =
+ getConf(SQLConf.LEGACY_IDENTIFIER_CLAUSE_ONLY)
+
def streamStatePollingInterval: Long = getConf(SQLConf.PIPELINES_STREAM_STATE_POLLING_INTERVAL)
def watchdogMinRetryTimeInSeconds: Long = {
@@ -7640,7 +7804,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
getConf(SQLConf.LEGACY_XML_PARSER_ENABLED)
def coerceMergeNestedTypes: Boolean =
- getConf(SQLConf.MERGE_INTO_SOURCE_NESTED_TYPE_COERCION_ENABLED)
+ getConf(SQLConf.MERGE_INTO_NESTED_TYPE_COERCION_ENABLED)
+
+ def isTimeTypeEnabled: Boolean = getConf(SQLConf.TIME_TYPE_ENABLED)
/** ********************** SQLConf functionality methods ************ */
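A minimal sketch of how downstream code might gate on one of the new feature flags; the call site is hypothetical and not part of this change:

    import org.apache.spark.sql.errors.QueryCompilationErrors
    import org.apache.spark.sql.internal.SQLConf

    def assertTimeTypeEnabled(): Unit = {
      if (!SQLConf.get.isTimeTypeEnabled) {
        throw QueryCompilationErrors.unsupportedTimeTypeError()
      }
    }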
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala
index 118c1af977454..5c54f28976458 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.internal.connector
+import org.apache.spark.sql.connector.expressions.GetArrayItem
import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
/**
@@ -35,4 +36,8 @@ class ToStringSQLBuilder extends V2ExpressionSQLBuilder with Serializable {
val distinct = if (isDistinct) "DISTINCT " else ""
s"""$funcName($distinct${inputs.mkString(", ")})"""
}
+
+ override protected def visitGetArrayItem(getArrayItem: GetArrayItem): String = {
+ s"${getArrayItem.childArray.toString}[${getArrayItem.ordinal.toString}]"
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/VariantExtractionImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/VariantExtractionImpl.scala
new file mode 100644
index 0000000000000..87db41a3217e6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/VariantExtractionImpl.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal.connector
+
+import org.apache.spark.sql.connector.read.VariantExtraction
+import org.apache.spark.sql.types.{DataType, Metadata}
+
+/**
+ * Implementation of [[VariantExtraction]].
+ *
+ * @param columnName Path to the variant column (e.g., Array("v") for top-level,
+ * Array("struct1", "v") for nested)
+ * @param metadata The metadata for extraction including JSON path, failOnError, and timeZoneId
+ * @param expectedDataType The expected data type for the extracted value
+ */
+case class VariantExtractionImpl(
+ columnName: Array[String],
+ metadata: Metadata,
+ expectedDataType: DataType) extends VariantExtraction {
+
+ require(columnName != null, "columnName cannot be null")
+ require(metadata != null, "metadata cannot be null")
+ require(expectedDataType != null, "expectedDataType cannot be null")
+ require(columnName.nonEmpty, "columnName cannot be empty")
+
+ override def toString: String = {
+ s"VariantExtraction{columnName=${columnName.mkString("[", ", ", "]")}, " +
+ s"metadata='$metadata', expectedDataType=$expectedDataType}"
+ }
+}
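A minimal construction sketch for the new class; the metadata keys below are assumptions used only for illustration:

    import org.apache.spark.sql.internal.connector.VariantExtractionImpl
    import org.apache.spark.sql.types.{MetadataBuilder, StringType}

    val metadata = new MetadataBuilder()
      .putString("path", "$.a.b")        // assumed key for the JSON path
      .putBoolean("failOnError", true)   // assumed key for the failOnError flag
      .build()
    val extraction = VariantExtractionImpl(Array("struct1", "v"), metadata, StringType)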
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
index cfc17e2683ac8..58ababa04739f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala
@@ -20,13 +20,16 @@ package org.apache.spark.sql.util
import java.util.Locale
import scala.collection.immutable.Queue
+import scala.collection.mutable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.SchemaValidationMode.{ALLOW_NEW_TOP_LEVEL_FIELDS, PROHIBIT_CHANGES}
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.SparkSchemaUtils
@@ -389,4 +392,142 @@ private[spark] object SchemaUtils {
case st: StringType => StringHelper.removeCollation(st)
case _ => dt
}
+
+ /**
+ * Validates schema compatibility by recursively checking type and nullability changes.
+ *
+ * @param schema the schema to validate against
+ * @param otherSchema the other schema to check for compatibility
+ * @param resolver the resolver that controls whether the validation is case sensitive
+ * @param mode the validation mode that controls what changes are allowed
+ * @return sequence of error messages describing incompatibilities, empty if fully compatible
+ */
+ def validateSchemaCompatibility(
+ schema: StructType,
+ otherSchema: StructType,
+ resolver: Resolver,
+ mode: SchemaValidationMode): Seq[String] = {
+ checkSchemaColumnNameDuplication(schema, resolver)
+ checkSchemaColumnNameDuplication(otherSchema, resolver)
+ val errors = mutable.ArrayBuffer[String]()
+ validateTypeCompatibility(
+ schema,
+ otherSchema,
+ nullable = false,
+ otherNullable = false,
+ colPath = Seq.empty,
+ resolver,
+ mode,
+ errors)
+ errors.toSeq
+ }
+
+ private def validateTypeCompatibility(
+ dataType: DataType,
+ otherDataType: DataType,
+ nullable: Boolean,
+ otherNullable: Boolean,
+ colPath: Seq[String],
+ resolver: Resolver,
+ mode: SchemaValidationMode,
+ errors: mutable.ArrayBuffer[String]): Unit = {
+ if (nullable && !otherNullable) {
+ errors += s"${colPath.fullyQuoted} is no longer nullable"
+ } else if (!nullable && otherNullable) {
+ errors += s"${colPath.fullyQuoted} is nullable now"
+ }
+
+ (dataType, otherDataType) match {
+ case (StructType(fields), StructType(otherFields)) =>
+ val fieldsByName = index(fields, resolver)
+ val otherFieldsByName = index(otherFields, resolver)
+
+ fieldsByName.foreach { case (normalizedName, field) =>
+ otherFieldsByName.get(normalizedName) match {
+ case Some(otherField) =>
+ validateTypeCompatibility(
+ field.dataType,
+ otherField.dataType,
+ field.nullable,
+ otherField.nullable,
+ colPath :+ field.name,
+ resolver,
+ mode,
+ errors)
+ case None =>
+ errors += s"${formatField(colPath, field)} has been removed"
+ }
+ }
+
+ if (mode == PROHIBIT_CHANGES || (mode == ALLOW_NEW_TOP_LEVEL_FIELDS && colPath.nonEmpty)) {
+ otherFieldsByName.foreach { case (normalizedName, otherField) =>
+ if (!fieldsByName.contains(normalizedName)) {
+ errors += s"${formatField(colPath, otherField)} has been added"
+ }
+ }
+ }
+
+ case (ArrayType(elem, containsNull), ArrayType(otherElem, otherContainsNull)) =>
+ validateTypeCompatibility(
+ elem,
+ otherElem,
+ containsNull,
+ otherContainsNull,
+ colPath :+ "element",
+ resolver,
+ mode,
+ errors)
+
+ case (MapType(keyType, valueType, valueContainsNull),
+ MapType(otherKeyType, otherValueType, otherValueContainsNull)) =>
+ validateTypeCompatibility(
+ keyType,
+ otherKeyType,
+ nullable = false,
+ otherNullable = false,
+ colPath :+ "key",
+ resolver,
+ mode,
+ errors)
+ validateTypeCompatibility(
+ valueType,
+ otherValueType,
+ valueContainsNull,
+ otherValueContainsNull,
+ colPath :+ "value",
+ resolver,
+ mode,
+ errors)
+
+ case _ if dataType != otherDataType =>
+ errors += s"${colPath.fullyQuoted} type has changed " +
+ s"from ${dataType.sql} to ${otherDataType.sql}"
+
+ case _ =>
+ // OK
+ }
+ }
+
+ private def formatField(colPath: Seq[String], field: StructField): String = {
+ val nameParts = colPath :+ field.name
+ val name = nameParts.fullyQuoted
+ val dataType = field.dataType.sql
+ if (field.nullable) s"$name $dataType" else s"$name $dataType NOT NULL"
+ }
+
+ private def index(fields: Array[StructField], resolver: Resolver): Map[String, StructField] = {
+ if (isCaseSensitiveAnalysis(resolver)) {
+ fields.map(field => field.name -> field).toMap
+ } else {
+ fields.map(field => field.name.toLowerCase(Locale.ROOT) -> field).toMap
+ }
+ }
+}
+
+private[spark] sealed trait SchemaValidationMode
+
+private[spark] object SchemaValidationMode {
+ case object PROHIBIT_CHANGES extends SchemaValidationMode
+ case object ALLOW_NEW_FIELDS extends SchemaValidationMode
+ case object ALLOW_NEW_TOP_LEVEL_FIELDS extends SchemaValidationMode
}
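A minimal usage sketch of the new validation entry point; the schemas and the resulting message are illustrative:

    import org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution
    import org.apache.spark.sql.types.{IntegerType, LongType, StructType}
    import org.apache.spark.sql.util.{SchemaUtils, SchemaValidationMode}

    val viewSchema = new StructType().add("id", IntegerType, nullable = false)
    val tableSchema = new StructType().add("id", LongType, nullable = false)
    val errors = SchemaUtils.validateSchemaCompatibility(
      viewSchema, tableSchema, caseInsensitiveResolution, SchemaValidationMode.PROHIBIT_CHANGES)
    // errors would contain something like: "`id` type has changed from INT to BIGINT"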
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/connector/catalog/CatalogLoadingSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/connector/catalog/CatalogLoadingSuite.java
index 339f16407ae60..c7e8d7b0f7f30 100644
--- a/sql/catalyst/src/test/java/org/apache/spark/sql/connector/catalog/CatalogLoadingSuite.java
+++ b/sql/catalyst/src/test/java/org/apache/spark/sql/connector/catalog/CatalogLoadingSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.sql.connector.catalog;
+import org.apache.spark.network.util.JavaUtils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@@ -24,6 +25,7 @@
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.apache.spark.util.Utils;
+
public class CatalogLoadingSuite {
@Test
public void testLoad() throws SparkException {
@@ -58,6 +60,7 @@ public void testInitializationOptions() throws SparkException {
conf.setConfString("spark.sql.catalog.test-name", TestCatalogPlugin.class.getCanonicalName());
conf.setConfString("spark.sql.catalog.test-name.name", "not-catalog-name");
conf.setConfString("spark.sql.catalog.test-name.kEy", "valUE");
+ conf.setConfString("spark.sql.catalog.test-name.osName", "${system:os.name}");
CatalogPlugin plugin = Catalogs.load("test-name", conf);
Assertions.assertNotNull(plugin,"Should instantiate a non-null plugin");
@@ -66,11 +69,13 @@ public void testInitializationOptions() throws SparkException {
TestCatalogPlugin testPlugin = (TestCatalogPlugin) plugin;
- Assertions.assertEquals(2, testPlugin.options.size(), "Options should contain only two keys");
+ Assertions.assertEquals(3, testPlugin.options.size(), "Options should contain only three keys");
Assertions.assertEquals("not-catalog-name", testPlugin.options.get("name"),
"Options should contain correct value for name (not overwritten)");
Assertions.assertEquals("valUE", testPlugin.options.get("key"),
"Options should contain correct value for key");
+ Assertions.assertEquals(JavaUtils.osName, testPlugin.options.get("osName"),
+ "Options should contain correct substitution for value");
}
@Test
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala
index 42acc38eee2d1..fa5027ce259d5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala
@@ -174,6 +174,25 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase {
widenTest(FloatType, FloatType, Some(FloatType))
widenTest(DoubleType, DoubleType, Some(DoubleType))
+ // Geography with same fixed SRIDs.
+ widenTest(GeographyType(4326), GeographyType(4326), Some(GeographyType(4326)))
+ // Geography with mixed SRIDs.
+ widenTest(GeographyType("ANY"), GeographyType("ANY"), Some(GeographyType("ANY")))
+ widenTest(GeographyType("ANY"), GeographyType(4326), Some(GeographyType("ANY")))
+ widenTest(GeographyType(4326), GeographyType("ANY"), Some(GeographyType("ANY")))
+ // Geometry with same fixed SRIDs.
+ widenTest(GeometryType(0), GeometryType(0), Some(GeometryType(0)))
+ widenTest(GeometryType(3857), GeometryType(3857), Some(GeometryType(3857)))
+ widenTest(GeometryType(4326), GeometryType(4326), Some(GeometryType(4326)))
+ // Geometry with different fixed SRIDs.
+ widenTest(GeometryType(0), GeometryType(3857), Some(GeometryType("ANY")))
+ widenTest(GeometryType(3857), GeometryType(4326), Some(GeometryType("ANY")))
+ widenTest(GeometryType(4326), GeometryType(0), Some(GeometryType("ANY")))
+ // Geometry with mixed SRIDs.
+ widenTest(GeometryType("ANY"), GeometryType("ANY"), Some(GeometryType("ANY")))
+ widenTest(GeometryType("ANY"), GeometryType(4326), Some(GeometryType("ANY")))
+ widenTest(GeometryType(4326), GeometryType("ANY"), Some(GeometryType("ANY")))
+
// Integral mixed with floating point.
widenTest(IntegerType, FloatType, Some(DoubleType))
widenTest(IntegerType, DoubleType, Some(DoubleType))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala
index 189509e317364..9685ed5c6d256 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TableLookupCacheSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.analysis
import java.io.File
+import java.util
import scala.jdk.CollectionConverters._
@@ -30,6 +31,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat, CatalogTable, CatalogTableType, ExternalCatalog, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.connector.catalog.{CatalogManager, Identifier, InMemoryTable, InMemoryTableCatalog, Table}
+import org.apache.spark.sql.connector.catalog.TableWritePrivilege
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
@@ -55,6 +57,11 @@ class TableLookupCacheSuite extends AnalysisTest with Matchers {
Array.empty,
Map.empty[String, String].asJava)
}
+ override def loadTable(
+ ident: Identifier,
+ writePrivileges: util.Set[TableWritePrivilege]): Table = {
+ loadTable(ident)
+ }
override def name: String = CatalogManager.SESSION_CATALOG_NAME
}
val catalogManager = mock(classOf[CatalogManager])
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 250f20fd09571..e6a9690ad7570 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -597,6 +597,25 @@ class TypeCoercionSuite extends TypeCoercionSuiteBase {
widenTest(FloatType, FloatType, Some(FloatType))
widenTest(DoubleType, DoubleType, Some(DoubleType))
+ // Geography with same fixed SRIDs.
+ widenTest(GeographyType(4326), GeographyType(4326), Some(GeographyType(4326)))
+ // Geography with mixed SRIDs.
+ widenTest(GeographyType("ANY"), GeographyType("ANY"), Some(GeographyType("ANY")))
+ widenTest(GeographyType("ANY"), GeographyType(4326), Some(GeographyType("ANY")))
+ widenTest(GeographyType(4326), GeographyType("ANY"), Some(GeographyType("ANY")))
+ // Geometry with same fixed SRIDs.
+ widenTest(GeometryType(0), GeometryType(0), Some(GeometryType(0)))
+ widenTest(GeometryType(3857), GeometryType(3857), Some(GeometryType(3857)))
+ widenTest(GeometryType(4326), GeometryType(4326), Some(GeometryType(4326)))
+ // Geometry with different fixed SRIDs.
+ widenTest(GeometryType(0), GeometryType(3857), Some(GeometryType("ANY")))
+ widenTest(GeometryType(3857), GeometryType(4326), Some(GeometryType("ANY")))
+ widenTest(GeometryType(4326), GeometryType(0), Some(GeometryType("ANY")))
+ // Geometry with mixed SRIDs.
+ widenTest(GeometryType("ANY"), GeometryType("ANY"), Some(GeometryType("ANY")))
+ widenTest(GeometryType("ANY"), GeometryType(4326), Some(GeometryType("ANY")))
+ widenTest(GeometryType(4326), GeometryType("ANY"), Some(GeometryType("ANY")))
+
// Integral mixed with floating point.
widenTest(IntegerType, FloatType, Some(FloatType))
widenTest(IntegerType, DoubleType, Some(DoubleType))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index 6ee19bab5180a..425df0856a58a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -853,6 +853,27 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
Deduplicate(Seq(attribute), streamRelation)), outputMode = Append)
}
+ /*
+ =======================================================================================
+ REAL-TIME STREAMING
+ =======================================================================================
+ */
+
+ {
+ assertNotSupportedForRealTime(
+ "real-time without operators - append mode",
+ streamRelation,
+ Append,
+ "STREAMING_REAL_TIME_MODE.OUTPUT_MODE_NOT_SUPPORTED"
+ )
+
+ assertSupportedForRealTime(
+ "real-time with stream-batch join - update mode",
+ streamRelation.join(batchRelation, joinType = Inner),
+ Update
+ )
+ }
+
/*
=======================================================================================
TESTING FUNCTIONS
@@ -1017,6 +1038,31 @@ class UnsupportedOperationsSuite extends SparkFunSuite with SQLHelper {
}
}
+ /** Assert that the logical plan is supported for real-time mode */
+ def assertSupportedForRealTime(name: String, plan: LogicalPlan, outputMode: OutputMode): Unit = {
+ test(s"real-time trigger - $name: supported") {
+ UnsupportedOperationChecker.checkAdditionalRealTimeModeConstraints(plan, outputMode)
+ }
+ }
+
+ /**
+ * Assert that the logical plan is not supported inside a streaming plan with the
+ * real-time trigger.
+ */
+ def assertNotSupportedForRealTime(
+ name: String,
+ plan: LogicalPlan,
+ outputMode: OutputMode,
+ condition: String): Unit = {
+ testError(
+ s"real-time trigger - $name: not supported",
+ Seq("Streaming real-time mode"),
+ condition
+ ) {
+ UnsupportedOperationChecker.checkAdditionalRealTimeModeConstraints(plan, outputMode)
+ }
+ }
+
/**
* Assert that the logical plan is not supported inside a streaming plan.
*
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
index 309d518f5ef4f..7a0a37f380992 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM
import org.apache.spark.sql.connector.catalog.SupportsNamespaces.PROP_OWNER
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
class InMemorySessionCatalogSuite extends SessionCatalogSuite {
protected val utils = new CatalogTestUtils {
@@ -2029,66 +2030,244 @@ abstract class SessionCatalogSuite extends AnalysisTest with Eventually {
test("corrupted view metadata: mismatch between viewQueryColumnNames and schema") {
withSQLConf("spark.sql.viewSchemaBinding.enabled" -> "true") {
- val catalog = new SessionCatalog(newBasicCatalog())
+ withBasicCatalog { catalog =>
+ val db = "test_db"
+ catalog.createDatabase(newDb(db), ignoreIfExists = false)
+
+ // First create a base table for the view to reference
+ val baseTable = CatalogTable(
+ identifier = TableIdentifier("base_table", Some(db)),
+ tableType = CatalogTableType.MANAGED,
+ storage = CatalogStorageFormat.empty,
+ schema = new StructType()
+ .add("id", IntegerType)
+ .add("name", StringType)
+ .add("value", DoubleType)
+ )
+ catalog.createTable(baseTable, ignoreIfExists = false)
+
+ // Create a view with corrupted metadata where viewQueryColumnNames length
+ // doesn't match schema length
+ // We need to set the properties to define viewQueryColumnNames
+ val properties = Map(
+ "view.query.out.numCols" -> "2",
+ "view.query.out.col.0" -> "id",
+ "view.query.out.col.1" -> "name",
+ "view.schema.mode" -> "binding" // Ensure it's not SchemaEvolution
+ )
+ val corruptedView = CatalogTable(
+ identifier = TableIdentifier("corrupted_view", Some(db)),
+ tableType = CatalogTableType.VIEW,
+ storage = CatalogStorageFormat.empty,
+ schema = new StructType()
+ .add("id", IntegerType)
+ .add("name", StringType)
+ .add("value", DoubleType),
+ viewText = Some("SELECT * FROM test_db.base_table"),
+ provider = Some("spark"), // Ensure it's not Hive-created
+ properties = properties // Only 2 query column names but schema has 3 columns
+ )
+
+ catalog.createTable(corruptedView, ignoreIfExists = false)
+
+ // Verify the view was created with corrupted metadata
+ val retrievedView = catalog.getTableMetadata(TableIdentifier("corrupted_view", Some(db)))
+ assert(retrievedView.viewQueryColumnNames.length == 2)
+ assert(retrievedView.schema.length == 3)
+
+ // Attempting to look up the view should throw an assertion error with detailed message
+ val exception = intercept[AssertionError] {
+ catalog.lookupRelation(TableIdentifier("corrupted_view", Some(db)))
+ }
+
+ // The expected message pattern allows for optional catalog prefix
+ val expectedPattern =
+ "assertion failed: Corrupted view metadata detected for view " +
+ "(\\`\\w+\\`\\.)?\\`test_db\\`\\.\\`corrupted_view\\`\\. " +
+ "The number of view query column names 2 " +
+ "does not match the number of columns in the view schema 3\\. " +
+ "View query column names: \\[id, name\\], " +
+ "View schema columns: \\[id, name, value\\]\\. " +
+ "This indicates corrupted view metadata that needs to be repaired\\."
+ assert(exception.getMessage.matches(expectedPattern))
+ }
+ }
+ }
+
+ test("UnresolvedCatalogRelation requires database in identifier") {
+ withEmptyCatalog { catalog =>
val db = "test_db"
- catalog.createDatabase(newDb(db), ignoreIfExists = false)
+ catalog.createDatabase(newDb(db), ignoreIfExists = true)
- // First create a base table for the view to reference
- val baseTable = CatalogTable(
- identifier = TableIdentifier("base_table", Some(db)),
+ // Create a table with database
+ val validTable = CatalogTable(
+ identifier = TableIdentifier("test_table", Some(db)),
tableType = CatalogTableType.MANAGED,
storage = CatalogStorageFormat.empty,
- schema = new StructType()
- .add("id", IntegerType)
- .add("name", StringType)
- .add("value", DoubleType)
+ schema = new StructType().add("id", IntegerType)
)
- catalog.createTable(baseTable, ignoreIfExists = false)
-
- // Create a view with corrupted metadata where viewQueryColumnNames length
- // doesn't match schema length
- // We need to set the properties to define viewQueryColumnNames
- val properties = Map(
- "view.query.out.numCols" -> "2",
- "view.query.out.col.0" -> "id",
- "view.query.out.col.1" -> "name",
- "view.schema.mode" -> "binding" // Ensure it's not SchemaEvolution
+ catalog.createTable(validTable, ignoreIfExists = false)
+
+ // Try to create UnresolvedCatalogRelation without database - should fail
+ val tableMetaWithoutDb = validTable.copy(
+ identifier = TableIdentifier("test_table", None)
)
- val corruptedView = CatalogTable(
- identifier = TableIdentifier("corrupted_view", Some(db)),
- tableType = CatalogTableType.VIEW,
+
+ val exception = intercept[AssertionError] {
+ UnresolvedCatalogRelation(tableMetaWithoutDb)
+ }
+
+ val expectedMessage =
+ "assertion failed: Table identifier `test_table` is missing database name. " +
+ "UnresolvedCatalogRelation requires a fully qualified table identifier with database."
+ assert(exception.getMessage === expectedMessage)
+ }
+ }
+
+ test("HiveTableRelation requires database in identifier") {
+ withEmptyCatalog { catalog =>
+ val db = "test_db"
+ catalog.createDatabase(newDb(db), ignoreIfExists = true)
+
+ // Create a table with database
+ val validTable = CatalogTable(
+ identifier = TableIdentifier("test_table", Some(db)),
+ tableType = CatalogTableType.MANAGED,
storage = CatalogStorageFormat.empty,
schema = new StructType()
.add("id", IntegerType)
.add("name", StringType)
- .add("value", DoubleType),
- viewText = Some("SELECT * FROM test_db.base_table"),
- provider = Some("spark"), // Ensure it's not Hive-created
- properties = properties // Only 2 query column names but schema has 3 columns
)
- catalog.createTable(corruptedView, ignoreIfExists = false)
-
- // Verify the view was created with corrupted metadata
- val retrievedView = catalog.getTableMetadata(TableIdentifier("corrupted_view", Some(db)))
- assert(retrievedView.viewQueryColumnNames.length == 2)
- assert(retrievedView.schema.length == 3)
+ // Try to create HiveTableRelation without database - should fail
+ val tableMetaWithoutDb = validTable.copy(
+ identifier = TableIdentifier("test_table", None)
+ )
- // Attempting to look up the view should throw an assertion error with detailed message
val exception = intercept[AssertionError] {
- catalog.lookupRelation(TableIdentifier("corrupted_view", Some(db)))
+ HiveTableRelation(
+ tableMetaWithoutDb,
+ Seq(AttributeReference("id", IntegerType)()),
+ Seq.empty
+ )
}
- // The expected message pattern allows for optional catalog prefix
- val expectedPattern =
- "assertion failed: Corrupted view metadata detected for view " +
- "(\\`\\w+\\`\\.)?\\`test_db\\`\\.\\`corrupted_view\\`\\. " +
- "The number of view query column names 2 " +
- "does not match the number of columns in the view schema 3\\. " +
- "View query column names: \\[id, name\\], " +
- "View schema columns: \\[id, name, value\\]\\. " +
- "This indicates corrupted view metadata that needs to be repaired\\."
- assert(exception.getMessage.matches(expectedPattern))
+ val expectedMessage =
+ "assertion failed: Table identifier `test_table` is missing database name. " +
+ "HiveTableRelation requires a fully qualified table identifier with database."
+ assert(exception.getMessage === expectedMessage)
}
}
+
+ test("SQLFunction requires either exprText or queryText") {
+ // Test case 1: Neither exprText nor queryText provided
+ val exception1 = intercept[AssertionError] {
+ SQLFunction(
+ name = FunctionIdentifier("test_func"),
+ inputParam = None,
+ returnType = scala.util.Left(IntegerType),
+ exprText = None,
+ queryText = None,
+ comment = None,
+ deterministic = Some(true),
+ containsSQL = Some(false),
+ isTableFunc = false,
+ properties = Map.empty
+ )
+ }
+
+ val expectedMessage = "assertion failed: SQL function 'test_func' is missing function body. " +
+ "Either exprText or queryText must be defined. " +
+ "Found: exprText=None, queryText=None."
+ assert(exception1.getMessage === expectedMessage)
+ }
+
+ test("SQLFunction return type must match function type") {
+ // Test case: isTableFunc=true but returnType is Left (scalar type)
+ val exception = intercept[AssertionError] {
+ SQLFunction(
+ name = FunctionIdentifier("test_func"),
+ inputParam = None,
+ returnType = scala.util.Left(IntegerType), // Scalar return type
+ exprText = Some("SELECT 1"),
+ queryText = None,
+ comment = None,
+ deterministic = Some(true),
+ containsSQL = Some(true),
+ isTableFunc = true, // But marked as table function
+ properties = Map.empty
+ )
+ }
+
+ val expectedMessage =
+ "assertion failed: SQL function 'test_func' has mismatched function type " +
+ "and return type. " +
+ "isTableFunc=true, returnType.isRight=false, returnType.isLeft=true. " +
+ "Table functions require Right[StructType] and scalar functions require Left[DataType]."
+ assert(exception.getMessage === expectedMessage)
+ }
+
+ test("InMemoryCatalog.createTable requires database in identifier") {
+ val catalog = new InMemoryCatalog()
+ val db = "test_db"
+ val dbDefinition = CatalogDatabase(
+ name = db,
+ description = "test database",
+ locationUri = Utils.createTempDir().toURI,
+ properties = Map.empty
+ )
+ catalog.createDatabase(dbDefinition, ignoreIfExists = false)
+
+ // Try to create table without database - should fail
+ val tableWithoutDb = CatalogTable(
+ identifier = TableIdentifier("test_table", None),
+ tableType = CatalogTableType.MANAGED,
+ storage = CatalogStorageFormat.empty,
+ schema = new StructType().add("id", IntegerType)
+ )
+
+ val exception = intercept[AssertionError] {
+ catalog.createTable(tableWithoutDb, ignoreIfExists = false)
+ }
+
+ val expectedMessage =
+ "assertion failed: Table identifier `test_table` is missing database name. " +
+ "Cannot create table without a database specified."
+ assert(exception.getMessage === expectedMessage)
+ }
+
+ test("InMemoryCatalog.alterTable requires database in identifier") {
+ val catalog = new InMemoryCatalog()
+ val db = "test_db"
+ val dbDefinition = CatalogDatabase(
+ name = db,
+ description = "test database",
+ locationUri = Utils.createTempDir().toURI,
+ properties = Map.empty
+ )
+ catalog.createDatabase(dbDefinition, ignoreIfExists = false)
+
+ // First create a valid table
+ val validTable = CatalogTable(
+ identifier = TableIdentifier("test_table", Some(db)),
+ tableType = CatalogTableType.MANAGED,
+ storage = CatalogStorageFormat.empty,
+ schema = new StructType().add("id", IntegerType)
+ )
+ catalog.createTable(validTable, ignoreIfExists = false)
+
+ // Try to alter table with identifier without database - should fail
+ val tableWithoutDb = validTable.copy(
+ identifier = TableIdentifier("test_table", None)
+ )
+
+ val exception = intercept[AssertionError] {
+ catalog.alterTable(tableWithoutDb)
+ }
+
+ val expectedMessage =
+ "assertion failed: Table identifier `test_table` is missing database name. " +
+ "Cannot alter table without a database specified."
+ assert(exception.getMessage === expectedMessage)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index a5c11558b3db3..287b99d10d659 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -612,6 +612,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
provider,
nullable = true))
.resolveAndBind()
+ assert(encoder.isInstanceOf[Serializable])
assert(encoder.schema == new StructType().add("value", BinaryType))
val toRow = encoder.createSerializer()
val fromRow = encoder.createDeserializer()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeMapSuite.scala
new file mode 100644
index 0000000000000..fbb37d452437b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeMapSuite.scala
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType}
+
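+/**
+ * Tests for [[AttributeMap]], whose entries are keyed by the attribute's ExprId rather than
+ * by full attribute equality, so lookups succeed for attributes that differ only in name
+ * case, metadata, or qualifier.
+ */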
+class AttributeMapSuite extends SparkFunSuite {
+
+ val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
+ val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
+ val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
+
+ val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
+ val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
+
+ val cAttr = AttributeReference("c", StringType)(exprId = ExprId(4))
+
+ test("basic map operations - get") {
+ val map = AttributeMap(Seq((aUpper, "value1"), (bUpper, "value2")))
+
+ // Should find by exprId, not by attribute equality
+ assert(map.get(aLower) === Some("value1"))
+ assert(map.get(aUpper) === Some("value1"))
+ assert(map.get(bLower) === Some("value2"))
+ assert(map.get(bUpper) === Some("value2"))
+
+ // Different exprId should not be found
+ assert(map.get(fakeA) === None)
+ }
+
+ test("basic map operations - contains") {
+ val map = AttributeMap(Seq((aUpper, "value1"), (bUpper, "value2")))
+
+ // Should find by exprId, not by attribute equality
+ assert(map.contains(aLower))
+ assert(map.contains(aUpper))
+ assert(map.contains(bUpper))
+ assert(!map.contains(fakeA))
+ }
+
+ test("basic map operations - getOrElse") {
+ val map = AttributeMap(Seq((aUpper, "value1")))
+
+ assert(map.getOrElse(aLower, "default") === "value1")
+ assert(map.getOrElse(fakeA, "default") === "default")
+ }
+
+ test("+ operator preserves ExprId-based hashing") {
+ val map1 = AttributeMap(Seq((aUpper, "value1")))
+ val map2 = map1 + (bUpper -> "value2")
+
+ // The resulting map should still be an AttributeMap
+ assert(map2.isInstanceOf[AttributeMap[_]])
+
+ // Should look up by exprId, not by attribute equality
+ assert(map2.get(aLower) === Some("value1"))
+ assert(map2.get(bLower) === Some("value2"))
+ }
+
+ test("+ operator with attribute having different metadata") {
+ val metadata1 = new MetadataBuilder().putString("key", "value1").build()
+ val metadata2 = new MetadataBuilder().putString("key", "value2").build()
+
+ // Two attributes with same exprId but different metadata
+ val attrWithMetadata1 = AttributeReference("col", IntegerType, metadata = metadata1)(
+ exprId = ExprId(100))
+ val attrWithMetadata2 = AttributeReference("col", IntegerType, metadata = metadata2)(
+ exprId = ExprId(100))
+
+ // These should have different hashCodes but same exprId
+ assert(attrWithMetadata1.hashCode() != attrWithMetadata2.hashCode(),
+ "Attributes with different metadata should have different hashCodes")
+ assert(attrWithMetadata1.exprId == attrWithMetadata2.exprId,
+ "Attributes should have the same exprId")
+
+ // Create a map with the first attribute
+ val map1 = AttributeMap(Seq((attrWithMetadata1, "original")))
+
+ // Add an entry using the + operator
+ val map2 = map1 + (cAttr -> "new")
+
+ // CRITICAL: The map should still find the original entry using an attribute
+ // with the same exprId but different metadata
+ assert(map2.get(attrWithMetadata2) === Some("original"),
+ "AttributeMap should look up by exprId, not by attribute hashCode")
+
+ // And the new entry should also be present
+ assert(map2.get(cAttr) === Some("new"))
+ }
+
+ test("+ operator updates existing key") {
+ val map1 = AttributeMap(Seq((aUpper, "value1")))
+ val map2 = map1 + (aLower -> "updated")
+
+ // Since aLower has the same exprId as aUpper, it should update the value
+ assert(map2.get(aUpper) === Some("updated"))
+ assert(map2.get(aLower) === Some("updated"))
+ assert(map2.size === 1)
+ }
+
+ test("+ operator with type widening") {
+ val map1: AttributeMap[String] = AttributeMap(Seq((aUpper, "value1")))
+ val map2: AttributeMap[Any] = map1 + (bUpper -> 42)
+
+ assert(map2.get(aUpper) === Some("value1"))
+ assert(map2.get(bUpper) === Some(42))
+ }
+
+ test("++ operator preserves AttributeMap semantics") {
+ val map1 = AttributeMap(Seq((aUpper, "value1")))
+ val map2 = AttributeMap(Seq((bUpper, "value2")))
+ val combined = map1 ++ map2
+
+ assert(combined.isInstanceOf[AttributeMap[_]])
+ assert(combined.get(aLower) === Some("value1"))
+ assert(combined.get(bLower) === Some("value2"))
+ }
+
+ test("updated method") {
+ val map1 = AttributeMap(Seq((aUpper, "value1")))
+ val map2 = map1.updated(bUpper, "value2")
+
+ // Note: updated returns a Map[Attribute, B1], not AttributeMap
+ assert(map2.get(aUpper) === Some("value1"))
+ assert(map2.get(bUpper) === Some("value2"))
+ }
+
+ test("- operator (removal)") {
+ val map1 = AttributeMap(Seq((aUpper, "value1"), (bUpper, "value2")))
+ val map2 = map1 - aLower
+
+ // Note: - returns a Map[Attribute, A], not AttributeMap
+ assert(map2.get(aUpper) === None)
+ assert(map2.get(bUpper) === Some("value2"))
+ }
+
+ test("iterator") {
+ val map = AttributeMap(Seq((aUpper, "value1"), (bUpper, "value2")))
+ val entries = map.iterator.toSeq
+
+ assert(entries.size === 2)
+ assert(entries.contains((aUpper, "value1")))
+ assert(entries.contains((bUpper, "value2")))
+ }
+
+ test("size") {
+ val emptyMap = AttributeMap.empty[String]
+ assert(emptyMap.size === 0)
+
+ val map1 = AttributeMap(Seq((aUpper, "value1")))
+ assert(map1.size === 1)
+
+ val map2 = AttributeMap(Seq((aUpper, "value1"), (bUpper, "value2")))
+ assert(map2.size === 2)
+ }
+
+ test("empty map") {
+ val emptyMap = AttributeMap.empty[String]
+ assert(emptyMap.get(aUpper) === None)
+ assert(emptyMap.size === 0)
+ assert(!emptyMap.contains(aUpper))
+ }
+
+ test("duplicate keys in construction") {
+ // When constructing with duplicate exprIds, the last one should win
+ val map = AttributeMap(Seq((aUpper, "value1"), (aLower, "value2")))
+ assert(map.size === 1)
+ assert(map.get(aUpper) === Some("value2"))
+ }
+
+ test("map construction from Map") {
+ val regularMap = Map(aUpper -> "value1", bUpper -> "value2")
+ val attrMap = AttributeMap(regularMap)
+
+ assert(attrMap.get(aLower) === Some("value1"))
+ assert(attrMap.get(bLower) === Some("value2"))
+ }
+
+ test("chained + operations") {
+ val map = AttributeMap.empty[String] + (aUpper -> "value1") + (bUpper -> "value2") +
+ (cAttr -> "value3")
+
+ assert(map.size === 3)
+ assert(map.get(aLower) === Some("value1"))
+ assert(map.get(bLower) === Some("value2"))
+ assert(map.get(cAttr) === Some("value3"))
+ }
+
+ test("+ should be deterministic with Attributes with diff hashcodes and same exprId") {
+ // The backing HashMap needs to reach a certain size before hash collisions come into play,
+ // so seed these AttributeMaps with five entries.
+ var map1 = AttributeMap(
+ Seq(
+ AttributeReference("a", IntegerType)(exprId = ExprId(1)) -> 1,
+ AttributeReference("b", IntegerType)(exprId = ExprId(2)) -> 2,
+ AttributeReference("c", IntegerType)(exprId = ExprId(3)) -> 3,
+ AttributeReference("d", IntegerType)(exprId = ExprId(4)) -> 4,
+ AttributeReference("e", IntegerType)(exprId = ExprId(5)) -> 5
+ )
+ )
+ var map2 = AttributeMap(
+ Seq(
+ AttributeReference("a", IntegerType)(exprId = ExprId(1)) -> 1,
+ AttributeReference("b", IntegerType)(exprId = ExprId(2)) -> 2,
+ AttributeReference("c", IntegerType)(exprId = ExprId(3)) -> 3,
+ AttributeReference("d", IntegerType)(exprId = ExprId(4)) -> 4,
+ AttributeReference("e", IntegerType)(exprId = ExprId(5)) -> 5
+ )
+ )
+ val qualifier1 = Seq("d")
+ val qualifier2 = Seq()
+ val exprId = ExprId(42)
+ val attr1 = AttributeReference("x", IntegerType)(exprId = exprId, qualifier = qualifier1)
+ val attr2 = AttributeReference("x", IntegerType)(exprId = exprId, qualifier = qualifier2)
+ assert(attr1.hashCode != attr2.hashCode)
+
+ map1 = map1 + (attr1 -> 100)
+ map1 = map1 + (attr2 -> 200)
+ assert(map1.get(attr2) === Some(200))
+
+ map2 = map2 + (attr2 -> 200)
+ map2 = map2 + (attr1 -> 100)
+ assert(map2.get(attr2) === Some(100))
+ }
+
+ test("updated should be deterministic with Attributes with diff hashcodes and same exprId") {
+ // The backing HashMap needs to reach a certain size before hash collisions come into play,
+ // so seed these AttributeMaps with five entries.
+ var map1: Map[Attribute, Int] = AttributeMap(
+ Seq(
+ AttributeReference("a", IntegerType)(exprId = ExprId(1)) -> 1,
+ AttributeReference("b", IntegerType)(exprId = ExprId(2)) -> 2,
+ AttributeReference("c", IntegerType)(exprId = ExprId(3)) -> 3,
+ AttributeReference("d", IntegerType)(exprId = ExprId(4)) -> 4,
+ AttributeReference("e", IntegerType)(exprId = ExprId(5)) -> 5
+ )
+ )
+ var map2: Map[Attribute, Int] = AttributeMap(
+ Seq(
+ AttributeReference("a", IntegerType)(exprId = ExprId(1)) -> 1,
+ AttributeReference("b", IntegerType)(exprId = ExprId(2)) -> 2,
+ AttributeReference("c", IntegerType)(exprId = ExprId(3)) -> 3,
+ AttributeReference("d", IntegerType)(exprId = ExprId(4)) -> 4,
+ AttributeReference("e", IntegerType)(exprId = ExprId(5)) -> 5
+ )
+ )
+ val qualifier1 = Seq("d")
+ val qualifier2 = Seq()
+ val exprId = ExprId(42)
+ val attr1 = AttributeReference("x", IntegerType)(exprId = exprId, qualifier = qualifier1)
+ val attr2 = AttributeReference("x", IntegerType)(exprId = exprId, qualifier = qualifier2)
+ assert(attr1.hashCode != attr2.hashCode)
+
+ map1 = map1.updated(attr1, 100)
+ map1 = map1.updated(attr2, 200)
+ assert(map1.get(attr2) === Some(200))
+
+ map2 = map2.updated(attr2, 200)
+ map2 = map2.updated(attr1, 100)
+ assert(map2.get(attr2) === Some(100))
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index 7a87c86b63c04..e18a489d36f3b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -1489,6 +1489,84 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
}
}
+ // The following tests confirm the casting behavior between geospatial types.
+
+ test("Casting GeographyType to GeographyType") {
+ // Casting from fixed SRID GEOGRAPHY() to mixed SRID GEOGRAPHY(ANY) is always allowed.
+ // Type casting is always safe in this direction, so no additional constraints are imposed.
+ // Casting from mixed SRID GEOGRAPHY(ANY) to fixed SRID GEOGRAPHY() is not allowed.
+ // Type casting can be unsafe in this direction, because per-row SRID values may be different.
+
+ // Valid cast test cases.
+ val canCastTestCases: Seq[(DataType, DataType)] = Seq(
+ (GeographyType(4326), GeographyType("ANY"))
+ )
+ // Iterate over the test cases and verify casting.
+ canCastTestCases.foreach { case (fromType, toType) =>
+ // Cast can be performed from `fromType` to `toType`.
+ assert(Cast.canCast(fromType, toType))
+ assert(Cast.canAnsiCast(fromType, toType))
+ // Cast cannot be performed from `toType` to `fromType`.
+ assert(!Cast.canCast(toType, fromType))
+ assert(!Cast.canAnsiCast(toType, fromType))
+ }
+ }
+
+ test("Casting GeographyType to GeometryType") {
+ // Casting from GEOGRAPHY to GEOMETRY is only allowed if the SRIDs are the same.
+
+ // Valid cast test cases.
+ val canAnsiCastTestCases: Seq[(DataType, DataType)] = Seq(
+ (GeographyType(4326), GeometryType(4326)),
+ (GeographyType("ANY"), GeometryType("ANY"))
+ )
+ // Iterate over the test cases and verify casting.
+ canAnsiCastTestCases.foreach { case (fromType, toType) =>
+ // Cast can be performed from `fromType` to `toType`.
+ assert(Cast.canCast(fromType, toType))
+ assert(Cast.canAnsiCast(fromType, toType))
+ }
+
+ // Invalid cast test cases.
+ val cannotAnsiCastTestCases: Seq[(DataType, DataType)] = Seq(
+ (GeographyType(4326), GeometryType(0)),
+ (GeographyType(4326), GeometryType(3857)),
+ (GeographyType(4326), GeometryType("ANY")),
+ (GeographyType("ANY"), GeometryType(0)),
+ (GeographyType("ANY"), GeometryType(3857)),
+ (GeographyType("ANY"), GeometryType(4326))
+ )
+ // Iterate over the test cases and verify casting.
+ cannotAnsiCastTestCases.foreach { case (fromType, toType) =>
+ // Cast cannot be performed from `fromType` to `toType`.
+ assert(!Cast.canCast(fromType, toType))
+ assert(!Cast.canAnsiCast(fromType, toType))
+ }
+ }
+
+ test("Casting GeometryType to GeometryType") {
+ // Casting from fixed SRID GEOMETRY() to mixed SRID GEOMETRY(ANY) is always allowed.
+ // Type casting is always safe in this direction, so no additional constraints are imposed.
+ // Casting from mixed SRID GEOMETRY(ANY) to fixed SRID GEOMETRY() is not allowed.
+ // Type casting can be unsafe in this direction, because per-row SRID values may be different.
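+ // For example, a GEOMETRY(ANY) column may hold rows with SRID 3857 alongside rows with
+ // SRID 4326, so narrowing the column type to GEOMETRY(4326) could mislabel the 3857 rows.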
+
+ // Valid cast test cases.
+ val canCastTestCases: Seq[(DataType, DataType)] = Seq(
+ (GeometryType(0), GeometryType("ANY")),
+ (GeometryType(3857), GeometryType("ANY")),
+ (GeometryType(4326), GeometryType("ANY"))
+ )
+ // Iterate over the test cases and verify casting.
+ canCastTestCases.foreach { case (fromType, toType) =>
+ // Cast can be performed from `fromType` to `toType`.
+ assert(Cast.canCast(fromType, toType))
+ assert(Cast.canAnsiCast(fromType, toType))
+ // Cast cannot be performed from `toType` to `fromType`.
+ assert(!Cast.canCast(toType, fromType))
+ assert(!Cast.canAnsiCast(toType, fromType))
+ }
+ }
+
test("cast string to time") {
checkEvaluation(cast(Literal.create("0:0:0"), TimeType()), 0L)
checkEvaluation(cast(Literal.create(" 01:2:3.01 "), TimeType(2)), localTime(1, 2, 3, 10000))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DatasketchesHllSketchSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DatasketchesHllSketchSuite.scala
index 0841702cc5180..0f7f5ca54be01 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DatasketchesHllSketchSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/DatasketchesHllSketchSuite.scala
@@ -108,4 +108,49 @@ class DatasketchesHllSketchSuite extends SparkFunSuite {
assert(HllSketch.heapify(Memory.wrap(binary3.asInstanceOf[Array[Byte]])).getLgConfigK == 12)
}
+
+ test("HllUnionAgg throws proper error for invalid binary input causing ArrayIndexOutOfBounds") {
+ val aggFunc = new HllUnionAgg(BoundReference(0, BinaryType, nullable = true), true)
+ val union = aggFunc.createAggregationBuffer()
+
+ // Craft a byte array that passes initial size checks but has an invalid CurMode ordinal.
+ // HLL preamble layout:
+ // Byte 0: preInts (preamble size in ints)
+ // Byte 1: serVer (must be 1)
+ // Byte 2: famId (must be 7 for HLL)
+ // Byte 3: lgK (4-21)
+ // Byte 5: flags
+ // Byte 7: modeByte - bits 0-1 contain curMode ordinal (0=LIST, 1=SET, 2=HLL)
+ //
+ // Setting bits 0-1 of byte 7 to 0b11 (=3) causes CurMode.fromOrdinal(3) to throw
+ // ArrayIndexOutOfBoundsException since CurMode only has ordinals 0, 1, 2.
+ // This happens in PreambleUtil.extractCurMode() before other validations run.
+ val invalidBinary = Array[Byte](
+ 2, // byte 0: preInts = 2 (LIST_PREINTS, passes check)
+ 1, // byte 1: serVer = 1 (valid)
+ 7, // byte 2: famId = 7 (HLL family)
+ 12, // byte 3: lgK = 12 (valid range 4-21)
+ 0, // byte 4: unused
+ 0, // byte 5: flags = 0
+ 0, // byte 6: unused
+ 3 // byte 7: modeByte with bits 0-1 = 0b11 = 3 (INVALID curMode ordinal!)
+ )
+
+ val exception = intercept[Exception] {
+ aggFunc.update(union, InternalRow(invalidBinary))
+ }
+
+ // Verify that ArrayIndexOutOfBoundsException is properly caught and converted
+ // to the user-friendly HLL_INVALID_INPUT_SKETCH_BUFFER error
+ assert(
+ !exception.isInstanceOf[ArrayIndexOutOfBoundsException],
+ s"ArrayIndexOutOfBoundsException should be caught and converted to " +
+ s"HLL_INVALID_INPUT_SKETCH_BUFFER error, but got: ${exception.getClass.getName}"
+ )
+ assert(
+ exception.getMessage.contains("HLL_INVALID_INPUT_SKETCH_BUFFER"),
+ s"Expected HLL_INVALID_INPUT_SKETCH_BUFFER error, " +
+ s"but got: ${exception.getClass.getName}: ${exception.getMessage}"
+ )
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
index 298329db1ee30..30ecc902ff89f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/SqlScriptingParserSuite.scala
@@ -30,10 +30,12 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
protected override def beforeAll(): Unit = {
super.beforeAll()
+ conf.setConf(SQLConf.SQL_SCRIPTING_ENABLED, true)
conf.setConf(SQLConf.SQL_SCRIPTING_CONTINUE_HANDLER_ENABLED, true)
}
protected override def afterAll(): Unit = {
+ conf.unsetConf(SQLConf.SQL_SCRIPTING_ENABLED.key)
conf.unsetConf(SQLConf.SQL_SCRIPTING_CONTINUE_HANDLER_ENABLED.key)
super.afterAll()
}
@@ -2273,11 +2275,11 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
| END;
|END""".stripMargin
checkError(
- exception = intercept[SqlScriptingException] {
+ exception = intercept[ParseException] {
parsePlan(sqlScriptText)
},
- condition = "INVALID_LABEL_USAGE.QUALIFIED_LABEL_NAME",
- parameters = Map("labelName" -> "PART1.PART2"))
+ condition = "PARSE_SYNTAX_ERROR",
+ parameters = Map("error" -> "'.'", "hint" -> ""))
}
test("qualified label name: label cannot be qualified + end label") {
@@ -2288,11 +2290,11 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
| END part1.part2;
|END""".stripMargin
checkError(
- exception = intercept[SqlScriptingException] {
+ exception = intercept[ParseException] {
parsePlan(sqlScriptText)
},
- condition = "INVALID_LABEL_USAGE.QUALIFIED_LABEL_NAME",
- parameters = Map("labelName" -> "PART1.PART2"))
+ condition = "PARSE_SYNTAX_ERROR",
+ parameters = Map("error" -> "'.'", "hint" -> ""))
}
test("unique label names: nested labeled scope statements") {
@@ -2785,13 +2787,13 @@ class SqlScriptingParserSuite extends SparkFunSuite with SQLHelper {
|BEGIN
| DECLARE TEST.CONDITION CONDITION FOR SQLSTATE '12345';
|END""".stripMargin
- val exception = intercept[SqlScriptingException] {
+ val exception = intercept[ParseException] {
parsePlan(sqlScriptText)
}
checkError(
exception = exception,
- condition = "INVALID_ERROR_CONDITION_DECLARATION.QUALIFIED_CONDITION_NAME",
- parameters = Map("conditionName" -> "TEST.CONDITION"))
+ condition = "PARSE_SYNTAX_ERROR",
+ parameters = Map("error" -> "'FOR'", "hint" -> ": missing ';'"))
assert(exception.origin.line.contains(3))
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
index cbaebfa12238a..a87d599711cfc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/IntervalUtilsSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util
import java.time.{Duration, Period}
import java.util.concurrent.TimeUnit
-import org.apache.spark.{SparkFunSuite, SparkIllegalArgumentException}
+import org.apache.spark.{SparkArithmeticException, SparkFunSuite, SparkIllegalArgumentException}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToMicros
@@ -364,10 +364,29 @@ class IntervalUtilsSuite extends SparkFunSuite with SQLHelper {
assert(duration("1 microsecond", TimeUnit.MICROSECONDS, 30) === 1)
assert(duration("1 month -30 days", TimeUnit.DAYS, 31) === 1)
- val e = intercept[ArithmeticException] {
- duration(s"${Integer.MAX_VALUE} month", TimeUnit.SECONDS, 31)
+ checkError(
+ exception = intercept[SparkArithmeticException] {
+ duration(s"${Integer.MAX_VALUE} month", TimeUnit.SECONDS, 31)
+ },
+ condition = "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+ parameters = Map.empty
+ )
+ }
+
+ test("interval overflow with large day values") {
+ // Test case for SPARK-50072: handling ArithmeticException during interval parsing
+ // The interval below just exceeds Long.MaxValue microseconds, so the conversion overflows
+ def duration(s: String, unit: TimeUnit): Long = {
+ IntervalUtils.getDuration(stringToInterval(UTF8String.fromString(s)), unit)
}
- assert(e.getMessage.contains("overflow"))
+
+ checkError(
+ exception = intercept[SparkArithmeticException] {
+ duration("106751991 days 4 hours 0 minutes 54.776 seconds", TimeUnit.MICROSECONDS)
+ },
+ condition = "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+ parameters = Map.empty
+ )
}
test("negative interval") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StUtilsSuite.java b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StUtilsSuite.java
index 8ad4d4c36e45c..f19a92b61641c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StUtilsSuite.java
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StUtilsSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.util;
+import org.apache.spark.SparkIllegalArgumentException;
import org.apache.spark.unsafe.types.GeographyVal;
import org.apache.spark.unsafe.types.GeometryVal;
import org.junit.jupiter.api.Test;
@@ -62,6 +63,16 @@ class STUtilsSuite {
System.arraycopy(testWkb, 0, testGeometryBytes, sridLen, wkbLen);
}
+ /** Tests for geospatial type casting utility methods. */
+
+ @Test
+ void testGeographyToGeometry() {
+ GeographyVal geographyVal = GeographyVal.fromBytes(testGeographyBytes);
+ GeometryVal geometryVal = STUtils.geographyToGeometry(geographyVal);
+ assertNotNull(geometryVal);
+ assertArrayEquals(geographyVal.getBytes(), geometryVal.getBytes());
+ }
+
/** Tests for ST expression utility methods. */
// ST_AsBinary
@@ -110,4 +121,49 @@ void testStSridGeometry() {
assertEquals(testGeometrySrid, STUtils.stSrid(geometryVal));
}
+ // ST_SetSrid
+ @Test
+ void testStSetSridGeography() {
+ for (int validGeographySrid : new int[]{4326}) {
+ GeographyVal geographyVal = GeographyVal.fromBytes(testGeographyBytes);
+ GeographyVal updatedGeographyVal = STUtils.stSetSrid(geographyVal, validGeographySrid);
+ assertNotNull(updatedGeographyVal);
+ Geography updatedGeography = Geography.fromBytes(updatedGeographyVal.getBytes());
+ assertEquals(validGeographySrid, updatedGeography.srid());
+ }
+ }
+
+ @Test
+ void testStSetSridGeographyInvalidSrid() {
+ for (int invalidGeographySrid : new int[]{-9999, -2, -1, 0, 1, 2, 3857, 9999}) {
+ GeographyVal geographyVal = GeographyVal.fromBytes(testGeographyBytes);
+ SparkIllegalArgumentException exception = assertThrows(SparkIllegalArgumentException.class,
+ () -> STUtils.stSetSrid(geographyVal, invalidGeographySrid));
+ assertEquals("ST_INVALID_SRID_VALUE", exception.getCondition());
+ assertTrue(exception.getMessage().contains("value: " + invalidGeographySrid + "."));
+ }
+ }
+
+ @Test
+ void testStSetSridGeometry() {
+ for (int validGeometrySrid : new int[]{0, 3857, 4326}) {
+ GeometryVal geometryVal = GeometryVal.fromBytes(testGeometryBytes);
+ GeometryVal updatedGeometryVal = STUtils.stSetSrid(geometryVal, validGeometrySrid);
+ assertNotNull(updatedGeometryVal);
+ Geometry updatedGeometry = Geometry.fromBytes(updatedGeometryVal.getBytes());
+ assertEquals(validGeometrySrid, updatedGeometry.srid());
+ }
+ }
+
+ @Test
+ void testStSetSridGeometryInvalidSrid() {
+ for (int invalidGeometrySrid : new int[]{-9999, -2, -1, 1, 2, 9999}) {
+ GeometryVal geometryVal = GeometryVal.fromBytes(testGeometryBytes);
+ SparkIllegalArgumentException exception = assertThrows(SparkIllegalArgumentException.class,
+ () -> STUtils.stSetSrid(geometryVal, invalidGeometrySrid));
+ assertEquals("ST_INVALID_SRID_VALUE", exception.getCondition());
+ assertTrue(exception.getMessage().contains("value: " + invalidGeometrySrid + "."));
+ }
+ }
+
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala
index 4798623417b1f..b4a4c6f46cda4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala
@@ -1357,24 +1357,24 @@ class CatalogSuite extends SparkFunSuite {
intercept[NoSuchFunctionException](catalog.loadFunction(Identifier.of(Array("ns1"), "func")))
}
- test("currentVersion") {
+ test("version") {
val catalog = newCatalog()
val table = catalog.createTable(testIdent, columns, emptyTrans, emptyProps)
.asInstanceOf[InMemoryTable]
- assert(table.currentVersion() == "0")
+ assert(table.version() == "0")
table.withData(Array(
BufferedRows("3", table.columns()).withRow(InternalRow(0, "abc", "3")),
BufferedRows("4", table.columns()).withRow(InternalRow(1, "def", "4"))))
- assert(table.currentVersion() == "1")
+ assert(table.version() == "1")
table.truncateTable()
- assert(catalog.loadTable(testIdent).currentVersion() == "2")
+ assert(catalog.loadTable(testIdent).version() == "2")
catalog.alterTable(testIdent, TableChange.setProperty("prop-1", "1"))
- assert(catalog.loadTable(testIdent).currentVersion() == "3")
+ assert(catalog.loadTable(testIdent).version() == "3")
catalog.alterTable(testIdent, TableChange.addConstraint(constraints.apply(0), "3"))
- assert(catalog.loadTable(testIdent).currentVersion() == "4")
+ assert(catalog.loadTable(testIdent).version() == "4")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index d66ba5a23cc84..407d592f82199 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -66,7 +66,7 @@ abstract class InMemoryBaseTable(
extends Table with SupportsRead with SupportsWrite with SupportsMetadataColumns {
// Tracks the current version number of the table.
- protected var currentTableVersion: Int = 0
+ protected var tableVersion: Int = 0
// Stores the table version validated during the last `ALTER TABLE ... ADD CONSTRAINT` operation.
private var validatedTableVersion: String = null
@@ -75,14 +75,14 @@ abstract class InMemoryBaseTable(
override def columns(): Array[Column] = tableColumns
- override def currentVersion(): String = currentTableVersion.toString
+ override def version(): String = tableVersion.toString
- def setCurrentVersion(version: String): Unit = {
- currentTableVersion = version.toInt
+ def setVersion(version: String): Unit = {
+ tableVersion = version.toInt
}
- def increaseCurrentVersion(): Unit = {
- currentTableVersion += 1
+ def increaseVersion(): Unit = {
+ tableVersion += 1
}
def validatedVersion(): String = {
@@ -236,15 +236,26 @@ abstract class InMemoryBaseTable(
case (v, t) =>
throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
}
+ // The result should be consistent with BucketFunctions defined in transformFunctions.scala
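+ // For example, a single INT bucket column with value 10 and numBuckets = 4 yields
+ // hash = 10 and bucket floorMod(10, 4) = 2.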
case BucketTransform(numBuckets, cols, _) =>
- val valueTypePairs = cols.map(col => extractor(col.fieldNames, cleanedSchema, row))
- var valueHashCode = 0
- valueTypePairs.foreach( pair =>
- if ( pair._1 != null) valueHashCode += pair._1.hashCode()
- )
- var dataTypeHashCode = 0
- valueTypePairs.foreach(dataTypeHashCode += _._2.hashCode())
- ((valueHashCode + 31 * dataTypeHashCode) & Integer.MAX_VALUE) % numBuckets
+ val hash: Long = cols.foldLeft(0L) { (acc, col) =>
+ val valueHash = extractor(col.fieldNames, cleanedSchema, row) match {
+ case (value: Byte, _: ByteType) => value.toLong
+ case (value: Short, _: ShortType) => value.toLong
+ case (value: Int, _: IntegerType) => value.toLong
+ case (value: Long, _: LongType) => value
+ case (value: Long, _: TimestampType) => value
+ case (value: Long, _: TimestampNTZType) => value
+ case (value: UTF8String, _: StringType) =>
+ value.hashCode.toLong
+ case (value: Array[Byte], BinaryType) =>
+ util.Arrays.hashCode(value).toLong
+ case (v, t) =>
+ throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
+ }
+ acc + valueHash
+ }
+ Math.floorMod(hash, numBuckets)
case NamedTransform("truncate", Seq(ref: NamedReference, length: V2Literal[_])) =>
extractor(ref.fieldNames, cleanedSchema, row) match {
case (str: UTF8String, StringType) =>
@@ -543,6 +554,10 @@ abstract class InMemoryBaseTable(
}
new BufferedRowsReaderFactory(metadataColumns.toSeq, nonMetadataColumns, tableSchema)
}
+
+ override def supportedCustomMetrics(): Array[CustomMetric] = {
+ Array(new RowsReadCustomMetric)
+ }
}
case class InMemoryBatchScan(
@@ -830,10 +845,13 @@ private class BufferedRowsReader(
}
private var index: Int = -1
+ private var rowsRead: Long = 0
override def next(): Boolean = {
index += 1
- index < partition.rows.length
+ val hasNext = index < partition.rows.length
+ if (hasNext) rowsRead += 1
+ hasNext
}
override def get(): InternalRow = {
@@ -976,6 +994,22 @@ private class BufferedRowsReader(
private def castElement(elem: Any, toType: DataType, fromType: DataType): Any =
Cast(Literal(elem, fromType), toType, None, EvalMode.TRY).eval(null)
+
+ override def initMetricsValues(metrics: Array[CustomTaskMetric]): Unit = {
+ metrics.foreach { m =>
+ m.name match {
+ case "rows_read" => rowsRead = m.value()
+ case _ => // ignore metrics this reader does not report
+ }
+ }
+ }
+
+ override def currentMetricsValues(): Array[CustomTaskMetric] = {
+ val metric = new CustomTaskMetric {
+ override def name(): String = "rows_read"
+ override def value(): Long = rowsRead
+ }
+ Array(metric)
+ }
}
private class BufferedRowsWriterFactory(schema: StructType)
@@ -1044,6 +1078,11 @@ class InMemoryCustomDriverTaskMetric(value: Long) extends CustomTaskMetric {
override def value(): Long = value
}
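+// Driver-side metric definition; CustomSumMetric aggregates the per-task "rows_read" values
+// reported by BufferedRowsReader.currentMetricsValues by summing them.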
+class RowsReadCustomMetric extends CustomSumMetric {
+ override def name(): String = "rows_read"
+ override def description(): String = "number of rows read"
+}
+
case class Commit(id: Long, writeSummary: Option[WriteSummary] = None)
sealed trait Operation
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index 46169a9db4914..3bea136b34d46 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.connector.catalog
import java.util
+import java.util.{Objects, UUID}
import org.apache.spark.sql.connector.catalog.constraints.Constraint
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
@@ -42,7 +43,8 @@ class InMemoryTable(
numPartitions: Option[Int] = None,
advisoryPartitionSize: Option[Long] = None,
isDistributionStrictlyRequired: Boolean = true,
- override val numRowsPerSplit: Int = Int.MaxValue)
+ override val numRowsPerSplit: Int = Int.MaxValue,
+ override val id: String = UUID.randomUUID().toString)
extends InMemoryBaseTable(name, columns, partitioning, properties, constraints, distribution,
ordering, numPartitions, advisoryPartitionSize, isDistributionStrictlyRequired,
numRowsPerSplit) with SupportsDelete {
@@ -67,7 +69,7 @@ class InMemoryTable(
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
dataMap --= InMemoryTable
.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted).toImmutableArraySeq, filters)
- increaseCurrentVersion()
+ increaseVersion()
}
override def withData(data: Array[BufferedRows]): InMemoryTable = {
@@ -107,7 +109,7 @@ class InMemoryTable(
row.getInt(0) == InMemoryTable.uncommittableValue()))) {
throw new IllegalArgumentException(s"Test only mock write failure")
}
- increaseCurrentVersion()
+ increaseVersion()
this
}
}
@@ -137,7 +139,8 @@ class InMemoryTable(
numPartitions,
advisoryPartitionSize,
isDistributionStrictlyRequired,
- numRowsPerSplit)
+ numRowsPerSplit,
+ id)
dataMap.synchronized {
dataMap.foreach { case (key, splits) =>
@@ -152,7 +155,7 @@ class InMemoryTable(
copiedTable.commits ++= commits.map(_.copy())
- copiedTable.setCurrentVersion(currentVersion())
+ copiedTable.setVersion(version())
if (validatedVersion() != null) {
copiedTable.setValidatedVersion(validatedVersion())
}
@@ -160,6 +163,16 @@ class InMemoryTable(
copiedTable
}
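+ // Logical table equality: two InMemoryTable instances represent the same table state
+ // when they share the same id and the same version.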
+ override def equals(other: Any): Boolean = other match {
+ case that: InMemoryTable =>
+ this.id == that.id && this.version() == that.version()
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ Objects.hash(id, version())
+ }
+
class InMemoryWriterBuilderWithOverWrite(override val info: LogicalWriteInfo)
extends InMemoryWriterBuilder(info) with SupportsOverwrite {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
index 1da0882ec211a..5ea33a1764aae 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
@@ -46,9 +46,11 @@ class BasicInMemoryTableCatalog extends TableCatalog {
private val invalidatedTables: util.Set[Identifier] = ConcurrentHashMap.newKeySet()
private var _name: Option[String] = None
+ private var copyOnLoad: Boolean = false
override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
_name = Some(name)
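+ // Test-only option: when "copyOnLoad" is set, loadTable returns a copy of the table for
+ // scans so that logical table equality (id + version) is exercised instead of reference
+ // equality.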
+ copyOnLoad = options.getBoolean("copyOnLoad", false)
}
override def name: String = _name.get
@@ -57,7 +59,22 @@ class BasicInMemoryTableCatalog extends TableCatalog {
tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray
}
+ // load table for scans
override def loadTable(ident: Identifier): Table = {
+ Option(tables.get(ident)) match {
+ case Some(table: InMemoryTable) if copyOnLoad =>
+ table.copy() // copy to validate logical table equality
+ case Some(table) =>
+ table
+ case _ =>
+ throw new NoSuchTableException(ident.asMultipartIdentifier)
+ }
+ }
+
+ // load table for writes
+ override def loadTable(
+ ident: Identifier,
+ writePrivileges: util.Set[TableWritePrivilege]): Table = {
Option(tables.get(ident)) match {
case Some(table) =>
table
@@ -162,16 +179,17 @@ class BasicInMemoryTableCatalog extends TableCatalog {
throw new IllegalArgumentException(s"Cannot drop all fields")
}
- table.increaseCurrentVersion()
- val currentVersion = table.currentVersion()
+ table.increaseVersion()
+ val currentVersion = table.version()
val newTable = new InMemoryTable(
name = table.name,
columns = CatalogV2Util.structTypeToV2Columns(schema),
partitioning = finalPartitioning,
properties = properties,
- constraints = constraints)
+ constraints = constraints,
+ id = table.id)
.alterTableWithData(table.data, schema)
- newTable.setCurrentVersion(currentVersion)
+ newTable.setVersion(currentVersion)
changes.foreach {
case a: TableChange.AddConstraint =>
newTable.setValidatedVersion(a.validatedTableVersion())
@@ -191,7 +209,7 @@ class BasicInMemoryTableCatalog extends TableCatalog {
Option(tables.remove(oldIdent)) match {
case Some(table: InMemoryBaseTable) =>
- table.increaseCurrentVersion()
+ table.increaseVersion()
tables.put(newIdent, table)
case _ =>
throw new NoSuchTableException(oldIdent.asMultipartIdentifier)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala
index ee2400cab35c8..7ded99c709a39 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/StagingInMemoryTableCatalog.scala
@@ -36,7 +36,8 @@ class StagingInMemoryTableCatalog extends InMemoryTableCatalog with StagingTable
new TestStagedCreateTable(
ident,
new InMemoryTable(s"$name.${ident.quoted}",
- tableInfo.schema(), tableInfo.partitions(), tableInfo.properties()))
+ tableInfo.columns(), tableInfo.partitions(), tableInfo.properties(),
+ tableInfo.constraints()))
}
override def stageReplace(ident: Identifier, tableInfo: TableInfo): StagedTable = {
@@ -44,7 +45,8 @@ class StagingInMemoryTableCatalog extends InMemoryTableCatalog with StagingTable
new TestStagedReplaceTable(
ident,
new InMemoryTable(s"$name.${ident.quoted}",
- tableInfo.schema(), tableInfo.partitions(), tableInfo.properties()))
+ tableInfo.columns(), tableInfo.partitions(), tableInfo.properties(),
+ tableInfo.constraints()))
}
override def stageCreateOrReplace(ident: Identifier, tableInfo: TableInfo) : StagedTable = {
@@ -52,7 +54,8 @@ class StagingInMemoryTableCatalog extends InMemoryTableCatalog with StagingTable
new TestStagedCreateOrReplaceTable(
ident,
new InMemoryTable(s"$name.${ident.quoted}",
- tableInfo.schema(), tableInfo.partitions(), tableInfo.properties))
+ tableInfo.columns(), tableInfo.partitions(), tableInfo.properties(),
+ tableInfo.constraints()))
}
private def validateStagedTable(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2TableUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2TableUtilSuite.scala
new file mode 100644
index 0000000000000..a9e5668d7fefb
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2TableUtilSuite.scala
@@ -0,0 +1,656 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import java.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, MetadataAttribute}
+import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.sql.util.SchemaValidationMode.{ALLOW_NEW_TOP_LEVEL_FIELDS, PROHIBIT_CHANGES}
+import org.apache.spark.sql.util.SchemaValidationMode
+import org.apache.spark.util.ArrayImplicits.SparkArrayOps
+
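+/**
+ * Tests for [[V2TableUtil]] validation helpers, which compare the columns and metadata columns
+ * captured when a relation was resolved against the table's current columns and report any
+ * type, nullability, or presence changes.
+ */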
+class V2TableUtilSuite extends SparkFunSuite {
+
+ test("validateCapturedColumns - no changes") {
+ val cols = Array(
+ col("id", LongType, nullable = false),
+ col("name", StringType, nullable = true))
+ val table = TestTableWithMetadataSupport("test", cols)
+
+ val errors = validateCapturedColumns(table, cols)
+ assert(errors.isEmpty, "No changes should produce no errors")
+ }
+
+ test("validateCapturedColumns - column type changed") {
+ val originCols = Array(
+ col("id", LongType, nullable = true), // original type
+ col("name", StringType, nullable = true))
+ val currentCols = Array(
+ col("id", StringType, nullable = true), // changed from LongType
+ col("name", StringType, nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.size == 1)
+ assert(errors.head.contains("`id` type has changed from BIGINT to STRING"))
+ }
+
+ test("validateCapturedColumns - column nullability changed to not null") {
+ val originCols = Array(
+ col("id", LongType, nullable = true), // originally nullable
+ col("name", StringType, nullable = true))
+ val currentCols = Array(
+ col("id", LongType, nullable = false), // now NOT NULL
+ col("name", StringType, nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`id` is no longer nullable")
+ }
+
+ test("validateCapturedColumns - column nullability changed to nullable") {
+ val originCols = Array(
+ col("id", LongType, nullable = false), // originally NOT NULL
+ col("name", StringType, nullable = true))
+ val currentCols = Array(
+ col("id", LongType, nullable = true), // now nullable
+ col("name", StringType, nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`id` is nullable now")
+ }
+
+ test("validateCapturedColumns - column removed") {
+ val originCols = Array(
+ col("id", LongType, nullable = true), // originally present
+ col("name", StringType, nullable = true))
+ val currentCols = Array(
+ col("name", StringType, nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`id` BIGINT has been removed")
+ }
+
+ test("validateCapturedColumns - column added") {
+ val originCols = Array(
+ col("id", LongType, nullable = true),
+ col("name", StringType, nullable = true))
+ val currentCols = Array(
+ col("id", LongType, nullable = true),
+ col("name", StringType, nullable = true),
+ col("age", IntegerType, nullable = true)) // new column
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`age` INT has been added")
+ }
+
+ test("validateCapturedColumns - multiple columns removed and added") {
+ val originCols = Array(
+ col("id", LongType, nullable = true),
+ col("name", StringType, nullable = true), // originally present
+ col("address", StringType, nullable = true)) // originally present
+ val currentCols = Array(
+ col("id", LongType, nullable = true),
+ col("email", StringType, nullable = true), // new column
+ col("age", IntegerType, nullable = true)) // new column
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.size == 4) // 2 removed + 2 added
+ assert(errors.count(_.contains("removed")) == 2)
+ assert(errors.count(_.contains("has been added")) == 2)
+ }
+
+ test("validateCapturedColumns - case insensitive column names") {
+ val originCols = Array(
+ col("id", LongType, nullable = true), // lowercase
+ col("name", StringType, nullable = true))
+ val currentCols = Array(
+ col("ID", LongType, nullable = true), // uppercase
+ col("NAME", StringType, nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.isEmpty, "Case insensitive comparison should match")
+ }
+
+ test("validateCapturedColumns - duplicate columns with different case") {
+ val originCols = Array(
+ col("id", LongType, nullable = true),
+ col("name", StringType, nullable = true))
+ val currentCols = Array(
+ col("id", LongType, nullable = true),
+ col("ID", StringType, nullable = true), // duplicate with different case
+ col("name", StringType, nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val e = intercept[AnalysisException] { validateCapturedColumns(table, originCols) }
+ assert(e.message.contains("Choose another name or rename the existing column"))
+ }
+
+ test("validateCapturedColumns - complex types") {
+ val structType = StructType(Seq(
+ StructField("street", StringType),
+ StructField("city", StringType)))
+ val originCols = Array(
+ col("id", LongType, nullable = true),
+ col("address", structType, nullable = true))
+ val currentCols = Array(
+ col("id", LongType, nullable = true),
+ col("address", structType, nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.isEmpty)
+ }
+
+ test("validateCapturedColumns - complex type changed") {
+ val originStructType = StructType(Seq(
+ StructField("street", StringType),
+ StructField("city", StringType))) // originally StringType
+ val originCols = Array(
+ col("id", LongType, nullable = true),
+ col("address", originStructType, nullable = true))
+ val currentStructType = StructType(Seq(
+ StructField("street", StringType),
+ StructField("city", IntegerType))) // changed type
+ val currentCols = Array(
+ col("id", LongType, nullable = true),
+ col("address", currentStructType, nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = V2TableUtil.validateCapturedColumns(table, originCols.toSeq)
+ assert(errors.size == 1)
+ assert(errors.head.contains("`address`.`city` type has changed from STRING to INT"))
+ }
+
+ test("validateCapturedMetadataColumns - no changes") {
+ val originMetaCols = Seq(
+ metaCol("_partition", StringType, nullable = false),
+ metaCol("index", IntegerType, nullable = false))
+ val currentMetaCols = Array(
+ metaCol("_partition", StringType, nullable = false),
+ metaCol("index", IntegerType, nullable = false))
+ val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols)
+
+ val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols)
+ assert(errors.isEmpty, "No changes should produce no errors")
+ }
+
+ test("validateCapturedMetadataColumns - type changed") {
+ val originMetaCols = Seq(
+ metaCol("index", IntegerType, nullable = false)) // originally IntegerType
+ val currentMetaCols = Array(
+ metaCol("index", StringType, nullable = false)) // changed to StringType
+ val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols)
+
+ val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`index` type has changed from INT to STRING")
+ }
+
+ test("validateCapturedMetadataColumns - nullability changed to nullable") {
+ val originMetaCols = Seq(
+ metaCol("index", IntegerType, nullable = false)) // originally NOT NULL
+ val currentMetaCols = Array(
+ metaCol("index", IntegerType, nullable = true)) // now nullable
+ val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols)
+
+ val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`index` is nullable now")
+ }
+
+ test("validateCapturedMetadataColumns - nullability changed to not null") {
+ val originMetaCols = Seq(metaCol("index", IntegerType, nullable = true)) // originally nullable
+ val currentMetaCols = Array(metaCol("index", IntegerType, nullable = false)) // now NOT NULL
+ val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols)
+
+ val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`index` is no longer nullable")
+ }
+
+ test("validateCapturedMetadataColumns - column removed") {
+ val originMetaCols = Seq(metaCol("index", IntegerType, nullable = true)) // originally present
+ val currentMetaCols = Array.empty[MetadataColumn] // no metadata columns
+ val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols)
+
+ val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`index` INT has been removed")
+ }
+
+ test("validateCapturedMetadataColumns - table doesn't support metadata") {
+ // table that doesn't implement SupportsMetadataColumns
+ val table = TestTable("test", Array(col("id", LongType, nullable = true)))
+ val originMetaCols = Seq(metaCol("index", IntegerType, nullable = false))
+
+ val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`index` INT NOT NULL has been removed")
+ }
+
+ test("validateCapturedMetadataColumns - multiple errors") {
+ val originMetaCols = Seq(
+ metaCol("_partition", StringType, nullable = false),
+ metaCol("index", IntegerType, nullable = false)) // originally present
+ val currentMetaCols = Array(
+ metaCol("_partition", IntegerType, nullable = false)) // type changed from StringType
+ val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols)
+
+ val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols)
+ assert(errors.size == 2)
+ assert(errors.exists(e => e.contains("_partition") && e.contains("type has changed")))
+ assert(errors.exists(e => e.contains("index") && e.contains("removed")))
+ }
+
+ test("validateCapturedMetadataColumns - case insensitive names") {
+ val originMetaCols = Seq(metaCol("index", IntegerType, nullable = true)) // lowercase
+ val currentMetaCols = Array(metaCol("INDEX", IntegerType, nullable = true)) // uppercase
+ val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols)
+
+ val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols)
+ assert(errors.isEmpty, "Case insensitive comparison should match")
+ }
+
+ test("validateCapturedMetadataColumns - duplicate metadata columns with different case") {
+ val originMetaCols = Seq(
+ metaCol("_partition", StringType, nullable = false),
+ metaCol("index", IntegerType, nullable = false))
+ val currentMetaCols = Array(
+ metaCol("_partition", StringType, nullable = false),
+ metaCol("index", IntegerType, nullable = false),
+ metaCol("INDEX", StringType, nullable = false)) // duplicate with different case
+ val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols)
+
+ val e = intercept[AnalysisException] {
+ V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols)
+ }
+ assert(e.message.contains("Choose another name or rename the existing column"))
+ }
+
+ test("validateCapturedMetadataColumns - empty metadata columns") {
+ val originMetaCols = Seq.empty[MetadataColumn]
+ val currentMetaCols = Array.empty[MetadataColumn]
+ val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols)
+
+ val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols)
+ assert(errors.isEmpty, "No metadata columns should produce no errors")
+ }
+
+ test("validateCapturedMetadataColumns - complex metadata type") {
+ val structType = StructType(Seq(
+ StructField("bucket", IntegerType),
+ StructField("partition", IntegerType)))
+ val originMetaCols = Seq(metaCol("_partition", structType, nullable = false))
+ val currentMetaCols = Array(metaCol("_partition", structType, nullable = false))
+ val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols)
+
+ val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols)
+ assert(errors.isEmpty)
+ }
+
+ test("validateCapturedMetadataColumns - complex metadata type changed") {
+ val originStructType = StructType(Seq(
+ StructField("bucket", IntegerType), // originally IntegerType
+ StructField("partition", IntegerType)))
+ val originMetaCols = Seq(metaCol("_partition", originStructType, nullable = false))
+ val currentStructType = StructType(Seq(
+ StructField("bucket", StringType), // changed type
+ StructField("partition", IntegerType)))
+ val currentMetaCols = Array(metaCol("_partition", currentStructType, nullable = false))
+ val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols)
+
+ val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols)
+ assert(errors.size == 1)
+ assert(errors.head.contains("`_partition`.`bucket` type has changed from INT to STRING"))
+ }
+
+ test("validateCapturedMetadataColumns - with DataSourceV2Relation") {
+ val dataCols = Array(
+ col("id", LongType, nullable = true),
+ col("name", StringType, nullable = true))
+ val originMetaCols = Array(
+ metaCol("_partition", StringType, nullable = false),
+ metaCol("index", IntegerType, nullable = false))
+ val originTable = TestTableWithMetadataSupport("test", dataCols, originMetaCols)
+
+ val dataAttrs = dataCols.map(c => AttributeReference(c.name, c.dataType, c.nullable)())
+ val metadataAttrs = originMetaCols.map(c => MetadataAttribute(c.name, c.dataType, c.isNullable))
+ val attrs = dataAttrs ++ metadataAttrs
+
+ val relation = DataSourceV2Relation(
+ originTable,
+ attrs.toImmutableArraySeq,
+ None,
+ None,
+ CaseInsensitiveStringMap.empty())
+
+ val currentMetaCols = Array(
+ metaCol("_partition", IntegerType, nullable = false), // type changed
+ metaCol("index", IntegerType, nullable = false))
+ val currentTable = TestTableWithMetadataSupport("test", dataCols, currentMetaCols)
+
+ val errors = V2TableUtil.validateCapturedMetadataColumns(
+ currentTable,
+ relation,
+ mode = PROHIBIT_CHANGES)
+ assert(errors.size == 1)
+ assert(errors.head.contains("`_partition` type has changed"))
+ }
+
+ test("validateCapturedMetadataColumns - with DataSourceV2Relation no metadata attrs") {
+ val dataCols = Array(
+ col("id", LongType, nullable = true),
+ col("name", StringType, nullable = true))
+ val originTable = TestTable("test", dataCols)
+
+ val dataAttrs = dataCols.map(c => AttributeReference(c.name, c.dataType, c.nullable)())
+
+ val relation = DataSourceV2Relation(
+ originTable,
+ dataAttrs.toImmutableArraySeq,
+ None,
+ None,
+ CaseInsensitiveStringMap.empty())
+
+ val currentTable = TestTable("test", dataCols)
+
+ val errors = V2TableUtil.validateCapturedMetadataColumns(
+ currentTable,
+ relation,
+ mode = PROHIBIT_CHANGES)
+ assert(errors.isEmpty)
+ }
+
+ test("validateCapturedColumns - array element type changed") {
+ val originCols = Array(
+ col("id", LongType, nullable = true),
+ col("data", ArrayType(IntegerType, containsNull = true), nullable = true))
+ val currentCols = Array(
+ col("id", LongType, nullable = true),
+ col("data", ArrayType(LongType, containsNull = true), nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`data`.`element` type has changed from INT to BIGINT")
+ }
+
+ test("validateCapturedColumns - array containsNull changed to false") {
+ val originCols = Array(
+ col("id", LongType, nullable = true),
+ col("data", ArrayType(IntegerType, containsNull = true), nullable = true))
+ val currentCols = Array(
+ col("id", LongType, nullable = true),
+ col("data", ArrayType(IntegerType, containsNull = false), nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`data`.`element` is no longer nullable")
+ }
+
+ test("validateCapturedColumns - array containsNull changed to true") {
+ val originCols = Array(
+ col("id", LongType, nullable = true),
+ col("data", ArrayType(IntegerType, containsNull = false), nullable = true))
+ val currentCols = Array(
+ col("id", LongType, nullable = true),
+ col("data", ArrayType(IntegerType, containsNull = true), nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`data`.`element` is nullable now")
+ }
+
+ test("validateCapturedColumns - map key type changed") {
+ val originCols = Array(
+ col("id", LongType, nullable = true),
+ col("data", MapType(IntegerType, StringType, valueContainsNull = true), nullable = true))
+ val currentCols = Array(
+ col("id", LongType, nullable = true),
+ col("data", MapType(LongType, StringType, valueContainsNull = true), nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`data`.`key` type has changed from INT to BIGINT")
+ }
+
+ test("validateCapturedColumns - map value type changed") {
+ val originCols = Array(
+ col("id", LongType, nullable = true),
+ col("data", MapType(StringType, IntegerType, valueContainsNull = true), nullable = true))
+ val currentCols = Array(
+ col("id", LongType, nullable = true),
+ col("data", MapType(StringType, LongType, valueContainsNull = true), nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`data`.`value` type has changed from INT to BIGINT")
+ }
+
+ test("validateCapturedColumns - map valueContainsNull changed to false") {
+ val originCols = Array(
+ col("id", LongType, nullable = true),
+ col("data", MapType(StringType, IntegerType, valueContainsNull = true), nullable = true))
+ val currentCols = Array(
+ col("id", LongType, nullable = true),
+ col("data", MapType(StringType, IntegerType, valueContainsNull = false), nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`data`.`value` is no longer nullable")
+ }
+
+ test("validateCapturedColumns - map valueContainsNull changed to true") {
+ val originCols = Array(
+ col("id", LongType, nullable = true),
+ col("data", MapType(StringType, IntegerType, valueContainsNull = false), nullable = true))
+ val currentCols = Array(
+ col("id", LongType, nullable = true),
+ col("data", MapType(StringType, IntegerType, valueContainsNull = true), nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`data`.`value` is nullable now")
+ }
+
+ test("validateCapturedColumns - nested array in struct element type changed") {
+ val originStructType = StructType(Seq(
+ StructField("name", StringType),
+ StructField("scores", ArrayType(IntegerType, containsNull = true))))
+ val originCols = Array(
+ col("id", LongType, nullable = true),
+ col("person", originStructType, nullable = true))
+ val currentStructType = StructType(Seq(
+ StructField("name", StringType),
+ StructField("scores", ArrayType(LongType, containsNull = true))))
+ val currentCols = Array(
+ col("id", LongType, nullable = true),
+ col("person", currentStructType, nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`person`.`scores`.`element` type has changed from INT to BIGINT")
+ }
+
+ test("validateCapturedColumns - nested map in struct value type changed") {
+ val originStructType = StructType(Seq(
+ StructField("name", StringType),
+ StructField("attrs", MapType(StringType, IntegerType, valueContainsNull = true))))
+ val originCols = Array(
+ col("id", LongType, nullable = true),
+ col("person", originStructType, nullable = true))
+ val currentStructType = StructType(Seq(
+ StructField("name", StringType),
+ StructField("attrs", MapType(StringType, LongType, valueContainsNull = true))))
+ val currentCols = Array(
+ col("id", LongType, nullable = true),
+ col("person", currentStructType, nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols)
+ assert(errors.size == 1)
+ assert(errors.head == "`person`.`attrs`.`value` type has changed from INT to BIGINT")
+ }
+
+ test("validateCapturedColumns - ALLOW_NEW_TOP_LEVEL_FIELDS allows top-level additions") {
+ val originCols = Array(
+ col("id", LongType, nullable = false),
+ col("name", StringType, nullable = true))
+ val currentCols = Array(
+ col("id", LongType, nullable = false),
+ col("name", StringType, nullable = true),
+ col("age", IntegerType, nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols, ALLOW_NEW_TOP_LEVEL_FIELDS)
+ assert(errors.isEmpty)
+ }
+
+ test("validateCapturedColumns - ALLOW_NEW_TOP_LEVEL_FIELDS prohibits nested additions") {
+ val originAddress = StructType(Seq(
+ StructField("street", StringType),
+ StructField("city", StringType)))
+ val originCols = Array(
+ col("id", LongType, nullable = false),
+ col("address", originAddress, nullable = true))
+ val currentAddress = StructType(Seq(
+ StructField("street", StringType),
+ StructField("city", StringType),
+ StructField("zipCode", StringType)))
+ val currentCols = Array(
+ col("id", LongType, nullable = false),
+ col("address", currentAddress, nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols, ALLOW_NEW_TOP_LEVEL_FIELDS)
+ assert(errors.size == 1)
+ assert(errors.head.contains("`address`.`zipCode` STRING has been added"))
+ }
+
+ test("validateCapturedColumns - ALLOW_NEW_TOP_LEVEL_FIELDS fails new nested fields in array") {
+ val originItem = StructType(Seq(
+ StructField("itemId", LongType),
+ StructField("itemName", StringType)))
+ val originCols = Array(
+ col("id", LongType, nullable = false),
+ col("items", ArrayType(originItem), nullable = true))
+ val currentItem = StructType(Seq(
+ StructField("itemId", LongType),
+ StructField("itemName", StringType),
+ StructField("price", IntegerType)))
+ val currentCols = Array(
+ col("id", LongType, nullable = false),
+ col("items", ArrayType(currentItem), nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = V2TableUtil.validateCapturedColumns(
+ table,
+ originCols.toImmutableArraySeq,
+ mode = ALLOW_NEW_TOP_LEVEL_FIELDS)
+ assert(errors.size == 1)
+ assert(errors.head.contains("`items`.`element`.`price` INT has been added"))
+ }
+
+ test("validateCapturedColumns - ALLOW_NEW_TOP_LEVEL_FIELDS prohibits nested map additions") {
+ val originValue = StructType(Seq(
+ StructField("count", IntegerType),
+ StructField("status", StringType)))
+ val originCols = Array(
+ col("id", LongType, nullable = false),
+ col("metadata", MapType(StringType, originValue), nullable = true))
+ val currentValue = StructType(Seq(
+ StructField("count", IntegerType),
+ StructField("status", StringType),
+ StructField("timestamp", LongType)))
+ val currentCols = Array(
+ col("id", LongType, nullable = false),
+ col("metadata", MapType(StringType, currentValue), nullable = true))
+ val table = TestTableWithMetadataSupport("test", currentCols)
+
+ val errors = validateCapturedColumns(table, originCols, ALLOW_NEW_TOP_LEVEL_FIELDS)
+ assert(errors.size == 1)
+ assert(errors.head.contains("`metadata`.`value`.`timestamp` BIGINT has been added"))
+ }
+
+ // simple table without metadata column support
+ private case class TestTable(
+ override val name: String,
+ override val columns: Array[Column])
+ extends Table {
+ override def capabilities: util.Set[TableCapability] = util.Set.of(BATCH_READ)
+ }
+
+ // simple table implementation with metadata column support
+ private case class TestTableWithMetadataSupport(
+ override val name: String,
+ override val columns: Array[Column],
+ override val metadataColumns: Array[MetadataColumn] = Array.empty)
+ extends Table with SupportsMetadataColumns {
+ override def capabilities: util.Set[TableCapability] = util.Set.of(BATCH_READ)
+ }
+
+ private case class TestMetadataColumn(
+ override val name: String,
+ override val dataType: DataType,
+ override val isNullable: Boolean)
+ extends MetadataColumn {
+ override def comment: String = s"Test metadata column $name"
+ override def metadataInJSON: String = "{}"
+ }
+
+ private def validateCapturedColumns(
+ table: Table,
+ originCols: Array[Column],
+ mode: SchemaValidationMode = PROHIBIT_CHANGES): Seq[String] = {
+ V2TableUtil.validateCapturedColumns(table, originCols.toImmutableArraySeq, mode)
+ }
+
+ private def col(name: String, dataType: DataType, nullable: Boolean): Column = {
+ Column.create(name, dataType, nullable)
+ }
+
+ private def metaCol(
+ name: String,
+ dataType: DataType,
+ nullable: Boolean): MetadataColumn = {
+ TestMetadataColumn(name, dataType, nullable)
+ }
+}
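For orientation, a hedged, illustrative sketch (not part of this patch) of how a caller might consume the validation results exercised above. Column.create, V2TableUtil.validateCapturedColumns, and PROHIBIT_CHANGES are the APIs the tests use; the table value and the failure handling are assumptions.

    // Illustrative only: reject a plan whose captured columns have drifted from the
    // current table schema. Error strings follow the formats asserted in the tests,
    // e.g. "`data`.`element` type has changed from INT to BIGINT".
    val captured = Seq(
      Column.create("id", LongType, false),
      Column.create("data", ArrayType(IntegerType, true), true))
    val errors = V2TableUtil.validateCapturedColumns(table, captured, PROHIBIT_CHANGES)
    if (errors.nonEmpty) {
      throw new IllegalStateException(
        s"Captured columns for table ${table.name} are stale:\n${errors.mkString("\n")}")
    }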
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
index 7124c94b390d0..8011e69e724c4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala
@@ -49,6 +49,10 @@ class ArrowUtilsSuite extends SparkFunSuite {
roundtrip(BinaryType)
roundtrip(DecimalType.SYSTEM_DEFAULT)
roundtrip(DateType)
+ roundtrip(GeometryType("ANY"))
+ roundtrip(GeometryType(4326))
+ roundtrip(GeographyType("ANY"))
+ roundtrip(GeographyType(4326))
roundtrip(YearMonthIntervalType())
roundtrip(DayTimeIntervalType())
checkError(
diff --git a/sql/connect/client/jdbc/pom.xml b/sql/connect/client/jdbc/pom.xml
index c2dda12b1e639..c84ae04d3d735 100644
--- a/sql/connect/client/jdbc/pom.xml
+++ b/sql/connect/client/jdbc/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.13
- 4.1.0-SNAPSHOT
+ 4.1.2-SNAPSHOT
../../../../pom.xml
@@ -37,79 +37,8 @@
org.apache.spark
- spark-connect-common_${scala.binary.version}
- ${project.version}
-
-
- org.apache.spark
- spark-sql-api_${scala.binary.version}
- ${project.version}
-
-
- org.apache.spark
- spark-connect-shims_${scala.binary.version}
- ${project.version}
-
-
- org.apache.spark
- spark-sketch_${scala.binary.version}
- ${project.version}
-
-
- org.scala-lang
- scala-compiler
- compile
-
-
- org.scala-lang.modules
- scala-xml_${scala.binary.version}
- compile
-
-
-
- com.google.protobuf
- protobuf-java
- compile
-
-
- com.google.guava
- guava
- ${connect.guava.version}
- compile
-
-
- com.google.guava
- failureaccess
- ${guava.failureaccess.version}
- compile
-
-
- org.apache.spark
- spark-tags_${scala.binary.version}
- test-jar
- test
-
-
- org.scalacheck
- scalacheck_${scala.binary.version}
- test
-
-
- org.apache.spark
- spark-sql-api_${scala.binary.version}
- ${project.version}
- tests
- test
-
-
- org.apache.spark
- spark-common-utils_${scala.binary.version}
+ spark-connect-client-jvm_${scala.binary.version}
${project.version}
- tests
- test
org.apache.spark
@@ -118,13 +47,6 @@
tests
test
-
-
- com.typesafe
- mima-core_${scala.binary.version}
- ${mima.version}
- test
-
target/scala-${scala.binary.version}/classes
@@ -144,10 +66,10 @@
maven-shade-plugin
false
- true
+ false
- org.apache.spark:spark-connect-client-jdbc_${scala.binary.version}
+ org.spark-project.spark:unused
@@ -177,14 +99,6 @@
io.netty
${spark.shade.packageName}.io.netty
-
- org.checkerframework
- ${spark.shade.packageName}.org.checkerframework
-
-
- javax.annotation
- ${spark.shade.packageName}.javax.annotation
-
io.perfmark
${spark.shade.packageName}.io.perfmark
diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectConnection.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectConnection.scala
index 95ec956771dbb..21b9471bb6069 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectConnection.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectConnection.scala
@@ -185,7 +185,7 @@ class SparkConnectConnection(val url: String, val info: Properties) extends Conn
if (level != Connection.TRANSACTION_NONE) {
throw new SQLFeatureNotSupportedException(
"Requested transaction isolation level " +
- s"${stringfiyTransactionIsolationLevel(level)} is not supported")
+ s"${stringifyTransactionIsolationLevel(level)} is not supported")
}
}
@@ -207,7 +207,7 @@ class SparkConnectConnection(val url: String, val info: Properties) extends Conn
override def setHoldability(holdability: Int): Unit = {
if (holdability != ResultSet.HOLD_CURSORS_OVER_COMMIT) {
throw new SQLFeatureNotSupportedException(
- s"Holdability ${stringfiyHoldability(holdability)} is not supported")
+ s"Holdability ${stringifyHoldability(holdability)} is not supported")
}
}
diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaData.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaData.scala
index a16cba5e3da4e..7a37c272daf2b 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaData.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaData.scala
@@ -18,13 +18,22 @@
package org.apache.spark.sql.connect.client.jdbc
import java.sql.{Array => _, _}
+import java.sql.DatabaseMetaData._
import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
+import org.apache.spark.SparkThrowable
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.util.QuotingUtils._
+import org.apache.spark.sql.connect
import org.apache.spark.sql.connect.client.jdbc.SparkConnectDatabaseMetaData._
+import org.apache.spark.sql.connect.client.jdbc.util.JdbcTypeUtils
+import org.apache.spark.sql.functions._
import org.apache.spark.util.VersionUtils
class SparkConnectDatabaseMetaData(conn: SparkConnectConnection) extends DatabaseMetaData {
+ import conn.spark.implicits._
+
override def allProceduresAreCallable: Boolean = false
override def allTablesAreSelectable: Boolean = false
@@ -95,8 +104,7 @@ class SparkConnectDatabaseMetaData(conn: SparkConnectConnection) extends Databas
override def getTimeDateFunctions: String =
throw new SQLFeatureNotSupportedException
- override def getSearchStringEscape: String =
- throw new SQLFeatureNotSupportedException
+ override def getSearchStringEscape: String = "\\"
override def getExtraNameCharacters: String = ""
@@ -275,6 +283,9 @@ class SparkConnectDatabaseMetaData(conn: SparkConnectConnection) extends Databas
override def dataDefinitionIgnoredInTransactions: Boolean = false
+ private def isNullOrWildcard(pattern: String): Boolean =
+ pattern == null || pattern == "%"
+
override def getProcedures(
catalog: String,
schemaPattern: String,
@@ -288,31 +299,279 @@ class SparkConnectDatabaseMetaData(conn: SparkConnectConnection) extends Databas
columnNamePattern: String): ResultSet =
throw new SQLFeatureNotSupportedException
- override def getCatalogs: ResultSet =
- throw new SQLFeatureNotSupportedException
+ override def getCatalogs: ResultSet = {
+ conn.checkOpen()
- override def getSchemas: ResultSet =
- throw new SQLFeatureNotSupportedException
+ val df = conn.spark.sql("SHOW CATALOGS")
+ .select($"catalog".as("TABLE_CAT"))
+ .orderBy("TABLE_CAT")
+ new SparkConnectResultSet(df.collectResult())
+ }
- override def getSchemas(catalog: String, schemaPattern: String): ResultSet =
- throw new SQLFeatureNotSupportedException
+ override def getSchemas: ResultSet = {
+ conn.checkOpen()
- override def getTableTypes: ResultSet =
- throw new SQLFeatureNotSupportedException
+ getSchemas(null, null)
+ }
+
+ // Schema of the returned DataFrame is:
+ // |-- TABLE_SCHEM: string (nullable = false)
+ // |-- TABLE_CATALOG: string (nullable = false)
+ private def getSchemasDataFrame(
+ catalog: String, schemaPatternOpt: Option[String]): connect.DataFrame = {
+
+ val schemaFilterExpr = schemaPatternOpt match {
+ case None => $"TABLE_SCHEM".equalTo(conn.spark.catalog.currentDatabase)
+ case Some(schemaPattern) if isNullOrWildcard(schemaPattern) => lit(true)
+ case Some(schemaPattern) => $"TABLE_SCHEM".like(schemaPattern)
+ }
+
+ lazy val emptyDf = conn.spark.emptyDataFrame
+ .withColumn("TABLE_SCHEM", lit(""))
+ .withColumn("TABLE_CATALOG", lit(""))
+
+ def internalGetSchemas(
+ catalogOpt: Option[String],
+ schemaFilterExpr: Column): connect.DataFrame = {
+ val catalog = catalogOpt.getOrElse(conn.getCatalog)
+ try {
+ // Spark SQL supports a LIKE clause in the SHOW SCHEMAS command, but we can't use it
+ // because its pattern syntax does not follow the SQL standard.
+ conn.spark.sql(s"SHOW SCHEMAS IN ${quoteIdentifier(catalog)}")
+ .select($"namespace".as("TABLE_SCHEM"))
+ .filter(schemaFilterExpr)
+ .withColumn("TABLE_CATALOG", lit(catalog))
+ } catch {
+ case st: SparkThrowable if st.getCondition == "MISSING_CATALOG_ABILITY.NAMESPACES" =>
+ emptyDf
+ }
+ }
+
+ if (catalog == null) {
+ // search in all catalogs
+ conn.spark.catalog.listCatalogs().collect().map(_.name).map { c =>
+ internalGetSchemas(Some(c), schemaFilterExpr)
+ }.fold(emptyDf) { (l, r) => l.unionAll(r) }
+ } else if (catalog == "") {
+ // search only in current catalog
+ internalGetSchemas(None, schemaFilterExpr)
+ } else {
+ // search in the specific catalog
+ internalGetSchemas(Some(catalog), schemaFilterExpr)
+ }
+ }
+
+ override def getSchemas(catalog: String, schemaPattern: String): ResultSet = {
+ conn.checkOpen()
+
+ val df = getSchemasDataFrame(catalog, Some(schemaPattern))
+ .orderBy("TABLE_CATALOG", "TABLE_SCHEM")
+ new SparkConnectResultSet(df.collectResult())
+ }
+
+ override def getTableTypes: ResultSet = {
+ conn.checkOpen()
+
+ val df = TABLE_TYPES.toDF("TABLE_TYPE")
+ .orderBy("TABLE_TYPE")
+ new SparkConnectResultSet(df.collectResult())
+ }
+
+ // Schema of the returned DataFrame is:
+ // |-- TABLE_CAT: string (nullable = false)
+ // |-- TABLE_SCHEM: string (nullable = false)
+ // |-- TABLE_NAME: string (nullable = false)
+ // |-- TABLE_TYPE: string (nullable = false)
+ // |-- REMARKS: string (nullable = false)
+ // |-- TYPE_CAT: string (nullable = false)
+ // |-- TYPE_SCHEM: string (nullable = false)
+ // |-- TYPE_NAME: string (nullable = false)
+ // |-- SELF_REFERENCING_COL_NAME: string (nullable = false)
+ // |-- REF_GENERATION: string (nullable = false)
+ private def getTablesDataFrame(
+ catalog: String,
+ schemaPattern: String,
+ tableNamePattern: String): connect.DataFrame = {
+
+ val catalogSchemasDf = if (schemaPattern == "") {
+ getSchemasDataFrame(catalog, None)
+ } else {
+ getSchemasDataFrame(catalog, Some(schemaPattern))
+ }
+
+ val catalogSchemas = catalogSchemasDf.collect()
+ .map { row => (row.getString(1), row.getString(0)) }
+
+ val tableNameFilterExpr = if (isNullOrWildcard(tableNamePattern)) {
+ lit(true)
+ } else {
+ $"TABLE_NAME".like(tableNamePattern)
+ }
+
+ lazy val emptyDf = conn.spark.emptyDataFrame
+ .withColumn("TABLE_CAT", lit(""))
+ .withColumn("TABLE_SCHEM", lit(""))
+ .withColumn("TABLE_NAME", lit(""))
+ .withColumn("TABLE_TYPE", lit(""))
+ .withColumn("REMARKS", lit(""))
+ .withColumn("TYPE_CAT", lit(""))
+ .withColumn("TYPE_SCHEM", lit(""))
+ .withColumn("TYPE_NAME", lit(""))
+ .withColumn("SELF_REFERENCING_COL_NAME", lit(""))
+ .withColumn("REF_GENERATION", lit(""))
+
+ catalogSchemas.map { case (catalog, schema) =>
+ val viewDf = try {
+ conn.spark
+ .sql(s"SHOW VIEWS IN ${quoteNameParts(Seq(catalog, schema))}")
+ .select($"namespace".as("TABLE_SCHEM"), $"viewName".as("TABLE_NAME"))
+ .filter(tableNameFilterExpr)
+ } catch {
+ case st: SparkThrowable if st.getCondition == "MISSING_CATALOG_ABILITY.VIEWS" =>
+ emptyDf.select("TABLE_SCHEM", "TABLE_NAME")
+ }
+
+ val tableDf = try {
+ conn.spark
+ .sql(s"SHOW TABLES IN ${quoteNameParts(Seq(catalog, schema))}")
+ .select($"namespace".as("TABLE_SCHEM"), $"tableName".as("TABLE_NAME"))
+ .filter(tableNameFilterExpr)
+ .exceptAll(viewDf)
+ } catch {
+ case st: SparkThrowable if st.getCondition == "MISSING_CATALOG_ABILITY.TABLES" =>
+ emptyDf.select("TABLE_SCHEM", "TABLE_NAME")
+ }
+
+ tableDf.withColumn("TABLE_TYPE", lit("TABLE"))
+ .unionAll(viewDf.withColumn("TABLE_TYPE", lit("VIEW")))
+ .withColumn("TABLE_CAT", lit(catalog))
+ .withColumn("REMARKS", lit(""))
+ .withColumn("TYPE_CAT", lit(""))
+ .withColumn("TYPE_SCHEM", lit(""))
+ .withColumn("TYPE_NAME", lit(""))
+ .withColumn("SELF_REFERENCING_COL_NAME", lit(""))
+ .withColumn("REF_GENERATION", lit(""))
+ .select("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "TABLE_TYPE", "REMARKS",
+ "TYPE_CAT", "TYPE_SCHEM", "TYPE_NAME", "SELF_REFERENCING_COL_NAME",
+ "REF_GENERATION")
+ }.fold(emptyDf) { (l, r) => l.unionAll(r) }
+ }
override def getTables(
catalog: String,
schemaPattern: String,
tableNamePattern: String,
- types: Array[String]): ResultSet =
- throw new SQLFeatureNotSupportedException
+ types: Array[String]): ResultSet = {
+ conn.checkOpen()
+
+ if (types != null) {
+ val unsupported = types.diff(TABLE_TYPES)
+ if (unsupported.nonEmpty) {
+ throw new SQLException(
+ "The requested table types contains unsupported items: " +
+ s"${unsupported.mkString(", ")}. Available table types are: " +
+ s"${TABLE_TYPES.mkString(", ")}.")
+ }
+ }
+
+ var df = getTablesDataFrame(catalog, schemaPattern, tableNamePattern)
+ .orderBy("TABLE_TYPE", "TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME")
+
+ if (types != null) {
+ df = df.filter($"TABLE_TYPE".isInCollection(types))
+ }
+ new SparkConnectResultSet(df.collectResult())
+ }
override def getColumns(
catalog: String,
schemaPattern: String,
tableNamePattern: String,
- columnNamePattern: String): ResultSet =
- throw new SQLFeatureNotSupportedException
+ columnNamePattern: String): ResultSet = {
+ conn.checkOpen()
+
+ val columnNameFilterExpr = if (isNullOrWildcard(columnNamePattern)) {
+ lit(true)
+ } else {
+ $"COLUMN_NAME".like(columnNamePattern)
+ }
+
+ lazy val emptyDf = conn.spark.emptyDataFrame
+ .withColumn("TABLE_CAT", lit(""))
+ .withColumn("TABLE_SCHEM", lit(""))
+ .withColumn("TABLE_NAME", lit(""))
+ .withColumn("COLUMN_NAME", lit(""))
+ .withColumn("DATA_TYPE", lit(0))
+ .withColumn("TYPE_NAME", lit(""))
+ .withColumn("COLUMN_SIZE", lit(0))
+ .withColumn("BUFFER_LENGTH", lit(0))
+ .withColumn("DECIMAL_DIGITS", lit(0))
+ .withColumn("NUM_PREC_RADIX", lit(0))
+ .withColumn("NULLABLE", lit(0))
+ .withColumn("REMARKS", lit(""))
+ .withColumn("COLUMN_DEF", lit(""))
+ .withColumn("SQL_DATA_TYPE", lit(0))
+ .withColumn("SQL_DATETIME_SUB", lit(0))
+ .withColumn("CHAR_OCTET_LENGTH", lit(0))
+ .withColumn("ORDINAL_POSITION", lit(0))
+ .withColumn("IS_NULLABLE", lit(""))
+ .withColumn("SCOPE_CATALOG", lit(""))
+ .withColumn("SCOPE_SCHEMA", lit(""))
+ .withColumn("SCOPE_TABLE", lit(""))
+ .withColumn("SOURCE_DATA_TYPE", lit(0.toShort))
+ .withColumn("IS_AUTOINCREMENT", lit(""))
+ .withColumn("IS_GENERATEDCOLUMN", lit(""))
+
+ val catalogSchemaTables =
+ getTablesDataFrame(catalog, schemaPattern, tableNamePattern)
+ .select("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME")
+ .collect().map { row => (row.getString(0), row.getString(1), row.getString(2)) }
+
+ val df = catalogSchemaTables.map { case (catalog, schema, table) =>
+ val columns = conn.spark.table(quoteNameParts(Seq(catalog, schema, table)))
+ .schema.zipWithIndex.map { case (field, i) =>
+ (
+ field.name, // COLUMN_NAME
+ JdbcTypeUtils.getColumnType(field), // DATA_TYPE
+ field.dataType.sql, // TYPE_NAME
+ JdbcTypeUtils.getDisplaySize(field), // COLUMN_SIZE
+ JdbcTypeUtils.getDecimalDigits(field), // DECIMAL_DIGITS
+ JdbcTypeUtils.getNumPrecRadix(field), // NUM_PREC_RADIX
+ if (field.nullable) columnNullable else columnNoNulls, // NULLABLE
+ field.getComment().orNull, // REMARKS
+ field.getCurrentDefaultValue().orNull, // COLUMN_DEF
+ 0, // CHAR_OCTET_LENGTH
+ i + 1, // ORDINAL_POSITION
+ if (field.nullable) "YES" else "NO", // IS_NULLABLE
+ "", // IS_AUTOINCREMENT
+ "" // IS_GENERATEDCOLUMN
+ )
+ }
+ columns.toDF("COLUMN_NAME", "DATA_TYPE", "TYPE_NAME", "COLUMN_SIZE", "DECIMAL_DIGITS",
+ "NUM_PREC_RADIX", "NULLABLE", "REMARKS", "COLUMN_DEF", "CHAR_OCTET_LENGTH",
+ "ORDINAL_POSITION", "IS_NULLABLE", "IS_AUTOINCREMENT", "IS_GENERATEDCOLUMN")
+ .filter(columnNameFilterExpr)
+ .withColumn("TABLE_CAT", lit(catalog))
+ .withColumn("TABLE_SCHEM", lit(schema))
+ .withColumn("TABLE_NAME", lit(table))
+ .withColumn("BUFFER_LENGTH", lit(0))
+ .withColumn("SQL_DATA_TYPE", lit(0))
+ .withColumn("SQL_DATETIME_SUB", lit(0))
+ .withColumn("SCOPE_CATALOG", lit(""))
+ .withColumn("SCOPE_SCHEMA", lit(""))
+ .withColumn("SCOPE_TABLE", lit(""))
+ .withColumn("SOURCE_DATA_TYPE", lit(0.toShort))
+ .select("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "DATA_TYPE",
+ "TYPE_NAME", "COLUMN_SIZE", "BUFFER_LENGTH", "DECIMAL_DIGITS", "NUM_PREC_RADIX",
+ "NULLABLE", "REMARKS", "COLUMN_DEF", "SQL_DATA_TYPE", "SQL_DATETIME_SUB",
+ "CHAR_OCTET_LENGTH", "ORDINAL_POSITION", "IS_NULLABLE", "SCOPE_CATALOG",
+ "SCOPE_SCHEMA", "SCOPE_TABLE", "SOURCE_DATA_TYPE", "IS_AUTOINCREMENT",
+ "IS_GENERATEDCOLUMN")
+ }.fold(emptyDf) { (l, r) => l.unionAll(r) }
+ .orderBy("TABLE_CAT", "TABLE_SCHEM", "TABLE_NAME", "ORDINAL_POSITION")
+
+ new SparkConnectResultSet(df.collectResult())
+ }
override def getColumnPrivileges(
catalog: String,
@@ -555,4 +814,6 @@ object SparkConnectDatabaseMetaData {
"XMLFOREST", "XMLNAMESPACES", "XMLPARSE", "XMLPI", "XMLROOT", "XMLSERIALIZE",
"YEAR"
)
+
+ private[jdbc] val TABLE_TYPES = Seq("TABLE", "VIEW")
}
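As a usage illustration (not part of this patch), a JDBC client could exercise the metadata calls implemented above roughly as follows. The jdbc:sc:// URL scheme matches the one used in this module's tests; the host/port and driver auto-registration are assumptions.

    import java.sql.DriverManager
    import scala.util.Using

    // Hypothetical host/port; assumes the Spark Connect JDBC driver is on the classpath
    // and registers itself with DriverManager.
    Using.resource(DriverManager.getConnection("jdbc:sc://localhost:15002")) { conn =>
      val md = conn.getMetaData
      Using.resource(md.getCatalogs) { rs =>
        while (rs.next()) println(rs.getString("TABLE_CAT"))
      }
      // A null catalog searches all catalogs; "%" matches every schema and table name.
      Using.resource(md.getTables(null, "%", "%", Array("TABLE", "VIEW"))) { rs =>
        while (rs.next()) {
          println(s"${rs.getString("TABLE_SCHEM")}.${rs.getString("TABLE_NAME")} " +
            s"(${rs.getString("TABLE_TYPE")})")
        }
      }
    }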
diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
index 0745ddc099111..e90f80f783dc9 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSet.scala
@@ -20,12 +20,15 @@ package org.apache.spark.sql.connect.client.jdbc
import java.io.{InputStream, Reader}
import java.net.URL
import java.sql.{Array => JdbcArray, _}
+import java.time.{LocalDateTime, LocalTime}
+import java.time.temporal.ChronoUnit
import java.util
import java.util.Calendar
import org.apache.spark.sql.Row
import org.apache.spark.sql.connect.client.SparkResult
import org.apache.spark.sql.connect.client.jdbc.util.JdbcErrorUtils
+import org.apache.spark.sql.types.{TimestampNTZType, TimestampType}
class SparkConnectResultSet(
sparkResult: SparkResult[Row],
@@ -42,9 +45,15 @@ class SparkConnectResultSet(
private var cursor: Int = 0
private var _wasNull: Boolean = false
- override def wasNull: Boolean = _wasNull
+
+ override def wasNull: Boolean = {
+ checkOpen()
+ _wasNull
+ }
override def next(): Boolean = {
+ checkOpen()
+
val hasNext = iterator.hasNext
if (hasNext) {
currentRow = iterator.next()
@@ -58,7 +67,7 @@ class SparkConnectResultSet(
hasNext
}
- @volatile protected var closed: Boolean = false
+ @volatile private var closed: Boolean = false
override def isClosed: Boolean = closed
@@ -76,7 +85,27 @@ class SparkConnectResultSet(
}
}
+ private def getColumnValue[T](columnIndex: Int, defaultVal: T)(getter: Int => T): T = {
+ checkOpen()
+ // the passed index value is 1-indexed, but the underlying array is 0-indexed
+ val index = columnIndex - 1
+ if (index < 0 || index >= currentRow.length) {
+ throw new SQLException(s"The column index is out of range: $columnIndex, " +
+ s"number of columns: ${currentRow.length}.")
+ }
+
+ if (currentRow.isNullAt(index)) {
+ _wasNull = true
+ defaultVal
+ } else {
+ _wasNull = false
+ getter(index)
+ }
+ }
+
override def findColumn(columnLabel: String): Int = {
+ checkOpen()
+
sparkResult.schema.getFieldIndex(columnLabel) match {
case Some(i) => i + 1
case None =>
@@ -85,91 +114,90 @@ class SparkConnectResultSet(
}
override def getString(columnIndex: Int): String = {
- if (currentRow.isNullAt(columnIndex - 1)) {
- _wasNull = true
- return null
+ getColumnValue(columnIndex, null: String) { idx =>
+ currentRow.get(idx) match {
+ case bytes: Array[Byte] =>
+ new String(bytes, java.nio.charset.StandardCharsets.UTF_8)
+ case other => String.valueOf(other)
+ }
}
- _wasNull = false
- String.valueOf(currentRow.get(columnIndex - 1))
}
override def getBoolean(columnIndex: Int): Boolean = {
- if (currentRow.isNullAt(columnIndex - 1)) {
- _wasNull = true
- return false
- }
- _wasNull = false
- currentRow.getBoolean(columnIndex - 1)
+ getColumnValue(columnIndex, false) { idx => currentRow.getBoolean(idx) }
}
override def getByte(columnIndex: Int): Byte = {
- if (currentRow.isNullAt(columnIndex - 1)) {
- _wasNull = true
- return 0.toByte
- }
- _wasNull = false
- currentRow.getByte(columnIndex - 1)
+ getColumnValue(columnIndex, 0.toByte) { idx => currentRow.getByte(idx) }
}
override def getShort(columnIndex: Int): Short = {
- if (currentRow.isNullAt(columnIndex - 1)) {
- _wasNull = true
- return 0.toShort
- }
- _wasNull = false
- currentRow.getShort(columnIndex - 1)
+ getColumnValue(columnIndex, 0.toShort) { idx => currentRow.getShort(idx) }
}
override def getInt(columnIndex: Int): Int = {
- if (currentRow.isNullAt(columnIndex - 1)) {
- _wasNull = true
- return 0
- }
- _wasNull = false
- currentRow.getInt(columnIndex - 1)
+ getColumnValue(columnIndex, 0) { idx => currentRow.getInt(idx) }
}
override def getLong(columnIndex: Int): Long = {
- if (currentRow.isNullAt(columnIndex - 1)) {
- _wasNull = true
- return 0L
- }
- _wasNull = false
- currentRow.getLong(columnIndex - 1)
+ getColumnValue(columnIndex, 0.toLong) { idx => currentRow.getLong(idx) }
}
override def getFloat(columnIndex: Int): Float = {
- if (currentRow.isNullAt(columnIndex - 1)) {
- _wasNull = true
- return 0.toFloat
- }
- _wasNull = false
- currentRow.getFloat(columnIndex - 1)
+ getColumnValue(columnIndex, 0.toFloat) { idx => currentRow.getFloat(idx) }
}
override def getDouble(columnIndex: Int): Double = {
- if (currentRow.isNullAt(columnIndex - 1)) {
- _wasNull = true
- return 0.toDouble
- }
- _wasNull = false
- currentRow.getDouble(columnIndex - 1)
+ getColumnValue(columnIndex, 0.toDouble) { idx => currentRow.getDouble(idx) }
}
override def getBigDecimal(columnIndex: Int, scale: Int): java.math.BigDecimal =
throw new SQLFeatureNotSupportedException
- override def getBytes(columnIndex: Int): Array[Byte] =
- throw new SQLFeatureNotSupportedException
+ override def getBytes(columnIndex: Int): Array[Byte] = {
+ getColumnValue(columnIndex, null: Array[Byte]) { idx =>
+ currentRow.get(idx).asInstanceOf[Array[Byte]]
+ }
+ }
- override def getDate(columnIndex: Int): Date =
- throw new SQLFeatureNotSupportedException
+ override def getDate(columnIndex: Int): Date = {
+ getColumnValue(columnIndex, null: Date) { idx => currentRow.getDate(idx) }
+ }
- override def getTime(columnIndex: Int): Time =
- throw new SQLFeatureNotSupportedException
+ override def getTime(columnIndex: Int): Time = {
+ getColumnValue(columnIndex, null: Time) { idx =>
+ val localTime = currentRow.get(idx).asInstanceOf[LocalTime]
+ // Convert LocalTime to milliseconds since midnight to preserve fractional seconds.
+ // Note: java.sql.Time can only store up to millisecond precision (3 digits).
+ // For TIME types with higher precision (TIME(4-9)), microseconds/nanoseconds are truncated.
+ // If full precision is needed,
+ // use getObject(columnIndex, classOf[LocalTime]) instead.
+ val millisSinceMidnight = ChronoUnit.MILLIS.between(LocalTime.MIDNIGHT, localTime)
+ new Time(millisSinceMidnight)
+ }
+ }
- override def getTimestamp(columnIndex: Int): Timestamp =
- throw new SQLFeatureNotSupportedException
+ override def getTimestamp(columnIndex: Int): Timestamp = {
+ getColumnValue(columnIndex, null: Timestamp) { idx =>
+ val value = currentRow.get(idx)
+ if (value == null) {
+ null
+ } else {
+ sparkResult.schema.fields(idx).dataType match {
+ case TimestampNTZType =>
+ // TIMESTAMP_NTZ is represented as LocalDateTime
+ Timestamp.valueOf(value.asInstanceOf[LocalDateTime])
+ case TimestampType =>
+ // TIMESTAMP is represented as Timestamp
+ value.asInstanceOf[Timestamp]
+ case other =>
+ throw new SQLException(
+ s"Cannot call getTimestamp() on column of type $other. " +
+ s"Expected TIMESTAMP or TIMESTAMP_NTZ.")
+ }
+ }
+ }
+ }
override def getAsciiStream(columnIndex: Int): InputStream =
throw new SQLFeatureNotSupportedException
@@ -208,16 +236,16 @@ class SparkConnectResultSet(
throw new SQLFeatureNotSupportedException
override def getBytes(columnLabel: String): Array[Byte] =
- throw new SQLFeatureNotSupportedException
+ getBytes(findColumn(columnLabel))
override def getDate(columnLabel: String): Date =
- throw new SQLFeatureNotSupportedException
+ getDate(findColumn(columnLabel))
override def getTime(columnLabel: String): Time =
- throw new SQLFeatureNotSupportedException
+ getTime(findColumn(columnLabel))
override def getTimestamp(columnLabel: String): Timestamp =
- throw new SQLFeatureNotSupportedException
+ getTimestamp(findColumn(columnLabel))
override def getAsciiStream(columnLabel: String): InputStream =
throw new SQLFeatureNotSupportedException
@@ -240,12 +268,9 @@ class SparkConnectResultSet(
}
override def getObject(columnIndex: Int): AnyRef = {
- if (currentRow.isNullAt(columnIndex - 1)) {
- _wasNull = true
- return null
+ getColumnValue(columnIndex, null: AnyRef) { idx =>
+ currentRow.get(idx).asInstanceOf[AnyRef]
}
- _wasNull = false
- currentRow.get(columnIndex - 1).asInstanceOf[AnyRef]
}
override def getObject(columnLabel: String): AnyRef =
@@ -257,11 +282,14 @@ class SparkConnectResultSet(
override def getCharacterStream(columnLabel: String): Reader =
throw new SQLFeatureNotSupportedException
- override def getBigDecimal(columnIndex: Int): java.math.BigDecimal =
- throw new SQLFeatureNotSupportedException
+ override def getBigDecimal(columnIndex: Int): java.math.BigDecimal = {
+ getColumnValue(columnIndex, null: java.math.BigDecimal) { idx =>
+ currentRow.getDecimal(idx)
+ }
+ }
override def getBigDecimal(columnLabel: String): java.math.BigDecimal =
- throw new SQLFeatureNotSupportedException
+ getBigDecimal(findColumn(columnLabel))
override def isBeforeFirst: Boolean = {
checkOpen()
@@ -518,23 +546,62 @@ class SparkConnectResultSet(
override def getArray(columnLabel: String): JdbcArray =
throw new SQLFeatureNotSupportedException
- override def getDate(columnIndex: Int, cal: Calendar): Date =
- throw new SQLFeatureNotSupportedException
+ /**
+ * @inheritdoc
+ *
+ * Note: The Calendar parameter is ignored. Spark Connect handles timezone conversions
+ * server-side to avoid client/server timezone inconsistencies.
+ */
+ override def getDate(columnIndex: Int, cal: Calendar): Date = {
+ getDate(columnIndex)
+ }
+ /**
+ * @inheritdoc
+ *
+ * Note: The Calendar parameter is ignored. Spark Connect handles timezone conversions
+ * server-side to avoid client/server timezone inconsistencies.
+ */
override def getDate(columnLabel: String, cal: Calendar): Date =
- throw new SQLFeatureNotSupportedException
-
- override def getTime(columnIndex: Int, cal: Calendar): Time =
- throw new SQLFeatureNotSupportedException
+ getDate(findColumn(columnLabel))
+
+ /**
+ * @inheritdoc
+ *
+ * Note: The Calendar parameter is ignored. Spark Connect handles timezone conversions
+ * server-side to avoid client/server timezone inconsistencies.
+ */
+ override def getTime(columnIndex: Int, cal: Calendar): Time = {
+ getTime(columnIndex)
+ }
+ /**
+ * @inheritdoc
+ *
+ * Note: The Calendar parameter is ignored. Spark Connect handles timezone conversions
+ * server-side to avoid client/server timezone inconsistencies.
+ */
override def getTime(columnLabel: String, cal: Calendar): Time =
- throw new SQLFeatureNotSupportedException
-
- override def getTimestamp(columnIndex: Int, cal: Calendar): Timestamp =
- throw new SQLFeatureNotSupportedException
+ getTime(findColumn(columnLabel))
+
+ /**
+ * @inheritdoc
+ *
+ * Note: The Calendar parameter is ignored. Spark Connect handles timezone conversions
+ * server-side to avoid client/server timezone inconsistencies.
+ */
+ override def getTimestamp(columnIndex: Int, cal: Calendar): Timestamp = {
+ getTimestamp(columnIndex)
+ }
+ /**
+ * @inheritdoc
+ *
+ * Note: The Calendar parameter is ignored. Spark Connect handles timezone conversions
+ * server-side to avoid client/server timezone inconsistencies.
+ */
override def getTimestamp(columnLabel: String, cal: Calendar): Timestamp =
- throw new SQLFeatureNotSupportedException
+ getTimestamp(findColumn(columnLabel), cal)
override def getURL(columnIndex: Int): URL =
throw new SQLFeatureNotSupportedException
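To make the getter semantics above concrete, here is a hedged sketch (hypothetical statement and query, not part of this patch). Numeric getters return their default value for SQL NULL and wasNull() reports it afterwards; getTime() keeps only millisecond precision, and per the comment above, getObject(columnIndex, classOf[LocalTime]) would be the way to keep full TIME precision.

    // Illustrative only; `stmt` is an already-created java.sql.Statement.
    val rs = stmt.executeQuery(
      "SELECT CAST(NULL AS INT), TIMESTAMP'2025-01-01 10:20:30.123456', CAST(1.5 AS DECIMAL(3,1))")
    rs.next()
    val i = rs.getInt(1)          // 0, because the column value is SQL NULL
    val sawNull = rs.wasNull()    // true
    val ts = rs.getTimestamp(2)   // java.sql.Timestamp for TIMESTAMP columns
    val dec = rs.getBigDecimal(3) // java.math.BigDecimal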
diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala
index 8de227f9d07c2..d1947ae93a40c 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatement.scala
@@ -19,11 +19,15 @@ package org.apache.spark.sql.connect.client.jdbc
import java.sql.{Array => _, _}
+import org.apache.spark.sql.connect.client.SparkResult
+
class SparkConnectStatement(conn: SparkConnectConnection) extends Statement {
private var operationId: String = _
private var resultSet: SparkConnectResultSet = _
+ private var maxRows: Int = 0
+
@volatile private var closed: Boolean = false
override def isClosed: Boolean = closed
@@ -31,14 +35,21 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement {
override def close(): Unit = synchronized {
if (!closed) {
if (operationId != null) {
- conn.spark.interruptOperation(operationId)
+ try {
+ conn.spark.interruptOperation(operationId)
+ } catch {
+ case _: java.net.ConnectException =>
+ // Ignore ConnectExceptions during cleanup as the operation may have already completed
+ // or the server may be unavailable. The important part is marking this statement
+ // as closed to prevent further use.
+ }
operationId = null
}
if (resultSet != null) {
resultSet.close()
resultSet = null
}
- closed = false
+ closed = true
}
}
@@ -49,33 +60,54 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement {
}
override def executeQuery(sql: String): ResultSet = {
- checkOpen()
-
- val df = conn.spark.sql(sql)
- val sparkResult = df.collectResult()
- operationId = sparkResult.operationId
- resultSet = new SparkConnectResultSet(sparkResult, this)
- resultSet
+ val hasResultSet = execute(sql)
+ if (hasResultSet) {
+ assert(resultSet != null)
+ resultSet
+ } else {
+ throw new SQLException("The query does not produce a ResultSet.")
+ }
}
override def executeUpdate(sql: String): Int = {
- checkOpen()
-
- val df = conn.spark.sql(sql)
- val sparkResult = df.collectResult()
- operationId = sparkResult.operationId
- resultSet = null
+ val hasResultSet = execute(sql)
+ if (hasResultSet) {
+ // users are not expected to access the result set in this case,
+ // so we must close it to avoid a memory leak.
+ resultSet.close()
+ throw new SQLException("The query produces a ResultSet.")
+ } else {
+ assert(resultSet == null)
+ getUpdateCount
+ }
+ }
- // always return 0 because affected rows is not supported yet
- 0
+ private def hasResultSet(sparkResult: SparkResult[_]): Boolean = {
+ // heuristic: a non-empty schema means the command produced a result set
+ sparkResult.schema.length > 0
}
override def execute(sql: String): Boolean = {
checkOpen()
- // always perform executeQuery and reture a ResultSet
- executeQuery(sql)
- true
+ // a statement can be reused to execute more than one query,
+ // so reset the state before executing a new one
+ operationId = null
+ resultSet = null
+
+ var df = conn.spark.sql(sql)
+ if (maxRows > 0) {
+ df = df.limit(maxRows)
+ }
+ val sparkResult = df.collectResult()
+ operationId = sparkResult.operationId
+ if (hasResultSet(sparkResult)) {
+ resultSet = new SparkConnectResultSet(sparkResult, this)
+ true
+ } else {
+ sparkResult.close()
+ false
+ }
}
override def getResultSet: ResultSet = {
@@ -91,11 +123,17 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement {
override def getMaxRows: Int = {
checkOpen()
- 0
+ this.maxRows
}
- override def setMaxRows(max: Int): Unit =
- throw new SQLFeatureNotSupportedException
+ override def setMaxRows(max: Int): Unit = {
+ checkOpen()
+
+ if (max < 0) {
+ throw new SQLException("The max rows must be zero or a positive integer.")
+ }
+ this.maxRows = max
+ }
override def setEscapeProcessing(enable: Boolean): Unit =
throw new SQLFeatureNotSupportedException
@@ -123,8 +161,15 @@ class SparkConnectStatement(conn: SparkConnectConnection) extends Statement {
override def setCursorName(name: String): Unit =
throw new SQLFeatureNotSupportedException
- override def getUpdateCount: Int =
- throw new SQLFeatureNotSupportedException
+ override def getUpdateCount: Int = {
+ checkOpen()
+
+ if (resultSet != null) {
+ -1
+ } else {
+ 0 // always return 0 because affected rows is not supported yet
+ }
+ }
override def getMoreResults: Boolean =
throw new SQLFeatureNotSupportedException
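A hedged sketch of the revised Statement behavior (hypothetical SQL, not part of this patch): execute() now reports whether a ResultSet was produced, setMaxRows() is applied as a LIMIT, and getUpdateCount() returns -1 while a ResultSet is open and 0 otherwise, since affected-row counts are not reported yet.

    // Illustrative only; `conn` is an open SparkConnectConnection and createStatement()
    // is the standard java.sql.Connection entry point.
    val stmt = conn.createStatement()
    stmt.setMaxRows(10)
    if (stmt.execute("SELECT id FROM range(100)")) {
      val rs = stmt.getResultSet
      var n = 0
      while (rs.next()) n += 1
      assert(n == 10)                 // capped by setMaxRows
    } else {
      val count = stmt.getUpdateCount // 0: affected rows are not reported yet
    }
    stmt.close()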
diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcErrorUtils.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcErrorUtils.scala
index 3d9f72d87d150..6480c5d768f3f 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcErrorUtils.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcErrorUtils.scala
@@ -21,7 +21,7 @@ import java.sql.{Array => _, _}
private[jdbc] object JdbcErrorUtils {
- def stringfiyTransactionIsolationLevel(level: Int): String = level match {
+ def stringifyTransactionIsolationLevel(level: Int): String = level match {
case Connection.TRANSACTION_NONE => "NONE"
case Connection.TRANSACTION_READ_UNCOMMITTED => "READ_UNCOMMITTED"
case Connection.TRANSACTION_READ_COMMITTED => "READ_COMMITTED"
@@ -31,7 +31,7 @@ private[jdbc] object JdbcErrorUtils {
throw new IllegalArgumentException(s"Invalid transaction isolation level: $level")
}
- def stringfiyHoldability(holdability: Int): String = holdability match {
+ def stringifyHoldability(holdability: Int): String = holdability match {
case ResultSet.HOLD_CURSORS_OVER_COMMIT => "HOLD_CURSORS_OVER_COMMIT"
case ResultSet.CLOSE_CURSORS_AT_COMMIT => "CLOSE_CURSORS_AT_COMMIT"
case _ =>
diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcTypeUtils.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcTypeUtils.scala
index 55e3d29c99a5e..458f94c51f893 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcTypeUtils.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/util/JdbcTypeUtils.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.connect.client.jdbc.util
import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort}
+import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Array => _, _}
import org.apache.spark.sql.types._
@@ -34,6 +35,12 @@ private[jdbc] object JdbcTypeUtils {
case FloatType => Types.FLOAT
case DoubleType => Types.DOUBLE
case StringType => Types.VARCHAR
+ case _: DecimalType => Types.DECIMAL
+ case DateType => Types.DATE
+ case TimestampType => Types.TIMESTAMP
+ case TimestampNTZType => Types.TIMESTAMP
+ case BinaryType => Types.VARBINARY
+ case _: TimeType => Types.TIME
case other =>
throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.")
}
@@ -48,13 +55,21 @@ private[jdbc] object JdbcTypeUtils {
case FloatType => classOf[JFloat].getName
case DoubleType => classOf[JDouble].getName
case StringType => classOf[String].getName
+ case _: DecimalType => classOf[JBigDecimal].getName
+ case DateType => classOf[Date].getName
+ case TimestampType => classOf[Timestamp].getName
+ case TimestampNTZType => classOf[Timestamp].getName
+ case BinaryType => classOf[Array[Byte]].getName
+ case _: TimeType => classOf[Time].getName
case other =>
throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.")
}
def isSigned(field: StructField): Boolean = field.dataType match {
- case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true
- case NullType | BooleanType | StringType => false
+ case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType |
+ _: DecimalType => true
+ case NullType | BooleanType | StringType | DateType | BinaryType | _: TimeType |
+ TimestampType | TimestampNTZType => false
case other =>
throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.")
}
@@ -68,7 +83,17 @@ private[jdbc] object JdbcTypeUtils {
case LongType => 19
case FloatType => 7
case DoubleType => 15
- case StringType => 255
+ case StringType => Int.MaxValue
+ case DecimalType.Fixed(p, _) => p
+ case DateType => 10
+ case TimestampType => 29
+ case TimestampNTZType => 29
+ case BinaryType => Int.MaxValue
+ // Returns the Spark SQL TIME type precision, even though java.sql.ResultSet.getTime()
+ // can only retrieve up to millisecond precision (3) due to java.sql.Time limitations.
+ // Users can call getObject(index, classOf[LocalTime]) to access full microsecond
+ // precision when the source type is TIME(4) or higher.
+ case TimeType(precision) => precision
case other =>
throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.")
}
@@ -76,7 +101,11 @@ private[jdbc] object JdbcTypeUtils {
def getScale(field: StructField): Int = field.dataType match {
case FloatType => 7
case DoubleType => 15
- case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | StringType => 0
+ case TimestampType => 6
+ case TimestampNTZType => 6
+ case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | StringType |
+ DateType | BinaryType | _: TimeType => 0
+ case DecimalType.Fixed(_, s) => s
case other =>
throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.")
}
@@ -90,7 +119,34 @@ private[jdbc] object JdbcTypeUtils {
case DoubleType => 24
case StringType =>
getPrecision(field)
+ case DateType => 10 // length of `YYYY-MM-DD`
+ case TimestampType => 29 // length of `YYYY-MM-DD HH:MM:SS.SSSSSS`
+ case TimestampNTZType => 29 // length of `YYYY-MM-DD HH:MM:SS.SSSSSS`
+ case BinaryType => Int.MaxValue
+ case TimeType(precision) if precision > 0 => 8 + 1 + precision // length of `HH:MM:SS.ffffff`
+ case TimeType(_) => 8 // length of `HH:MM:SS`
+ // precision + negative sign + leading zero + decimal point, like DECIMAL(5,5) = -0.12345
+ case DecimalType.Fixed(p, s) if p == s => p + 3
+ // precision + negative sign, like DECIMAL(5,0) = -12345
+ case DecimalType.Fixed(p, s) if s == 0 => p + 1
+ // precision + negative sign + decimal point, like DECIMAL(5,2) = -123.45
+ case DecimalType.Fixed(p, _) => p + 2
case other =>
throw new SQLFeatureNotSupportedException(s"DataType $other is not supported yet.")
}
+
+ def getDecimalDigits(field: StructField): Integer = field.dataType match {
+ case BooleanType | _: IntegralType => 0
+ case FloatType => 7
+ case DoubleType => 15
+ case d: DecimalType => d.scale
+ case TimeType(scale) => scale
+ case TimestampType | TimestampNTZType => 6
+ case _ => null
+ }
+
+ def getNumPrecRadix(field: StructField): Integer = field.dataType match {
+ case _: NumericType => 10
+ case _ => null
+ }
}
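Worked examples of the size arithmetic above (illustrative, not part of this patch; JdbcTypeUtils is private[jdbc], so this only compiles inside the org.apache.spark.sql.connect.client.jdbc package):

    //   DECIMAL(5,5) -> p + 3 = 8        e.g. "-0.12345"
    //   DECIMAL(5,0) -> p + 1 = 6        e.g. "-12345"
    //   DECIMAL(5,2) -> p + 2 = 7        e.g. "-123.45"
    //   TIME(6)      -> 8 + 1 + 6 = 15   e.g. "23:59:59.123456"
    import org.apache.spark.sql.types.{DecimalType, StructField}
    val sizes = Seq(DecimalType(5, 5), DecimalType(5, 0), DecimalType(5, 2))
      .map(dt => JdbcTypeUtils.getDisplaySize(StructField("c", dt)))
    assert(sizes == Seq(8, 6, 7))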
diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala
index b2ecc163b2b8a..4d66392109e70 100644
--- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala
+++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala
@@ -19,16 +19,29 @@ package org.apache.spark.sql.connect.client.jdbc
import java.sql.{Array => _, _}
+import scala.util.Using
+
import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connect.client.jdbc.test.JdbcHelper
-import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession}
+import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper}
import org.apache.spark.util.VersionUtils
class SparkConnectDatabaseMetaDataSuite extends ConnectFunSuite with RemoteSparkSession
- with JdbcHelper {
+ with JdbcHelper with SQLHelper {
def jdbcUrl: String = s"jdbc:sc://localhost:$serverPort"
+ // catalyst test jar is inaccessible here, but is present on the testing connect server classpath
+ private val TEST_IN_MEMORY_CATALOG = "org.apache.spark.sql.connector.catalog.InMemoryCatalog"
+ private val TEST_BASIC_IN_MEMORY_CATALOG =
+ "org.apache.spark.sql.connector.catalog.BasicInMemoryTableCatalog"
+
+ private def registerCatalog(
+ name: String, className: String)(implicit spark: SparkSession): Unit = {
+ spark.conf.set(s"spark.sql.catalog.$name", className)
+ }
+
test("SparkConnectDatabaseMetaData simple methods") {
withConnection { conn =>
val spark = conn.asInstanceOf[SparkConnectConnection].spark
@@ -58,6 +71,7 @@ class SparkConnectDatabaseMetaDataSuite extends ConnectFunSuite with RemoteSpark
assert(metadata.storesLowerCaseQuotedIdentifiers === false)
assert(metadata.storesMixedCaseQuotedIdentifiers === false)
assert(metadata.getIdentifierQuoteString === "`")
+ assert(metadata.getSearchStringEscape === "\\")
assert(metadata.getExtraNameCharacters === "")
assert(metadata.supportsAlterTableWithAddColumn === true)
assert(metadata.supportsAlterTableWithDropColumn === true)
@@ -199,4 +213,592 @@ class SparkConnectDatabaseMetaDataSuite extends ConnectFunSuite with RemoteSpark
// scalastyle:on line.size.limit
}
}
+
+ test("SparkConnectDatabaseMetaData getCatalogs") {
+ withConnection { conn =>
+ implicit val spark: SparkSession = conn.asInstanceOf[SparkConnectConnection].spark
+
+ registerCatalog("testcat", TEST_IN_MEMORY_CATALOG)
+ registerCatalog("testcat2", TEST_IN_MEMORY_CATALOG)
+
+ // forcibly initialize the registered catalogs because SHOW CATALOGS only
+ // returns the initialized catalogs.
+ spark.sql("USE testcat")
+ spark.sql("USE testcat2")
+ spark.sql("USE spark_catalog")
+
+ val metadata = conn.getMetaData
+ Using.resource(metadata.getCatalogs) { rs =>
+ val catalogs = new Iterator[String] {
+ def hasNext: Boolean = rs.next()
+ def next(): String = rs.getString("TABLE_CAT")
+ }.toSeq
+ // results are ordered by TABLE_CAT
+ assert(catalogs === Seq("spark_catalog", "testcat", "testcat2"))
+ }
+ }
+ }
+
+ test("SparkConnectDatabaseMetaData getSchemas") {
+
+ def verifyGetSchemas(
+ getSchemas: () => ResultSet)(verify: Seq[(String, String)] => Unit): Unit = {
+ Using.resource(getSchemas()) { rs =>
+ val catalogDatabases = new Iterator[(String, String)] {
+ def hasNext: Boolean = rs.next()
+ def next(): (String, String) =
+ (rs.getString("TABLE_CATALOG"), rs.getString("TABLE_SCHEM"))
+ }.toSeq
+ verify(catalogDatabases)
+ }
+ }
+
+ withConnection { conn =>
+ implicit val spark: SparkSession = conn.asInstanceOf[SparkConnectConnection].spark
+
+ // this catalog does not support namespace
+ registerCatalog("test_noop", TEST_BASIC_IN_MEMORY_CATALOG)
+ // Spark loads catalog plugins lazily, so we must initialize the catalog first;
+ // otherwise it won't be listed by SHOW CATALOGS
+ conn.setCatalog("test_noop")
+
+ registerCatalog("test`cat", TEST_IN_MEMORY_CATALOG)
+
+ spark.sql("CREATE DATABASE IF NOT EXISTS `test``cat`.t_db1")
+ spark.sql("CREATE DATABASE IF NOT EXISTS `test``cat`.t_db2")
+ spark.sql("CREATE DATABASE IF NOT EXISTS `test``cat`.t_db_")
+
+ spark.sql("CREATE DATABASE IF NOT EXISTS spark_catalog.db1")
+ spark.sql("CREATE DATABASE IF NOT EXISTS spark_catalog.db2")
+ spark.sql("CREATE DATABASE IF NOT EXISTS spark_catalog.test_db3")
+
+ val metadata = conn.getMetaData
+
+ // no need to care about "test`cat" because it is memory-based, session-isolated,
+ // and inaccessible from another SparkSession
+ withDatabase("spark_catalog.db1", "spark_catalog.db2", "spark_catalog.test_db3") {
+ // list schemas in all catalogs
+ val getSchemasInAllCatalogs = (() => metadata.getSchemas) ::
+ List(null, "%").map { database => () => metadata.getSchemas(null, database) } ::: Nil
+
+ getSchemasInAllCatalogs.foreach { getSchemas =>
+ verifyGetSchemas(getSchemas) { catalogDatabases =>
+ // results are ordered by TABLE_CATALOG, TABLE_SCHEM
+ assert {
+ catalogDatabases === Seq(
+ ("spark_catalog", "db1"),
+ ("spark_catalog", "db2"),
+ ("spark_catalog", "default"),
+ ("spark_catalog", "test_db3"),
+ ("test`cat", "t_db1"),
+ ("test`cat", "t_db2"),
+ ("test`cat", "t_db_"))
+ }
+ }
+ }
+
+ // list schemas in current catalog
+ conn.setCatalog("spark_catalog")
+ assert(conn.getCatalog === "spark_catalog")
+ val getSchemasInCurrentCatalog =
+ List(null, "%").map { database => () => metadata.getSchemas("", database) }
+ getSchemasInCurrentCatalog.foreach { getSchemas =>
+ verifyGetSchemas(getSchemas) { catalogDatabases =>
+ // results are ordered by TABLE_CATALOG, TABLE_SCHEM
+ assert {
+ catalogDatabases === Seq(
+ ("spark_catalog", "db1"),
+ ("spark_catalog", "db2"),
+ ("spark_catalog", "default"),
+ ("spark_catalog", "test_db3"))
+ }
+ }
+ }
+
+ // list schemas with schema pattern
+ verifyGetSchemas { () => metadata.getSchemas(null, "db%") } { catalogDatabases =>
+ // results are ordered by TABLE_CATALOG, TABLE_SCHEM
+ assert {
+ catalogDatabases === Seq(
+ ("spark_catalog", "db1"),
+ ("spark_catalog", "db2"))
+ }
+ }
+
+ verifyGetSchemas { () => metadata.getSchemas(null, "db_") } { catalogDatabases =>
+ // results are ordered by TABLE_CATALOG, TABLE_SCHEM
+ assert {
+ catalogDatabases === Seq(
+ ("spark_catalog", "db1"),
+ ("spark_catalog", "db2"))
+ }
+ }
+
+ // escape backtick in catalog, and _ in schema pattern
+ verifyGetSchemas {
+ () => metadata.getSchemas("test`cat", "t\\_db\\_")
+ } { catalogDatabases =>
+ assert(catalogDatabases === Seq(("test`cat", "t_db_")))
+ }
+
+ // skip testing escaping of ' and % in the schema pattern, because Spark SQL does not
+ // allow those chars in schema or table names.
+ //
+ // CREATE DATABASE IF NOT EXISTS `t_db1'`;
+ //
+ // the above SQL fails with error condition:
+ // [INVALID_SCHEMA_OR_RELATION_NAME] `t_db1'` is not a valid name for tables/schemas.
+ // Valid names only contain alphabet characters, numbers and _. SQLSTATE: 42602
+ }
+ }
+ }
+
+ test("SparkConnectDatabaseMetaData getTableTypes") {
+ withConnection { conn =>
+ val metadata = conn.getMetaData
+ Using.resource(metadata.getTableTypes) { rs =>
+ val types = new Iterator[String] {
+ def hasNext: Boolean = rs.next()
+ def next(): String = rs.getString("TABLE_TYPE")
+ }.toSeq
+ // results are ordered by TABLE_TYPE
+ assert(types === Seq("TABLE", "VIEW"))
+ }
+ }
+ }
+
+ test("SparkConnectDatabaseMetaData getTables") {
+
+ case class GetTableResult(
+ TABLE_CAT: String,
+ TABLE_SCHEM: String,
+ TABLE_NAME: String,
+ TABLE_TYPE: String,
+ REMARKS: String,
+ TYPE_CAT: String,
+ TYPE_SCHEM: String,
+ TYPE_NAME: String,
+ SELF_REFERENCING_COL_NAME: String,
+ REF_GENERATION: String)
+
+ def verifyEmptyStringFields(result: GetTableResult): Unit = {
+ assert(result.REMARKS === "")
+ assert(result.TYPE_CAT === "")
+ assert(result.TYPE_SCHEM === "")
+ assert(result.TYPE_NAME === "")
+ assert(result.SELF_REFERENCING_COL_NAME === "")
+ assert(result.REF_GENERATION === "")
+ }
+
+ def verifyGetTables(
+ getTables: () => ResultSet)(verify: Seq[GetTableResult] => Unit): Unit = {
+ Using.resource(getTables()) { rs =>
+ val getTableResults = new Iterator[GetTableResult] {
+ def hasNext: Boolean = rs.next()
+ def next(): GetTableResult = GetTableResult(
+ TABLE_CAT = rs.getString("TABLE_CAT"),
+ TABLE_SCHEM = rs.getString("TABLE_SCHEM"),
+ TABLE_NAME = rs.getString("TABLE_NAME"),
+ TABLE_TYPE = rs.getString("TABLE_TYPE"),
+ REMARKS = rs.getString("REMARKS"),
+ TYPE_CAT = rs.getString("TYPE_CAT"),
+ TYPE_SCHEM = rs.getString("TYPE_SCHEM"),
+ TYPE_NAME = rs.getString("TYPE_NAME"),
+ SELF_REFERENCING_COL_NAME = rs.getString("SELF_REFERENCING_COL_NAME"),
+ REF_GENERATION = rs.getString("REF_GENERATION"))
+ }.toSeq
+ verify(getTableResults)
+ }
+ }
+
+ withConnection { conn =>
+ implicit val spark: SparkSession = conn.asInstanceOf[SparkConnectConnection].spark
+
+ // this catalog does not support namespace
+ registerCatalog("test_noop", TEST_BASIC_IN_MEMORY_CATALOG)
+ // Spark loads catalog plugins lazily, so we must initialize the catalog first;
+ // otherwise it won't be listed by SHOW CATALOGS
+ conn.setCatalog("test_noop")
+
+ // this catalog does not support view
+ registerCatalog("testcat", TEST_IN_MEMORY_CATALOG)
+
+ spark.sql("CREATE DATABASE IF NOT EXISTS testcat.t_db1")
+ spark.sql("CREATE TABLE IF NOT EXISTS testcat.t_db1.t_t1 (id INT)")
+
+ spark.sql("CREATE DATABASE IF NOT EXISTS spark_catalog.db1")
+ spark.sql("CREATE TABLE IF NOT EXISTS spark_catalog.db1.t1 (id INT)")
+ spark.sql("CREATE TABLE IF NOT EXISTS spark_catalog.db1.t_2 (id INT)")
+ spark.sql(
+ """CREATE VIEW IF NOT EXISTS spark_catalog.db1.t1_v AS
+ |SELECT id FROM spark_catalog.db1.t1
+ |""".stripMargin)
+
+ spark.sql("CREATE DATABASE IF NOT EXISTS spark_catalog.db_2")
+ spark.sql("CREATE TABLE IF NOT EXISTS spark_catalog.db_2.t_2 (id INT)")
+ spark.sql(
+ """CREATE VIEW IF NOT EXISTS spark_catalog.db_2.t_2_v AS
+ |SELECT id FROM spark_catalog.db_2.t_2
+ |""".stripMargin)
+
+ spark.sql("CREATE DATABASE IF NOT EXISTS spark_catalog.db_")
+ spark.sql("CREATE TABLE IF NOT EXISTS spark_catalog.db_.t_ (id INT)")
+
+ val metadata = conn.getMetaData
+
+ // no need to care about "testcat" because it is memory-based, session-isolated,
+ // and inaccessible from another SparkSession
+ withDatabase("spark_catalog.db1", "spark_catalog.db_2", "spark_catalog.db_") {
+ // list tables in all catalogs and schemas
+ val getTablesInAllCatalogsAndSchemas = List(null, "%").flatMap { database =>
+ List(null, "%").flatMap { table =>
+ List(null, Array("TABLE", "VIEW")).map { tableTypes =>
+ () => metadata.getTables(null, database, table, tableTypes)
+ }
+ }
+ }
+
+ getTablesInAllCatalogsAndSchemas.foreach { getTables =>
+ verifyGetTables(getTables) { getTableResults =>
+ // results are ordered by TABLE_TYPE, TABLE_CAT, TABLE_SCHEM and TABLE_NAME
+ assert {
+ getTableResults.map { result =>
+ (result.TABLE_TYPE, result.TABLE_CAT, result.TABLE_SCHEM, result.TABLE_NAME)
+ } === Seq(
+ ("TABLE", "spark_catalog", "db1", "t1"),
+ ("TABLE", "spark_catalog", "db1", "t_2"),
+ ("TABLE", "spark_catalog", "db_", "t_"),
+ ("TABLE", "spark_catalog", "db_2", "t_2"),
+ ("TABLE", "testcat", "t_db1", "t_t1"),
+ ("VIEW", "spark_catalog", "db1", "t1_v"),
+ ("VIEW", "spark_catalog", "db_2", "t_2_v"))
+ }
+ getTableResults.foreach(verifyEmptyStringFields)
+ }
+ }
+
+ // list tables with table types
+ val se = intercept[SQLException] {
+ metadata.getTables("spark_catalog", "foo", "bar", Array("TABLE", "MATERIALIZED VIEW"))
+ }
+ assert(se.getMessage ===
+ "The requested table types contains unsupported items: MATERIALIZED VIEW. " +
+ "Available table types are: TABLE, VIEW.")
+
+ verifyGetTables {
+ () => metadata.getTables("spark_catalog", "db1", "%", Array("TABLE"))
+ } { getTableResults =>
+ // results are ordered by TABLE_TYPE, TABLE_CAT, TABLE_SCHEM and TABLE_NAME
+ assert {
+ getTableResults.map { result =>
+ (result.TABLE_TYPE, result.TABLE_CAT, result.TABLE_SCHEM, result.TABLE_NAME)
+ } === Seq(
+ ("TABLE", "spark_catalog", "db1", "t1"),
+ ("TABLE", "spark_catalog", "db1", "t_2"))
+ }
+ getTableResults.foreach(verifyEmptyStringFields)
+ }
+
+ verifyGetTables {
+ () => metadata.getTables("spark_catalog", "db1", "%", Array("VIEW"))
+ } { getTableResults =>
+ // results are ordered by TABLE_TYPE, TABLE_CAT, TABLE_SCHEM and TABLE_NAME
+ assert {
+ getTableResults.map { result =>
+ (result.TABLE_TYPE, result.TABLE_CAT, result.TABLE_SCHEM, result.TABLE_NAME)
+ } === Seq(("VIEW", "spark_catalog", "db1", "t1_v"))
+ }
+ getTableResults.foreach(verifyEmptyStringFields)
+ }
+
+ // list tables in the current catalog and schema
+ conn.setCatalog("spark_catalog")
+ conn.setSchema("db1")
+ assert(conn.getCatalog === "spark_catalog")
+ assert(conn.getSchema === "db1")
+
+ verifyGetTables {
+ () => metadata.getTables("", "", "%", null)
+ } { getTableResults =>
+ assert {
+ getTableResults.map { result =>
+ (result.TABLE_TYPE, result.TABLE_CAT, result.TABLE_SCHEM, result.TABLE_NAME)
+ } === Seq(
+ ("TABLE", "spark_catalog", "db1", "t1"),
+ ("TABLE", "spark_catalog", "db1", "t_2"),
+ ("VIEW", "spark_catalog", "db1", "t1_v"))
+ }
+ getTableResults.foreach(verifyEmptyStringFields)
+ }
+
+ // list tables with schema pattern and table name pattern
+ verifyGetTables {
+ () => metadata.getTables(null, "db%", "t_", null)
+ } { getTableResults =>
+ assert {
+ getTableResults.map { result =>
+ (result.TABLE_TYPE, result.TABLE_CAT, result.TABLE_SCHEM, result.TABLE_NAME)
+ } === Seq(
+ ("TABLE", "spark_catalog", "db1", "t1"),
+ ("TABLE", "spark_catalog", "db_", "t_"))
+ }
+ getTableResults.foreach(verifyEmptyStringFields)
+ }
+
+ // escape _ in schema pattern and table name pattern
+ verifyGetTables {
+ () => metadata.getTables(null, "db\\_", "t\\_", null)
+ } { getTableResults =>
+ assert {
+ getTableResults.map { result =>
+ (result.TABLE_TYPE, result.TABLE_CAT, result.TABLE_SCHEM, result.TABLE_NAME)
+ } === Seq(("TABLE", "spark_catalog", "db_", "t_"))
+ }
+ getTableResults.foreach(verifyEmptyStringFields)
+ }
+
+ // skip testing escaping of ' and % in the schema pattern, because Spark SQL does not
+ // allow those chars in schema or table names.
+ }
+ }
+ }
+
+ test("SparkConnectDatabaseMetaData getColumns") {
+
+ case class GetColumnResult(
+ TABLE_CAT: String,
+ TABLE_SCHEM: String,
+ TABLE_NAME: String,
+ COLUMN_NAME: String,
+ DATA_TYPE: Int,
+ TYPE_NAME: String,
+ COLUMN_SIZE: Int,
+ BUFFER_LENGTH: Int,
+ DECIMAL_DIGITS: Int,
+ NUM_PREC_RADIX: Int,
+ NULLABLE: Int,
+ REMARKS: String,
+ COLUMN_DEF: String,
+ SQL_DATA_TYPE: Int,
+ SQL_DATETIME_SUB: Int,
+ CHAR_OCTET_LENGTH: Int,
+ ORDINAL_POSITION: Int,
+ IS_NULLABLE: String,
+ SCOPE_CATALOG: String,
+ SCOPE_SCHEMA: String,
+ SCOPE_TABLE: String,
+ SOURCE_DATA_TYPE: Short,
+ IS_AUTOINCREMENT: String,
+ IS_GENERATEDCOLUMN: String)
+
+ def verifyEmptyFields(result: GetColumnResult): Unit = {
+ assert(result.BUFFER_LENGTH === 0)
+ assert(result.SQL_DATA_TYPE === 0)
+ assert(result.SQL_DATETIME_SUB === 0)
+ assert(result.SCOPE_CATALOG === "")
+ assert(result.SCOPE_SCHEMA === "")
+ assert(result.SCOPE_TABLE === "")
+ assert(result.SOURCE_DATA_TYPE === 0.toShort)
+ }
+
+ def verifyGetColumns(
+ getColumns: () => ResultSet)(verify: Seq[GetColumnResult] => Unit): Unit = {
+ Using.resource(getColumns()) { rs =>
+ val getTableResults = new Iterator[GetColumnResult] {
+ def hasNext: Boolean = rs.next()
+
+ def next(): GetColumnResult = GetColumnResult(
+ TABLE_CAT = rs.getString("TABLE_CAT"),
+ TABLE_SCHEM = rs.getString("TABLE_SCHEM"),
+ TABLE_NAME = rs.getString("TABLE_NAME"),
+ COLUMN_NAME = rs.getString("COLUMN_NAME"),
+ DATA_TYPE = rs.getInt("DATA_TYPE"),
+ TYPE_NAME = rs.getString("TYPE_NAME"),
+ COLUMN_SIZE = rs.getInt("COLUMN_SIZE"),
+ BUFFER_LENGTH = rs.getInt("BUFFER_LENGTH"),
+ DECIMAL_DIGITS = rs.getInt("DECIMAL_DIGITS"),
+ NUM_PREC_RADIX = rs.getInt("NUM_PREC_RADIX"),
+ NULLABLE = rs.getInt("NULLABLE"),
+ REMARKS = rs.getString("REMARKS"),
+ COLUMN_DEF = rs.getString("COLUMN_DEF"),
+ SQL_DATA_TYPE = rs.getInt("SQL_DATA_TYPE"),
+ SQL_DATETIME_SUB = rs.getInt("SQL_DATETIME_SUB"),
+ CHAR_OCTET_LENGTH = rs.getInt("CHAR_OCTET_LENGTH"),
+ ORDINAL_POSITION = rs.getInt("ORDINAL_POSITION"),
+ IS_NULLABLE = rs.getString("IS_NULLABLE"),
+ SCOPE_CATALOG = rs.getString("SCOPE_CATALOG"),
+ SCOPE_SCHEMA = rs.getString("SCOPE_SCHEMA"),
+ SCOPE_TABLE = rs.getString("SCOPE_TABLE"),
+ SOURCE_DATA_TYPE = rs.getShort("SOURCE_DATA_TYPE"),
+ IS_AUTOINCREMENT = rs.getString("IS_AUTOINCREMENT"),
+ IS_GENERATEDCOLUMN = rs.getString("IS_GENERATEDCOLUMN"))
+ }.toSeq
+ verify(getTableResults)
+ }
+ }
+
+ withConnection { conn =>
+ implicit val spark: SparkSession = conn.asInstanceOf[SparkConnectConnection].spark
+
+ spark.sql("CREATE DATABASE IF NOT EXISTS testcat.t_db1")
+ spark.sql("CREATE TABLE IF NOT EXISTS testcat.t_db1.t_t1 (id INT)")
+
+ spark.sql("CREATE DATABASE IF NOT EXISTS spark_catalog.db1")
+ spark.sql(
+ """CREATE TABLE IF NOT EXISTS spark_catalog.db1.t1 (
+ | id INT NOT NULL,
+ | i_ INT,
+ | location STRING COMMENT 'city name' DEFAULT 'unknown')
+ |""".stripMargin)
+ spark.sql(
+ """CREATE TABLE IF NOT EXISTS spark_catalog.db1.t2 (
+ | col_null VOID,
+ | col_boolean BOOLEAN,
+ | col_byte BYTE,
+ | col_short SHORT,
+ | col_int INT,
+ | col_long LONG,
+ | col_float FLOAT,
+ | col_double DOUBLE,
+ | col_string STRING,
+ | col_decimal DECIMAL(10, 5),
+ | col_date DATE,
+ | col_timestamp TIMESTAMP,
+ | col_timestamp_ntz TIMESTAMP_NTZ,
+ | col_binary BINARY,
+ | col_time TIME)""".stripMargin)
+
+ spark.sql(
+ """CREATE VIEW IF NOT EXISTS spark_catalog.db1.t1_v AS
+ |SELECT id FROM spark_catalog.db1.t1
+ |""".stripMargin)
+
+ spark.sql("CREATE DATABASE IF NOT EXISTS spark_catalog.db_")
+ spark.sql("CREATE TABLE IF NOT EXISTS spark_catalog.db_.t_ (id INT, i_ INT)")
+
+ spark.sql("CREATE DATABASE IF NOT EXISTS spark_catalog.db_2")
+ spark.sql("CREATE TABLE IF NOT EXISTS spark_catalog.db_2.t_2 (id INT)")
+
+ val metadata = conn.getMetaData
+
+ // no need to care about "testcat" because it is memory-based, session-isolated,
+ // and inaccessible from another SparkSession
+ withDatabase("spark_catalog.db1", "spark_catalog.db_2", "spark_catalog.db_") {
+ // list columns of all tables in all catalogs and schemas
+ val getColumnsInAllTables = List(null, "%").flatMap { database =>
+ List(null, "%").flatMap { table =>
+ List(null, "%").map { column =>
+ () => metadata.getColumns(null, database, table, column)
+ }
+ }
+ }
+
+ getColumnsInAllTables.foreach { getColumns =>
+ verifyGetColumns(getColumns) { getColumnResults =>
+ // results are ordered by TABLE_CAT, TABLE_SCHEM, TABLE_NAME, ORDINAL_POSITION
+ assert {
+ getColumnResults.map { r =>
+ (r.TABLE_CAT, r.TABLE_SCHEM, r.TABLE_NAME, r.ORDINAL_POSITION, r.COLUMN_NAME)
+ } === Seq(
+ ("spark_catalog", "db1", "t1", 1, "id"),
+ ("spark_catalog", "db1", "t1", 2, "i_"),
+ ("spark_catalog", "db1", "t1", 3, "location"),
+ ("spark_catalog", "db1", "t1_v", 1, "id"),
+ ("spark_catalog", "db1", "t2", 1, "col_null"),
+ ("spark_catalog", "db1", "t2", 2, "col_boolean"),
+ ("spark_catalog", "db1", "t2", 3, "col_byte"),
+ ("spark_catalog", "db1", "t2", 4, "col_short"),
+ ("spark_catalog", "db1", "t2", 5, "col_int"),
+ ("spark_catalog", "db1", "t2", 6, "col_long"),
+ ("spark_catalog", "db1", "t2", 7, "col_float"),
+ ("spark_catalog", "db1", "t2", 8, "col_double"),
+ ("spark_catalog", "db1", "t2", 9, "col_string"),
+ ("spark_catalog", "db1", "t2", 10, "col_decimal"),
+ ("spark_catalog", "db1", "t2", 11, "col_date"),
+ ("spark_catalog", "db1", "t2", 12, "col_timestamp"),
+ ("spark_catalog", "db1", "t2", 13, "col_timestamp_ntz"),
+ ("spark_catalog", "db1", "t2", 14, "col_binary"),
+ ("spark_catalog", "db1", "t2", 15, "col_time"),
+ ("spark_catalog", "db_", "t_", 1, "id"),
+ ("spark_catalog", "db_", "t_", 2, "i_"),
+ ("spark_catalog", "db_2", "t_2", 1, "id"),
+ ("testcat", "t_db1", "t_t1", 1, "id"))
+ }
+
+ // TODO verify the remaining attributes
+ // DATA_TYPE = rs.getInt("DATA_TYPE"),
+ // TYPE_NAME = rs.getString("TYPE_NAME"),
+ // COLUMN_SIZE = rs.getInt("COLUMN_SIZE"),
+ // DECIMAL_DIGITS = rs.getInt("DECIMAL_DIGITS"),
+ // NUM_PREC_RADIX = rs.getInt("NUM_PREC_RADIX"),
+ // NULLABLE = rs.getInt("NULLABLE"),
+ // REMARKS = rs.getString("REMARKS"),
+ // COLUMN_DEF = rs.getString("COLUMN_DEF"),
+ // CHAR_OCTET_LENGTH = rs.getInt("CHAR_OCTET_LENGTH"),
+ // IS_NULLABLE = rs.getString("IS_NULLABLE"),
+ // IS_AUTOINCREMENT = rs.getString("IS_AUTOINCREMENT"),
+ // IS_GENERATEDCOLUMN = rs.getString("IS_GENERATEDCOLUMN")
+
+ getColumnResults.foreach(verifyEmptyFields)
+ }
+ }
+
+ // list columns of all tables in the current catalog and schema
+ conn.setCatalog("spark_catalog")
+ conn.setSchema("db1")
+ assert(conn.getCatalog === "spark_catalog")
+ assert(conn.getSchema === "db1")
+
+ verifyGetColumns(() => metadata.getColumns("", "", "%", "id")) { getColumnResults =>
+ // results are ordered by TABLE_CAT, TABLE_SCHEM, TABLE_NAME, ORDINAL_POSITION
+ assert {
+ getColumnResults.map { r =>
+ (r.TABLE_CAT, r.TABLE_SCHEM, r.TABLE_NAME, r.ORDINAL_POSITION, r.COLUMN_NAME)
+ } === Seq(
+ ("spark_catalog", "db1", "t1", 1, "id"),
+ ("spark_catalog", "db1", "t1_v", 1, "id"))
+ }
+
+ getColumnResults.foreach(verifyEmptyFields)
+ }
+
+ // list columns of tables with schema pattern, table name pattern, and column name pattern
+ verifyGetColumns {
+ () => metadata.getColumns(null, "%db_", "%t_", "%d%")
+ } { getColumnResults =>
+ // results are ordered by TABLE_CAT, TABLE_SCHEM, TABLE_NAME, ORDINAL_POSITION
+ assert {
+ getColumnResults.map { r =>
+ (r.TABLE_CAT, r.TABLE_SCHEM, r.TABLE_NAME, r.ORDINAL_POSITION, r.COLUMN_NAME)
+ } === Seq(
+ ("spark_catalog", "db1", "t1", 1, "id"),
+ ("spark_catalog", "db1", "t2", 8, "col_double"),
+ ("spark_catalog", "db1", "t2", 10, "col_decimal"),
+ ("spark_catalog", "db1", "t2", 11, "col_date"),
+ ("spark_catalog", "db_", "t_", 1, "id"),
+ ("testcat", "t_db1", "t_t1", 1, "id"))
+ }
+
+ getColumnResults.foreach(verifyEmptyFields)
+ }
+
+ // escape _ in schema pattern and table name pattern
+ verifyGetColumns {
+ () => metadata.getColumns(null, "db\\_", "t\\_", "i\\_")
+ } { getColumnResults =>
+ // results are ordered by TABLE_CAT, TABLE_SCHEM, TABLE_NAME, ORDINAL_POSITION
+ assert {
+ getColumnResults.map { r =>
+ (r.TABLE_CAT, r.TABLE_SCHEM, r.TABLE_NAME, r.ORDINAL_POSITION, r.COLUMN_NAME)
+ } === Seq(("spark_catalog", "db_", "t_", 2, "i_"))
+ }
+
+ getColumnResults.foreach(verifyEmptyFields)
+ }
+
+ // skip testing escaping of ' and % in the schema pattern, because Spark SQL does not
+ // allow those chars in schema or table names.
+ }
+ }
+ }
}
diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
index 619b279310eb3..75bd056879eff 100644
--- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
+++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectJdbcDataTypeSuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.connect.client.jdbc
-import java.sql.Types
+import java.sql.{ResultSet, SQLException, Types}
+
+import scala.util.Using
import org.apache.spark.sql.connect.client.jdbc.test.JdbcHelper
import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession}
@@ -27,6 +29,10 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess
override def jdbcUrl: String = s"jdbc:sc://localhost:$serverPort"
+ private def timeToMillis(hour: Int, minute: Int, second: Int, millis: Int): Long = {
+ hour * 3600000L + minute * 60000L + second * 1000L + millis
+ }
+
test("get null type") {
withExecuteQuery("SELECT null") { rs =>
assert(rs.next())
@@ -210,9 +216,572 @@ class SparkConnectJdbcDataTypeSuite extends ConnectFunSuite with RemoteSparkSess
assert(metaData.getColumnTypeName(1) === "STRING")
assert(metaData.getColumnClassName(1) === "java.lang.String")
assert(metaData.isSigned(1) === false)
- assert(metaData.getPrecision(1) === 255)
+ assert(metaData.getPrecision(1) === Int.MaxValue)
assert(metaData.getScale(1) === 0)
- assert(metaData.getColumnDisplaySize(1) === 255)
+ assert(metaData.getColumnDisplaySize(1) === Int.MaxValue)
+ }
+ }
+
+ test("get decimal type") {
+ withStatement { stmt =>
+ Seq(
+ ("123.45", 37, 2, 39),
+ ("-0.12345", 5, 5, 8),
+ ("-0.12345", 6, 5, 8),
+ ("-123.45", 5, 2, 7),
+ ("12345", 5, 0, 6),
+ ("-12345", 5, 0, 6)
+ ).foreach {
+ case (value, precision, scale, expectedColumnDisplaySize) =>
+ val decimalType = s"DECIMAL($precision,$scale)"
+ withExecuteQuery(stmt, s"SELECT cast('$value' as $decimalType)") { rs =>
+ assert(rs.next())
+ assert(rs.getBigDecimal(1) === new java.math.BigDecimal(value))
+ assert(!rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === s"CAST($value AS $decimalType)")
+ assert(metaData.getColumnLabel(1) === s"CAST($value AS $decimalType)")
+ assert(metaData.getColumnType(1) === Types.DECIMAL)
+ assert(metaData.getColumnTypeName(1) === decimalType)
+ assert(metaData.getColumnClassName(1) === "java.math.BigDecimal")
+ assert(metaData.isSigned(1) === true)
+ assert(metaData.getPrecision(1) === precision)
+ assert(metaData.getScale(1) === scale)
+ assert(metaData.getColumnDisplaySize(1) === expectedColumnDisplaySize)
+ assert(metaData.getColumnDisplaySize(1) >= value.size)
+ }
+ }
+ }
+ }
+
+ test("getter functions column index out of bound") {
+ withStatement { stmt =>
+ Seq(
+ ("'foo'", (rs: ResultSet) => rs.getString(999)),
+ ("true", (rs: ResultSet) => rs.getBoolean(999)),
+ ("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(999)),
+ ("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(999)),
+ ("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(999)),
+ ("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(999)),
+ ("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(999)),
+ ("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(999)),
+ ("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(999)),
+ ("CAST(X'0A0B0C' AS BINARY)", (rs: ResultSet) => rs.getBytes(999)),
+ ("date '2025-11-15'", (rs: ResultSet) => rs.getBytes(999)),
+ ("time '12:34:56.123456'", (rs: ResultSet) => rs.getBytes(999)),
+ ("timestamp '2025-11-15 10:30:45.123456'", (rs: ResultSet) => rs.getTimestamp(999)),
+ ("timestamp_ntz '2025-11-15 10:30:45.789012'", (rs: ResultSet) => rs.getTimestamp(999))
+ ).foreach {
+ case (query, getter) =>
+ withExecuteQuery(stmt, s"SELECT $query") { rs =>
+ assert(rs.next())
+ val exception = intercept[SQLException] {
+ getter(rs)
+ }
+ assert(exception.getMessage() ===
+ "The column index is out of range: 999, number of columns: 1.")
+ }
+ }
+ }
+ }
+
+ test("getter functions called after statement closed") {
+ withStatement { stmt =>
+ Seq(
+ ("'foo'", (rs: ResultSet) => rs.getString(1), "foo"),
+ ("true", (rs: ResultSet) => rs.getBoolean(1), true),
+ ("cast(1 AS BYTE)", (rs: ResultSet) => rs.getByte(1), 1.toByte),
+ ("cast(1 AS SHORT)", (rs: ResultSet) => rs.getShort(1), 1.toShort),
+ ("cast(1 AS INT)", (rs: ResultSet) => rs.getInt(1), 1.toInt),
+ ("cast(1 AS BIGINT)", (rs: ResultSet) => rs.getLong(1), 1.toLong),
+ ("cast(1 AS FLOAT)", (rs: ResultSet) => rs.getFloat(1), 1.toFloat),
+ ("cast(1 AS DOUBLE)", (rs: ResultSet) => rs.getDouble(1), 1.toDouble),
+ ("cast(1 AS DECIMAL(10,5))", (rs: ResultSet) => rs.getBigDecimal(1),
+ new java.math.BigDecimal("1.00000")),
+ ("CAST(X'0A0B0C' AS BINARY)", (rs: ResultSet) => rs.getBytes(1),
+ Array[Byte](0x0A, 0x0B, 0x0C)),
+ ("date '2023-11-15'", (rs: ResultSet) => rs.getDate(1),
+ java.sql.Date.valueOf("2023-11-15")),
+ ("time '12:34:56.123456'", (rs: ResultSet) => rs.getTime(1), {
+ val millis = timeToMillis(12, 34, 56, 123)
+ new java.sql.Time(millis)
+ })
+ ).foreach {
+ case (query, getter, expectedValue) =>
+ var resultSet: Option[ResultSet] = None
+ withExecuteQuery(stmt, s"SELECT $query") { rs =>
+ assert(rs.next())
+ expectedValue match {
+ case arr: Array[Byte] =>
+ assert(getter(rs).asInstanceOf[Array[Byte]].sameElements(arr))
+ case other => assert(getter(rs) === other)
+ }
+ assert(!rs.wasNull)
+ resultSet = Some(rs)
+ }
+ assert(resultSet.isDefined)
+ val exception = intercept[SQLException] {
+ getter(resultSet.get)
+ }
+ assert(exception.getMessage() === "JDBC Statement is closed.")
+ }
+ }
+ }
+
+ test("get date type") {
+ withStatement { stmt =>
+ // Test basic date type
+ withExecuteQuery(stmt, "SELECT date '2023-11-15'") { rs =>
+ assert(rs.next())
+ assert(rs.getDate(1) === java.sql.Date.valueOf("2023-11-15"))
+ assert(!rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "DATE '2023-11-15'")
+ assert(metaData.getColumnLabel(1) === "DATE '2023-11-15'")
+ assert(metaData.getColumnType(1) === Types.DATE)
+ assert(metaData.getColumnTypeName(1) === "DATE")
+ assert(metaData.getColumnClassName(1) === "java.sql.Date")
+ assert(metaData.isSigned(1) === false)
+ assert(metaData.getPrecision(1) === 10)
+ assert(metaData.getScale(1) === 0)
+ assert(metaData.getColumnDisplaySize(1) === 10)
+ }
+
+ // Test date type with null
+ withExecuteQuery(stmt, "SELECT cast(null as date)") { rs =>
+ assert(rs.next())
+ assert(rs.getDate(1) === null)
+ assert(rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "CAST(NULL AS DATE)")
+ assert(metaData.getColumnLabel(1) === "CAST(NULL AS DATE)")
+ assert(metaData.getColumnType(1) === Types.DATE)
+ assert(metaData.getColumnTypeName(1) === "DATE")
+ assert(metaData.getColumnClassName(1) === "java.sql.Date")
+ assert(metaData.isSigned(1) === false)
+ assert(metaData.getPrecision(1) === 10)
+ assert(metaData.getScale(1) === 0)
+ assert(metaData.getColumnDisplaySize(1) === 10)
+ }
+
+ // Test date type by column label
+ withExecuteQuery(stmt, "SELECT date '2025-11-15' as test_date") { rs =>
+ assert(rs.next())
+ assert(rs.getDate("test_date") === java.sql.Date.valueOf("2025-11-15"))
+ assert(!rs.wasNull)
+ assert(!rs.next())
+ }
+ }
+ }
+
+ test("get binary type") {
+ withStatement { stmt =>
+ // Test basic binary type
+ val testBytes = Array[Byte](0x01, 0x02, 0x03, 0x04, 0x05)
+ val hexString = testBytes.map(b => "%02X".format(b)).mkString
+ withExecuteQuery(stmt, s"SELECT CAST(X'$hexString' AS BINARY)") { rs =>
+ assert(rs.next())
+ val bytes = rs.getBytes(1)
+ assert(bytes !== null)
+ assert(bytes.length === testBytes.length)
+ assert(bytes.sameElements(testBytes))
+ assert(!rs.wasNull)
+
+ val stringValue = rs.getString(1)
+ val expectedString = new String(testBytes, java.nio.charset.StandardCharsets.UTF_8)
+ assert(stringValue === expectedString)
+
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnType(1) === Types.VARBINARY)
+ assert(metaData.getColumnTypeName(1) === "BINARY")
+ assert(metaData.getColumnClassName(1) === "[B")
+ assert(metaData.isSigned(1) === false)
+ }
+
+ // Test binary type with UTF-8 text
+ val textBytes = "\\xDeAdBeEf".getBytes(java.nio.charset.StandardCharsets.UTF_8)
+ val hexString2 = textBytes.map(b => "%02X".format(b)).mkString
+ withExecuteQuery(stmt, s"SELECT CAST(X'$hexString2' AS BINARY)") { rs =>
+ assert(rs.next())
+ val bytes = rs.getBytes(1)
+ assert(bytes !== null)
+ assert(bytes.sameElements(textBytes))
+
+ val stringValue = rs.getString(1)
+ assert(stringValue === "\\xDeAdBeEf")
+
+ assert(!rs.next())
+ }
+
+ // Test binary type with null
+ withExecuteQuery(stmt, "SELECT cast(null as binary)") { rs =>
+ assert(rs.next())
+ assert(rs.getBytes(1) === null)
+ assert(rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnType(1) === Types.VARBINARY)
+ assert(metaData.getColumnTypeName(1) === "BINARY")
+ assert(metaData.getColumnClassName(1) === "[B")
+ }
+
+ // Test binary type by column label
+ val testBytes2 = Array[Byte](0x0A, 0x0B, 0x0C)
+ val hexString3 = testBytes2.map(b => "%02X".format(b)).mkString
+ withExecuteQuery(stmt, s"SELECT CAST(X'$hexString3' AS BINARY) as test_binary") { rs =>
+ assert(rs.next())
+ val bytes = rs.getBytes("test_binary")
+ assert(bytes !== null)
+ assert(bytes.length === testBytes2.length)
+ assert(bytes.sameElements(testBytes2))
+ assert(!rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "test_binary")
+ assert(metaData.getColumnLabel(1) === "test_binary")
+ }
+
+ // Test empty binary
+ withExecuteQuery(stmt, "SELECT CAST(X'' AS BINARY)") { rs =>
+ assert(rs.next())
+ val bytes = rs.getBytes(1)
+ assert(bytes !== null)
+ assert(bytes.length === 0)
+ assert(!rs.wasNull)
+
+ val stringValue = rs.getString(1)
+ assert(stringValue === "")
+ assert(!rs.next())
+ }
+ }
+ }
+
+ test("get time type") {
+ withStatement { stmt =>
+ // Test basic time type
+ withExecuteQuery(stmt, "SELECT time '12:34:56.123456'") { rs =>
+ assert(rs.next())
+ val time = rs.getTime(1)
+ // Verify milliseconds are preserved (123 from 123456 microseconds)
+ val expectedMillis = timeToMillis(12, 34, 56, 123)
+ assert(time.getTime === expectedMillis)
+ assert(!rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "TIME '12:34:56.123456'")
+ assert(metaData.getColumnLabel(1) === "TIME '12:34:56.123456'")
+ assert(metaData.getColumnType(1) === Types.TIME)
+ assert(metaData.getColumnTypeName(1) === "TIME(6)")
+ assert(metaData.getColumnClassName(1) === "java.sql.Time")
+ assert(metaData.isSigned(1) === false)
+ assert(metaData.getPrecision(1) === 6)
+ assert(metaData.getScale(1) === 0)
+ assert(metaData.getColumnDisplaySize(1) === 15)
+ }
+
+ // Test time type with null
+ withExecuteQuery(stmt, "SELECT cast(null as time)") { rs =>
+ assert(rs.next())
+ assert(rs.getTime(1) === null)
+ assert(rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "CAST(NULL AS TIME(6))")
+ assert(metaData.getColumnLabel(1) === "CAST(NULL AS TIME(6))")
+ assert(metaData.getColumnType(1) === Types.TIME)
+ assert(metaData.getColumnTypeName(1) === "TIME(6)")
+ assert(metaData.getColumnClassName(1) === "java.sql.Time")
+ assert(metaData.isSigned(1) === false)
+ assert(metaData.getPrecision(1) === 6)
+ assert(metaData.getScale(1) === 0)
+ assert(metaData.getColumnDisplaySize(1) === 15)
+ }
+
+ // Test time type by column label
+ withExecuteQuery(stmt, "SELECT time '09:15:30.456789' as test_time") { rs =>
+ assert(rs.next())
+ val time = rs.getTime("test_time")
+ // Verify milliseconds are preserved (456 from 456789 microseconds)
+ val expectedMillis = timeToMillis(9, 15, 30, 456)
+ assert(time.getTime === expectedMillis)
+ assert(!rs.wasNull)
+ assert(!rs.next())
+ }
+ }
+ }
+
+ test("get time type with different precisions") {
+ withStatement { stmt =>
+ Seq(
+ // (timeValue, precision, expectedDisplaySize, expectedMillis)
+ // HH:MM:SS (no fractional)
+ ("15:45:30.123456", 0, 8, timeToMillis(15, 45, 30, 0)),
+ // HH:MM:SS.f (100ms from .1)
+ ("10:20:30.123456", 1, 10, timeToMillis(10, 20, 30, 100)),
+ // HH:MM:SS.fff (123ms)
+ ("08:15:45.123456", 3, 12, timeToMillis(8, 15, 45, 123)),
+ // HH:MM:SS.ffffff (999ms). Spark TIME values can have microsecond precision,
+ // but java.sql.Time can only store up to millisecond precision
+ ("23:59:59.999999", 6, 15, timeToMillis(23, 59, 59, 999))
+ ).foreach {
+ case (timeValue, precision, expectedDisplaySize, expectedMillis) =>
+ withExecuteQuery(stmt, s"SELECT cast(time '$timeValue' as time($precision))") { rs =>
+ assert(rs.next(), s"Failed to get next row for precision $precision")
+ val time = rs.getTime(1)
+ assert(time.getTime === expectedMillis,
+ s"Time millis mismatch for precision" +
+ s" $precision: expected $expectedMillis, got ${time.getTime}")
+ assert(!rs.wasNull, s"wasNull should be false for precision $precision")
+ assert(!rs.next(), s"Should have no more rows for precision $precision")
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnType(1) === Types.TIME,
+ s"Column type mismatch for precision $precision")
+ assert(metaData.getColumnTypeName(1) === s"TIME($precision)",
+ s"Column type name mismatch for precision $precision")
+ assert(metaData.getColumnClassName(1) === "java.sql.Time",
+ s"Column class name mismatch for precision $precision")
+ assert(metaData.getPrecision(1) === precision,
+ s"Precision mismatch for precision $precision")
+ assert(metaData.getScale(1) === 0,
+ s"Scale should be 0 for precision $precision")
+ assert(metaData.getColumnDisplaySize(1) === expectedDisplaySize,
+ s"Display size mismatch for precision $precision: " +
+ s"expected $expectedDisplaySize, got ${metaData.getColumnDisplaySize(1)}")
+ }
+ }
+ }
+ }
+
+ test("get date type with spark.sql.datetime.java8API.enabled") {
+ withStatement { stmt =>
+ Seq(true, false).foreach { java8APIEnabled =>
+ stmt.execute(s"set spark.sql.datetime.java8API.enabled=$java8APIEnabled")
+ Using.resource(stmt.executeQuery("SELECT date '2025-11-15'")) { rs =>
+ assert(rs.next())
+ assert(rs.getDate(1) === java.sql.Date.valueOf("2025-11-15"))
+ assert(!rs.wasNull)
+ assert(!rs.next())
+ }
+ }
+ }
+ }
+
+ test("get time type with spark.sql.datetime.java8API.enabled") {
+ withStatement { stmt =>
+ Seq(true, false).foreach { java8APIEnabled =>
+ stmt.execute(s"set spark.sql.datetime.java8API.enabled=$java8APIEnabled")
+ Using.resource(stmt.executeQuery("SELECT time '12:34:56.123456'")) { rs =>
+ assert(rs.next())
+ val time = rs.getTime(1)
+ val expectedMillis = timeToMillis(12, 34, 56, 123)
+ assert(time.getTime === expectedMillis)
+ assert(!rs.wasNull)
+ assert(!rs.next())
+ }
+ }
+ }
+ }
+
+ test("get timestamp type") {
+ withStatement { stmt =>
+ // Test basic timestamp type
+ withExecuteQuery(stmt, "SELECT timestamp '2025-11-15 10:30:45.123456'") { rs =>
+ assert(rs.next())
+ val timestamp = rs.getTimestamp(1)
+ assert(timestamp !== null)
+ assert(timestamp === java.sql.Timestamp.valueOf("2025-11-15 10:30:45.123456"))
+ assert(!rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "TIMESTAMP '2025-11-15 10:30:45.123456'")
+ assert(metaData.getColumnLabel(1) === "TIMESTAMP '2025-11-15 10:30:45.123456'")
+ assert(metaData.getColumnType(1) === Types.TIMESTAMP)
+ assert(metaData.getColumnTypeName(1) === "TIMESTAMP")
+ assert(metaData.getColumnClassName(1) === "java.sql.Timestamp")
+ assert(metaData.isSigned(1) === false)
+ assert(metaData.getPrecision(1) === 29)
+ assert(metaData.getScale(1) === 6)
+ assert(metaData.getColumnDisplaySize(1) === 29)
+ }
+
+ // Test timestamp type with null
+ withExecuteQuery(stmt, "SELECT cast(null as timestamp)") { rs =>
+ assert(rs.next())
+ assert(rs.getTimestamp(1) === null)
+ assert(rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "CAST(NULL AS TIMESTAMP)")
+ assert(metaData.getColumnLabel(1) === "CAST(NULL AS TIMESTAMP)")
+ assert(metaData.getColumnType(1) === Types.TIMESTAMP)
+ assert(metaData.getColumnTypeName(1) === "TIMESTAMP")
+ assert(metaData.getColumnClassName(1) === "java.sql.Timestamp")
+ assert(metaData.isSigned(1) === false)
+ assert(metaData.getPrecision(1) === 29)
+ assert(metaData.getScale(1) === 6)
+ assert(metaData.getColumnDisplaySize(1) === 29)
+ }
+
+ // Test timestamp type by column label and with calendar
+ val tsString = "2025-11-15 10:30:45.987654"
+ withExecuteQuery(stmt, s"SELECT timestamp '$tsString' as test_timestamp") { rs =>
+ assert(rs.next())
+
+ // Test by column label
+ val timestamp = rs.getTimestamp("test_timestamp")
+ assert(timestamp !== null)
+ assert(timestamp === java.sql.Timestamp.valueOf(tsString))
+ assert(!rs.wasNull)
+
+ // Test with calendar - should return same value (Calendar is ignored)
+ // Note: Spark Connect handles the timezone on the server; the Calendar param is for API compliance
+ val calUTC = java.util.Calendar.getInstance(java.util.TimeZone.getTimeZone("UTC"))
+ val timestampUTC = rs.getTimestamp(1, calUTC)
+ assert(timestampUTC !== null)
+ assert(timestampUTC.getTime === timestamp.getTime)
+
+ val calPST = java.util.Calendar.getInstance(
+ java.util.TimeZone.getTimeZone("America/Los_Angeles"))
+ val timestampPST = rs.getTimestamp(1, calPST)
+ assert(timestampPST !== null)
+ // Same value regardless of calendar
+ assert(timestampPST.getTime === timestamp.getTime)
+ assert(timestampUTC.getTime === timestampPST.getTime)
+
+ // Test with calendar by label
+ val timestampLabel = rs.getTimestamp("test_timestamp", calUTC)
+ assert(timestampLabel !== null)
+ assert(timestampLabel.getTime === timestamp.getTime)
+
+ // Test with null calendar - returns same value
+ val timestampNullCal = rs.getTimestamp(1, null)
+ assert(timestampNullCal !== null)
+ assert(timestampNullCal.getTime === timestamp.getTime)
+
+ assert(!rs.next())
+ }
+
+ // Test timestamp type with calendar for null value
+ withExecuteQuery(stmt, "SELECT cast(null as timestamp)") { rs =>
+ assert(rs.next())
+
+ // Calendar parameter should not affect null handling
+ val cal = java.util.Calendar.getInstance(java.util.TimeZone.getTimeZone("UTC"))
+ val timestamp = rs.getTimestamp(1, cal)
+ assert(timestamp === null)
+ assert(rs.wasNull)
+ assert(!rs.next())
+ }
+ }
+ }
+
+ test("get timestamp_ntz type") {
+ withStatement { stmt =>
+ // Test basic timestamp_ntz type
+ withExecuteQuery(stmt, "SELECT timestamp_ntz '2025-11-15 10:30:45.123456'") { rs =>
+ assert(rs.next())
+ val timestamp = rs.getTimestamp(1)
+ assert(timestamp !== null)
+ assert(timestamp === java.sql.Timestamp.valueOf("2025-11-15 10:30:45.123456"))
+ assert(!rs.wasNull)
+ assert(!rs.next())
+
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 1)
+ assert(metaData.getColumnName(1) === "TIMESTAMP_NTZ '2025-11-15 10:30:45.123456'")
+ assert(metaData.getColumnLabel(1) === "TIMESTAMP_NTZ '2025-11-15 10:30:45.123456'")
+ assert(metaData.getColumnType(1) === Types.TIMESTAMP)
+ assert(metaData.getColumnTypeName(1) === "TIMESTAMP_NTZ")
+ assert(metaData.getColumnClassName(1) === "java.sql.Timestamp")
+ assert(metaData.isSigned(1) === false)
+ assert(metaData.getPrecision(1) === 29)
+ assert(metaData.getScale(1) === 6)
+ assert(metaData.getColumnDisplaySize(1) === 29)
+ }
+
+ // Test timestamp_ntz by label, null, and with calendar - non-null value
+ val tsString = "2025-11-15 14:22:33.789456"
+ withExecuteQuery(stmt, s"SELECT timestamp_ntz '$tsString' as test_ts_ntz") { rs =>
+ assert(rs.next())
+
+ // Test by column label
+ val timestamp = rs.getTimestamp("test_ts_ntz")
+ assert(timestamp !== null)
+ assert(timestamp === java.sql.Timestamp.valueOf(tsString))
+ assert(!rs.wasNull)
+
+ // Test with calendar - should return same value (Calendar is ignored)
+ val calUTC = java.util.Calendar.getInstance(java.util.TimeZone.getTimeZone("UTC"))
+ val timestampCal = rs.getTimestamp(1, calUTC)
+ assert(timestampCal !== null)
+ assert(timestampCal.getTime === timestamp.getTime)
+
+ assert(!rs.next())
+ }
+
+ // Test timestamp_ntz with null value
+ withExecuteQuery(stmt, "SELECT cast(null as timestamp_ntz)") { rs =>
+ assert(rs.next())
+ assert(rs.getTimestamp(1) === null)
+ assert(rs.wasNull)
+ assert(!rs.next())
+ }
+ }
+ }
+
+ test("get timestamp types with spark.sql.datetime.java8API.enabled") {
+ withStatement { stmt =>
+ Seq(true, false).foreach { java8APIEnabled =>
+ stmt.execute(s"set spark.sql.datetime.java8API.enabled=$java8APIEnabled")
+
+ Using.resource(stmt.executeQuery(
+ """SELECT
+ | timestamp '2025-11-15 10:30:45.123456' as ts,
+ | timestamp_ntz '2025-11-15 14:22:33.789012' as ts_ntz
+ |""".stripMargin)) { rs =>
+ assert(rs.next())
+
+ // Test TIMESTAMP type
+ val timestamp = rs.getTimestamp(1)
+ assert(timestamp !== null)
+ assert(timestamp === java.sql.Timestamp.valueOf("2025-11-15 10:30:45.123456"))
+ assert(!rs.wasNull)
+
+ // Test TIMESTAMP_NTZ type
+ val timestampNtz = rs.getTimestamp(2)
+ assert(timestampNtz !== null)
+ assert(timestampNtz === java.sql.Timestamp.valueOf("2025-11-15 14:22:33.789012"))
+ assert(!rs.wasNull)
+
+ assert(!rs.next())
+ }
+ }
}
}
}
diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSetSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSetSuite.scala
index ac2866837e939..21b8e261aef45 100644
--- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSetSuite.scala
+++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectResultSetSuite.scala
@@ -122,4 +122,68 @@ class SparkConnectResultSetSuite extends ConnectFunSuite with RemoteSparkSession
assert(rs.isAfterLast)
}
}
+
+ test("getTimestamp with multiple columns, rows, and types") {
+ withExecuteQuery(
+ """SELECT ts_tz, ts_ntz, id FROM VALUES
+ | (timestamp '2025-01-15 10:30:45.123456', timestamp_ntz '2025-06-20 14:22:33.789012', 1),
+ | (null, timestamp_ntz '2025-03-01 18:30:45.456789', 2),
+ | (timestamp '2025-10-31 23:59:59.999999', null, 3)
+ | AS t(ts_tz, ts_ntz, id)
+ |""".stripMargin) { rs =>
+
+ // Test findColumn
+ assert(rs.findColumn("ts_tz") === 1)
+ assert(rs.findColumn("ts_ntz") === 2)
+ assert(rs.findColumn("id") === 3)
+
+ // Verify metadata
+ val metaData = rs.getMetaData
+ assert(metaData.getColumnCount === 3)
+ assert(metaData.getColumnTypeName(1) === "TIMESTAMP")
+ assert(metaData.getColumnTypeName(2) === "TIMESTAMP_NTZ")
+
+ // Row 1: Both timestamps have values
+ assert(rs.next())
+ assert(rs.getRow === 1)
+
+ val ts1 = rs.getTimestamp(1)
+ assert(ts1 !== null)
+ assert(ts1 === java.sql.Timestamp.valueOf("2025-01-15 10:30:45.123456"))
+ assert(!rs.wasNull)
+
+ val tsNtz1 = rs.getTimestamp("ts_ntz")
+ assert(tsNtz1 !== null)
+ assert(tsNtz1 === java.sql.Timestamp.valueOf("2025-06-20 14:22:33.789012"))
+ assert(!rs.wasNull)
+
+ // Row 2: TIMESTAMP is null, TIMESTAMP_NTZ has value
+ assert(rs.next())
+ assert(rs.getRow === 2)
+
+ val ts2 = rs.getTimestamp(1)
+ assert(ts2 === null)
+ assert(rs.wasNull)
+
+ val tsNtz2 = rs.getTimestamp(2)
+ assert(tsNtz2 !== null)
+ assert(tsNtz2 === java.sql.Timestamp.valueOf("2025-03-01 18:30:45.456789"))
+ assert(!rs.wasNull)
+
+ // Row 3: TIMESTAMP has value, TIMESTAMP_NTZ is null
+ assert(rs.next())
+ assert(rs.getRow === 3)
+
+ val ts3 = rs.getTimestamp("ts_tz")
+ assert(ts3 !== null)
+ assert(ts3 === java.sql.Timestamp.valueOf("2025-10-31 23:59:59.999999"))
+ assert(!rs.wasNull)
+
+ val tsNtz3 = rs.getTimestamp(2)
+ assert(tsNtz3 === null)
+ assert(rs.wasNull)
+
+ assert(!rs.next())
+ }
+ }
}
diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatementSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatementSuite.scala
new file mode 100644
index 0000000000000..fa9df3f1247f7
--- /dev/null
+++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectStatementSuite.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.client.jdbc
+
+import java.sql.{Array => _, _}
+
+import scala.util.Using
+
+import org.apache.spark.sql.connect.client.jdbc.test.JdbcHelper
+import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession, SQLHelper}
+
+class SparkConnectStatementSuite extends ConnectFunSuite with RemoteSparkSession
+ with JdbcHelper with SQLHelper {
+
+ override def jdbcUrl: String = s"jdbc:sc://localhost:$serverPort"
+
+ test("returned result set and update count of execute* methods") {
+ withTable("t1", "t2", "t3") {
+ withStatement { stmt =>
+ // CREATE TABLE
+ assert(!stmt.execute("CREATE TABLE t1 (id INT) USING Parquet"))
+ assert(stmt.getUpdateCount === 0)
+ assert(stmt.getResultSet === null)
+
+ var se = intercept[SQLException] {
+ stmt.executeQuery("CREATE TABLE t2 (id INT) USING Parquet")
+ }
+ assert(se.getMessage === "The query does not produce a ResultSet.")
+
+ assert(stmt.executeUpdate("CREATE TABLE t3 (id INT) USING Parquet") === 0)
+ assert(stmt.getResultSet === null)
+
+ // INSERT INTO
+ assert(!stmt.execute("INSERT INTO t1 VALUES (1)"))
+ assert(stmt.getUpdateCount === 0)
+ assert(stmt.getResultSet === null)
+
+ se = intercept[SQLException] {
+ stmt.executeQuery("INSERT INTO t1 VALUES (1)")
+ }
+ assert(se.getMessage === "The query does not produce a ResultSet.")
+
+ assert(stmt.executeUpdate("INSERT INTO t1 VALUES (1)") === 0)
+ assert(stmt.getResultSet === null)
+
+ // SELECT
+ assert(stmt.execute("SELECT id FROM t1"))
+ assert(stmt.getUpdateCount === -1)
+ Using.resource(stmt.getResultSet) { rs =>
+ assert(rs !== null)
+ }
+
+ Using.resource(stmt.executeQuery("SELECT id FROM t1")) { rs =>
+ assert(stmt.getUpdateCount === -1)
+ assert(rs !== null)
+ }
+
+ se = intercept[SQLException] {
+ stmt.executeUpdate("SELECT id FROM t1")
+ }
+ assert(se.getMessage === "The query produces a ResultSet.")
+ }
+ }
+ }
+
+ test("max rows from SparkConnectStatement") {
+ def verifyMaxRows(
+ expectedRows: Int, query: String)(stmt: Statement): Unit = {
+ Using(stmt.executeQuery(query)) { rs =>
+ (0 until expectedRows).foreach { _ =>
+ assert(rs.next())
+ }
+ assert(!rs.next())
+ }
+ }
+
+ withStatement { stmt =>
+ // by default, it has no max rows limitation
+ assert(stmt.getMaxRows === 0)
+ verifyMaxRows(10, "SELECT id FROM range(10)")(stmt)
+
+ val se = intercept[SQLException] {
+ stmt.setMaxRows(-1)
+ }
+ assert(se.getMessage === "The max rows must be zero or a positive integer.")
+
+ stmt.setMaxRows(5)
+ assert(stmt.getMaxRows === 5)
+ verifyMaxRows(5, "SELECT id FROM range(10)")(stmt)
+
+ // set max rows for a query that has a LIMIT clause
+ stmt.setMaxRows(5)
+ assert(stmt.getMaxRows === 5)
+ verifyMaxRows(3, "SELECT id FROM range(10) LIMIT 3")(stmt)
+ verifyMaxRows(5, "SELECT id FROM range(10) LIMIT 8")(stmt)
+
+ // setting max rows on one statement won't affect others
+ withStatement { stmt2 =>
+ assert(stmt2.getMaxRows === 0)
+ verifyMaxRows(10, "SELECT id FROM range(10)")(stmt2)
+ }
+ }
+ }
+}
diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala
index 9b3aa373e93ce..a512a44cac3b3 100644
--- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala
+++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/test/JdbcHelper.scala
@@ -39,8 +39,10 @@ trait JdbcHelper {
}
def withExecuteQuery(query: String)(f: ResultSet => Unit): Unit = {
- withStatement { stmt =>
- Using.resource { stmt.executeQuery(query) } { rs => f(rs) }
- }
+ withStatement { stmt => withExecuteQuery(stmt, query)(f) }
+ }
+
+ def withExecuteQuery(stmt: Statement, query: String)(f: ResultSet => Unit): Unit = {
+ Using.resource { stmt.executeQuery(query) } { rs => f(rs) }
}
}
diff --git a/sql/connect/client/jvm/pom.xml b/sql/connect/client/jvm/pom.xml
index 1c16b7e9ca8ca..f939c328c3cb2 100644
--- a/sql/connect/client/jvm/pom.xml
+++ b/sql/connect/client/jvm/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.13
- 4.1.0-SNAPSHOT
+ 4.1.2-SNAPSHOT
../../../../pom.xml
@@ -74,16 +74,24 @@
protobuf-java
compile
+
+ com.google.protobuf
+ protobuf-java-util
+ compile
+
com.google.guava
guava
- ${connect.guava.version}
compile
com.google.guava
failureaccess
- ${guava.failureaccess.version}
+ compile
+
+
+ com.github.luben
+ zstd-jni
compile
- org.apache.tomcat
- annotations-api
- ${tomcat.annotations.api.version}
+
+ com.github.luben
+ zstd-jni
false
org.spark-project.spark:unused
- org.apache.tomcat:annotations-api
diff --git a/sql/connect/common/src/main/buf.gen.yaml b/sql/connect/common/src/main/buf.gen.yaml
index d6120bfd36fa1..beaa7f1949e25 100644
--- a/sql/connect/common/src/main/buf.gen.yaml
+++ b/sql/connect/common/src/main/buf.gen.yaml
@@ -16,20 +16,20 @@
#
version: v1
plugins:
- - plugin: buf.build/protocolbuffers/cpp:v29.5
+ - plugin: buf.build/protocolbuffers/cpp:v33.0
out: gen/proto/cpp
- - plugin: buf.build/protocolbuffers/csharp:v29.5
+ - plugin: buf.build/protocolbuffers/csharp:v33.0
out: gen/proto/csharp
- - plugin: buf.build/protocolbuffers/java:v29.5
+ - plugin: buf.build/protocolbuffers/java:v33.0
out: gen/proto/java
- - plugin: buf.build/grpc/ruby:v1.67.0
+ - plugin: buf.build/grpc/ruby:v1.76.0
out: gen/proto/ruby
- - plugin: buf.build/protocolbuffers/ruby:v29.5
+ - plugin: buf.build/protocolbuffers/ruby:v33.0
out: gen/proto/ruby
# Building the Python build and building the mypy interfaces.
- - plugin: buf.build/protocolbuffers/python:v29.5
+ - plugin: buf.build/protocolbuffers/python:v33.0
out: gen/proto/python
- - plugin: buf.build/grpc/python:v1.67.0
+ - plugin: buf.build/grpc/python:v1.76.0
out: gen/proto/python
- name: mypy
out: gen/proto/python
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/base.proto b/sql/connect/common/src/main/protobuf/spark/connect/base.proto
index 6e1029bf0a6a2..a97d2d25f490e 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -33,17 +33,35 @@ option java_package = "org.apache.spark.connect.proto";
option go_package = "internal/generated";
// A [[Plan]] is the structure that carries the runtime information for the execution from the
-// client to the server. A [[Plan]] can either be of the type [[Relation]] which is a reference
-// to the underlying logical plan or it can be of the [[Command]] type that is used to execute
-// commands on the server.
+// client to the server. A [[Plan]] can be one of the following:
+// - [[Relation]]: a reference to the underlying logical plan.
+// - [[Command]]: used to execute commands on the server.
+// - [[CompressedOperation]]: a compressed representation of either a Relation or a Command.
message Plan {
oneof op_type {
Relation root = 1;
Command command = 2;
+ CompressedOperation compressed_operation = 3;
}
-}
+ message CompressedOperation {
+ bytes data = 1;
+ OpType op_type = 2;
+ CompressionCodec compression_codec = 3;
+ enum OpType {
+ OP_TYPE_UNSPECIFIED = 0;
+ OP_TYPE_RELATION = 1;
+ OP_TYPE_COMMAND = 2;
+ }
+ }
+}
+
+// Compression codec for plan compression.
+enum CompressionCodec {
+ COMPRESSION_CODEC_UNSPECIFIED = 0;
+ COMPRESSION_CODEC_ZSTD = 1;
+}
// User Context is used to refer to one particular user session that is executing
// queries in the backend.
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
index c6a5e571f9792..0874c2d10ec5c 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto
@@ -104,6 +104,9 @@ message PipelineCommand {
spark.connect.DataType schema_data_type = 4;
string schema_string = 5;
}
+
+ // Optional clustering columns for the table.
+ repeated string clustering_columns = 6;
}
// Metadata that's only applicable to sinks.
@@ -149,6 +152,13 @@ message PipelineCommand {
optional spark.connect.Relation relation = 1;
}
+ // If true, defines the flow as a one-time flow, such as for backfill.
+ // Setting this to true changes the flow in two ways:
+ // - The flow is run only once by default. If the pipeline is run with a full refresh,
+ //   the flow will run again.
+ // - The flow function must produce a batch DataFrame, not a streaming DataFrame.
+ optional bool once = 8;
+
message Response {
// Fully qualified flow name that uniquely identify a flow in the Dataflow graph.
optional string flow_name = 1;
@@ -289,6 +299,8 @@ message PipelineAnalysisContext {
optional string dataflow_graph_id = 1;
// The path of the top-level pipeline file determined at runtime during pipeline initialization.
optional string definition_path = 2;
+ // The name of the Flow involved in this analysis.
+ optional string flow_name = 3;
// Reserved field for protocol extensions.
repeated google.protobuf.Any extension = 999;
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/types.proto b/sql/connect/common/src/main/protobuf/spark/connect/types.proto
index 1800e3885774f..caaa2340f95dd 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/types.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/types.proto
@@ -67,15 +67,17 @@ message DataType {
// UserDefinedType
UDT udt = 23;
+ // Geospatial types
+ Geometry geometry = 26;
+
+ Geography geography = 27;
+
// UnparsedDataType
Unparsed unparsed = 24;
Time time = 28;
}
- // Reserved for geometry and geography types
- reserved 26, 27;
-
message Boolean {
uint32 type_variation_reference = 1;
}
@@ -192,6 +194,16 @@ message DataType {
uint32 type_variation_reference = 4;
}
+ message Geometry {
+ int32 srid = 1;
+ uint32 type_variation_reference = 2;
+ }
+
+ message Geography {
+ int32 srid = 1;
+ uint32 type_variation_reference = 2;
+ }
+
message Variant {
uint32 type_variation_reference = 1;
}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index 0d9d4e5d60f0a..42dd1a2b99793 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -51,13 +51,13 @@ import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{agnosticEncoderFor, BoxedLongEncoder, UnboundRowEncoder}
import org.apache.spark.sql.connect.ColumnNodeToProtoConverter.toLiteral
import org.apache.spark.sql.connect.ConnectConversions._
-import org.apache.spark.sql.connect.client.{ClassFinder, CloseableIterator, SparkConnectClient, SparkResult}
+import org.apache.spark.sql.connect.client.{ClassFinder, SparkConnectClient, SparkResult}
import org.apache.spark.sql.connect.client.SparkConnectClient.Configuration
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.internal.{SessionState, SharedState, SqlApiConf, SubqueryExpression}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ExecutionListenerManager
+import org.apache.spark.sql.util.{CloseableIterator, ExecutionListenerManager}
import org.apache.spark.util.ArrayImplicits._
/**
@@ -118,39 +118,73 @@ class SparkSession private[sql] (
newDataset(encoder) { builder =>
if (data.nonEmpty) {
val threshold = conf.get(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY).toInt
- val maxRecordsPerBatch = conf.get(SqlApiConf.LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY).toInt
- val maxBatchSize = conf.get(SqlApiConf.LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY).toInt
+ val maxChunkSizeRows = conf.get(SqlApiConf.LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY).toInt
+ val maxChunkSizeBytes = conf.get(SqlApiConf.LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY).toInt
+ val maxBatchOfChunksSize =
+ conf.get(SqlApiConf.LOCAL_RELATION_BATCH_OF_CHUNKS_SIZE_BYTES_KEY).toLong
+
// Serialize with chunking support
val it = ArrowSerializer.serialize(
data,
encoder,
allocator,
- maxRecordsPerBatch = maxRecordsPerBatch,
- maxBatchSize = maxBatchSize,
+ maxRecordsPerBatch = maxChunkSizeRows,
+ maxBatchSize = math.min(maxChunkSizeBytes, maxBatchOfChunksSize),
timeZoneId = timeZoneId,
largeVarTypes = largeVarTypes,
- batchSizeCheckInterval = math.min(1024, maxRecordsPerBatch))
+ batchSizeCheckInterval = math.min(1024, maxChunkSizeRows))
+
+ try {
+ val schemaBytes = encoder.schema.json.getBytes
+ // The schema is the first chunk; data chunks from the iterator follow
+ val currentBatch = scala.collection.mutable.ArrayBuffer[Array[Byte]](schemaBytes)
+ var totalChunks = 1
+ var currentBatchSize = schemaBytes.length.toLong
+ var totalSize = currentBatchSize
+
+ // Store all hashes of uploaded chunks. The first hash is the schema hash; the rest are data hashes
+ val allHashes = scala.collection.mutable.ArrayBuffer[String]()
+ while (it.hasNext) {
+ val chunk = it.next()
+ val chunkSize = chunk.length
+ totalChunks += 1
+ totalSize += chunkSize
+
+ // Check if adding this chunk would exceed batch size
+ if (currentBatchSize + chunkSize > maxBatchOfChunksSize) {
+ // Upload current batch
+ allHashes ++= client.artifactManager.cacheArtifacts(currentBatch.toArray)
+ // Start new batch
+ currentBatch.clear()
+ currentBatchSize = 0
+ }
- val chunks =
- try {
- it.toArray
- } finally {
- it.close()
+ currentBatch += chunk
+ currentBatchSize += chunkSize
}
- // If we got multiple chunks or a single large chunk, use ChunkedCachedLocalRelation
- val totalSize = chunks.map(_.length).sum
- if (chunks.length > 1 || totalSize > threshold) {
- val (dataHashes, schemaHash) = client.cacheLocalRelation(chunks, encoder.schema.json)
- builder.getChunkedCachedLocalRelationBuilder
- .setSchemaHash(schemaHash)
- .addAllDataHashes(dataHashes.asJava)
- } else {
- // Small data, use LocalRelation directly
- val arrowData = ByteString.copyFrom(chunks(0))
- builder.getLocalRelationBuilder
- .setSchema(encoder.schema.json)
- .setData(arrowData)
+ // Decide whether to use LocalRelation or ChunkedCachedLocalRelation
+ if (totalChunks == 2 && totalSize <= threshold) {
+ // Schema + single small data chunk: use LocalRelation with inline data
+ val arrowData = ByteString.copyFrom(currentBatch.last)
+ builder.getLocalRelationBuilder
+ .setSchema(encoder.schema.json)
+ .setData(arrowData)
+ } else {
+ // Multiple data chunks or large data: use ChunkedCachedLocalRelation
+ // Upload remaining batch
+ allHashes ++= client.artifactManager.cacheArtifacts(currentBatch.toArray)
+
+ // The first hash is the schema hash; the rest are data hashes
+ val schemaHash = allHashes.head
+ val dataHashes = allHashes.tail
+
+ builder.getChunkedCachedLocalRelationBuilder
+ .setSchemaHash(schemaHash)
+ .addAllDataHashes(dataHashes.asJava)
+ }
+ } finally {
+ it.close()
}
} else {
builder.getLocalRelationBuilder
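The new upload path above groups serialized chunks greedily under maxBatchOfChunksSize before handing each group to the artifact manager. A standalone sketch of that grouping rule, with a hypothetical upload callback standing in for client.artifactManager.cacheArtifacts:

    import scala.collection.mutable.ArrayBuffer

    // Greedy grouping: flush the current batch before a chunk would push it past the byte cap.
    def batchChunks(chunks: Iterator[Array[Byte]], maxBatchBytes: Long)(
        upload: Seq[Array[Byte]] => Unit): Unit = {
      val current = ArrayBuffer.empty[Array[Byte]]
      var currentBytes = 0L
      chunks.foreach { chunk =>
        if (current.nonEmpty && currentBytes + chunk.length > maxBatchBytes) {
          upload(current.toSeq)
          current.clear()
          currentBytes = 0L
        }
        current += chunk
        currentBytes += chunk.length
      }
      // Upload whatever is left over.
      if (current.nonEmpty) upload(current.toSeq)
    }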
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryListenerBus.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryListenerBus.scala
index 52b0ea24e9e33..7548238351468 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryListenerBus.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/StreamingQueryListenerBus.scala
@@ -23,9 +23,9 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.connect.proto.{Command, ExecutePlanResponse, Plan, StreamingQueryEventType}
import org.apache.spark.internal.{Logging, LogKeys}
-import org.apache.spark.sql.connect.client.CloseableIterator
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.sql.streaming.StreamingQueryListener.{Event, QueryIdleEvent, QueryProgressEvent, QueryStartedEvent, QueryTerminatedEvent}
+import org.apache.spark.sql.util.CloseableIterator
class StreamingQueryListenerBus(sparkSession: SparkSession) extends Logging {
private val listeners = new CopyOnWriteArrayList[StreamingQueryListener]()
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/TableValuedFunction.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/TableValuedFunction.scala
index 05fc4b441f98e..e773f66cd6d05 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/TableValuedFunction.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/TableValuedFunction.scala
@@ -102,4 +102,8 @@ class TableValuedFunction(sparkSession: SparkSession) extends sql.TableValuedFun
/** @inheritdoc */
override def variant_explode_outer(input: Column): Dataset[Row] =
fn("variant_explode_outer", Seq(input))
+
+ /** @inheritdoc */
+ override def python_worker_logs(): Dataset[Row] =
+ fn("python_worker_logs", Seq.empty)
}
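A usage sketch for the new table-valued function, assuming a Spark Connect SparkSession named spark; the column layout of the returned Dataset is not asserted here:

    // Surfaces Python worker logs as a DataFrame via the TVF registry.
    val workerLogs = spark.tvf.python_worker_logs()
    workerLogs.show(truncate = false)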
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
index 913f068fcf345..715da0df73491 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
@@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters._
import io.grpc.ManagedChannel
import org.apache.spark.connect.proto._
+import org.apache.spark.sql.util.CloseableIterator
private[connect] class CustomSparkConnectBlockingStub(
channel: ManagedChannel,
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
index f3c13c9c2c4d8..131a2e77cc431 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ExecutePlanResponseReattachableIterator.scala
@@ -28,6 +28,7 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connect.client.GrpcRetryHandler.RetryException
+import org.apache.spark.sql.util.WrappedCloseableIterator
/**
* Retryable iterator of ExecutePlanResponses to an ExecutePlan call.
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
index d3dae47f4c471..7e0b0949fcf1d 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException,
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.streaming.StreamingQueryException
+import org.apache.spark.sql.util.{CloseableIterator, WrappedCloseableIterator}
import org.apache.spark.util.ArrayImplicits._
/**
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
index 3f4558ee97dad..d92dc902fedc5 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcRetryHandler.scala
@@ -23,6 +23,7 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{ERROR, NUM_RETRY, POLICY, RETRY_WAIT_TIME}
+import org.apache.spark.sql.util.{CloseableIterator, WrappedCloseableIterator}
private[sql] class GrpcRetryHandler(
private val policies: Seq[RetryPolicy],
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
index 03548120457f3..6cf39b8d18798 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ResponseValidator.scala
@@ -23,6 +23,7 @@ import io.grpc.{Status, StatusRuntimeException}
import io.grpc.stub.StreamObserver
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.util.{CloseableIterator, WrappedCloseableIterator}
// This is common logic to be shared between different stub instances to keep the server-side
// session id and to validate responses as seen by the client.
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index fa32eba91eb2c..5d36fc45f9480 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -24,15 +24,23 @@ import java.util.concurrent.Executor
import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.util.Properties
+import scala.util.control.NonFatal
+import com.google.protobuf
+import com.google.protobuf.ByteString
import io.grpc._
import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
+import org.apache.spark.SparkThrowable
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.UserContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.LogKeys.{ERROR, RATIO, SIZE, TIME}
+import org.apache.spark.sql.connect.RuntimeConfig
import org.apache.spark.sql.connect.common.ProtoUtils
import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.util.CloseableIterator
import org.apache.spark.util.SparkSystemUtils
/**
@@ -40,7 +48,8 @@ import org.apache.spark.util.SparkSystemUtils
*/
private[sql] class SparkConnectClient(
private[sql] val configuration: SparkConnectClient.Configuration,
- private[sql] val channel: ManagedChannel) {
+ private[sql] val channel: ManagedChannel)
+ extends Logging {
private val userContext: UserContext = configuration.userContext
@@ -64,6 +73,70 @@ private[sql] class SparkConnectClient(
// a new client will create a new session ID.
private[sql] val sessionId: String = configuration.sessionId.getOrElse(UUID.randomUUID.toString)
+ private val conf: RuntimeConfig = new RuntimeConfig(this)
+
+ // Cached plan compression options.
+ private var _planCompressionOptions: Option[Option[PlanCompressionOptions]] = None
+
+ // Get the plan compression options. The options are cached after the first call.
+ private[sql] def getPlanCompressionOptions: Option[PlanCompressionOptions] = {
+ _planCompressionOptions match {
+ case Some(options) => options
+ case None =>
+ val options =
+ try {
+ Some(
+ PlanCompressionOptions(
+ thresholdBytes =
+ conf.get("spark.connect.session.planCompression.threshold").toInt,
+ algorithm = conf.get("spark.connect.session.planCompression.defaultAlgorithm")))
+ } catch {
+ // Disable plan compression if the server does not support it. Other exceptions are not
+ // swallowed.
+ case e: NoSuchElementException =>
+ logWarning(
+ log"Plan compression is disabled because the server does not support it",
+ e)
+ None
+ case e: SparkThrowable
+ if e.getCondition == "INVALID_CONF_VALUE"
+ || e.getCondition == "SQL_CONF_NOT_FOUND"
+ || e.getCondition == "CONFIG_NOT_AVAILABLE" =>
+ logWarning(
+ log"Plan compression is disabled because the server does not support it",
+ e)
+ None
+ }
+ _planCompressionOptions = Some(options)
+ options
+ }
+ }
+
+ // For testing and internal use only.
+ private[sql] def setPlanCompressionOptions(
+ planCompressionOptions: Option[PlanCompressionOptions]): Unit = {
+ _planCompressionOptions = Some(planCompressionOptions)
+ }
+
+ /**
+ * Handle plan compression errors.
+ */
+ private def handlePlanCompressionErrors[E](fn: => E): E = {
+ try {
+ fn
+ } catch {
+ // If the server cannot parse the compressed plan, disable plan compression for subsequent
+ // requests on the session.
+ case e: SparkThrowable if e.getCondition == "CONNECT_INVALID_PLAN.CANNOT_PARSE" =>
+ logWarning(
+ log"Disabling plan compression for the session due to " +
+ log"CONNECT_INVALID_PLAN.CANNOT_PARSE error.")
+ setPlanCompressionOptions(None)
+ // Retry the code block without plan compression.
+ fn
+ }
+ }
+
/**
* Hijacks the stored server side session ID with the given suffix. Used for testing to make
* sure that server is validating the session ID.
@@ -120,6 +193,90 @@ private[sql] class SparkConnectClient(
}
}
+ /**
+ * Try to compress the plan if it exceeds the threshold defined in the planCompressionOptions.
+ * Return the original plan if compression is disabled, not needed, or not effective.
+ */
+ private def tryCompressPlan(plan: proto.Plan): proto.Plan = {
+ def tryCompressMessage(
+ message: protobuf.Message,
+ opType: proto.Plan.CompressedOperation.OpType,
+ options: PlanCompressionOptions): Option[proto.Plan.CompressedOperation] = {
+ val serialized = message.toByteArray
+ if (serialized.length > options.thresholdBytes) {
+ try {
+ import com.github.luben.zstd.Zstd
+
+ val startTime = System.nanoTime()
+ val compressed = Zstd.compress(serialized)
+ val duration = (System.nanoTime() - startTime) / 1e9
+ val savingRatio = 1 - compressed.length.toDouble / serialized.length
+ logDebug(
+ log"Plan compression: original_size=${MDC(SIZE, serialized.length)}, " +
+ log"compressed_size=${MDC(SIZE, compressed.length)}, " +
+ log"saving_ratio=${MDC(RATIO, savingRatio)}, " +
+ log"duration_s=${MDC(TIME, duration)}")
+ if (compressed.length < serialized.length) {
+ return Some(
+ proto.Plan.CompressedOperation
+ .newBuilder()
+ .setData(ByteString.copyFrom(compressed))
+ .setOpType(opType)
+ .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+ .build())
+ } else {
+ logDebug(log"Plan compression not effective. Using original plan.")
+ }
+ } catch {
+ case _: NoClassDefFoundError | _: ClassNotFoundException =>
+ logInfo(log"Zstd library not available. Disabling plan compression.")
+ setPlanCompressionOptions(None)
+ case NonFatal(e) =>
+ logWarning(
+ log"Failed to compress plan: ${MDC(ERROR, e.getMessage)}. Using original " +
+ log"plan and disabling plan compression.")
+ setPlanCompressionOptions(None)
+ }
+ }
+ None
+ }
+
+ def maybeCompressPlan(
+ plan: proto.Plan,
+ message: protobuf.Message,
+ opType: proto.Plan.CompressedOperation.OpType,
+ clearFn: proto.Plan.Builder => proto.Plan.Builder,
+ options: PlanCompressionOptions): proto.Plan = {
+ tryCompressMessage(message, opType, options) match {
+ case Some(compressedOperation) =>
+ clearFn(proto.Plan.newBuilder(plan)).setCompressedOperation(compressedOperation).build()
+ case None => plan
+ }
+ }
+
+ getPlanCompressionOptions match {
+ case Some(options) if options.algorithm == "ZSTD" && options.thresholdBytes >= 0 =>
+ plan.getOpTypeCase match {
+ case proto.Plan.OpTypeCase.ROOT =>
+ maybeCompressPlan(
+ plan,
+ plan.getRoot,
+ proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION,
+ _.clearRoot(),
+ options)
+ case proto.Plan.OpTypeCase.COMMAND =>
+ maybeCompressPlan(
+ plan,
+ plan.getCommand,
+ proto.Plan.CompressedOperation.OpType.OP_TYPE_COMMAND,
+ _.clearCommand(),
+ options)
+ case _ => plan
+ }
+ case _ => plan
+ }
+ }
+
/**
* Execute the plan and return response iterator.
*
@@ -131,25 +288,46 @@ private[sql] class SparkConnectClient(
plan: proto.Plan,
operationId: Option[String] = None): CloseableIterator[proto.ExecutePlanResponse] = {
artifactManager.uploadAllClassFileArtifacts()
- val request = proto.ExecutePlanRequest
- .newBuilder()
- .setPlan(plan)
- .setUserContext(userContext)
- .setSessionId(sessionId)
- .setClientType(userAgent)
- .addAllTags(tags.get.toSeq.asJava)
- serverSideSessionId.foreach(session => request.setClientObservedServerSideSessionId(session))
- operationId.foreach { opId =>
- require(
- isValidUUID(opId),
- s"Invalid operationId: $opId. The id must be an UUID string of " +
- "the format `00112233-4455-6677-8899-aabbccddeeff`")
- request.setOperationId(opId)
- }
- if (configuration.useReattachableExecute) {
- bstub.executePlanReattachable(request.build())
- } else {
- bstub.executePlan(request.build())
+ handlePlanCompressionErrors {
+ // Compress the plan if needed.
+ val maybeCompressedPlan = tryCompressPlan(plan)
+ val request = proto.ExecutePlanRequest
+ .newBuilder()
+ .setPlan(maybeCompressedPlan)
+ .setUserContext(userContext)
+ .setSessionId(sessionId)
+ .setClientType(userAgent)
+ .addAllTags(tags.get.toSeq.asJava)
+
+ // Add request option to allow result chunking.
+ if (configuration.allowArrowBatchChunking) {
+ val chunkingOptionsBuilder = proto.ResultChunkingOptions
+ .newBuilder()
+ .setAllowArrowBatchChunking(true)
+ configuration.preferredArrowChunkSize.foreach { size =>
+ chunkingOptionsBuilder.setPreferredArrowChunkSize(size)
+ }
+ request.addRequestOptions(
+ proto.ExecutePlanRequest.RequestOption
+ .newBuilder()
+ .setResultChunkingOptions(chunkingOptionsBuilder.build())
+ .build())
+ }
+
+ serverSideSessionId.foreach(session =>
+ request.setClientObservedServerSideSessionId(session))
+ operationId.foreach { opId =>
+ require(
+ isValidUUID(opId),
+ s"Invalid operationId: $opId. The id must be an UUID string of " +
+ "the format `00112233-4455-6677-8899-aabbccddeeff`")
+ request.setOperationId(opId)
+ }
+ if (configuration.useReattachableExecute) {
+ bstub.executePlanReattachable(request.build())
+ } else {
+ bstub.executePlan(request.build())
+ }
}
}
@@ -180,71 +358,87 @@ private[sql] class SparkConnectClient(
plan: Option[proto.Plan] = None,
explainMode: Option[proto.AnalyzePlanRequest.Explain.ExplainMode] = None)
: proto.AnalyzePlanResponse = {
- val builder = proto.AnalyzePlanRequest.newBuilder()
- method match {
- case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
- assert(plan.isDefined)
- builder.setSchema(
- proto.AnalyzePlanRequest.Schema
- .newBuilder()
- .setPlan(plan.get)
- .build())
- case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
- if (explainMode.isEmpty) {
- throw new IllegalArgumentException(s"ExplainMode is required in Explain request")
- }
- assert(plan.isDefined)
- builder.setExplain(
- proto.AnalyzePlanRequest.Explain
- .newBuilder()
- .setPlan(plan.get)
- .setExplainMode(explainMode.get)
- .build())
- case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
- assert(plan.isDefined)
- builder.setIsLocal(
- proto.AnalyzePlanRequest.IsLocal
- .newBuilder()
- .setPlan(plan.get)
- .build())
- case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
- assert(plan.isDefined)
- builder.setIsStreaming(
- proto.AnalyzePlanRequest.IsStreaming
- .newBuilder()
- .setPlan(plan.get)
- .build())
- case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
- assert(plan.isDefined)
- builder.setInputFiles(
- proto.AnalyzePlanRequest.InputFiles
- .newBuilder()
- .setPlan(plan.get)
- .build())
- case proto.AnalyzePlanRequest.AnalyzeCase.SPARK_VERSION =>
- builder.setSparkVersion(proto.AnalyzePlanRequest.SparkVersion.newBuilder().build())
- case other => throw new IllegalArgumentException(s"Unknown Analyze request $other")
+ handlePlanCompressionErrors {
+ val builder = proto.AnalyzePlanRequest.newBuilder()
+ // Compress the plan if needed.
+ val maybeCompressedPlan = plan match {
+ case Some(p) => Some(tryCompressPlan(p))
+ case None => None
+ }
+ method match {
+ case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
+ assert(maybeCompressedPlan.isDefined)
+ builder.setSchema(
+ proto.AnalyzePlanRequest.Schema
+ .newBuilder()
+ .setPlan(maybeCompressedPlan.get)
+ .build())
+ case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
+ if (explainMode.isEmpty) {
+ throw new IllegalArgumentException(s"ExplainMode is required in Explain request")
+ }
+ assert(maybeCompressedPlan.isDefined)
+ builder.setExplain(
+ proto.AnalyzePlanRequest.Explain
+ .newBuilder()
+ .setPlan(maybeCompressedPlan.get)
+ .setExplainMode(explainMode.get)
+ .build())
+ case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
+ assert(maybeCompressedPlan.isDefined)
+ builder.setIsLocal(
+ proto.AnalyzePlanRequest.IsLocal
+ .newBuilder()
+ .setPlan(maybeCompressedPlan.get)
+ .build())
+ case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
+ assert(maybeCompressedPlan.isDefined)
+ builder.setIsStreaming(
+ proto.AnalyzePlanRequest.IsStreaming
+ .newBuilder()
+ .setPlan(maybeCompressedPlan.get)
+ .build())
+ case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
+ assert(maybeCompressedPlan.isDefined)
+ builder.setInputFiles(
+ proto.AnalyzePlanRequest.InputFiles
+ .newBuilder()
+ .setPlan(maybeCompressedPlan.get)
+ .build())
+ case proto.AnalyzePlanRequest.AnalyzeCase.SPARK_VERSION =>
+ builder.setSparkVersion(proto.AnalyzePlanRequest.SparkVersion.newBuilder().build())
+ case other => throw new IllegalArgumentException(s"Unknown Analyze request $other")
+ }
+ analyze(builder)
}
- analyze(builder)
}
def sameSemantics(plan: proto.Plan, otherPlan: proto.Plan): proto.AnalyzePlanResponse = {
- val builder = proto.AnalyzePlanRequest.newBuilder()
- builder.setSameSemantics(
- proto.AnalyzePlanRequest.SameSemantics
- .newBuilder()
- .setTargetPlan(plan)
- .setOtherPlan(otherPlan))
- analyze(builder)
+ handlePlanCompressionErrors {
+ val builder = proto.AnalyzePlanRequest.newBuilder()
+ // Compress the plan if needed.
+ val maybeCompressedPlan = tryCompressPlan(plan)
+ val otherMaybeCompressedPlan = tryCompressPlan(otherPlan)
+ builder.setSameSemantics(
+ proto.AnalyzePlanRequest.SameSemantics
+ .newBuilder()
+ .setTargetPlan(maybeCompressedPlan)
+ .setOtherPlan(otherMaybeCompressedPlan))
+ analyze(builder)
+ }
}
def semanticHash(plan: proto.Plan): proto.AnalyzePlanResponse = {
- val builder = proto.AnalyzePlanRequest.newBuilder()
- builder.setSemanticHash(
- proto.AnalyzePlanRequest.SemanticHash
- .newBuilder()
- .setPlan(plan))
- analyze(builder)
+ handlePlanCompressionErrors {
+ val builder = proto.AnalyzePlanRequest.newBuilder()
+ // Compress the plan if needed.
+ val maybeCompressedPlan = tryCompressPlan(plan)
+ builder.setSemanticHash(
+ proto.AnalyzePlanRequest.SemanticHash
+ .newBuilder()
+ .setPlan(maybeCompressedPlan))
+ analyze(builder)
+ }
}
private[sql] def analyze(
@@ -332,6 +526,16 @@ private[sql] class SparkConnectClient(
def copy(): SparkConnectClient = configuration.toSparkConnectClient
+ /**
+ * Returns whether arrow batch chunking is allowed.
+ */
+ def allowArrowBatchChunking: Boolean = configuration.allowArrowBatchChunking
+
+ /**
+ * Returns the preferred arrow chunk size in bytes.
+ */
+ def preferredArrowChunkSize: Option[Int] = configuration.preferredArrowChunkSize
+
/**
* Add a single artifact to the client session.
*
@@ -402,26 +606,6 @@ private[sql] class SparkConnectClient(
channel.shutdownNow()
}
- /**
- * Cache the given local relation Arrow stream from a local file and return its hashes. The file
- * is streamed in chunks and does not need to fit in memory.
- *
- * This method batches artifact status checks and uploads to minimize RPC overhead.
- */
- private[sql] def cacheLocalRelation(
- data: Array[Array[Byte]],
- schema: String): (Seq[String], String) = {
- val schemaBytes = schema.getBytes
- val allBlobs = data :+ schemaBytes
- val allHashes = artifactManager.cacheArtifacts(allBlobs)
-
- // Last hash is the schema hash, rest are data hashes
- val dataHashes = allHashes.dropRight(1)
- val schemaHash = allHashes.last
-
- (dataHashes, schemaHash)
- }
-
/**
* Clone this client session, creating a new session with the same configuration and shared
* state as the current session but with independent runtime state.
@@ -473,6 +657,9 @@ private[sql] class SparkConnectClient(
}
}
+// Options for plan compression
+case class PlanCompressionOptions(thresholdBytes: Int, algorithm: String)
+
object SparkConnectClient {
private[sql] val SPARK_REMOTE: String = "SPARK_REMOTE"
@@ -757,6 +944,21 @@ object SparkConnectClient {
this
}
+ def allowArrowBatchChunking(allow: Boolean): Builder = {
+ _configuration = _configuration.copy(allowArrowBatchChunking = allow)
+ this
+ }
+
+ def allowArrowBatchChunking: Boolean = _configuration.allowArrowBatchChunking
+
+ def preferredArrowChunkSize(size: Option[Int]): Builder = {
+ size.foreach(s => require(s > 0, "preferredArrowChunkSize must be positive"))
+ _configuration = _configuration.copy(preferredArrowChunkSize = size)
+ this
+ }
+
+ def preferredArrowChunkSize: Option[Int] = _configuration.preferredArrowChunkSize
+
def build(): SparkConnectClient = _configuration.toSparkConnectClient
}
@@ -801,7 +1003,9 @@ object SparkConnectClient {
interceptors: List[ClientInterceptor] = List.empty,
sessionId: Option[String] = None,
grpcMaxMessageSize: Int = ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE,
- grpcMaxRecursionLimit: Int = ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT) {
+ grpcMaxRecursionLimit: Int = ConnectCommon.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT,
+ allowArrowBatchChunking: Boolean = true,
+ preferredArrowChunkSize: Option[Int] = None) {
private def isLocal = host.equals("localhost")
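As a reading aid for the compression path added in this file, a minimal sketch of the size gate it applies: compress only when the serialized plan exceeds the threshold, and keep the compressed bytes only when they are actually smaller. The real implementation additionally logs metrics, tolerates a missing Zstd library, and disables compression for the session on failure:

    import com.github.luben.zstd.Zstd

    def maybeCompress(serialized: Array[Byte], thresholdBytes: Int): Array[Byte] = {
      if (thresholdBytes >= 0 && serialized.length > thresholdBytes) {
        val compressed = Zstd.compress(serialized)
        // Keep the original bytes when compression is not effective.
        if (compressed.length < serialized.length) compressed else serialized
      } else {
        serialized
      }
    }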
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index ef55edd10c8a3..4199801d8505c 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -16,33 +16,37 @@
*/
package org.apache.spark.sql.connect.client
+import java.io.SequenceInputStream
import java.lang.ref.Cleaner
import java.util.Objects
import scala.collection.mutable
import scala.jdk.CollectionConverters._
+import com.google.protobuf.ByteString
import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector.ipc.message.{ArrowMessage, ArrowRecordBatch}
import org.apache.arrow.vector.types.pojo
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse.ObservedMetrics
+import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder}
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.connect.client.arrow.{AbstractMessageIterator, ArrowDeserializingIterator, ConcatenatingArrowStreamReader, MessageIterator}
+import org.apache.spark.sql.connect.client.arrow.ArrowDeserializingIterator
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, LiteralValueProtoConverter}
import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.util.{AbstractMessageIterator, ArrowUtils, CloseableIterator, ConcatenatingArrowStreamReader, MessageIterator}
private[sql] class SparkResult[T](
responses: CloseableIterator[proto.ExecutePlanResponse],
allocator: BufferAllocator,
encoder: AgnosticEncoder[T],
timeZoneId: String)
- extends AutoCloseable { self =>
+ extends AutoCloseable
+ with Logging { self =>
case class StageInfo(
stageId: Long,
@@ -118,6 +122,7 @@ private[sql] class SparkResult[T](
stopOnFirstNonEmptyResponse: Boolean = false): Boolean = {
var nonEmpty = false
var stop = false
+ val arrowBatchChunksToAssemble = mutable.Buffer.empty[ByteString]
while (!stop && responses.hasNext) {
val response = responses.next()
@@ -151,55 +156,96 @@ private[sql] class SparkResult[T](
stop |= stopOnSchema
}
if (response.hasArrowBatch) {
- val ipcStreamBytes = response.getArrowBatch.getData
- val expectedNumRows = response.getArrowBatch.getRowCount
- val reader = new MessageIterator(ipcStreamBytes.newInput(), allocator)
- if (arrowSchema == null) {
- arrowSchema = reader.schema
- stop |= stopOnArrowSchema
- } else if (arrowSchema != reader.schema) {
- throw new IllegalStateException(
- s"""Schema Mismatch between expected and received schema:
- |=== Expected Schema ===
- |$arrowSchema
- |=== Received Schema ===
- |${reader.schema}
- |""".stripMargin)
- }
- if (structType == null) {
- // If the schema is not available yet, fallback to the arrow schema.
- structType = ArrowUtils.fromArrowSchema(reader.schema)
- }
- if (response.getArrowBatch.hasStartOffset) {
- val expectedStartOffset = response.getArrowBatch.getStartOffset
- if (numRecords != expectedStartOffset) {
+ val arrowBatch = response.getArrowBatch
+ logDebug(
+ s"Received arrow batch rows=${arrowBatch.getRowCount} " +
+ s"Number of chunks in batch=${arrowBatch.getNumChunksInBatch} " +
+ s"Chunk index=${arrowBatch.getChunkIndex} " +
+ s"size=${arrowBatch.getData.size()}")
+
+ if (arrowBatchChunksToAssemble.nonEmpty) {
+ // Expect next chunk of the same batch
+ if (arrowBatch.getChunkIndex != arrowBatchChunksToAssemble.size) {
throw new IllegalStateException(
- s"Expected arrow batch to start at row offset $numRecords in results, " +
- s"but received arrow batch starting at offset $expectedStartOffset.")
+ s"Expected chunk index ${arrowBatchChunksToAssemble.size} of the " +
+ s"arrow batch but got ${arrowBatch.getChunkIndex}.")
}
- }
- var numRecordsInBatch = 0
- val messages = Seq.newBuilder[ArrowMessage]
- while (reader.hasNext) {
- val message = reader.next()
- message match {
- case batch: ArrowRecordBatch =>
- numRecordsInBatch += batch.getLength
- case _ =>
+ } else {
+ // Expect next batch
+ if (arrowBatch.hasStartOffset) {
+ val expectedStartOffset = arrowBatch.getStartOffset
+ if (numRecords != expectedStartOffset) {
+ throw new IllegalStateException(
+ s"Expected arrow batch to start at row offset $numRecords in results, " +
+ s"but received arrow batch starting at offset $expectedStartOffset.")
+ }
+ }
+ if (arrowBatch.getChunkIndex != 0) {
+ throw new IllegalStateException(
+ s"Expected chunk index 0 of the next arrow batch " +
+ s"but got ${arrowBatch.getChunkIndex}.")
}
- messages += message
- }
- if (numRecordsInBatch != expectedNumRows) {
- throw new IllegalStateException(
- s"Expected $expectedNumRows rows in arrow batch but got $numRecordsInBatch.")
}
- // Skip the entire result if it is empty.
- if (numRecordsInBatch > 0) {
- numRecords += numRecordsInBatch
- resultMap.put(nextResultIndex, (reader.bytesRead, messages.result()))
- nextResultIndex += 1
- nonEmpty |= true
- stop |= stopOnFirstNonEmptyResponse
+
+ arrowBatchChunksToAssemble += arrowBatch.getData
+
+ // Assemble the chunks into an arrow batch to process if
+ // (a) chunking is not enabled (numChunksInBatch is not set or is 0;
+ // in that case this is the only chunk in the batch), or
+ // (b) the client has received all chunks of the batch.
+ if (!arrowBatch.hasNumChunksInBatch ||
+ arrowBatch.getNumChunksInBatch == 0 ||
+ arrowBatchChunksToAssemble.size == arrowBatch.getNumChunksInBatch) {
+
+ val numChunks = arrowBatchChunksToAssemble.size
+ val inputStreams =
+ arrowBatchChunksToAssemble.map(_.newInput()).iterator.asJavaEnumeration
+ val input = new SequenceInputStream(inputStreams)
+ arrowBatchChunksToAssemble.clear()
+ logDebug(s"Assembling arrow batch from $numChunks chunks.")
+
+ val expectedNumRows = arrowBatch.getRowCount
+ val reader = new MessageIterator(input, allocator)
+ if (arrowSchema == null) {
+ arrowSchema = reader.schema
+ stop |= stopOnArrowSchema
+ } else if (arrowSchema != reader.schema) {
+ throw new IllegalStateException(
+ s"""Schema Mismatch between expected and received schema:
+ |=== Expected Schema ===
+ |$arrowSchema
+ |=== Received Schema ===
+ |${reader.schema}
+ |""".stripMargin)
+ }
+ if (structType == null) {
+ // If the schema is not available yet, fallback to the arrow schema.
+ structType = ArrowUtils.fromArrowSchema(reader.schema)
+ }
+
+ var numRecordsInBatch = 0
+ val messages = Seq.newBuilder[ArrowMessage]
+ while (reader.hasNext) {
+ val message = reader.next()
+ message match {
+ case batch: ArrowRecordBatch =>
+ numRecordsInBatch += batch.getLength
+ case _ =>
+ }
+ messages += message
+ }
+ if (numRecordsInBatch != expectedNumRows) {
+ throw new IllegalStateException(
+ s"Expected $expectedNumRows rows in arrow batch but got $numRecordsInBatch.")
+ }
+ // Skip the entire result if it is empty.
+ if (numRecordsInBatch > 0) {
+ numRecords += numRecordsInBatch
+ resultMap.put(nextResultIndex, (reader.bytesRead, messages.result()))
+ nextResultIndex += 1
+ nonEmpty |= true
+ stop |= stopOnFirstNonEmptyResponse
+ }
}
}
}
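The reassembly above concatenates the chunk payloads of one arrow batch into a single IPC stream before handing it to MessageIterator. A standalone sketch of that step using plain byte arrays:

    import java.io.{ByteArrayInputStream, InputStream, SequenceInputStream}
    import scala.jdk.CollectionConverters._

    // Concatenate received chunks (in chunkIndex order) into one contiguous stream.
    def assembleBatch(chunks: Seq[Array[Byte]]): InputStream = {
      val streams = chunks.iterator
        .map(bytes => new ByteArrayInputStream(bytes): InputStream)
        .asJavaEnumeration
      new SequenceInputStream(streams)
    }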
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 7597a0ceeb8cd..82029025a7f0b 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -37,9 +37,9 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
-import org.apache.spark.sql.connect.client.CloseableIterator
import org.apache.spark.sql.errors.{CompilationErrors, ExecutionErrors}
import org.apache.spark.sql.types.Decimal
+import org.apache.spark.sql.util.{CloseableIterator, ConcatenatingArrowStreamReader, MessageIterator}
import org.apache.spark.unsafe.types.VariantVal
/**
@@ -341,6 +341,14 @@ object ArrowDeserializers {
}
}
+ case (_: GeometryEncoder, StructVectors(struct, vectors)) =>
+ val gdser = new GeometryArrowSerDe
+ gdser.createDeserializer(struct, vectors, timeZoneId)
+
+ case (_: GeographyEncoder, StructVectors(struct, vectors)) =>
+ val gdser = new GeographyArrowSerDe
+ gdser.createDeserializer(struct, vectors, timeZoneId)
+
case (VariantEncoder, StructVectors(struct, vectors)) =>
assert(vectors.exists(_.getName == "value"))
assert(
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
index 5b1539e39f4f4..2430c2bbc86fc 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderUtils.scala
@@ -41,6 +41,22 @@ private[arrow] object ArrowEncoderUtils {
def unsupportedCollectionType(cls: Class[_]): Nothing = {
throw new RuntimeException(s"Unsupported collection type: $cls")
}
+
+ def assertMetadataPresent(
+ vectors: Seq[FieldVector],
+ expectedVectors: Seq[String],
+ expectedMetadata: Seq[(String, String)]): Unit = {
+ expectedVectors.foreach { vectorName =>
+ assert(vectors.exists(_.getName == vectorName))
+ }
+
+ expectedVectors.zip(expectedMetadata).foreach { case (vectorName, (key, value)) =>
+ assert(
+ vectors.exists(field =>
+ field.getName == vectorName && field.getField.getMetadata
+ .containsKey(key) && field.getField.getMetadata.get(key) == value))
+ }
+ }
}
private[arrow] object StructVectors {
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
index 4acb11f014d19..d547c81afe5ad 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
@@ -38,10 +38,9 @@ import org.apache.spark.sql.catalyst.DefinedByConstructorParams
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, Codec}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
import org.apache.spark.sql.catalyst.util.{SparkDateTimeUtils, SparkIntervalUtils}
-import org.apache.spark.sql.connect.client.CloseableIterator
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types.Decimal
-import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.util.{ArrowUtils, CloseableIterator}
import org.apache.spark.unsafe.types.VariantVal
/**
@@ -487,6 +486,14 @@ object ArrowSerializer {
extractor = (v: Any) => v.asInstanceOf[VariantVal].getMetadata,
serializerFor(BinaryEncoder, struct.getChild("metadata")))))
+ case (_: GeographyEncoder, StructVectors(struct, vectors)) =>
+ val gser = new GeographyArrowSerDe
+ gser.createSerializer(struct, vectors)
+
+ case (_: GeometryEncoder, StructVectors(struct, vectors)) =>
+ val gser = new GeometryArrowSerDe
+ gser.createSerializer(struct, vectors)
+
case (JavaBeanEncoder(tag, fields), StructVectors(struct, vectors)) =>
structSerializerFor(fields, struct, vectors) { (field, _) =>
val getter = methodLookup.findVirtual(
@@ -585,12 +592,14 @@ object ArrowSerializer {
}
}
- private class StructFieldSerializer(val extractor: Any => Any, val serializer: Serializer) {
+ private[arrow] class StructFieldSerializer(
+ val extractor: Any => Any,
+ val serializer: Serializer) {
def write(index: Int, value: Any): Unit = serializer.write(index, extractor(value))
def writeNull(index: Int): Unit = serializer.write(index, null)
}
- private class StructSerializer(
+ private[arrow] class StructSerializer(
struct: StructVector,
fieldSerializers: Seq[StructFieldSerializer])
extends Serializer {
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/GeospatialArrowSerDe.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/GeospatialArrowSerDe.scala
new file mode 100644
index 0000000000000..443523ef02cdc
--- /dev/null
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/GeospatialArrowSerDe.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.client.arrow
+
+import org.apache.arrow.vector.FieldVector
+import org.apache.arrow.vector.complex.StructVector
+
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, PrimitiveIntEncoder}
+import org.apache.spark.sql.errors.CompilationErrors
+import org.apache.spark.sql.types.{Geography, Geometry}
+
+abstract class GeospatialArrowSerDe[T](typeName: String) {
+
+ def createDeserializer(
+ struct: StructVector,
+ vectors: Seq[FieldVector],
+ timeZoneId: String): ArrowDeserializers.StructFieldSerializer[T] = {
+ assertMetadataPresent(vectors)
+ val wkbDecoder = ArrowDeserializers.deserializerFor(
+ BinaryEncoder,
+ vectors
+ .find(_.getName == "wkb")
+ .getOrElse(throw CompilationErrors.columnNotFoundError("wkb")),
+ timeZoneId)
+ val sridDecoder = ArrowDeserializers.deserializerFor(
+ PrimitiveIntEncoder,
+ vectors
+ .find(_.getName == "srid")
+ .getOrElse(throw CompilationErrors.columnNotFoundError("srid")),
+ timeZoneId)
+ new ArrowDeserializers.StructFieldSerializer[T](struct) {
+ override def value(i: Int): T = createInstance(wkbDecoder.get(i), sridDecoder.get(i))
+ }
+ }
+
+ def createSerializer(
+ struct: StructVector,
+ vectors: Seq[FieldVector]): ArrowSerializer.StructSerializer = {
+ assertMetadataPresent(vectors)
+ new ArrowSerializer.StructSerializer(
+ struct,
+ Seq(
+ new ArrowSerializer.StructFieldSerializer(
+ extractor = (v: Any) => extractSrid(v),
+ ArrowSerializer.serializerFor(PrimitiveIntEncoder, struct.getChild("srid"))),
+ new ArrowSerializer.StructFieldSerializer(
+ extractor = (v: Any) => extractBytes(v),
+ ArrowSerializer.serializerFor(BinaryEncoder, struct.getChild("wkb")))))
+ }
+
+ private def assertMetadataPresent(vectors: Seq[FieldVector]): Unit = {
+ assert(vectors.exists(_.getName == "srid"))
+ assert(
+ vectors.exists(field =>
+ field.getName == "wkb" && field.getField.getMetadata
+ .containsKey(typeName) && field.getField.getMetadata.get(typeName) == "true"))
+ }
+
+ protected def createInstance(wkb: Any, srid: Any): T
+ protected def extractSrid(value: Any): Int
+ protected def extractBytes(value: Any): Array[Byte]
+}
+
+// Geography-specific implementation
+class GeographyArrowSerDe extends GeospatialArrowSerDe[Geography]("geography") {
+ override protected def createInstance(wkb: Any, srid: Any): Geography =
+ Geography.fromWKB(wkb.asInstanceOf[Array[Byte]], srid.asInstanceOf[Int])
+
+ override protected def extractSrid(value: Any): Int =
+ value.asInstanceOf[Geography].getSrid
+
+ override protected def extractBytes(value: Any): Array[Byte] =
+ value.asInstanceOf[Geography].getBytes
+}
+
+// Geometry-specific implementation
+class GeometryArrowSerDe extends GeospatialArrowSerDe[Geometry]("geometry") {
+ override protected def createInstance(wkb: Any, srid: Any): Geometry =
+ Geometry.fromWKB(wkb.asInstanceOf[Array[Byte]], srid.asInstanceOf[Int])
+
+ override protected def extractSrid(value: Any): Int =
+ value.asInstanceOf[Geometry].getSrid
+
+ override protected def extractBytes(value: Any): Array[Byte] =
+ value.asInstanceOf[Geometry].getBytes
+}
diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
index 419cc8e082af2..ac69f084c307b 100644
--- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
+++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/common/DataTypeProtoConverter.scala
@@ -71,6 +71,21 @@ object DataTypeProtoConverter {
case proto.DataType.KindCase.MAP => toCatalystMapType(t.getMap)
case proto.DataType.KindCase.VARIANT => VariantType
+ case proto.DataType.KindCase.GEOMETRY =>
+ val srid = t.getGeometry.getSrid
+ if (srid == GeometryType.MIXED_SRID) {
+ GeometryType("ANY")
+ } else {
+ GeometryType(srid)
+ }
+ case proto.DataType.KindCase.GEOGRAPHY =>
+ val srid = t.getGeography.getSrid
+ if (srid == GeographyType.MIXED_SRID) {
+ GeographyType("ANY")
+ } else {
+ GeographyType(srid)
+ }
+
case proto.DataType.KindCase.UDT => toCatalystUDT(t.getUdt)
case _ =>
@@ -307,6 +322,26 @@ object DataTypeProtoConverter {
.build())
.build()
+ case g: GeographyType =>
+ proto.DataType
+ .newBuilder()
+ .setGeography(
+ proto.DataType.Geography
+ .newBuilder()
+ .setSrid(g.srid)
+ .build())
+ .build()
+
+ case g: GeometryType =>
+ proto.DataType
+ .newBuilder()
+ .setGeometry(
+ proto.DataType.Geometry
+ .newBuilder()
+ .setSrid(g.srid)
+ .build())
+ .build()
+
case VariantType => ProtoDataTypes.VariantType
case pyudt: PythonUserDefinedType =>
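A hedged round-trip sketch of the new conversions, assuming the GeometryType import path and the converter entry points used by this module (toConnectProtoType / toCatalystType); as shown above, MIXED_SRID is mapped to the "ANY" form:

    import org.apache.spark.sql.connect.common.DataTypeProtoConverter
    import org.apache.spark.sql.types.GeometryType

    // Catalyst -> proto carries the SRID; proto -> Catalyst restores the same type.
    val protoType = DataTypeProtoConverter.toConnectProtoType(GeometryType(4326))
    assert(protoType.getGeometry.getSrid == 4326)
    val catalystType = DataTypeProtoConverter.toCatalystType(protoType) // GeometryType(4326)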
diff --git a/sql/connect/server/pom.xml b/sql/connect/server/pom.xml
index 5f99b5af3243f..25394b1c6cb88 100644
--- a/sql/connect/server/pom.xml
+++ b/sql/connect/server/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.13
- 4.1.0-SNAPSHOT
+ 4.1.2-SNAPSHOT
../../../pom.xml
@@ -183,13 +183,11 @@
com.google.guava
guava
- ${connect.guava.version}
compile
com.google.guava
failureaccess
- ${guava.failureaccess.version}
compile
@@ -240,12 +238,6 @@
${netty.version}
provided
-
- org.apache.tomcat
- annotations-api
- ${tomcat.annotations.api.version}
- provided
-
org.scalacheck
scalacheck_${scala.binary.version}
@@ -292,7 +284,6 @@
false
- com.google.guava:*
io.grpc:*:
com.google.protobuf:*
@@ -304,28 +295,24 @@
com.google.api.grpc:proto-google-common-protos
io.perfmark:perfmark-api
org.codehaus.mojo:animal-sniffer-annotations
- com.google.errorprone:error_prone_annotations
- com.google.j2objc:j2objc-annotations
- org.checkerframework:checker-qual
com.google.code.gson:gson
org.apache.spark:spark-connect-common_${scala.binary.version}
+
com.google.common
- ${spark.shade.packageName}.connect.guava
-
- com.google.common.**
-
+ ${spark.shade.packageName}.guava
com.google.thirdparty
- ${spark.shade.packageName}.connect.guava
-
- com.google.thirdparty.**
-
+ ${spark.shade.packageName}.guava.thirdparty
+
com.google.protobuf
${spark.shade.packageName}.connect.protobuf
@@ -350,18 +337,6 @@
org.codehaus.mojo.animal_sniffer
${spark.shade.packageName}.connect.animal_sniffer
-
- com.google.j2objc.annotations
- ${spark.shade.packageName}.connect.j2objc_annotations
-
-
- com.google.errorprone.annotations
- ${spark.shade.packageName}.connect.errorprone_annotations
-
-
- org.checkerframework
- ${spark.shade.packageName}.connect.checkerframework
-
com.google.gson
${spark.shade.packageName}.connect.gson
@@ -376,6 +351,10 @@
com.google.api
${spark.shade.packageName}.connect.google_protos.api
+
+ com.google.apps
+ ${spark.shade.packageName}.connect.google_protos.apps
+
com.google.cloud
${spark.shade.packageName}.connect.google_protos.cloud
@@ -396,6 +375,10 @@
com.google.rpc
${spark.shade.packageName}.connect.google_protos.rpc
+
+ com.google.shopping
+ ${spark.shade.packageName}.connect.google_protos.shopping
+
com.google.type
${spark.shade.packageName}.connect.google_protos.type
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index c6049187f6be8..1df97d855678e 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.connect.config
+import java.util.Locale
import java.util.concurrent.TimeUnit
import org.apache.spark.SparkEnv
@@ -418,4 +419,35 @@ object Connect {
.bytesConf(ByteUnit.BYTE)
// 90% of the max message size by default to allow for some overhead.
.createWithDefault((ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE * 0.9).toInt)
+
+ private[spark] val CONNECT_MAX_PLAN_SIZE =
+ buildStaticConf("spark.connect.maxPlanSize")
+ .doc(
+ "The maximum size of a (decompressed) proto plan that can be executed in Spark " +
+ "Connect. If the size of the plan exceeds this limit, an error will be thrown. " +
+ "The size is in bytes.")
+ .version("4.1.0")
+ .internal()
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefault(512 * 1024 * 1024) // 512 MB
+
+ val CONNECT_SESSION_PLAN_COMPRESSION_THRESHOLD =
+ buildConf("spark.connect.session.planCompression.threshold")
+ .doc("The threshold in bytes for the size of a proto plan to be compressed. " +
+ "If the size of the proto plan is smaller than this threshold, it will not be compressed. " +
+ "Set to -1 to disable plan compression.")
+ .version("4.1.0")
+ .internal()
+ .intConf
+ .createWithDefault(10 * 1024 * 1024) // 10 MB
+
+ val CONNECT_PLAN_COMPRESSION_DEFAULT_ALGORITHM =
+ buildConf("spark.connect.session.planCompression.defaultAlgorithm")
+ .doc("The default algorithm for proto plan compression.")
+ .version("4.1.0")
+ .internal()
+ .stringConf
+ .transform(_.toUpperCase(Locale.ROOT))
+ .checkValues(ConnectPlanCompressionAlgorithm.values.map(_.toString))
+ .createWithDefault(ConnectPlanCompressionAlgorithm.ZSTD.toString)
}
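A usage sketch for the new session confs, assuming a SparkSession named spark; the values are illustrative, not recommendations:

    // Lower the compression threshold to 4 MiB; the algorithm string is normalized to upper case.
    spark.conf.set("spark.connect.session.planCompression.threshold", (4 * 1024 * 1024).toString)
    spark.conf.set("spark.connect.session.planCompression.defaultAlgorithm", "zstd")
    // Setting the threshold to -1 disables plan compression for the session.
    spark.conf.set("spark.connect.session.planCompression.threshold", "-1")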
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/IgnoreCachedData.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/ConnectPlanCompressionAlgorithm.scala
similarity index 80%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/IgnoreCachedData.scala
rename to sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/ConnectPlanCompressionAlgorithm.scala
index 85958cb43d4f8..4052627fd8c3b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/IgnoreCachedData.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/config/ConnectPlanCompressionAlgorithm.scala
@@ -14,10 +14,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+package org.apache.spark.sql.connect.config
-package org.apache.spark.sql.catalyst.plans.logical
-
-/**
- * A [[LogicalPlan]] operator that does not use the cached results stored in CacheManager
- */
-trait IgnoreCachedData extends LogicalPlan {}
+object ConnectPlanCompressionAlgorithm extends Enumeration {
+ val ZSTD, NONE = Value
+}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index 38ed2528cbde0..f206ee1555a73 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.connect.service.{ExecuteHolder, ExecuteSessionTag, S
import org.apache.spark.sql.connect.utils.ErrorUtils
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils
+import org.apache.spark.util.Utils.CONNECT_EXECUTE_THREAD_PREFIX
/**
* This class launches the actual execution in an execution thread. The execution pushes the
@@ -329,7 +330,7 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
}
private class ExecutionThread()
- extends Thread(s"SparkConnectExecuteThread_opId=${executeHolder.operationId}") {
+ extends Thread(s"${CONNECT_EXECUTE_THREAD_PREFIX}_opId=${executeHolder.operationId}") {
override def run(): Unit = execute()
}
}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index f5cb2696d849b..4b12c96e977e6 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -29,6 +29,7 @@ import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.classic.{DataFrame, Dataset}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
@@ -126,6 +127,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
val sessionId = executePlan.sessionHolder.sessionId
val spark = dataframe.sparkSession
val schema = dataframe.schema
+ TypeUtils.failUnsupportedDataType(schema, spark.sessionState.conf)
val maxRecordsPerBatch = spark.sessionState.conf.arrowMaxRecordsPerBatch
val timeZoneId = spark.sessionState.conf.sessionLocalTimeZone
val largeVarTypes = spark.sessionState.conf.arrowUseLargeVarTypes
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala
index e0c7beb43001d..e8114f38ec40c 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/DataflowGraphRegistry.scala
@@ -32,13 +32,14 @@ class DataflowGraphRegistry {
private val dataflowGraphs = new ConcurrentHashMap[String, GraphRegistrationContext]()
- /** Registers a DataflowGraph and generates a unique id to associate with the graph */
+ /**
+ * Registers a GraphRegistrationContext and generates a unique id to associate with the graph
+ */
def createDataflowGraph(
defaultCatalog: String,
defaultDatabase: String,
defaultSqlConf: Map[String, String]): String = {
val graphId = java.util.UUID.randomUUID().toString
- // TODO: propagate pipeline catalog and schema from pipeline spec here.
dataflowGraphs.put(
graphId,
new GraphRegistrationContext(defaultCatalog, defaultDatabase, defaultSqlConf))
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
index 7e69e546893e6..62f060014117c 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.connect.pipelines
+import scala.collection.Seq
import scala.jdk.CollectionConverters._
import scala.util.Using
@@ -27,9 +28,10 @@ import org.apache.spark.connect.proto.{ExecutePlanResponse, PipelineCommandResul
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Command, CreateNamespace, CreateTable, CreateTableAsSelect, CreateView, DescribeRelation, DropView, InsertIntoStatement, LogicalPlan, RenameTable, ShowColumns, ShowCreateTable, ShowFunctions, ShowTableProperties, ShowTables, ShowViews}
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.service.SessionHolder
+import org.apache.spark.sql.execution.command.{ShowCatalogsCommand, ShowNamespacesCommand}
import org.apache.spark.sql.pipelines.Language.Python
import org.apache.spark.sql.pipelines.common.RunState.{CANCELED, FAILED}
import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis, GraphIdentifierManager, GraphRegistrationContext, IdentifierHelper, NoTables, PipelineUpdateContextImpl, QueryContext, QueryOrigin, QueryOriginType, Sink, SinkImpl, SomeTables, SqlGraphRegistrationContext, Table, TableFilter, TemporaryView, UnresolvedFlow}
@@ -47,8 +49,6 @@ private[connect] object PipelinesHandler extends Logging {
* Command to be handled
* @param responseObserver
* The response observer where the response will be sent
- * @param sparkSession
- * The spark session
* @param transformRelationFunc
* Function used to convert a relation to a LogicalPlan. This is used when determining the
* LogicalPlan that a flow returns.
@@ -85,9 +85,7 @@ private[connect] object PipelinesHandler extends Logging {
defineOutput(cmd.getDefineOutput, sessionHolder)
val identifierBuilder = ResolvedIdentifier.newBuilder()
resolvedDataset.catalog.foreach(identifierBuilder.setCatalogName)
- resolvedDataset.database.foreach { ns =>
- identifierBuilder.addNamespace(ns)
- }
+ resolvedDataset.database.foreach(identifierBuilder.addNamespace)
identifierBuilder.setTableName(resolvedDataset.identifier)
val identifier = identifierBuilder.build()
PipelineCommandResult
@@ -114,7 +112,7 @@ private[connect] object PipelinesHandler extends Logging {
.setDefineFlowResult(
PipelineCommandResult.DefineFlowResult
.newBuilder()
- .setResolvedIdentifier(identifierBuilder)
+ .setResolvedIdentifier(identifier)
.build())
.build()
case proto.PipelineCommand.CommandTypeCase.START_RUN =>
@@ -129,25 +127,78 @@ private[connect] object PipelinesHandler extends Logging {
}
}
+ /**
+ * Block SQL commands that have side effects or modify data.
+ *
+ * Pipeline definitions should be declarative and side-effect free. This prevents users from
+ * inadvertently modifying catalogs, creating tables, or performing other stateful operations
+ * outside the pipeline API boundary during pipeline registration or analysis.
+ *
+ * This is a best-effort approach: we block known problematic commands while allowing a curated
+ * set of read-only operations (e.g., SHOW, DESCRIBE).
+ */
+ def blockUnsupportedSqlCommand(queryPlan: LogicalPlan): Unit = {
+ val allowlistedCommands = Set(
+ classOf[DescribeRelation],
+ classOf[ShowTables],
+ classOf[ShowTableProperties],
+ classOf[ShowNamespacesCommand],
+ classOf[ShowColumns],
+ classOf[ShowFunctions],
+ classOf[ShowViews],
+ classOf[ShowCatalogsCommand],
+ classOf[ShowCreateTable])
+ val isSqlCommandExplicitlyAllowlisted = allowlistedCommands.exists(_.isInstance(queryPlan))
+ val isUnsupportedSqlPlan = if (isSqlCommandExplicitlyAllowlisted) {
+ false
+ } else {
+ // Disable all [[Command]] subclasses except the ones that are explicitly allowlisted
+ // in "allowlistedCommands".
+ queryPlan.isInstanceOf[Command] ||
+ // The following commands are not subclasses of [[Command]] but have side effects.
+ queryPlan.isInstanceOf[CreateTableAsSelect] ||
+ queryPlan.isInstanceOf[CreateTable] ||
+ queryPlan.isInstanceOf[CreateView] ||
+ queryPlan.isInstanceOf[InsertIntoStatement] ||
+ queryPlan.isInstanceOf[RenameTable] ||
+ queryPlan.isInstanceOf[CreateNamespace] ||
+ queryPlan.isInstanceOf[DropView]
+ }
+ if (isUnsupportedSqlPlan) {
+ throw new AnalysisException(
+ "UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND",
+ Map("command" -> queryPlan.getClass.getSimpleName))
+ }
+ }
+
private def createDataflowGraph(
cmd: proto.PipelineCommand.CreateDataflowGraph,
sessionHolder: SessionHolder): String = {
val defaultCatalog = Option
.when(cmd.hasDefaultCatalog)(cmd.getDefaultCatalog)
.getOrElse {
- logInfo(s"No default catalog was supplied. Falling back to the current catalog.")
- sessionHolder.session.catalog.currentCatalog()
+ val currentCatalog = sessionHolder.session.catalog.currentCatalog()
+ logInfo(
+ "No default catalog was supplied. " +
+ s"Falling back to the current catalog: $currentCatalog.")
+ currentCatalog
}
val defaultDatabase = Option
.when(cmd.hasDefaultDatabase)(cmd.getDefaultDatabase)
.getOrElse {
- logInfo(s"No default database was supplied. Falling back to the current database.")
- sessionHolder.session.catalog.currentDatabase
+ val currentDatabase = sessionHolder.session.catalog.currentDatabase
+ logInfo(
+ "No default database was supplied. " +
+ s"Falling back to the current database: $currentDatabase.")
+ currentDatabase
}
val defaultSqlConf = cmd.getSqlConfMap.asScala.toMap
+ sessionHolder.session.catalog.setCurrentCatalog(defaultCatalog)
+ sessionHolder.session.catalog.setCurrentDatabase(defaultDatabase)
+
sessionHolder.dataflowGraphRegistry.createDataflowGraph(
defaultCatalog = defaultCatalog,
defaultDatabase = defaultDatabase,
@@ -203,6 +254,8 @@ private[connect] object PipelinesHandler extends Logging {
},
partitionCols = Option(tableDetails.getPartitionColsList.asScala.toSeq)
.filter(_.nonEmpty),
+ clusterCols = Option(tableDetails.getClusteringColumnsList.asScala.toSeq)
+ .filter(_.nonEmpty),
properties = tableDetails.getTablePropertiesMap.asScala.toMap,
origin = QueryOrigin(
filePath = Option.when(output.getSourceCodeLocation.hasFileName)(
@@ -229,18 +282,15 @@ private[connect] object PipelinesHandler extends Logging {
output.getSourceCodeLocation.getFileName),
line = Option.when(output.getSourceCodeLocation.hasLineNumber)(
output.getSourceCodeLocation.getLineNumber),
- objectType = Option(QueryOriginType.View.toString),
+ objectType = Some(QueryOriginType.View.toString),
objectName = Option(viewIdentifier.unquotedString),
- language = Option(Python())),
+ language = Some(Python())),
properties = Map.empty,
sqlText = None))
viewIdentifier
case proto.OutputType.SINK =>
- val dataflowGraphId = output.getDataflowGraphId
- val graphElementRegistry =
- sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
val identifier = GraphIdentifierManager
- .parseTableIdentifier(name = output.getOutputName, spark = sessionHolder.session)
+ .parseTableIdentifier(output.getOutputName, sessionHolder.session)
val sinkDetails = output.getSinkDetails
graphElementRegistry.registerSink(
SinkImpl(
@@ -254,7 +304,7 @@ private[connect] object PipelinesHandler extends Logging {
output.getSourceCodeLocation.getLineNumber),
objectType = Option(QueryOriginType.Sink.toString),
objectName = Option(identifier.unquotedString),
- language = Option(Python()))))
+ language = Some(Python()))))
identifier
case _ =>
throw new IllegalArgumentException(s"Unknown output type: ${output.getOutputType}")
@@ -265,6 +315,11 @@ private[connect] object PipelinesHandler extends Logging {
flow: proto.PipelineCommand.DefineFlow,
transformRelationFunc: Relation => LogicalPlan,
sessionHolder: SessionHolder): TableIdentifier = {
+ if (flow.hasOnce) {
+ throw new AnalysisException(
+ "DEFINE_FLOW_ONCE_OPTION_NOT_SUPPORTED",
+ Map("flowName" -> flow.getFlowName))
+ }
val dataflowGraphId = flow.getDataflowGraphId
val graphElementRegistry =
sessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId)
@@ -286,8 +341,7 @@ private[connect] object PipelinesHandler extends Logging {
val rawDestinationIdentifier = GraphIdentifierManager
.parseTableIdentifier(name = flow.getTargetDatasetName, spark = sessionHolder.session)
val flowWritesToView =
- graphElementRegistry
- .getViews()
+ graphElementRegistry.getViews
.filter(_.isInstanceOf[TemporaryView])
.exists(_.identifier == rawDestinationIdentifier)
val flowWritesToSink =
@@ -297,7 +351,7 @@ private[connect] object PipelinesHandler extends Logging {
// If the flow is created implicitly as part of defining a view, or if it writes to a sink,
// then we do not qualify the flow identifier and the flow destination. This is because
// views and sinks are not permitted to have multipart identifiers.
- val isImplicitFlowForTempView = (isImplicitFlow && flowWritesToView)
+ val isImplicitFlowForTempView = isImplicitFlow && flowWritesToView
val Seq(flowIdentifier, destinationIdentifier) =
Seq(rawFlowIdentifier, rawDestinationIdentifier).map { rawIdentifier =>
if (isImplicitFlowForTempView || flowWritesToSink) {
@@ -314,7 +368,7 @@ private[connect] object PipelinesHandler extends Logging {
val relationFlowDetails = flow.getRelationFlowDetails
graphElementRegistry.registerFlow(
- new UnresolvedFlow(
+ UnresolvedFlow(
identifier = flowIdentifier,
destinationIdentifier = destinationIdentifier,
func = FlowAnalysis.createFlowFunctionFromLogicalPlan(
@@ -327,9 +381,9 @@ private[connect] object PipelinesHandler extends Logging {
flow.getSourceCodeLocation.getFileName),
line = Option.when(flow.getSourceCodeLocation.hasLineNumber)(
flow.getSourceCodeLocation.getLineNumber),
- objectType = Option(QueryOriginType.Flow.toString),
+ objectType = Some(QueryOriginType.Flow.toString),
objectName = Option(flowIdentifier.unquotedString),
- language = Option(Python()))))
+ language = Some(Python()))))
flowIdentifier
}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/InvalidInputErrors.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/InvalidInputErrors.scala
index fcef696c88afc..eb4df9673e594 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/InvalidInputErrors.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/InvalidInputErrors.scala
@@ -96,9 +96,6 @@ object InvalidInputErrors {
def chunkedCachedLocalRelationWithoutData(): InvalidPlanInput =
InvalidPlanInput("ChunkedCachedLocalRelation should contain data.")
- def chunkedCachedLocalRelationChunksWithDifferentSchema(): InvalidPlanInput =
- InvalidPlanInput("ChunkedCachedLocalRelation data chunks have different schema.")
-
def schemaRequiredForLocalRelation(): InvalidPlanInput =
InvalidPlanInput("Schema for LocalRelation is required when the input data is not provided.")
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 8f8e6261066f4..9cbb760f6cc0e 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -19,22 +19,22 @@ package org.apache.spark.sql.connect.planner
import java.util.{HashMap, Properties, UUID}
-import scala.collection.immutable.ArraySeq
import scala.collection.mutable
import scala.jdk.CollectionConverters._
+import scala.reflect.ClassTag
import scala.util.Try
import scala.util.control.NonFatal
import com.google.common.collect.Lists
-import com.google.protobuf.{Any => ProtoAny, ByteString}
+import com.google.protobuf.{Any => ProtoAny, ByteString, Message}
import io.grpc.{Context, Status, StatusRuntimeException}
import io.grpc.stub.StreamObserver
-import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException, TaskContext}
+import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException}
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
+import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, PipelineAnalysisContext, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
@@ -1492,9 +1492,12 @@ class SparkConnectPlanner(
}
if (rel.hasData) {
- val (rows, structType) =
- ArrowConverters.fromIPCStream(rel.getData.toByteArray, TaskContext.get())
- buildLocalRelationFromRows(rows, structType, Option(schema))
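+ // Close the Arrow row iterator even if building the local relation fails, so the
+ // underlying buffers are released.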
+ val (rows, structType) = ArrowConverters.fromIPCStream(rel.getData.toByteArray)
+ try {
+ buildLocalRelationFromRows(rows, structType, Option(schema))
+ } finally {
+ rows.close()
+ }
} else {
if (schema == null) {
throw InvalidInputErrors.schemaRequiredForLocalRelation()
@@ -1565,28 +1568,13 @@ class SparkConnectPlanner(
}
// Load and combine all batches
- var combinedRows: Iterator[InternalRow] = Iterator.empty
- var structType: StructType = null
-
- for ((dataHash, batchIndex) <- dataHashes.zipWithIndex) {
- val dataBytes = readChunkedCachedLocalRelationBlock(dataHash)
- val (batchRows, batchStructType) =
- ArrowConverters.fromIPCStream(dataBytes, TaskContext.get())
-
- // For the first batch, set the schema; for subsequent batches, verify compatibility
- if (batchIndex == 0) {
- structType = batchStructType
- combinedRows = batchRows
-
- } else {
- if (batchStructType != structType) {
- throw InvalidInputErrors.chunkedCachedLocalRelationChunksWithDifferentSchema()
- }
- combinedRows = combinedRows ++ batchRows
- }
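+ // Feed all cached chunks through a single IPC stream reader and close the resulting row
+ // iterator once the local relation has been built.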
+ val (rows, structType) =
+ ArrowConverters.fromIPCStream(dataHashes.iterator.map(readChunkedCachedLocalRelationBlock))
+ try {
+ buildLocalRelationFromRows(rows, structType, Option(schema))
+ } finally {
+ rows.close()
}
-
- buildLocalRelationFromRows(combinedRows, structType, Option(schema))
}
private def toStructTypeOrWrap(dt: DataType): StructType = dt match {
@@ -1608,7 +1596,7 @@ class SparkConnectPlanner(
schemaOpt match {
case None =>
- logical.LocalRelation(attributes, ArraySeq.unsafeWrapArray(data.map(_.copy()).toArray))
+ logical.LocalRelation(attributes, data.map(_.copy()).toArray.toImmutableArraySeq)
case Some(schema) =>
def normalize(dt: DataType): DataType = dt match {
case udt: UserDefinedType[_] => normalize(udt.sqlType)
@@ -2942,10 +2930,28 @@ class SparkConnectPlanner(
.build())
}
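+ /** Unpacks every protobuf `Any` extension in the buffer that wraps a message of type T. */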
+ private def getExtensionList[T <: Message: ClassTag](
+ extensions: mutable.Buffer[ProtoAny]): Seq[T] = {
+ val cls = implicitly[ClassTag[T]].runtimeClass
+ .asInstanceOf[Class[_ <: Message]]
+ extensions.collect {
+ case any if any.is(cls) => any.unpack(cls).asInstanceOf[T]
+ }.toSeq
+ }
+
private def handleSqlCommand(
command: SqlCommand,
responseObserver: StreamObserver[ExecutePlanResponse]): Unit = {
val tracker = executeHolder.eventsManager.createQueryPlanningTracker()
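+ // PipelineAnalysisContext extensions on the user context indicate that this SQL command was
+ // issued from a Spark Declarative Pipelines definition; a flow name additionally means it
+ // was issued from inside a flow function.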
+ val userContextExtensions = executeHolder.request.getUserContext.getExtensionsList.asScala
+ val pipelineAnalysisContextList = {
+ getExtensionList[PipelineAnalysisContext](userContextExtensions)
+ }
+ val hasPipelineAnalysisContext = pipelineAnalysisContextList.nonEmpty
+ val insidePipelineFlowFunction = pipelineAnalysisContextList.exists(_.hasFlowName)
+ // To avoid explicit handling of the result on the client, we build the expected input
+ // of the relation on the server. The client has to simply forward the result.
+ val result = SqlCommandResult.newBuilder()
val relation = if (command.hasInput) {
command.getInput
@@ -2965,6 +2971,26 @@ class SparkConnectPlanner(
.build()
}
+ // Block unsupported SQL commands if the request comes from Spark Declarative Pipelines.
+ if (hasPipelineAnalysisContext) {
+ PipelinesHandler.blockUnsupportedSqlCommand(queryPlan = transformRelation(relation))
+ }
+
+ // If spark.sql() is called inside a pipeline flow function, we don't execute the SQL
+ // command here; the actual analysis and execution are deferred to the flow function.
+ if (insidePipelineFlowFunction) {
+ result.setRelation(relation)
+ executeHolder.eventsManager.postFinished()
+ responseObserver.onNext(
+ ExecutePlanResponse
+ .newBuilder()
+ .setSessionId(sessionHolder.sessionId)
+ .setServerSideSessionId(sessionHolder.serverSessionId)
+ .setSqlCommandResult(result)
+ .build)
+ return
+ }
+
val df = relation.getRelTypeCase match {
case proto.Relation.RelTypeCase.SQL =>
executeSQL(relation.getSql, tracker)
@@ -2983,9 +3009,6 @@ class SparkConnectPlanner(
case _ => Seq.empty
}
- // To avoid explicit handling of the result on the client, we build the expected input
- // of the relation on the server. The client has to simply forward the result.
- val result = SqlCommandResult.newBuilder()
// Only filled when isCommand
val metrics = ExecutePlanResponse.Metrics.newBuilder()
if (isCommand || isSqlScript) {
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectStreamingQueryListenerHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectStreamingQueryListenerHandler.scala
index 04312a35a3b4b..7bee4539d01bd 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectStreamingQueryListenerHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectStreamingQueryListenerHandler.scala
@@ -21,6 +21,7 @@ import scala.util.control.NonFatal
import io.grpc.stub.StreamObserver
+import org.apache.spark.SparkSQLException
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.connect.proto.StreamingQueryListenerBusCommand
import org.apache.spark.connect.proto.StreamingQueryListenerEventsResult
@@ -117,7 +118,10 @@ class SparkConnectStreamingQueryListenerHandler(executeHolder: ExecuteHolder) ex
return
}
case StreamingQueryListenerBusCommand.CommandCase.COMMAND_NOT_SET =>
- throw new IllegalArgumentException("Missing command in StreamingQueryListenerBusCommand")
+ throw new SparkSQLException(
+ errorClass = "INVALID_PARAMETER_VALUE.STREAMING_LISTENER_COMMAND_MISSING",
+ messageParameters =
+ Map("parameter" -> "command", "functionName" -> "StreamingQueryListenerBusCommand"))
}
executeHolder.eventsManager.postFinished()
}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
index becd7d855133d..477d5b974facc 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
@@ -219,9 +219,9 @@ class SparkConnectAddArtifactsHandler(val responseObserver: StreamObserver[AddAr
ArtifactUtils.concatenatePaths(stagingDir, path)
} catch {
case _: IllegalArgumentException =>
- throw new IllegalArgumentException(
- s"Artifact with name: $name is invalid. The `name` " +
- s"must be a relative path and cannot reference parent/sibling/nephew directories.")
+ throw new SparkRuntimeException(
+ errorClass = "INVALID_ARTIFACT_PATH",
+ messageParameters = Map("name" -> name))
case NonFatal(e) => throw e
}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
index 8fa003c11681d..cdf7013211f77 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.classic.{DataFrame, Dataset}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter}
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.utils.PlanCompressionUtils
import org.apache.spark.sql.execution.{CodegenMode, CommandExecutionMode, CostMode, ExtendedMode, FormattedMode, SimpleMode}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.ArrayImplicits._
@@ -63,6 +64,9 @@ private[connect] class SparkConnectAnalyzeHandler(
val builder = proto.AnalyzePlanResponse.newBuilder()
def transformRelation(rel: proto.Relation) = planner.transformRelation(rel, cachePlan = true)
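+ // Plans may arrive compressed; decompress them before transforming the root relation.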
+ def transformRelationPlan(plan: proto.Plan) = {
+ transformRelation(PlanCompressionUtils.decompressPlan(plan).getRoot)
+ }
def getDataFrameWithoutExecuting(rel: LogicalPlan): DataFrame = {
val qe = session.sessionState.executePlan(rel, CommandExecutionMode.SKIP)
@@ -71,7 +75,7 @@ private[connect] class SparkConnectAnalyzeHandler(
request.getAnalyzeCase match {
case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
- val rel = transformRelation(request.getSchema.getPlan.getRoot)
+ val rel = transformRelationPlan(request.getSchema.getPlan)
val schema = getDataFrameWithoutExecuting(rel).schema
builder.setSchema(
proto.AnalyzePlanResponse.Schema
@@ -79,7 +83,7 @@ private[connect] class SparkConnectAnalyzeHandler(
.setSchema(DataTypeProtoConverter.toConnectProtoType(schema))
.build())
case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
- val rel = transformRelation(request.getExplain.getPlan.getRoot)
+ val rel = transformRelationPlan(request.getExplain.getPlan)
val queryExecution = getDataFrameWithoutExecuting(rel).queryExecution
val explainString = request.getExplain.getExplainMode match {
case proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE =>
@@ -101,7 +105,7 @@ private[connect] class SparkConnectAnalyzeHandler(
.build())
case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING =>
- val rel = transformRelation(request.getTreeString.getPlan.getRoot)
+ val rel = transformRelationPlan(request.getTreeString.getPlan)
val schema = getDataFrameWithoutExecuting(rel).schema
val treeString = if (request.getTreeString.hasLevel) {
schema.treeString(request.getTreeString.getLevel)
@@ -115,7 +119,7 @@ private[connect] class SparkConnectAnalyzeHandler(
.build())
case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
- val rel = transformRelation(request.getIsLocal.getPlan.getRoot)
+ val rel = transformRelationPlan(request.getIsLocal.getPlan)
val isLocal = getDataFrameWithoutExecuting(rel).isLocal
builder.setIsLocal(
proto.AnalyzePlanResponse.IsLocal
@@ -124,7 +128,7 @@ private[connect] class SparkConnectAnalyzeHandler(
.build())
case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
- val rel = transformRelation(request.getIsStreaming.getPlan.getRoot)
+ val rel = transformRelationPlan(request.getIsStreaming.getPlan)
val isStreaming = getDataFrameWithoutExecuting(rel).isStreaming
builder.setIsStreaming(
proto.AnalyzePlanResponse.IsStreaming
@@ -133,7 +137,7 @@ private[connect] class SparkConnectAnalyzeHandler(
.build())
case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
- val rel = transformRelation(request.getInputFiles.getPlan.getRoot)
+ val rel = transformRelationPlan(request.getInputFiles.getPlan)
val inputFiles = getDataFrameWithoutExecuting(rel).inputFiles
builder.setInputFiles(
proto.AnalyzePlanResponse.InputFiles
@@ -157,8 +161,8 @@ private[connect] class SparkConnectAnalyzeHandler(
.build())
case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS =>
- val targetRel = transformRelation(request.getSameSemantics.getTargetPlan.getRoot)
- val otherRel = transformRelation(request.getSameSemantics.getOtherPlan.getRoot)
+ val targetRel = transformRelationPlan(request.getSameSemantics.getTargetPlan)
+ val otherRel = transformRelationPlan(request.getSameSemantics.getOtherPlan)
val target = getDataFrameWithoutExecuting(targetRel)
val other = getDataFrameWithoutExecuting(otherRel)
builder.setSameSemantics(
@@ -167,7 +171,7 @@ private[connect] class SparkConnectAnalyzeHandler(
.setResult(target.sameSemantics(other)))
case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH =>
- val rel = transformRelation(request.getSemanticHash.getPlan.getRoot)
+ val rel = transformRelationPlan(request.getSemanticHash.getPlan)
val semanticHash = getDataFrameWithoutExecuting(rel)
.semanticHash()
builder.setSemanticHash(
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
index 027f4517cf3be..6780ca37e96a7 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutePlanHandler.scala
@@ -22,6 +22,7 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.SparkSQLException
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.connect.utils.PlanCompressionUtils
class SparkConnectExecutePlanHandler(responseObserver: StreamObserver[proto.ExecutePlanResponse])
extends Logging {
@@ -35,12 +36,20 @@ class SparkConnectExecutePlanHandler(responseObserver: StreamObserver[proto.Exec
.getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId, previousSessionId)
val executeKey = ExecuteKey(v, sessionHolder)
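+ // Decompress the plan up front so new executions store the decompressed plan and reattach
+ // requests can be matched against it.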
+ val decompressedRequest =
+ v.toBuilder.setPlan(PlanCompressionUtils.decompressPlan(v.getPlan)).build()
+
SparkConnectService.executionManager.getExecuteHolder(executeKey) match {
case None =>
// Create a new execute holder and attach to it.
SparkConnectService.executionManager
- .createExecuteHolderAndAttach(executeKey, v, sessionHolder, responseObserver)
- case Some(executeHolder) if executeHolder.request.getPlan.equals(v.getPlan) =>
+ .createExecuteHolderAndAttach(
+ executeKey,
+ decompressedRequest,
+ sessionHolder,
+ responseObserver)
+ case Some(executeHolder)
+ if executeHolder.request.getPlan.equals(decompressedRequest.getPlan) =>
// If the execute holder already exists with the same plan, reattach to it.
SparkConnectService.executionManager
.reattachExecuteHolder(executeHolder, responseObserver, None)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala
index ae38e55d3c672..8f41257ccdfdb 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectInterruptHandler.scala
@@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters._
import io.grpc.stub.StreamObserver
+import org.apache.spark.SparkSQLException
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
@@ -41,18 +42,23 @@ class SparkConnectInterruptHandler(responseObserver: StreamObserver[proto.Interr
sessionHolder.interruptAll()
case proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG =>
if (!v.hasOperationTag) {
- throw new IllegalArgumentException(
- s"INTERRUPT_TYPE_TAG requested, but no operation_tag provided.")
+ throw new SparkSQLException(
+ errorClass = "INVALID_PARAMETER_VALUE.INTERRUPT_TYPE_TAG_REQUIRES_TAG",
+ messageParameters =
+ Map("parameter" -> "operation_tag", "functionName" -> "interrupt"))
}
sessionHolder.interruptTag(v.getOperationTag)
case proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID =>
if (!v.hasOperationId) {
- throw new IllegalArgumentException(
- s"INTERRUPT_TYPE_OPERATION_ID requested, but no operation_id provided.")
+ throw new SparkSQLException(
+ errorClass = "INVALID_PARAMETER_VALUE.INTERRUPT_TYPE_OPERATION_ID_REQUIRES_ID",
+ messageParameters = Map("parameter" -> "operation_id", "functionName" -> "interrupt"))
}
sessionHolder.interruptOperation(v.getOperationId)
case other =>
- throw new UnsupportedOperationException(s"Unknown InterruptType $other!")
+ throw new SparkSQLException(
+ errorClass = "UNSUPPORTED_FEATURE.INTERRUPT_TYPE",
+ messageParameters = Map("interruptType" -> other.toString))
}
val response = proto.InterruptResponse
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectServer.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectServer.scala
index 5b1034a4a27b7..1b2130a0e66b5 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectServer.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectServer.scala
@@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{HOST, PORT}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.Utils
/**
* The Spark Connect server
@@ -38,9 +39,9 @@ object SparkConnectServer extends Logging {
try {
SparkConnectService.start(session.sparkContext)
val isa = SparkConnectService.bindingAddress
+ val host = Utils.normalizeIpIfNeeded(isa.getAddress.getHostAddress)
logInfo(
- log"Spark Connect server started at: " +
- log"${MDC(HOST, isa.getAddress.getHostAddress)}:${MDC(PORT, isa.getPort)}")
+ log"Spark Connect server started at: ${MDC(HOST, host)}:${MDC(PORT, isa.getPort)}")
} catch {
case e: Exception =>
logError("Error starting Spark Connect server", e)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 13ce2d64256b4..00b93c19b2c73 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -38,6 +38,8 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.HOST
import org.apache.spark.internal.config.UI.UI_ENABLED
import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerEvent}
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.connect.config.Connect.{getAuthenticateToken, CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE, CONNECT_GRPC_PORT_MAX_RETRIES}
import org.apache.spark.sql.connect.execution.ConnectProgressExecutionListener
import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
@@ -436,6 +438,8 @@ object SparkConnectService extends Logging {
return
}
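+ // Register the creator the session manager uses to lazily build the base session from this
+ // SparkContext.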
+ sessionManager.initializeBaseSession(() =>
+ SparkSession.builder().sparkContext(sc).getOrCreate().newSession())
startGRPCService()
createListenerAndUI(sc)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
index f28af0379a04c..d3ddf592e9e7d 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala
@@ -39,6 +39,12 @@ import org.apache.spark.util.ThreadUtils
*/
class SparkConnectSessionManager extends Logging {
+ // Used to lazily initialize the base session
+ @volatile private var baseSessionCreator: Option[() => SparkSession] = None
+
+ // Base SparkSession created from the SparkContext, used to create new isolated sessions
+ @volatile private var _baseSession: Option[SparkSession] = None
+
private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] =
new ConcurrentHashMap[SessionKey, SessionHolder]()
@@ -48,6 +54,23 @@ class SparkConnectSessionManager extends Logging {
.maximumSize(SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE))
.build[SessionKey, SessionHolderInfo]()
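+ // Lazily creates the base session from the registered creator on first use and caches it for
+ // subsequent calls.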
+ private def baseSession: Option[SparkSession] = {
+ if (_baseSession.isEmpty && baseSessionCreator.isDefined) {
+ _baseSession = Some(baseSessionCreator.get())
+ }
+ _baseSession
+ }
+
+ /**
+ * Registers the creator function used to lazily initialize the base SparkSession. This should
+ * be called once during SparkConnectService startup.
+ */
+ def initializeBaseSession(createSession: () => SparkSession): Unit = {
+ if (baseSessionCreator.isEmpty) {
+ baseSessionCreator = Some(createSession)
+ }
+ }
+
/** Executor for the periodic maintenance */
private val scheduledExecutor: AtomicReference[ScheduledExecutorService] =
new AtomicReference[ScheduledExecutorService]()
@@ -333,12 +356,12 @@ class SparkConnectSessionManager extends Logging {
}
private def newIsolatedSession(): SparkSession = {
- val active = SparkSession.active
- if (active.sparkContext.isStopped) {
+ val session = baseSession.get
+ if (session.sparkContext.isStopped) {
assert(SparkSession.getDefaultSession.nonEmpty)
SparkSession.getDefaultSession.get.newSession()
} else {
- active.newSession()
+ session.newSession()
}
}
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/PlanCompressionUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/PlanCompressionUtils.scala
new file mode 100644
index 0000000000000..708ef1ee6558f
--- /dev/null
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/PlanCompressionUtils.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.utils
+
+import java.io.IOException
+
+import scala.util.control.NonFatal
+
+import com.github.luben.zstd.{Zstd, ZstdInputStreamNoFinalizer}
+import com.google.protobuf.{ByteString, CodedInputStream}
+import org.apache.commons.io.input.BoundedInputStream
+
+import org.apache.spark.{SparkEnv, SparkSQLException}
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.planner.InvalidInputErrors
+
+object PlanCompressionUtils {
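+ /**
+ * If the plan carries a compressed operation, decompresses the payload and parses it back
+ * into a relation or command; otherwise returns the plan unchanged.
+ */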
+ def decompressPlan(plan: proto.Plan): proto.Plan = {
+ plan.getOpTypeCase match {
+ case proto.Plan.OpTypeCase.COMPRESSED_OPERATION =>
+ val (cis, closeStream) = decompressBytes(
+ plan.getCompressedOperation.getData,
+ plan.getCompressedOperation.getCompressionCodec)
+ try {
+ plan.getCompressedOperation.getOpType match {
+ case proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION =>
+ proto.Plan.newBuilder().setRoot(proto.Relation.parser().parseFrom(cis)).build()
+ case proto.Plan.CompressedOperation.OpType.OP_TYPE_COMMAND =>
+ proto.Plan.newBuilder().setCommand(proto.Command.parser().parseFrom(cis)).build()
+ case other =>
+ throw InvalidInputErrors.invalidOneOfField(
+ other,
+ plan.getCompressedOperation.getDescriptorForType)
+ }
+ } catch {
+ case e: SparkSQLException =>
+ throw e
+ case NonFatal(e) =>
+ throw new SparkSQLException(
+ errorClass = "CONNECT_INVALID_PLAN.CANNOT_PARSE",
+ messageParameters = Map("errorMsg" -> e.getMessage))
+ } finally {
+ try {
+ closeStream()
+ } catch {
+ case NonFatal(_) =>
+ }
+ }
+ case _ => plan
+ }
+ }
+
+ private def getMaxPlanSize: Long = {
+ SparkEnv.get.conf.get(Connect.CONNECT_MAX_PLAN_SIZE)
+ }
+
+ /**
+ * Decompress the given bytes using the specified codec.
+ * @return
+ * A tuple of decompressed CodedInputStream and a function to close the underlying stream.
+ */
+ private def decompressBytes(
+ data: ByteString,
+ compressionCodec: proto.CompressionCodec): (CodedInputStream, () => Unit) = {
+ compressionCodec match {
+ case proto.CompressionCodec.COMPRESSION_CODEC_ZSTD =>
+ decompressBytesWithZstd(data, getMaxPlanSize)
+ case other =>
+ throw InvalidInputErrors.invalidEnum(other)
+ }
+ }
+
+ private def decompressBytesWithZstd(
+ input: ByteString,
+ maxOutputSize: Long): (CodedInputStream, () => Unit) = {
+ // Check the declared size in the header against the limit.
+ val declaredSize = Zstd.getFrameContentSize(input.asReadOnlyByteBuffer())
+ if (declaredSize > maxOutputSize) {
+ throw new SparkSQLException(
+ errorClass = "CONNECT_INVALID_PLAN.PLAN_SIZE_LARGER_THAN_MAX",
+ messageParameters =
+ Map("planSize" -> declaredSize.toString, "maxPlanSize" -> maxOutputSize.toString))
+ }
+
+ val zstdStream = new ZstdInputStreamNoFinalizer(input.newInput())
+
+ // Create a bounded input stream to limit the decompressed output size to avoid decompression
+ // bomb attacks.
+ val boundedStream = new BoundedInputStream(zstdStream, maxOutputSize) {
+ @throws[IOException]
+ override protected def onMaxLength(maxBytes: Long, count: Long): Unit =
+ throw new SparkSQLException(
+ errorClass = "CONNECT_INVALID_PLAN.PLAN_SIZE_LARGER_THAN_MAX",
+ messageParameters =
+ Map("planSize" -> "unknown", "maxPlanSize" -> maxOutputSize.toString))
+ }
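+ // The bounded stream enforces the output cap, so the CodedInputStream size limit is set to
+ // the maximum; the recursion limit guards against deeply nested plans.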
+ val cis = CodedInputStream.newInstance(boundedStream)
+ cis.setSizeLimit(Integer.MAX_VALUE)
+ cis.setRecursionLimit(SparkEnv.get.conf.get(Connect.CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT))
+ (cis, () => boundedStream.close())
+ }
+}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
index 1b2b7ab420296..77ede8e852e87 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala
@@ -26,15 +26,19 @@ import org.scalatest.time.Span
import org.scalatest.time.SpanSugar._
import org.apache.spark.connect.proto
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, RetryPolicy, SparkConnectClient, SparkConnectStubState}
+import org.apache.spark.sql.classic
+import org.apache.spark.sql.connect
+import org.apache.spark.sql.connect.client.{CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, RetryPolicy, SparkConnectClient, SparkConnectStubState}
import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
import org.apache.spark.sql.connect.common.config.ConnectCommon
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.dsl.MockRemoteSession
import org.apache.spark.sql.connect.dsl.plans._
-import org.apache.spark.sql.connect.service.{ExecuteHolder, SparkConnectService}
+import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionKey, SparkConnectService}
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.util.CloseableIterator
/**
* Base class and utilities for a test suite that starts and tests the real SparkConnectService
@@ -320,4 +324,71 @@ trait SparkConnectServerTest extends SharedSparkSession {
val plan = buildPlan(query)
runQuery(plan, queryTimeout, iterSleep)
}
+
+ /**
+ * Helper method to create a connect SparkSession that connects to the localhost server. Similar
+ * to withClient, but provides a full SparkSession API instead of just a client.
+ *
+ * @param sessionId
+ * Optional session ID (defaults to defaultSessionId)
+ * @param userId
+ * Optional user ID (defaults to defaultUserId)
+ * @param f
+ * Function to execute with the session
+ */
+ protected def withSession(sessionId: String = defaultSessionId, userId: String = defaultUserId)(
+ f: SparkSession => Unit): Unit = {
+ withSession(f, sessionId, userId)
+ }
+
+ /**
+ * Helper method to create a connect SparkSession with default session and user IDs.
+ *
+ * @param f
+ * Function to execute with the session
+ */
+ protected def withSession(f: SparkSession => Unit): Unit = {
+ withSession(f, defaultSessionId, defaultUserId)
+ }
+
+ private def withSession(f: SparkSession => Unit, sessionId: String, userId: String): Unit = {
+ val client = SparkConnectClient
+ .builder()
+ .port(serverPort)
+ .sessionId(sessionId)
+ .userId(userId)
+ .build()
+
+ val session = connect.SparkSession
+ .builder()
+ .client(client)
+ .create()
+ try f(session)
+ finally {
+ session.close()
+ }
+ }
+
+ /**
+ * Get the server-side SparkSession corresponding to a client SparkSession.
+ *
+ * This helper takes a sql.SparkSession (which is assumed to be a connect.SparkSession),
+ * extracts the userId and sessionId from it, and looks up the corresponding server-side classic
+ * SparkSession using SparkConnectSessionManager.
+ *
+ * @param clientSession
+ * The client SparkSession (must be a connect.SparkSession)
+ * @return
+ * The server-side classic SparkSession
+ */
+ protected def getServerSession(clientSession: SparkSession): classic.SparkSession = {
+ val connectSession = clientSession.asInstanceOf[connect.SparkSession]
+ val userId = connectSession.client.userId
+ val sessionId = connectSession.sessionId
+ val key = SessionKey(userId, sessionId)
+ SparkConnectService.sessionManager
+ .getIsolatedSessionIfPresent(key)
+ .get
+ .session
+ }
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTestSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTestSuite.scala
new file mode 100644
index 0000000000000..c14114ced6634
--- /dev/null
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTestSuite.scala
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect
+
+import org.scalatest.time.SpanSugar._
+
+/**
+ * Test suite showcasing the APIs provided by the SparkConnectServerTest trait.
+ *
+ * This suite demonstrates:
+ * - Session and client helper methods (withSession, withClient, getServerSession)
+ * - Low-level stub helpers (withRawBlockingStub, withCustomBlockingStub)
+ * - Plan building helpers (buildPlan, buildExecutePlanRequest, etc.)
+ * - Assertion helpers for execution state
+ */
+class SparkConnectServerTestSuite extends SparkConnectServerTest {
+
+ test("withSession: execute SQL and collect results") {
+ withSession { session =>
+ val df = session.sql("SELECT 1 as value")
+ val result = df.collect()
+ assert(result.length == 1)
+ assert(result(0).getInt(0) == 1)
+ }
+ }
+
+ test("withSession: with custom session and user IDs") {
+ val customSessionId = java.util.UUID.randomUUID().toString
+ val customUserId = "test-user"
+ withSession(sessionId = customSessionId, userId = customUserId) { session =>
+ val df = session.sql("SELECT 'hello' as greeting")
+ val result = df.collect()
+ assert(result.length == 1)
+ assert(result(0).getString(0) == "hello")
+ }
+ }
+
+ test("withSession: DataFrame operations") {
+ withSession { session =>
+ val df = session.range(10)
+ assert(df.count() == 10)
+
+ val sum = df.selectExpr("sum(id)").collect()(0).getLong(0)
+ assert(sum == 45) // 0 + 1 + ... + 9 = 45
+ }
+ }
+
+ test("withClient: execute plan and iterate results") {
+ withClient { client =>
+ val plan = buildPlan("SELECT 1 as x, 2 as y")
+ val iter = client.execute(plan)
+ var hasResults = false
+ while (iter.hasNext) {
+ iter.next()
+ hasResults = true
+ }
+ assert(hasResults)
+ }
+ }
+
+ test("withClient: with custom session and user IDs") {
+ val customSessionId = java.util.UUID.randomUUID().toString
+ val customUserId = "custom-user"
+ withClient(sessionId = customSessionId, userId = customUserId) { client =>
+ val plan = buildPlan("SELECT 42")
+ val iter = client.execute(plan)
+ while (iter.hasNext) iter.next()
+ }
+ }
+
+ test("getServerSession: returns server-side classic session") {
+ withSession { clientSession =>
+ clientSession.sql("SELECT 1").collect()
+
+ val serverSession = getServerSession(clientSession)
+
+ assert(serverSession != null)
+ assert(serverSession.sparkContext != null)
+ }
+ }
+
+ test("getServerSession: client and server share configuration") {
+ withSession { clientSession =>
+ clientSession.sql("SET spark.sql.shuffle.partitions=17").collect()
+
+ val serverSession = getServerSession(clientSession)
+ assert(serverSession.conf.get("spark.sql.shuffle.partitions") == "17")
+ }
+ }
+
+ test("getServerSession: register and use temporary view from server") {
+ withSession { clientSession =>
+ clientSession.sql("SELECT 1 as a, 2 as b").collect()
+
+ val serverSession = getServerSession(clientSession)
+
+ // Create a temp view on the server side
+ import serverSession.implicits._
+ val serverDf = Seq((100, "server"), (200, "side")).toDF("num", "source")
+ serverDf.createOrReplaceTempView("server_view")
+
+ // Access the view from the client
+ val result = clientSession.sql("SELECT * FROM server_view ORDER BY num").collect()
+ assert(result.length == 2)
+ assert(result(0).getInt(0) == 100)
+ assert(result(0).getString(1) == "server")
+ assert(result(1).getInt(0) == 200)
+ assert(result(1).getString(1) == "side")
+ }
+ }
+
+ test("withRawBlockingStub: execute plan via raw gRPC stub") {
+ withRawBlockingStub { stub =>
+ val request = buildExecutePlanRequest(buildPlan("SELECT 'raw' as mode"))
+ val iter = stub.executePlan(request)
+ assert(iter.hasNext)
+ while (iter.hasNext) iter.next()
+ }
+ }
+
+ test("withCustomBlockingStub: execute plan via custom blocking stub") {
+ withCustomBlockingStub() { stub =>
+ val request = buildExecutePlanRequest(buildPlan("SELECT 'custom' as mode"))
+ val iter = stub.executePlan(request)
+ while (iter.hasNext) iter.next()
+ }
+ }
+
+ test("buildPlan: creates plan from SQL query") {
+ val plan = buildPlan("SELECT 1, 2, 3")
+ assert(plan.hasRoot)
+ }
+
+ test("buildSqlCommandPlan: creates command plan") {
+ val plan = buildSqlCommandPlan("SET spark.sql.adaptive.enabled=true")
+ assert(plan.hasCommand)
+ assert(plan.getCommand.hasSqlCommand)
+ }
+
+ test("buildLocalRelation: creates plan from local data") {
+ val data = Seq((1, "a"), (2, "b"), (3, "c"))
+ val plan = buildLocalRelation(data)
+ assert(plan.hasRoot)
+ assert(plan.getRoot.hasLocalRelation)
+ }
+
+ test("buildExecutePlanRequest: creates request with options") {
+ val plan = buildPlan("SELECT 1")
+ val request = buildExecutePlanRequest(plan)
+ assert(request.hasPlan)
+ assert(request.hasUserContext)
+ assert(request.getSessionId == defaultSessionId)
+ }
+
+ test("buildExecutePlanRequest: with custom session and operation IDs") {
+ val plan = buildPlan("SELECT 1")
+ val customSessionId = "my-session"
+ val customOperationId = "my-operation"
+ val request =
+ buildExecutePlanRequest(plan, sessionId = customSessionId, operationId = customOperationId)
+ assert(request.getSessionId == customSessionId)
+ assert(request.getOperationId == customOperationId)
+ }
+
+ test("runQuery: executes query string with timeout") {
+ runQuery("SELECT * FROM range(100)", 30.seconds)
+ }
+
+ test("runQuery: executes plan with timeout and iter sleep") {
+ val plan = buildPlan("SELECT * FROM range(10)")
+ runQuery(plan, 30.seconds, iterSleep = 10)
+ }
+
+ test("assertNoActiveExecutions: verifies clean state") {
+ assertNoActiveExecutions()
+ }
+
+ test("assertNoActiveRpcs: verifies no active RPCs") {
+ assertNoActiveRpcs()
+ }
+
+ test("eventuallyGetExecutionHolder: retrieves active execution") {
+ withRawBlockingStub { stub =>
+ val request = buildExecutePlanRequest(buildPlan("SELECT * FROM range(1000000)"))
+ val iter = stub.executePlan(request)
+ iter.hasNext // trigger execution
+
+ val holder = eventuallyGetExecutionHolder
+ assert(holder != null)
+ assert(holder.operationId == request.getOperationId)
+ }
+ }
+}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala
index 55b8a315df570..f674b45bb072d 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala
@@ -162,7 +162,7 @@ class EndToEndAPISuite extends PipelineTest with APITest with SparkConnectServer
|name: test-pipeline
|${spec.catalog.map(catalog => s"""catalog: "$catalog"""").getOrElse("")}
|${spec.database.map(database => s"""database: "$database"""").getOrElse("")}
- |storage: "${projectDir.resolve("storage").toAbsolutePath}"
+ |storage: "file://${projectDir.resolve("storage").toAbsolutePath}"
|configuration:
| "spark.remote": "sc://localhost:$serverPort"
|libraries:
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
index 79c34ac46b9fb..98b33c3296fac 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala
@@ -26,6 +26,9 @@ import java.util.concurrent.TimeUnit
import scala.collection.mutable.ArrayBuffer
import scala.util.Try
+import org.scalactic.source.Position
+import org.scalatest.Tag
+
import org.apache.spark.api.python.PythonUtils
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
@@ -49,8 +52,7 @@ class PythonPipelineSuite
with EventVerificationTestHelpers {
def buildGraph(pythonText: String): DataflowGraph = {
- assume(PythonTestDepsChecker.isConnectDepsAvailable)
- val indentedPythonText = pythonText.linesIterator.map(" " + _).mkString("\n")
+ val indentedPythonText = pythonText.linesIterator.map(" " + _).mkString("\n")
// create a unique identifier to allow identifying the session and dataflow graph
val customSessionIdentifier = UUID.randomUUID().toString
val pythonCode =
@@ -64,6 +66,9 @@ class PythonPipelineSuite
|from pyspark.pipelines.graph_element_registry import (
| graph_element_registration_context,
|)
+ |from pyspark.pipelines.add_pipeline_analysis_context import (
+ | add_pipeline_analysis_context
+ |)
|
|spark = SparkSession.builder \\
| .remote("sc://localhost:$serverPort") \\
@@ -79,7 +84,10 @@ class PythonPipelineSuite
|)
|
|registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id)
- |with graph_element_registration_context(registry):
+ |with add_pipeline_analysis_context(
+ | spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name=None
+ |):
+ | with graph_element_registration_context(registry):
|$indentedPythonText
|""".stripMargin
@@ -143,7 +151,7 @@ class PythonPipelineSuite
QueryOrigin(
language = Option(Python()),
filePath = Option(""),
- line = Option(28),
+ line = Option(34),
objectName = Option("spark_catalog.default.table1"),
objectType = Option(QueryOriginType.Flow.toString))),
errorChecker = ex =>
@@ -195,7 +203,7 @@ class PythonPipelineSuite
QueryOrigin(
language = Option(Python()),
filePath = Option(""),
- line = Option(34),
+ line = Option(40),
objectName = Option("spark_catalog.default.mv2"),
objectType = Option(QueryOriginType.Flow.toString))),
expectedEventLevel = EventLevel.INFO)
@@ -209,7 +217,7 @@ class PythonPipelineSuite
QueryOrigin(
language = Option(Python()),
filePath = Option(""),
- line = Option(38),
+ line = Option(44),
objectName = Option("spark_catalog.default.mv"),
objectType = Option(QueryOriginType.Flow.toString))),
expectedEventLevel = EventLevel.INFO)
@@ -227,7 +235,7 @@ class PythonPipelineSuite
QueryOrigin(
language = Option(Python()),
filePath = Option(""),
- line = Option(28),
+ line = Option(34),
objectName = Option("spark_catalog.default.table1"),
objectType = Option(QueryOriginType.Flow.toString))),
expectedEventLevel = EventLevel.INFO)
@@ -241,7 +249,7 @@ class PythonPipelineSuite
QueryOrigin(
language = Option(Python()),
filePath = Option(""),
- line = Option(43),
+ line = Option(49),
objectName = Option("spark_catalog.default.standalone_flow1"),
objectType = Option(QueryOriginType.Flow.toString))),
expectedEventLevel = EventLevel.INFO)
@@ -334,21 +342,37 @@ class PythonPipelineSuite
|@dp.table
|def b():
| return spark.readStream.table("src")
+ |
+ |@dp.materialized_view
+ |def c():
+ | return spark.sql("SELECT * FROM src")
+ |
+ |@dp.table
+ |def d():
+ | return spark.sql("SELECT * FROM STREAM src")
|""".stripMargin).resolve().validate()
assert(
graph.table.keySet == Set(
graphIdentifier("src"),
graphIdentifier("a"),
- graphIdentifier("b")))
- Seq("a", "b").foreach { flowName =>
+ graphIdentifier("b"),
+ graphIdentifier("c"),
+ graphIdentifier("d")))
+ Seq("a", "b", "c").foreach { flowName =>
// dependency is properly tracked
assert(graph.resolvedFlow(graphIdentifier(flowName)).inputs == Set(graphIdentifier("src")))
}
val (streamingFlows, batchFlows) = graph.resolvedFlows.partition(_.df.isStreaming)
- assert(batchFlows.map(_.identifier) == Seq(graphIdentifier("src"), graphIdentifier("a")))
- assert(streamingFlows.map(_.identifier) == Seq(graphIdentifier("b")))
+ assert(
+ batchFlows.map(_.identifier).toSet == Set(
+ graphIdentifier("src"),
+ graphIdentifier("a"),
+ graphIdentifier("c")))
+ assert(
+ streamingFlows.map(_.identifier).toSet ==
+ Set(graphIdentifier("b"), graphIdentifier("d")))
}
test("referencing external datasets") {
@@ -365,18 +389,33 @@ class PythonPipelineSuite
|@dp.table
|def c():
| return spark.readStream.table("spark_catalog.default.src")
+ |
+ |@dp.materialized_view
+ |def d():
+ | return spark.sql("SELECT * FROM spark_catalog.default.src")
+ |
+ |@dp.table
+ |def e():
+ | return spark.sql("SELECT * FROM STREAM spark_catalog.default.src")
|""".stripMargin).resolve().validate()
assert(
graph.tables.map(_.identifier).toSet == Set(
graphIdentifier("a"),
graphIdentifier("b"),
- graphIdentifier("c")))
+ graphIdentifier("c"),
+ graphIdentifier("d"),
+ graphIdentifier("e")))
// dependency is not tracked
assert(graph.resolvedFlows.forall(_.inputs.isEmpty))
val (streamingFlows, batchFlows) = graph.resolvedFlows.partition(_.df.isStreaming)
- assert(batchFlows.map(_.identifier).toSet == Set(graphIdentifier("a"), graphIdentifier("b")))
- assert(streamingFlows.map(_.identifier) == Seq(graphIdentifier("c")))
+ assert(
+ batchFlows.map(_.identifier).toSet == Set(
+ graphIdentifier("a"),
+ graphIdentifier("b"),
+ graphIdentifier("d")))
+ assert(
+ streamingFlows.map(_.identifier).toSet == Set(graphIdentifier("c"), graphIdentifier("e")))
}
test("referencing internal datasets failed") {
@@ -392,9 +431,17 @@ class PythonPipelineSuite
|@dp.table
|def c():
| return spark.readStream.table("src")
+ |
+ |@dp.materialized_view
+ |def d():
+ | return spark.sql("SELECT * FROM src")
+ |
+ |@dp.table
+ |def e():
+ | return spark.sql("SELECT * FROM STREAM src")
|""".stripMargin).resolve()
- assert(graph.resolutionFailedFlows.size == 3)
+ assert(graph.resolutionFailedFlows.size == 5)
graph.resolutionFailedFlows.foreach { flow =>
assert(flow.failure.head.getMessage.contains("[TABLE_OR_VIEW_NOT_FOUND]"))
assert(flow.failure.head.getMessage.contains("`src`"))
@@ -414,14 +461,95 @@ class PythonPipelineSuite
|@dp.materialized_view
|def c():
| return spark.readStream.table("spark_catalog.default.src")
+ |
+ |@dp.materialized_view
+ |def d():
+ | return spark.sql("SELECT * FROM spark_catalog.default.src")
+ |
+ |@dp.table
+ |def e():
+ | return spark.sql("SELECT * FROM STREAM spark_catalog.default.src")
|""".stripMargin).resolve()
+ assert(graph.resolutionFailedFlows.size == 5)
graph.resolutionFailedFlows.foreach { flow =>
- assert(flow.failure.head.getMessage.contains("[TABLE_OR_VIEW_NOT_FOUND] The table or view"))
+ assert(flow.failure.head.getMessage.contains("[TABLE_OR_VIEW_NOT_FOUND]"))
+ assert(flow.failure.head.getMessage.contains("`spark_catalog`.`default`.`src`"))
}
}
+ test("reading external datasets outside query function works") {
+ sql("CREATE TABLE spark_catalog.default.src AS SELECT * FROM RANGE(5)")
+ val graph = buildGraph(s"""
+ |spark_sql_df = spark.sql("SELECT * FROM spark_catalog.default.src")
+ |read_table_df = spark.read.table("spark_catalog.default.src")
+ |
+ |@dp.materialized_view
+ |def mv_from_spark_sql_df():
+ | return spark_sql_df
+ |
+ |@dp.materialized_view
+ |def mv_from_read_table_df():
+ | return read_table_df
+ |""".stripMargin).resolve().validate()
+
+ assert(
+ graph.resolvedFlows.map(_.identifier).toSet == Set(
+ graphIdentifier("mv_from_spark_sql_df"),
+ graphIdentifier("mv_from_read_table_df")))
+ assert(graph.resolvedFlows.forall(_.inputs.isEmpty))
+ assert(graph.resolvedFlows.forall(!_.df.isStreaming))
+ }
+
+ test(
+ "reading internal datasets outside query function that don't trigger " +
+ "eager analysis or execution") {
+ val graph = buildGraph("""
+ |@dp.materialized_view
+ |def src():
+ | return spark.range(5)
+ |
+ |read_table_df = spark.read.table("src")
+ |
+ |@dp.materialized_view
+ |def mv_from_read_table_df():
+ | return read_table_df
+ |
+ |""".stripMargin).resolve().validate()
+ assert(
+ graph.resolvedFlows.map(_.identifier).toSet == Set(
+ graphIdentifier("mv_from_read_table_df"),
+ graphIdentifier("src")))
+ assert(graph.resolvedFlows.forall(!_.df.isStreaming))
+ assert(
+ graph
+ .resolvedFlow(graphIdentifier("mv_from_read_table_df"))
+ .inputs
+ .contains(graphIdentifier("src")))
+ }
+
+ gridTest(
+ "reading internal datasets outside query function that trigger " +
+ "eager analysis or execution will fail")(
+ Seq("""spark.sql("SELECT * FROM src")""", """spark.read.table("src").collect()""")) {
+ command =>
+ val ex = intercept[RuntimeException] {
+ buildGraph(s"""
+ |@dp.materialized_view
+ |def src():
+ | return spark.range(5)
+ |
+ |spark_sql_df = $command
+ |
+ |@dp.materialized_view
+ |def mv_from_spark_sql_df():
+ | return spark_sql_df
+ |""".stripMargin)
+ }
+ assert(ex.getMessage.contains("TABLE_OR_VIEW_NOT_FOUND"))
+ assert(ex.getMessage.contains("`src`"))
+ }
+
test("create dataset with the same name will fail") {
- assume(PythonTestDepsChecker.isConnectDepsAvailable)
val ex = intercept[AnalysisException] {
buildGraph(s"""
|@dp.materialized_view
@@ -495,7 +623,6 @@ class PythonPipelineSuite
}
test("create datasets with three part names") {
- assume(PythonTestDepsChecker.isConnectDepsAvailable)
val graphTry = Try {
buildGraph(s"""
|@dp.table(name = "some_catalog.some_schema.mv")
@@ -548,7 +675,6 @@ class PythonPipelineSuite
}
test("create named flow with multipart name will fail") {
- assume(PythonTestDepsChecker.isConnectDepsAvailable)
val ex = intercept[RuntimeException] {
buildGraph(s"""
|@dp.table
@@ -597,7 +723,8 @@ class PythonPipelineSuite
assert(
graph
.flowsTo(graphIdentifier("a"))
- .map(_.identifier) == Seq(graphIdentifier("a"), graphIdentifier("something")))
+ .map(_.identifier)
+ .toSet == Set(graphIdentifier("a"), graphIdentifier("something")))
}
test("groupby and rollup works with internal datasets, referencing with (col, str)") {
@@ -696,7 +823,6 @@ class PythonPipelineSuite
}
test("create pipeline without table will throw RUN_EMPTY_PIPELINE exception") {
- assume(PythonTestDepsChecker.isConnectDepsAvailable)
checkError(
exception = intercept[AnalysisException] {
buildGraph(s"""
@@ -708,7 +834,6 @@ class PythonPipelineSuite
}
test("create pipeline with only temp view will throw RUN_EMPTY_PIPELINE exception") {
- assume(PythonTestDepsChecker.isConnectDepsAvailable)
checkError(
exception = intercept[AnalysisException] {
buildGraph(s"""
@@ -722,7 +847,6 @@ class PythonPipelineSuite
}
test("create pipeline with only flow will throw RUN_EMPTY_PIPELINE exception") {
- assume(PythonTestDepsChecker.isConnectDepsAvailable)
checkError(
exception = intercept[AnalysisException] {
buildGraph(s"""
@@ -865,4 +989,128 @@ class PythonPipelineSuite
(exitCode, output.toSeq)
}
+
+ test("empty cluster_by list should work and create table with no clustering") {
+ withTable("mv", "st") {
+ val graph = buildGraph("""
+ |from pyspark.sql.functions import col
+ |
+ |@dp.materialized_view(cluster_by = [])
+ |def mv():
+ | return spark.range(5).withColumn("id_mod", col("id") % 2)
+ |
+ |@dp.table(cluster_by = [])
+ |def st():
+ | return spark.readStream.table("mv")
+ |""".stripMargin)
+ val updateContext =
+ new PipelineUpdateContextImpl(graph, eventCallback = _ => (), storageRoot = storageRoot)
+ updateContext.pipelineExecution.runPipeline()
+ updateContext.pipelineExecution.awaitCompletion()
+
+ // Check tables are created with no clustering transforms
+ val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
+
+ val mvIdentifier = Identifier.of(Array("default"), "mv")
+ val mvTable = catalog.loadTable(mvIdentifier)
+ val mvTransforms = mvTable.partitioning()
+ assert(
+ mvTransforms.isEmpty,
+ s"MaterializedView should have no transforms, but got: ${mvTransforms.mkString(", ")}")
+
+ val stIdentifier = Identifier.of(Array("default"), "st")
+ val stTable = catalog.loadTable(stIdentifier)
+ val stTransforms = stTable.partitioning()
+ assert(
+ stTransforms.isEmpty,
+ s"Table should have no transforms, but got: ${stTransforms.mkString(", ")}")
+ }
+ }
+
+ // List of unsupported SQL commands that should result in a failure.
+ private val unsupportedSqlCommandList: Seq[String] = Seq(
+ "SET CATALOG some_catalog",
+ "USE SCHEMA some_schema",
+ "SET `test_conf` = `true`",
+ "CREATE TABLE some_table (id INT)",
+ "CREATE VIEW some_view AS SELECT * FROM some_table",
+ "INSERT INTO some_table VALUES (1)",
+ "ALTER TABLE some_table RENAME TO some_new_table",
+ "CREATE NAMESPACE some_namespace",
+ "DROP VIEW some_view",
+ "CREATE MATERIALIZED VIEW some_view AS SELECT * FROM some_table",
+ "CREATE STREAMING TABLE some_table AS SELECT * FROM some_table")
+
+ gridTest("Unsupported SQL command outside query function should result in a failure")(
+ unsupportedSqlCommandList) { unsupportedSqlCommand =>
+ val ex = intercept[RuntimeException] {
+ buildGraph(s"""
+ |spark.sql("$unsupportedSqlCommand")
+ |
+ |@dp.materialized_view()
+ |def mv():
+ | return spark.range(5)
+ |""".stripMargin)
+ }
+ assert(ex.getMessage.contains("UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND"))
+ }
+
+ gridTest("Unsupported SQL command inside query function should result in a failure")(
+ unsupportedSqlCommandList) { unsupportedSqlCommand =>
+ val ex = intercept[RuntimeException] {
+ buildGraph(s"""
+ |@dp.materialized_view()
+ |def mv():
+ | spark.sql("$unsupportedSqlCommand")
+ | return spark.range(5)
+ |""".stripMargin)
+ }
+ assert(ex.getMessage.contains("UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND"))
+ }
+
+ // List of supported SQL commands that should work.
+ val supportedSqlCommandList: Seq[String] = Seq(
+ "DESCRIBE TABLE spark_catalog.default.src",
+ "SHOW TABLES",
+ "SHOW TBLPROPERTIES spark_catalog.default.src",
+ "SHOW NAMESPACES",
+ "SHOW COLUMNS FROM spark_catalog.default.src",
+ "SHOW FUNCTIONS",
+ "SHOW VIEWS",
+ "SHOW CATALOGS",
+ "SHOW CREATE TABLE spark_catalog.default.src",
+ "SELECT * FROM RANGE(5)",
+ "SELECT * FROM spark_catalog.default.src")
+
+ gridTest("Supported SQL command outside query function should work")(supportedSqlCommandList) {
+ supportedSqlCommand =>
+ sql("CREATE TABLE spark_catalog.default.src AS SELECT * FROM RANGE(5)")
+ buildGraph(s"""
+ |spark.sql("$supportedSqlCommand")
+ |
+ |@dp.materialized_view()
+ |def mv():
+ | return spark.range(5)
+ |""".stripMargin)
+ }
+
+ gridTest("Supported SQL command inside query function should work")(supportedSqlCommandList) {
+ supportedSqlCommand =>
+ sql("CREATE TABLE spark_catalog.default.src AS SELECT * FROM RANGE(5)")
+ buildGraph(s"""
+ |@dp.materialized_view()
+ |def mv():
+ | spark.sql("$supportedSqlCommand")
+ | return spark.range(5)
+ |""".stripMargin)
+ }
+
+ override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
+ pos: Position): Unit = {
+ if (PythonTestDepsChecker.isConnectDepsAvailable) {
+ super.test(testName, testTags: _*)(testFun)
+ } else {
+ super.ignore(testName, testTags: _*)(testFun)
+ }
+ }
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala
index 9dba27c4525c2..3cb45fa6e1720 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala
@@ -51,6 +51,32 @@ class SparkDeclarativePipelinesServerSuite
}
}
+ test(
+ "create dataflow graph set session catalog and database to pipeline " +
+ "default catalog and database") {
+ withRawBlockingStub { implicit stub =>
+ // Use default spark_catalog and create a test database
+ sql("CREATE DATABASE IF NOT EXISTS test_db")
+ try {
+ val graphId = sendPlan(
+ buildCreateDataflowGraphPlan(
+ proto.PipelineCommand.CreateDataflowGraph
+ .newBuilder()
+ .setDefaultCatalog("spark_catalog")
+ .setDefaultDatabase("test_db")
+ .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId
+ val definition =
+ getDefaultSessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(graphId)
+ assert(definition.defaultCatalog == "spark_catalog")
+ assert(definition.defaultDatabase == "test_db")
+ assert(getDefaultSessionHolder.session.catalog.currentCatalog() == "spark_catalog")
+ assert(getDefaultSessionHolder.session.catalog.currentDatabase == "test_db")
+ } finally {
+ sql("DROP DATABASE IF EXISTS test_db")
+ }
+ }
+ }
+
test("Define a flow for a graph that does not exist") {
val ex = intercept[Exception] {
withRawBlockingStub { implicit stub =>
@@ -71,6 +97,24 @@ class SparkDeclarativePipelinesServerSuite
}
+ gridTest("Define flow 'once' argument not supported")(Seq(true, false)) { onceValue =>
+ val ex = intercept[Exception] {
+ withRawBlockingStub { implicit stub =>
+ val graphId = createDataflowGraph
+ sendPlan(
+ buildPlanFromPipelineCommand(
+ PipelineCommand
+ .newBuilder()
+ .setDefineFlow(DefineFlow
+ .newBuilder()
+ .setDataflowGraphId(graphId)
+ .setOnce(onceValue))
+ .build()))
+ }
+ }
+ assert(ex.getMessage.contains("DEFINE_FLOW_ONCE_OPTION_NOT_SUPPORTED"))
+ }
+
test(
"Cross dependency between SQL dataset and non-SQL dataset is valid and can be registered") {
withRawBlockingStub { implicit stub =>
@@ -497,8 +541,7 @@ class SparkDeclarativePipelinesServerSuite
name: String,
datasetType: OutputType,
datasetName: String,
- defaultCatalog: String = "",
- defaultDatabase: String = "",
+ defaultDatabase: String,
expectedResolvedDatasetName: String,
expectedResolvedCatalog: String,
expectedResolvedNamespace: Seq[String])
@@ -508,6 +551,7 @@ class SparkDeclarativePipelinesServerSuite
name = "TEMPORARY_VIEW",
datasetType = OutputType.TEMPORARY_VIEW,
datasetName = "tv",
+ defaultDatabase = "default",
expectedResolvedDatasetName = "tv",
expectedResolvedCatalog = "",
expectedResolvedNamespace = Seq.empty),
@@ -515,6 +559,7 @@ class SparkDeclarativePipelinesServerSuite
name = "TABLE",
datasetType = OutputType.TABLE,
datasetName = "`tb`",
+ defaultDatabase = "default",
expectedResolvedDatasetName = "tb",
expectedResolvedCatalog = "spark_catalog",
expectedResolvedNamespace = Seq("default")),
@@ -522,6 +567,7 @@ class SparkDeclarativePipelinesServerSuite
name = "MV",
datasetType = OutputType.MATERIALIZED_VIEW,
datasetName = "mv",
+ defaultDatabase = "default",
expectedResolvedDatasetName = "mv",
expectedResolvedCatalog = "spark_catalog",
expectedResolvedNamespace = Seq("default"))).map(tc => tc.name -> tc).toMap
@@ -531,7 +577,6 @@ class SparkDeclarativePipelinesServerSuite
name = "TEMPORARY_VIEW",
datasetType = OutputType.TEMPORARY_VIEW,
datasetName = "tv",
- defaultCatalog = "custom_catalog",
defaultDatabase = "custom_db",
expectedResolvedDatasetName = "tv",
expectedResolvedCatalog = "",
@@ -540,19 +585,17 @@ class SparkDeclarativePipelinesServerSuite
name = "TABLE",
datasetType = OutputType.TABLE,
datasetName = "`tb`",
- defaultCatalog = "`my_catalog`",
defaultDatabase = "`my_db`",
expectedResolvedDatasetName = "tb",
- expectedResolvedCatalog = "`my_catalog`",
+ expectedResolvedCatalog = "spark_catalog",
expectedResolvedNamespace = Seq("`my_db`")),
DefineOutputTestCase(
name = "MV",
datasetType = OutputType.MATERIALIZED_VIEW,
datasetName = "mv",
- defaultCatalog = "another_catalog",
defaultDatabase = "another_db",
expectedResolvedDatasetName = "mv",
- expectedResolvedCatalog = "another_catalog",
+ expectedResolvedCatalog = "spark_catalog",
expectedResolvedNamespace = Seq("another_db")))
.map(tc => tc.name -> tc)
.toMap
@@ -586,40 +629,45 @@ class SparkDeclarativePipelinesServerSuite
}
}
- namedGridTest("DefineOutput returns resolved data name for custom catalog/schema")(
+ namedGridTest("DefineOutput returns resolved data name for custom schema")(
defineDatasetCustomTests) { testCase =>
withRawBlockingStub { implicit stub =>
- // Build and send the CreateDataflowGraph command with custom catalog/db
- val graphId = sendPlan(
- buildCreateDataflowGraphPlan(
- proto.PipelineCommand.CreateDataflowGraph
- .newBuilder()
- .setDefaultCatalog(testCase.defaultCatalog)
- .setDefaultDatabase(testCase.defaultDatabase)
- .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId
-
- assert(graphId.nonEmpty)
+ sql(s"CREATE DATABASE IF NOT EXISTS spark_catalog.${testCase.defaultDatabase}")
+ try {
+ // Build and send the CreateDataflowGraph command with a custom db
+ val graphId = sendPlan(
+ buildCreateDataflowGraphPlan(
+ proto.PipelineCommand.CreateDataflowGraph
+ .newBuilder()
+ .setDefaultCatalog("spark_catalog")
+ .setDefaultDatabase(testCase.defaultDatabase)
+ .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId
- // Build DefineOutput with the created graphId and dataset info
- val defineDataset = DefineOutput
- .newBuilder()
- .setDataflowGraphId(graphId)
- .setOutputName(testCase.datasetName)
- .setOutputType(testCase.datasetType)
- val pipelineCmd = PipelineCommand
- .newBuilder()
- .setDefineOutput(defineDataset)
- .build()
+ assert(graphId.nonEmpty)
- val res = sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult
- assert(res !== PipelineCommandResult.getDefaultInstance)
- assert(res.hasDefineOutputResult)
- val graphResult = res.getDefineOutputResult
- val identifier = graphResult.getResolvedIdentifier
+ // Build DefineOutput with the created graphId and dataset info
+ val defineDataset = DefineOutput
+ .newBuilder()
+ .setDataflowGraphId(graphId)
+ .setOutputName(testCase.datasetName)
+ .setOutputType(testCase.datasetType)
+ val pipelineCmd = PipelineCommand
+ .newBuilder()
+ .setDefineOutput(defineDataset)
+ .build()
- assert(identifier.getCatalogName == testCase.expectedResolvedCatalog)
- assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace)
- assert(identifier.getTableName == testCase.expectedResolvedDatasetName)
+ val res = sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult
+ assert(res !== PipelineCommandResult.getDefaultInstance)
+ assert(res.hasDefineOutputResult)
+ val graphResult = res.getDefineOutputResult
+ val identifier = graphResult.getResolvedIdentifier
+
+ assert(identifier.getCatalogName == testCase.expectedResolvedCatalog)
+ assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace)
+ assert(identifier.getTableName == testCase.expectedResolvedDatasetName)
+ } finally {
+ sql(s"DROP DATABASE IF EXISTS spark_catalog.${testCase.defaultDatabase}")
+ }
}
}
@@ -627,7 +675,6 @@ class SparkDeclarativePipelinesServerSuite
name: String,
datasetType: OutputType,
flowName: String,
- defaultCatalog: String,
defaultDatabase: String,
expectedResolvedFlowName: String,
expectedResolvedCatalog: String,
@@ -638,7 +685,6 @@ class SparkDeclarativePipelinesServerSuite
name = "MV",
datasetType = OutputType.MATERIALIZED_VIEW,
flowName = "`mv`",
- defaultCatalog = "`spark_catalog`",
defaultDatabase = "`default`",
expectedResolvedFlowName = "mv",
expectedResolvedCatalog = "spark_catalog",
@@ -647,7 +693,6 @@ class SparkDeclarativePipelinesServerSuite
name = "TV",
datasetType = OutputType.TEMPORARY_VIEW,
flowName = "tv",
- defaultCatalog = "spark_catalog",
defaultDatabase = "default",
expectedResolvedFlowName = "tv",
expectedResolvedCatalog = "",
@@ -658,16 +703,14 @@ class SparkDeclarativePipelinesServerSuite
name = "MV custom",
datasetType = OutputType.MATERIALIZED_VIEW,
flowName = "mv",
- defaultCatalog = "custom_catalog",
defaultDatabase = "custom_db",
expectedResolvedFlowName = "mv",
- expectedResolvedCatalog = "custom_catalog",
+ expectedResolvedCatalog = "spark_catalog",
expectedResolvedNamespace = Seq("custom_db")),
DefineFlowTestCase(
name = "TV custom",
datasetType = OutputType.TEMPORARY_VIEW,
flowName = "tv",
- defaultCatalog = "custom_catalog",
defaultDatabase = "custom_db",
expectedResolvedFlowName = "tv",
expectedResolvedCatalog = "",
@@ -738,68 +781,167 @@ class SparkDeclarativePipelinesServerSuite
namedGridTest("DefineFlow returns resolved data name for custom catalog/schema")(
defineFlowCustomTests) { testCase =>
withRawBlockingStub { implicit stub =>
- val graphId = sendPlan(
- buildCreateDataflowGraphPlan(
- proto.PipelineCommand.CreateDataflowGraph
+ sql(s"CREATE DATABASE IF NOT EXISTS spark_catalog.${testCase.defaultDatabase}")
+ try {
+ val graphId = sendPlan(
+ buildCreateDataflowGraphPlan(
+ proto.PipelineCommand.CreateDataflowGraph
+ .newBuilder()
+ .setDefaultCatalog("spark_catalog")
+ .setDefaultDatabase(testCase.defaultDatabase)
+ .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId
+ assert(graphId.nonEmpty)
+
+ // If the dataset type is TEMPORARY_VIEW, define the dataset explicitly first
+ if (testCase.datasetType == OutputType.TEMPORARY_VIEW) {
+ val defineDataset = DefineOutput
.newBuilder()
- .setDefaultCatalog(testCase.defaultCatalog)
- .setDefaultDatabase(testCase.defaultDatabase)
- .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId
- assert(graphId.nonEmpty)
+ .setDataflowGraphId(graphId)
+ .setOutputName(testCase.flowName)
+ .setOutputType(OutputType.TEMPORARY_VIEW)
- // If the dataset type is TEMPORARY_VIEW, define the dataset explicitly first
- if (testCase.datasetType == OutputType.TEMPORARY_VIEW) {
- val defineDataset = DefineOutput
+ val defineDatasetCmd = PipelineCommand
+ .newBuilder()
+ .setDefineOutput(defineDataset)
+ .build()
+
+ val datasetRes =
+ sendPlan(buildPlanFromPipelineCommand(defineDatasetCmd)).getPipelineCommandResult
+ assert(datasetRes.hasDefineOutputResult)
+ }
+
+ val defineFlow = DefineFlow
.newBuilder()
.setDataflowGraphId(graphId)
- .setOutputName(testCase.flowName)
- .setOutputType(OutputType.TEMPORARY_VIEW)
-
- val defineDatasetCmd = PipelineCommand
+ .setFlowName(testCase.flowName)
+ .setTargetDatasetName(testCase.flowName)
+ .setRelationFlowDetails(
+ DefineFlow.WriteRelationFlowDetails
+ .newBuilder()
+ .setRelation(
+ Relation
+ .newBuilder()
+ .setUnresolvedTableValuedFunction(
+ UnresolvedTableValuedFunction
+ .newBuilder()
+ .setFunctionName("range")
+ .addArguments(Expression
+ .newBuilder()
+ .setLiteral(Expression.Literal.newBuilder().setInteger(5).build())
+ .build())
+ .build())
+ .build())
+ .build())
+ .build()
+ val pipelineCmd = PipelineCommand
.newBuilder()
- .setDefineOutput(defineDataset)
+ .setDefineFlow(defineFlow)
.build()
-
- val datasetRes =
- sendPlan(buildPlanFromPipelineCommand(defineDatasetCmd)).getPipelineCommandResult
- assert(datasetRes.hasDefineOutputResult)
+ val res = sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult
+ assert(res.hasDefineFlowResult)
+ val graphResult = res.getDefineFlowResult
+ val identifier = graphResult.getResolvedIdentifier
+
+ assert(identifier.getCatalogName == testCase.expectedResolvedCatalog)
+ assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace)
+ assert(identifier.getTableName == testCase.expectedResolvedFlowName)
+ } finally {
+ sql(s"DROP DATABASE IF EXISTS spark_catalog.${testCase.defaultDatabase}")
}
+ }
+ }
- val defineFlow = DefineFlow
+ test(
+ "SPARK-54452: spark.sql() inside a pipeline flow function should return a sql_command_result") {
+ withRawBlockingStub { implicit stub =>
+ val graphId = createDataflowGraph
+ val pipelineAnalysisContext = proto.PipelineAnalysisContext
.newBuilder()
.setDataflowGraphId(graphId)
- .setFlowName(testCase.flowName)
- .setTargetDatasetName(testCase.flowName)
- .setRelationFlowDetails(
- DefineFlow.WriteRelationFlowDetails
+ .setFlowName("flow1")
+ .build()
+ val userContext = proto.UserContext
+ .newBuilder()
+ .addExtensions(com.google.protobuf.Any.pack(pipelineAnalysisContext))
+ .setUserId("test_user")
+ .build()
+
+ val relation = proto.Plan
+ .newBuilder()
+ .setCommand(
+ proto.Command
.newBuilder()
- .setRelation(
- Relation
+ .setSqlCommand(
+ proto.SqlCommand
.newBuilder()
- .setUnresolvedTableValuedFunction(
- UnresolvedTableValuedFunction
+ .setInput(
+ proto.Relation
.newBuilder()
- .setFunctionName("range")
- .addArguments(Expression
+ .setRead(proto.Read
.newBuilder()
- .setLiteral(Expression.Literal.newBuilder().setInteger(5).build())
+ .setNamedTable(
+ proto.Read.NamedTable.newBuilder().setUnparsedIdentifier("table"))
.build())
- .build())
- .build())
+ .build()))
.build())
.build()
- val pipelineCmd = PipelineCommand
+
+ val sparkSqlRequest = proto.ExecutePlanRequest
.newBuilder()
- .setDefineFlow(defineFlow)
+ .setUserContext(userContext)
+ .setPlan(relation)
+ .setSessionId(UUID.randomUUID().toString)
.build()
- val res = sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult
- assert(res.hasDefineFlowResult)
- val graphResult = res.getDefineFlowResult
- val identifier = graphResult.getResolvedIdentifier
+ val sparkSqlResponse = stub.executePlan(sparkSqlRequest).next()
+ assert(sparkSqlResponse.hasSqlCommandResult)
+ assert(
+ sparkSqlResponse.getSqlCommandResult.getRelation ==
+ relation.getCommand.getSqlCommand.getInput)
+ }
+ }
- assert(identifier.getCatalogName == testCase.expectedResolvedCatalog)
- assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace)
- assert(identifier.getTableName == testCase.expectedResolvedFlowName)
+ test(
+ "SPARK-54452: spark.sql() outside a pipeline flow function should return a " +
+ "sql_command_result") {
+ withRawBlockingStub { implicit stub =>
+ val graphId = createDataflowGraph
+ val pipelineAnalysisContext = proto.PipelineAnalysisContext
+ .newBuilder()
+ .setDataflowGraphId(graphId)
+ .build()
+ val userContext = proto.UserContext
+ .newBuilder()
+ .addExtensions(com.google.protobuf.Any.pack(pipelineAnalysisContext))
+ .setUserId("test_user")
+ .build()
+
+ val relation = proto.Plan
+ .newBuilder()
+ .setCommand(
+ proto.Command
+ .newBuilder()
+ .setSqlCommand(
+ proto.SqlCommand
+ .newBuilder()
+ .setInput(proto.Relation
+ .newBuilder()
+ .setSql(proto.SQL.newBuilder().setQuery("SELECT * FROM RANGE(5)"))
+ .build())
+ .build())
+ .build())
+ .build()
+
+ val sparkSqlRequest = proto.ExecutePlanRequest
+ .newBuilder()
+ .setUserContext(userContext)
+ .setPlan(relation)
+ .setSessionId(UUID.randomUUID().toString)
+ .build()
+ val sparkSqlResponse = stub.executePlan(sparkSqlRequest).next()
+ assert(sparkSqlResponse.hasSqlCommandResult)
+ assert(
+ sparkSqlResponse.getSqlCommandResult.getRelation ==
+ relation.getCommand.getSqlCommand.getInput)
}
}
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/TestPipelineDefinition.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/TestPipelineDefinition.scala
index dfb766b1df778..f3b63f7914218 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/TestPipelineDefinition.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/TestPipelineDefinition.scala
@@ -41,10 +41,12 @@ class TestPipelineDefinition(graphId: String) {
// TODO: Add support for specifiedSchema
// specifiedSchema: Option[StructType] = None,
partitionCols: Option[Seq[String]] = None,
+ clusterCols: Option[Seq[String]] = None,
properties: Map[String, String] = Map.empty): Unit = {
val tableDetails = sc.PipelineCommand.DefineOutput.TableDetails
.newBuilder()
.addAllPartitionCols(partitionCols.getOrElse(Seq()).asJava)
+ .addAllClusteringColumns(clusterCols.getOrElse(Seq()).asJava)
.putAllTableProperties(properties.asJava)
.build()
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
index 2989471d36a03..0e5488e312220 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala
@@ -65,6 +65,12 @@ class SparkConnectServiceSuite
with Logging
with SparkConnectPlanTest {
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ SparkConnectService.sessionManager.invalidateAllSessions()
+ SparkConnectService.sessionManager.initializeBaseSession(() => spark.newSession())
+ }
+
private def sparkSessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
private def DEFAULT_UUID = UUID.fromString("89ea6117-1f45-4c03-ae27-f47c6aded093")
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
index 6cc5daadfddd7..1df8ba46286cb 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
@@ -399,11 +399,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
handler.onNext(req)
}
assert(e.getStatus.getCode == Code.INTERNAL)
- val statusProto = StatusProto.fromThrowable(e)
- assert(statusProto.getDetailsCount == 1)
- val details = statusProto.getDetails(0)
- val info = details.unpack(classOf[ErrorInfo])
- assert(info.getReason.contains("java.lang.IllegalArgumentException"))
+ assert(e.getMessage.contains("INVALID_ARTIFACT_PATH"))
}
handler.onCompleted()
} finally {
@@ -422,11 +418,7 @@ class AddArtifactsHandlerSuite extends SharedSparkSession with ResourceHelper {
handler.onNext(req)
}
assert(e.getStatus.getCode == Code.INTERNAL)
- val statusProto = StatusProto.fromThrowable(e)
- assert(statusProto.getDetailsCount == 1)
- val details = statusProto.getDetails(0)
- val info = details.unpack(classOf[ErrorInfo])
- assert(info.getReason.contains("java.lang.IllegalArgumentException"))
+ assert(e.getMessage.contains("INVALID_ARTIFACT_PATH"))
}
handler.onCompleted()
} finally {
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala
index 7ce3ff46f5537..275808942d37d 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala
@@ -42,6 +42,12 @@ class ArtifactStatusesHandlerSuite extends SharedSparkSession with ResourceHelpe
val sessionId = UUID.randomUUID().toString
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ SparkConnectService.sessionManager.invalidateAllSessions()
+ SparkConnectService.sessionManager.initializeBaseSession(() => spark.newSession())
+ }
+
def getStatuses(names: Seq[String], exist: Set[String]): ArtifactStatusesResponse = {
val promise = Promise[ArtifactStatusesResponse]()
val handler = new SparkConnectArtifactStatusesHandler(new DummyStreamObserver(promise)) {
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala
index 922c239526f31..044103a3e4f13 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala
@@ -29,6 +29,7 @@ class SparkConnectCloneSessionSuite extends SharedSparkSession with BeforeAndAft
override def beforeEach(): Unit = {
super.beforeEach()
SparkConnectService.sessionManager.invalidateAllSessions()
+ SparkConnectService.sessionManager.initializeBaseSession(() => spark.newSession())
}
test("clone session with invalid target session ID format") {
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
index 0e18ff711c4c5..a433534b7511a 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectServiceE2ESuite.scala
@@ -16,12 +16,16 @@
*/
package org.apache.spark.sql.connect.service
+import java.io.ByteArrayOutputStream
import java.util.UUID
+import com.github.luben.zstd.{Zstd, ZstdOutputStreamNoFinalizer}
+import com.google.protobuf.ByteString
import org.scalatest.concurrent.Eventually
import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkException
+import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.SparkConnectServerTest
import org.apache.spark.sql.connect.config.Connect
@@ -178,8 +182,8 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
}
withClient(sessionId = sessionId, userId = userId) { client =>
// shall not be able to create a new session with the same id and user.
- val query = client.execute(buildPlan("SELECT 1"))
val queryError = intercept[SparkException] {
+ val query = client.execute(buildPlan("SELECT 1"))
while (query.hasNext) query.next()
}
assert(queryError.getMessage.contains("INVALID_HANDLE.SESSION_CLOSED"))
@@ -317,4 +321,210 @@ class SparkConnectServiceE2ESuite extends SparkConnectServerTest {
assert(error.getMessage.contains(fixedOperationId))
}
}
+
+ test("Interrupt with TAG type without operation_tag throws proper error class") {
+ withRawBlockingStub { stub =>
+ // Create an interrupt request with INTERRUPT_TYPE_TAG but no operation_tag
+ val request = org.apache.spark.connect.proto.InterruptRequest
+ .newBuilder()
+ .setSessionId(UUID.randomUUID().toString)
+ .setUserContext(org.apache.spark.connect.proto.UserContext
+ .newBuilder()
+ .setUserId(defaultUserId))
+ .setInterruptType(
+ org.apache.spark.connect.proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_TAG)
+ .build()
+
+ val error = intercept[io.grpc.StatusRuntimeException] {
+ stub.interrupt(request)
+ }
+
+ // Verify the error is INVALID_PARAMETER_VALUE.INTERRUPT_TYPE_TAG_REQUIRES_TAG
+ assert(error.getMessage.contains("INVALID_PARAMETER_VALUE.INTERRUPT_TYPE_TAG_REQUIRES_TAG"))
+ assert(error.getMessage.contains("operation_tag"))
+ }
+ }
+
+ test("Interrupt with OPERATION_ID type without operation_id throws proper error class") {
+ withRawBlockingStub { stub =>
+ // Create an interrupt request with INTERRUPT_TYPE_OPERATION_ID but no operation_id
+ val request = org.apache.spark.connect.proto.InterruptRequest
+ .newBuilder()
+ .setSessionId(UUID.randomUUID().toString)
+ .setUserContext(org.apache.spark.connect.proto.UserContext
+ .newBuilder()
+ .setUserId(defaultUserId))
+ .setInterruptType(
+ org.apache.spark.connect.proto.InterruptRequest.InterruptType.INTERRUPT_TYPE_OPERATION_ID)
+ .build()
+
+ val error = intercept[io.grpc.StatusRuntimeException] {
+ stub.interrupt(request)
+ }
+
+ // Verify the error is INVALID_PARAMETER_VALUE.INTERRUPT_TYPE_OPERATION_ID_REQUIRES_ID
+ assert(
+ error.getMessage.contains(
+ "INVALID_PARAMETER_VALUE.INTERRUPT_TYPE_OPERATION_ID_REQUIRES_ID"))
+ assert(error.getMessage.contains("operation_id"))
+ }
+ }
+
+ test("Relation as compressed plan works") {
+ withClient { client =>
+ val relation = buildPlan("SELECT 1").getRoot
+ val compressedRelation = Zstd.compress(relation.toByteArray)
+ val plan = proto.Plan
+ .newBuilder()
+ .setCompressedOperation(
+ proto.Plan.CompressedOperation
+ .newBuilder()
+ .setData(ByteString.copyFrom(compressedRelation))
+ .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION)
+ .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+ .build())
+ .build()
+ val query = client.execute(plan)
+ while (query.hasNext) query.next()
+ }
+ }
+
+ test("Command as compressed plan works") {
+ withClient { client =>
+ val command = buildSqlCommandPlan("SET spark.sql.session.timeZone=Europe/Berlin").getCommand
+ val compressedCommand = Zstd.compress(command.toByteArray)
+ val plan = proto.Plan
+ .newBuilder()
+ .setCompressedOperation(
+ proto.Plan.CompressedOperation
+ .newBuilder()
+ .setData(ByteString.copyFrom(compressedCommand))
+ .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_COMMAND)
+ .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+ .build())
+ .build()
+ val query = client.execute(plan)
+ while (query.hasNext) query.next()
+ }
+ }
+
+ private def compressInZstdStreamingMode(input: Array[Byte]): Array[Byte] = {
+ val outputStream = new ByteArrayOutputStream()
+ val zstdStream = new ZstdOutputStreamNoFinalizer(outputStream)
+ zstdStream.write(input)
+ zstdStream.flush()
+ zstdStream.close()
+ outputStream.toByteArray
+ }
+
+ test("Compressed plans generated in streaming mode also work correctly") {
+ withClient { client =>
+ val relation = buildPlan("SELECT 1").getRoot
+ val compressedRelation = compressInZstdStreamingMode(relation.toByteArray)
+ val plan = proto.Plan
+ .newBuilder()
+ .setCompressedOperation(
+ proto.Plan.CompressedOperation
+ .newBuilder()
+ .setData(ByteString.copyFrom(compressedRelation))
+ .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION)
+ .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+ .build())
+ .build()
+ val query = client.execute(plan)
+ while (query.hasNext) query.next()
+ }
+ }
+
+ test("Invalid compressed bytes errors out") {
+ withClient { client =>
+ val invalidBytes = "invalidBytes".getBytes
+ val plan = proto.Plan
+ .newBuilder()
+ .setCompressedOperation(
+ proto.Plan.CompressedOperation
+ .newBuilder()
+ .setData(ByteString.copyFrom(invalidBytes))
+ .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION)
+ .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+ .build())
+ .build()
+ val ex = intercept[SparkException] {
+ val query = client.execute(plan)
+ while (query.hasNext) query.next()
+ }
+ assert(ex.getMessage.contains("CONNECT_INVALID_PLAN.CANNOT_PARSE"))
+ }
+ }
+
+ test("Invalid compressed proto message errors out") {
+ withClient { client =>
+ val data = Zstd.compress("Apache Spark".getBytes)
+ val plan = proto.Plan
+ .newBuilder()
+ .setCompressedOperation(
+ proto.Plan.CompressedOperation
+ .newBuilder()
+ .setData(ByteString.copyFrom(data))
+ .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION)
+ .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+ .build())
+ .build()
+ val ex = intercept[SparkException] {
+ val query = client.execute(plan)
+ while (query.hasNext) query.next()
+ }
+ assert(ex.getMessage.contains("CONNECT_INVALID_PLAN.CANNOT_PARSE"))
+ }
+ }
+
+ test("Large compressed plan errors out") {
+ withClient { client =>
+ withSparkEnvConfs(Connect.CONNECT_MAX_PLAN_SIZE.key -> "100") {
+ val relation = buildPlan("SELECT '" + "Apache Spark" * 100 + "'").getRoot
+ val compressedRelation = Zstd.compress(relation.toByteArray)
+
+ val plan = proto.Plan
+ .newBuilder()
+ .setCompressedOperation(
+ proto.Plan.CompressedOperation
+ .newBuilder()
+ .setData(ByteString.copyFrom(compressedRelation))
+ .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION)
+ .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+ .build())
+ .build()
+ val ex = intercept[SparkException] {
+ val query = client.execute(plan)
+ while (query.hasNext) query.next()
+ }
+ assert(ex.getMessage.contains("CONNECT_INVALID_PLAN.PLAN_SIZE_LARGER_THAN_MAX"))
+ }
+ }
+ }
+
+ test("Large compressed plan generated in streaming mode also errors out") {
+ withClient { client =>
+ withSparkEnvConfs(Connect.CONNECT_MAX_PLAN_SIZE.key -> "100") {
+ val relation = buildPlan("SELECT '" + "Apache Spark" * 100 + "'").getRoot
+ val compressedRelation = compressInZstdStreamingMode(relation.toByteArray)
+
+ val plan = proto.Plan
+ .newBuilder()
+ .setCompressedOperation(
+ proto.Plan.CompressedOperation
+ .newBuilder()
+ .setData(ByteString.copyFrom(compressedRelation))
+ .setOpType(proto.Plan.CompressedOperation.OpType.OP_TYPE_RELATION)
+ .setCompressionCodec(proto.CompressionCodec.COMPRESSION_CODEC_ZSTD)
+ .build())
+ .build()
+ val ex = intercept[SparkException] {
+ val query = client.execute(plan)
+ while (query.hasNext) query.next()
+ }
+ assert(ex.getMessage.contains("CONNECT_INVALID_PLAN.PLAN_SIZE_LARGER_THAN_MAX"))
+ }
+ }
+ }
}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index 680755afdca21..1b747705e9ad7 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -437,7 +437,7 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
val pipelineUpdateContext = new PipelineUpdateContextImpl(
new DataflowGraph(Seq(), Seq(), Seq(), Seq()),
(_: PipelineEvent) => None,
- storageRoot = "test_storage_root")
+ storageRoot = "file:///test_storage_root")
sessionHolder.cachePipelineExecution(graphId, pipelineUpdateContext)
assert(
sessionHolder.getPipelineExecution(graphId).nonEmpty,
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
index a3d851c1ce7b9..4029e4775f44b 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
@@ -23,6 +23,7 @@ import org.scalatest.BeforeAndAfterEach
import org.scalatest.time.SpanSugar._
import org.apache.spark.SparkSQLException
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
import org.apache.spark.sql.pipelines.logging.PipelineEvent
import org.apache.spark.sql.test.SharedSparkSession
@@ -32,6 +33,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
override def beforeEach(): Unit = {
super.beforeEach()
SparkConnectService.sessionManager.invalidateAllSessions()
+ SparkConnectService.sessionManager.initializeBaseSession(() => spark.newSession())
}
test("sessionId needs to be an UUID") {
@@ -161,7 +163,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
val pipelineUpdateContext = new PipelineUpdateContextImpl(
new DataflowGraph(Seq(), Seq(), Seq(), Seq()),
(_: PipelineEvent) => None,
- storageRoot = "test_storage_root")
+ storageRoot = "file:///test_storage_root")
sessionHolder.cachePipelineExecution(graphId, pipelineUpdateContext)
assert(
sessionHolder.getPipelineExecution(graphId).nonEmpty,
@@ -171,4 +173,51 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
sessionHolder.getPipelineExecution(graphId).isEmpty,
"pipeline execution was not removed")
}
+
+ test("baseSession allows creating sessions after default session is cleared") {
+ // Create a new session manager to test initialization
+ val sessionManager = new SparkConnectSessionManager()
+
+ // Initialize the base session with the test SparkContext
+ sessionManager.initializeBaseSession(() => spark.newSession())
+
+ // Clear the default and active sessions to simulate the scenario where
+ // SparkSession.active or SparkSession.getDefaultSession would fail
+ SparkSession.clearDefaultSession()
+ SparkSession.clearActiveSession()
+
+ // Create an isolated session - this should still work because we have baseSession
+ val key = SessionKey("user", UUID.randomUUID().toString)
+ val sessionHolder = sessionManager.getOrCreateIsolatedSession(key, None)
+
+ // Verify the session was created successfully
+ assert(sessionHolder != null)
+ assert(sessionHolder.session != null)
+
+ // Clean up
+ sessionManager.closeSession(key)
+ }
+
+ test("initializeBaseSession is idempotent") {
+ // Create a new session manager to test initialization
+ val sessionManager = new SparkConnectSessionManager()
+
+ // Initialize the base session multiple times
+ sessionManager.initializeBaseSession(() => spark.newSession())
+ val key1 = SessionKey("user1", UUID.randomUUID().toString)
+ val sessionHolder1 = sessionManager.getOrCreateIsolatedSession(key1, None)
+ val baseSessionUUID1 = sessionHolder1.session.sessionUUID
+
+ // Initialize again - should not change the base session
+ sessionManager.initializeBaseSession(() => spark.newSession())
+ val key2 = SessionKey("user2", UUID.randomUUID().toString)
+ val sessionHolder2 = sessionManager.getOrCreateIsolatedSession(key2, None)
+
+ // Both sessions should be isolated from each other
+ assert(sessionHolder1.session.sessionUUID != sessionHolder2.session.sessionUUID)
+
+ // Clean up
+ sessionManager.closeSession(key1)
+ sessionManager.closeSession(key2)
+ }
}
diff --git a/sql/connect/shims/pom.xml b/sql/connect/shims/pom.xml
index 37e565bf25872..d4ee58e87c352 100644
--- a/sql/connect/shims/pom.xml
+++ b/sql/connect/shims/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.13
- 4.1.0-SNAPSHOT
+ 4.1.2-SNAPSHOT
../../../pom.xml
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 39d8c39954410..285ea9ae4205c 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.13
- 4.1.0-SNAPSHOT
+ 4.1.2-SNAPSHOT
../../pom.xml
@@ -49,6 +49,13 @@
spark-sketch_${scala.binary.version}
${project.version}
+
+ org.apache.spark
+ spark-common-utils_${scala.binary.version}
+ ${project.version}
+ tests
+ test
+
org.apache.spark
spark-core_${scala.binary.version}
@@ -279,6 +286,10 @@
bcpkix-jdk18on
test
+
+ org.apache.arrow
+ arrow-compression
+
target/scala-${scala.binary.version}/classes
diff --git a/sql/core/src/main/buf.gen.yaml b/sql/core/src/main/buf.gen.yaml
index 01a34ed308444..5f87a840c6a49 100644
--- a/sql/core/src/main/buf.gen.yaml
+++ b/sql/core/src/main/buf.gen.yaml
@@ -17,7 +17,7 @@
version: v1
plugins:
# Building the Python build and building the mypy interfaces.
- - plugin: buf.build/protocolbuffers/python:v29.5
+ - plugin: buf.build/protocolbuffers/python:v33.0
out: gen/proto/python
- name: mypy
out: gen/proto/python
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
index 37c936c84d5f7..002b7569a6e09 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
@@ -80,7 +80,7 @@ final class ParquetColumnVector {
}
if (defaultValue == null) {
- vector.setAllNull();
+ vector.setMissing();
return;
}
// For Parquet tables whose columns have associated DEFAULT values, this reader must return
@@ -137,7 +137,7 @@ final class ParquetColumnVector {
// Only use levels from a non-missing child; this can happen if only some but not all
// fields of a struct are missing.
- if (!childCv.vector.isAllNull()) {
+ if (!childCv.vector.isMissing()) {
allChildrenAreMissing = false;
this.repetitionLevels = childCv.repetitionLevels;
this.definitionLevels = childCv.definitionLevels;
@@ -147,7 +147,7 @@ final class ParquetColumnVector {
// This can happen if all the fields of a struct are missing, in which case we should mark
// the struct itself as a missing column
if (allChildrenAreMissing) {
- vector.setAllNull();
+ vector.setMissing();
}
}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
index eb6c84b8113b8..4f90f878da86a 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
@@ -26,6 +26,7 @@
import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation;
import org.apache.parquet.schema.LogicalTypeAnnotation.TimeLogicalTypeAnnotation;
import org.apache.parquet.schema.LogicalTypeAnnotation.TimestampLogicalTypeAnnotation;
+import org.apache.parquet.schema.LogicalTypeAnnotation.UnknownLogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.spark.SparkUnsupportedOperationException;
@@ -70,7 +71,12 @@ public class ParquetVectorUpdaterFactory {
}
public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType sparkType) {
- PrimitiveType.PrimitiveTypeName typeName = descriptor.getPrimitiveType().getPrimitiveTypeName();
+ PrimitiveType type = descriptor.getPrimitiveType();
+ PrimitiveType.PrimitiveTypeName typeName = type.getPrimitiveTypeName();
+ boolean isUnknownType = type.getLogicalTypeAnnotation() instanceof UnknownLogicalTypeAnnotation;
+ if (isUnknownType && sparkType instanceof NullType) {
+ return new NullTypeUpdater();
+ }
switch (typeName) {
case BOOLEAN -> {
@@ -244,6 +250,42 @@ boolean isUnsignedIntTypeMatched(int bitWidth) {
!annotation.isSigned() && annotation.getBitWidth() == bitWidth;
}
+ /**
+ * Updater should not be called if all values are null, so all methods throw an exception here.
+ */
+ private static class NullTypeUpdater implements ParquetVectorUpdater {
+ @Override
+ public void readValues(
+ int total,
+ int offset,
+ WritableColumnVector values,
+ VectorizedValuesReader valuesReader) {
+ throw SparkUnsupportedOperationException.apply();
+ }
+
+ @Override
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ throw SparkUnsupportedOperationException.apply();
+ }
+
+ @Override
+ public void readValue(
+ int offset,
+ WritableColumnVector values,
+ VectorizedValuesReader valuesReader) {
+ throw SparkUnsupportedOperationException.apply();
+ }
+
+ @Override
+ public void decodeSingleDictionaryId(
+ int offset,
+ WritableColumnVector values,
+ WritableColumnVector dictionaryIds,
+ Dictionary dictionary) {
+ throw SparkUnsupportedOperationException.apply();
+ }
+ }
+
private static class BooleanUpdater implements ParquetVectorUpdater {
@Override
public void readValues(
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
index 49c27f9775624..a46b5143eef6d 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java
@@ -96,6 +96,8 @@ public InternalRow copy() {
row.update(i, getArray(i).copy());
} else if (dt instanceof MapType) {
row.update(i, getMap(i).copy());
+ } else if (dt instanceof VariantType) {
+ row.update(i, getVariant(i));
} else {
throw new RuntimeException("Not implemented. " + dt);
}
@@ -217,6 +219,8 @@ public Object get(int ordinal, DataType dataType) {
return getStruct(ordinal, structType.fields().length);
} else if (dataType instanceof MapType) {
return getMap(ordinal);
+ } else if (dataType instanceof VariantType) {
+ return getVariant(ordinal);
} else {
throw new SparkUnsupportedOperationException(
"_LEGACY_ERROR_TEMP_3192", Map.of("dt", dataType.toString()));
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 2f64ffb42aa06..42454b283d098 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -111,12 +111,14 @@ public void putNotNull(int rowId) {
@Override
public void putNull(int rowId) {
+ if (isAllNull()) return; // Skip writing nulls to all-null vector.
Platform.putByte(null, nulls + rowId, (byte) 1);
++numNulls;
}
@Override
public void putNulls(int rowId, int count) {
+ if (isAllNull()) return; // Skip writing nulls to all-null vector.
long offset = nulls + rowId;
for (int i = 0; i < count; ++i, ++offset) {
Platform.putByte(null, offset, (byte) 1);
@@ -135,7 +137,7 @@ public void putNotNulls(int rowId, int count) {
@Override
public boolean isNullAt(int rowId) {
- return isAllNull || Platform.getByte(null, nulls + rowId) == 1;
+ return isAllNull() || Platform.getByte(null, nulls + rowId) == 1;
}
//
@@ -603,6 +605,8 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) {
// Split out the slow path.
@Override
protected void reserveInternal(int newCapacity) {
+ if (isAllNull()) return; // Skip allocation for all-null vector.
+
int oldCapacity = (nulls == 0L) ? 0 : capacity;
if (isArray() || type instanceof MapType) {
this.lengthData =
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index cd8d0b688bedb..401e499fee300 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -108,12 +108,14 @@ public void putNotNull(int rowId) {
@Override
public void putNull(int rowId) {
+ if (isAllNull()) return; // Skip writing nulls to all-null vector.
nulls[rowId] = (byte)1;
++numNulls;
}
@Override
public void putNulls(int rowId, int count) {
+ if (isAllNull()) return; // Skip writing nulls to all-null vector.
for (int i = 0; i < count; ++i) {
nulls[rowId + i] = (byte)1;
}
@@ -130,7 +132,7 @@ public void putNotNulls(int rowId, int count) {
@Override
public boolean isNullAt(int rowId) {
- return isAllNull || nulls[rowId] == 1;
+ return isAllNull() || nulls[rowId] == 1;
}
//
@@ -577,6 +579,8 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) {
// Split this function out since it is the slow path.
@Override
protected void reserveInternal(int newCapacity) {
+ if (isAllNull()) return; // Skip allocation for all-null vector.
+
if (isArray() || type instanceof MapType) {
int[] newLengths = new int[newCapacity];
int[] newOffsets = new int[newCapacity];
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index 3f552679bb6f1..c4f06e07911d3 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -59,7 +59,7 @@ public abstract class WritableColumnVector extends ColumnVector {
* Resets this column for writing. The currently stored values are no longer accessible.
*/
public void reset() {
- if (isConstant || isAllNull) return;
+ if (isConstant || isAllNull()) return;
if (childColumns != null) {
for (WritableColumnVector c: childColumns) {
@@ -142,7 +142,7 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) {
@Override
public boolean hasNull() {
- return isAllNull || numNulls > 0;
+ return isAllNull() || numNulls > 0;
}
@Override
@@ -876,17 +876,24 @@ public final void setIsConstant() {
}
/**
- * Marks this column only contains null values.
+ * Marks this column as missing from the file.
*/
- public final void setAllNull() {
- isAllNull = true;
+ public final void setMissing() {
+ isMissing = true;
+ }
+
+ /**
+ * Whether this column is missing from the file.
+ */
+ public final boolean isMissing() {
+ return isMissing;
}
/**
* Whether this column only contains null values.
*/
public final boolean isAllNull() {
- return isAllNull;
+ return isMissing || type instanceof NullType;
}
/**
@@ -921,10 +928,10 @@ public final boolean isAllNull() {
protected boolean isConstant;
/**
- * True if this column only contains nulls. This means the column values never change, even
- * across resets. Comparing to 'isConstant' above, this doesn't require any allocation of space.
+ * True if this column is missing from the file. This means the column values never change and
+ * are always null, even across resets. This doesn't require any allocation of space.
*/
- protected boolean isAllNull;
+ protected boolean isMissing;
/**
* Default size of each array length value. This grows as necessary.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
index 5889fe581d4e0..0055d220a6764 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
@@ -183,25 +183,34 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
if (normalizedRemoteRelativePath.startsWith(s"cache${File.separator}")) {
val tmpFile = serverLocalStagingPath.toFile
Utils.tryWithSafeFinallyAndFailureCallbacks {
+ val hash = normalizedRemoteRelativePath.toString.stripPrefix(s"cache${File.separator}")
val blockManager = session.sparkContext.env.blockManager
val blockId = CacheId(
sessionUUID = session.sessionUUID,
- hash = normalizedRemoteRelativePath.toString.stripPrefix(s"cache${File.separator}"))
- val updater = blockManager.TempFileBasedBlockStoreUpdater(
- blockId = blockId,
- level = StorageLevel.MEMORY_AND_DISK_SER,
- classTag = implicitly[ClassTag[Array[Byte]]],
- tmpFile = tmpFile,
- blockSize = tmpFile.length(),
- tellMaster = false)
- updater.save()
- val oldBlock = hashToCachedIdMap.put(blockId.hash, new RefCountedCacheId(blockId))
- if (oldBlock != null) {
- logWarning(
- log"Replacing existing cache artifact with hash ${MDC(LogKeys.BLOCK_ID, blockId)} " +
- log"in session ${MDC(LogKeys.SESSION_ID, session.sessionUUID)}. " +
- log"This may indicate duplicate artifact addition.")
- oldBlock.release(blockManager)
+ hash = hash)
+ // If the exact same block (same CacheId) already exists, skip re-adding.
+ // This prevents incorrectly removing the existing block from BlockManager.
+ // Note: We only skip if the CacheId matches - if it's a different session's block
+ // (e.g., after clone), we should replace it.
+ val existingBlock = hashToCachedIdMap.get(hash)
+ if (existingBlock == null || existingBlock.id != blockId) {
+ val updater = blockManager.TempFileBasedBlockStoreUpdater(
+ blockId = blockId,
+ level = StorageLevel.MEMORY_AND_DISK_SER,
+ classTag = implicitly[ClassTag[Array[Byte]]],
+ tmpFile = tmpFile,
+ blockSize = tmpFile.length(),
+ tellMaster = false)
+ updater.save()
+ hashToCachedIdMap.put(blockId.hash, new RefCountedCacheId(blockId))
+ if (existingBlock != null) {
+ // Release the old block - this is a legitimate replacement (different CacheId,
+ // e.g., after session clone). The old block will be removed when its ref count
+ // reaches zero.
+ existingBlock.release(blockManager)
+ }
+ } else {
+ logWarning(s"Cache artifact with hash $hash already exists in this session, skipping.")
}
}(finallyBlock = { tmpFile.delete() })
} else if (normalizedRemoteRelativePath.startsWith(s"classes${File.separator}")) {
@@ -422,8 +431,7 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
artifactPath)
// Ensure that no reference to `this` is captured/held by the cleanup lambda
private def getCleanable: Cleaner.Cleanable = cleaner.register(
- this,
- () => ArtifactManager.cleanUpGlobalResources(cleanUpStateForGlobalResources)
+ this, new StateCleanupRunner(cleanUpStateForGlobalResources)
)
private var cleanable = getCleanable
@@ -450,7 +458,20 @@ class ArtifactManager(session: SparkSession) extends AutoCloseable with Logging
pythonIncludeList.clear()
sparkContextRelativePaths.clear()
- // Removed cached classloader
+ // Close and remove cached classloader
+ cachedClassLoader.foreach {
+ case urlClassLoader: URLClassLoader =>
+ try {
+ urlClassLoader.close()
+ logDebug(log"Closed URLClassLoader for session " +
+ log"${MDC(LogKeys.SESSION_ID, session.sessionUUID)}")
+ } catch {
+ case e: IOException =>
+ logWarning(log"Failed to close URLClassLoader for session " +
+ log"${MDC(LogKeys.SESSION_ID, session.sessionUUID)}", e)
+ }
+ case _ =>
+ }
cachedClassLoader = None
}
@@ -516,6 +537,12 @@ object ArtifactManager extends Logging {
val JAR, FILE, ARCHIVE = Value
}
+ private class StateCleanupRunner(cleanupState: ArtifactStateForCleanup) extends Runnable {
+ override def run(): Unit = {
+ ArtifactManager.cleanUpGlobalResources(cleanupState)
+ }
+ }
+
// Shared cleaner instance
private val cleaner: Cleaner = Cleaner.create()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
index 402bab666948d..f3c6e168d2f04 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala
@@ -291,12 +291,12 @@ private[sql] class AvroSerializer(
var i = 0
while (i < numFields) {
if (row.isNullAt(i)) {
- if (!isSchemaNullable(i)) {
+ if (!isSchemaNullable(avroIndices(i))) {
throw new SparkRuntimeException(
errorClass = "AVRO_CANNOT_WRITE_NULL_FIELD",
messageParameters = Map(
- "name" -> toSQLId(avroFields.get(i).name),
- "dataType" -> avroFields.get(i).schema().toString))
+ "name" -> toSQLId(avroFields.get(avroIndices(i)).name),
+ "dataType" -> avroFields.get(avroIndices(i)).schema().toString))
}
result.put(avroIndices(i), null)
} else {
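The fix above matters whenever the Catalyst field order differs from the Avro field order: the null check and the error message must use the Avro-side position avroIndices(i), not the Catalyst index i. A small, hypothetical sketch of such a mapping (the schema and field names are made up, and AvroSerializer builds its mapping differently):

import org.apache.avro.Schema
import scala.jdk.CollectionConverters._

object AvroIndexMappingSketch {
  def main(args: Array[String]): Unit = {
    val avroSchema = new Schema.Parser().parse(
      """{"type": "record", "name": "r", "fields": [
        |  {"name": "b", "type": "string"},
        |  {"name": "a", "type": ["null", "int"], "default": null}
        |]}""".stripMargin)
    // Catalyst sees the fields as (a, b); the Avro record stores them as (b, a).
    val catalystOrder = Seq("a", "b")
    val avroFields = avroSchema.getFields.asScala
    val avroIndices = catalystOrder.map(name => avroFields.indexWhere(_.name == name))
    // Catalyst index 0 ("a") maps to Avro index 1, and 1 ("b") maps to 0.
    println(avroIndices) // List(1, 0)
  }
}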
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
index b664f52c8c1c5..a15ba27962029 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkException
import org.apache.spark.internal.LogKeys.CONFIG
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils, ClusterBySpec}
+import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, CatalogUtils, ClusterBySpec, HiveTableRelation}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
@@ -31,7 +31,7 @@ import org.apache.spark.sql.connector.catalog.{CatalogExtension, CatalogManager,
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.command._
-import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1}
+import org.apache.spark.sql.execution.datasources.{CreateTable => CreateTableV1, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.internal.connector.V1Function
@@ -128,6 +128,25 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
case DropColumns(ResolvedV1TableIdentifier(ident), _, _) =>
throw QueryCompilationErrors.unsupportedTableOperationError(ident, "DROP COLUMN")
+ // V1 and Hive tables do not support constraints
+ case AddConstraint(ResolvedV1TableIdentifier(ident), _) =>
+ throw QueryCompilationErrors.unsupportedTableOperationError(ident, "ADD CONSTRAINT")
+
+ case DropConstraint(ResolvedV1TableIdentifier(ident), _, _, _) =>
+ throw QueryCompilationErrors.unsupportedTableOperationError(ident, "DROP CONSTRAINT")
+
+ case a: AddCheckConstraint
+ if a.child.exists {
+ case _: LogicalRelation => true
+ case _: HiveTableRelation => true
+ case _ => false
+ } =>
+ val tableIdent = a.child.collectFirst {
+ case l: LogicalRelation => l.catalogTable.get.identifier
+ case h: HiveTableRelation => h.tableMeta.identifier
+ }.get
+ throw QueryCompilationErrors.unsupportedTableOperationError(tableIdent, "ADD CONSTRAINT")
+
case SetTableProperties(ResolvedV1TableIdentifier(ident), props) =>
AlterTableSetPropertiesCommand(ident, props, isView = false)
@@ -187,6 +206,10 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
c.tableSpec.provider, tableSpec.options, c.tableSpec.location, c.tableSpec.serde,
ctas = false)
if (!isV2Provider(provider)) {
+ if (tableSpec.constraints.nonEmpty) {
+ throw QueryCompilationErrors.unsupportedTableOperationError(
+ ident, "CONSTRAINT")
+ }
constructV1TableCmd(None, c.tableSpec, ident, c.tableSchema, c.partitioning,
c.ignoreIfExists, storageFormat, provider)
} else {
@@ -203,6 +226,10 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
ctas = true)
if (!isV2Provider(provider)) {
+ if (tableSpec.constraints.nonEmpty) {
+ throw QueryCompilationErrors.unsupportedTableOperationError(
+ ident, "CONSTRAINT")
+ }
constructV1TableCmd(Some(c.query), c.tableSpec, ident, new StructType, c.partitioning,
c.ignoreIfExists, storageFormat, provider)
} else {
@@ -439,13 +466,13 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
viewSchemaMode = viewSchemaMode)
case CreateView(ResolvedIdentifier(catalog, _), _, _, _, _, _, _, _, _, _) =>
- throw QueryCompilationErrors.missingCatalogAbilityError(catalog, "views")
+ throw QueryCompilationErrors.missingCatalogViewsAbilityError(catalog)
case ShowViews(ns: ResolvedNamespace, pattern, output) =>
ns match {
case ResolvedDatabaseInSessionCatalog(db) => ShowViewsCommand(db, pattern, output)
case _ =>
- throw QueryCompilationErrors.missingCatalogAbilityError(ns.catalog, "views")
+ throw QueryCompilationErrors.missingCatalogViewsAbilityError(ns.catalog)
}
// If target is view, force use v1 command
@@ -463,7 +490,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
if (isSessionCatalog(catalog)) {
DescribeFunctionCommand(func.asInstanceOf[V1Function].info, extended)
} else {
- throw QueryCompilationErrors.missingCatalogAbilityError(catalog, "functions")
+ throw QueryCompilationErrors.missingCatalogFunctionsAbilityError(catalog)
}
case ShowFunctions(
@@ -476,7 +503,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
identifier.asFunctionIdentifier)
DropFunctionCommand(funcIdentifier, ifExists, false)
} else {
- throw QueryCompilationErrors.missingCatalogAbilityError(catalog, "DROP FUNCTION")
+ throw QueryCompilationErrors.missingCatalogDropFunctionAbilityError(catalog)
}
case RefreshFunction(ResolvedPersistentFunc(catalog, identifier, _)) =>
@@ -485,7 +512,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
identifier.asFunctionIdentifier)
RefreshFunctionCommand(funcIdentifier.database, funcIdentifier.funcName)
} else {
- throw QueryCompilationErrors.missingCatalogAbilityError(catalog, "REFRESH FUNCTION")
+ throw QueryCompilationErrors.missingCatalogRefreshFunctionAbilityError(catalog)
}
case CreateFunction(
@@ -499,7 +526,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
replace)
case CreateFunction(ResolvedIdentifier(catalog, _), _, _, _, _) =>
- throw QueryCompilationErrors.missingCatalogAbilityError(catalog, "CREATE FUNCTION")
+ throw QueryCompilationErrors.missingCatalogCreateFunctionAbilityError(catalog)
case c @ CreateUserDefinedFunction(
ResolvedIdentifierInSessionCatalog(ident), _, _, _, _, _, _, _, _, _, _, _) =>
@@ -520,7 +547,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager)
case CreateUserDefinedFunction(
ResolvedIdentifier(catalog, _), _, _, _, _, _, _, _, _, _, _, _) =>
- throw QueryCompilationErrors.missingCatalogAbilityError(catalog, "CREATE FUNCTION")
+ throw QueryCompilationErrors.missingCatalogCreateFunctionAbilityError(catalog)
}
private def constructV1TableCmd(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala
index 2d3e4b84d9ae5..52012a8629421 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataFrameWriter.scala
@@ -438,17 +438,18 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
val session = df.sparkSession
- val canUseV2 = lookupV2Provider().isDefined || (hasCustomSessionCatalog &&
+ val v2ProviderOpt = lookupV2Provider()
+ val canUseV2 = v2ProviderOpt.isDefined || (hasCustomSessionCatalog &&
!df.sparkSession.sessionState.catalogManager.catalog(CatalogManager.SESSION_CATALOG_NAME)
.isInstanceOf[CatalogExtension])
session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
case nameParts @ NonSessionCatalogAndIdentifier(catalog, ident) =>
- saveAsTableCommand(catalog.asTableCatalog, ident, nameParts)
+ saveAsTableCommand(catalog.asTableCatalog, v2ProviderOpt, ident, nameParts)
case nameParts @ SessionCatalogAndIdentifier(catalog, ident)
if canUseV2 && ident.namespace().length <= 1 =>
- saveAsTableCommand(catalog.asTableCatalog, ident, nameParts)
+ saveAsTableCommand(catalog.asTableCatalog, v2ProviderOpt, ident, nameParts)
case AsTableIdentifier(tableIdentifier) =>
saveAsV1TableCommand(tableIdentifier)
@@ -459,7 +460,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram
}
private def saveAsTableCommand(
- catalog: TableCatalog, ident: Identifier, nameParts: Seq[String]): LogicalPlan = {
+ catalog: TableCatalog,
+ v2ProviderOpt: Option[TableProvider],
+ ident: Identifier,
+ nameParts: Seq[String]): LogicalPlan = {
val tableOpt = try Option(catalog.loadTable(ident, getWritePrivileges.toSet.asJava)) catch {
case _: NoSuchTableException => None
}
@@ -484,12 +488,19 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) extends sql.DataFram
serde = None,
external = false,
constraints = Seq.empty)
+ val writeOptions = v2ProviderOpt match {
+ case Some(p: SupportsV1OverwriteWithSaveAsTable)
+ if p.addV1OverwriteWithSaveAsTableOption() =>
+ extraOptions + (SupportsV1OverwriteWithSaveAsTable.OPTION_NAME -> "true")
+ case _ =>
+ extraOptions
+ }
ReplaceTableAsSelect(
UnresolvedIdentifier(nameParts),
partitioningAsV2,
df.queryExecution.analyzed,
tableSpec,
- writeOptions = extraOptions.toMap,
+ writeOptions = writeOptions.toMap,
orCreate = true) // Create the table if it doesn't exist
case (other, _) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala
index 471c5feadaabc..38483395ec8c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/DataStreamWriter.scala
@@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Utils, FileDat
import org.apache.spark.sql.execution.datasources.v2.python.PythonDataSourceV2
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, Trigger}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._
@@ -299,6 +300,14 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) extends streaming.D
recoverFromCheckpoint: Boolean = true,
catalogAndIdent: Option[(TableCatalog, Identifier)] = None,
catalogTable: Option[CatalogTable] = None): StreamingQuery = {
+ if (trigger.isInstanceOf[RealTimeTrigger]) {
+ RealTimeModeAllowlist.checkAllowedSink(
+ sink,
+ ds.sparkSession.sessionState.conf.getConf(
+ SQLConf.STREAMING_REAL_TIME_MODE_ALLOWLIST_CHECK)
+ )
+ }
+
val useTempCheckpointLocation = DataStreamWriter.SOURCES_ALLOW_ONE_TIME_QUERY.contains(source)
ds.sparkSession.sessionState.streamingQueryManager.startQuery(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala
index 3edd789b685f6..308651b449fd0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/ObservationManager.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.classic
import java.util.concurrent.ConcurrentHashMap
-import org.apache.spark.sql.Observation
+import org.apache.spark.sql.{Observation, Row}
import org.apache.spark.sql.catalyst.plans.logical.CollectMetrics
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener
@@ -56,10 +56,22 @@ private[sql] class ObservationManager(session: SparkSession) {
val allMetrics = qe.observedMetrics
qe.logical.foreach {
case c: CollectMetrics =>
- allMetrics.get(c.name).foreach { metrics =>
+ val keyExists = observations.containsKey((c.name, c.dataframeId))
+ val metrics = allMetrics.get(c.name)
+ if (keyExists && metrics.isEmpty) {
+ // If the key exists but no metrics were collected, it means the metrics could not be
+ // collected for some reason, e.g., because the CollectMetricsExec node was optimized away.
val observation = observations.remove((c.name, c.dataframeId))
if (observation != null) {
- observation.setMetricsAndNotify(metrics)
+ observation.setMetricsAndNotify(Row.empty)
+ }
+ } else {
+ metrics.foreach { metrics =>
+ val observation = observations.remove((c.name, c.dataframeId))
+ if (observation != null) {
+ observation.setMetricsAndNotify(metrics)
+ }
}
}
case _ =>
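For context on the change above: Observation.get blocks until the observed metrics arrive, so completing the observation with an empty row (rather than never notifying it) keeps callers from hanging when CollectMetricsExec is optimized away. A small usage sketch with the public API (the session setup and column names are illustrative):

import org.apache.spark.sql.{Observation, SparkSession}
import org.apache.spark.sql.functions._

object ObservationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("obs").getOrCreate()
    val observation = Observation("stats")
    spark.range(100)
      .observe(observation, count(lit(1)).as("rows"), max(col("id")).as("max_id"))
      .collect()
    // Blocks until the action finishes; with the fix above it completes with an empty
    // result instead of waiting forever when no metrics could be collected.
    println(observation.get)
    spark.stop()
  }
}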
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
index f7876d9a023bd..c47e845416213 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/SparkSession.scala
@@ -45,11 +45,11 @@ import org.apache.spark.sql.catalyst.analysis.{GeneralParameterizedQuery, NamePa
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Expression, Literal}
import org.apache.spark.sql.catalyst.parser.{HybridParameterContext, NamedParameterContext, ParserInterface, PositionalParameterContext}
-import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, OneRowRelation, Project, Range}
+import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, LogicalPlan, OneRowRelation, Project, Range}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.classic.SparkSession.applyAndLoadExtensions
-import org.apache.spark.sql.errors.SqlScriptingErrors
+import org.apache.spark.sql.errors.{QueryCompilationErrors, SqlScriptingErrors}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.ExternalCommandExecutor
import org.apache.spark.sql.execution.datasources.LogicalRelation
@@ -501,26 +501,31 @@ class SparkSession private(
private[sql] def sql(sqlText: String, args: Array[_], tracker: QueryPlanningTracker): DataFrame =
withActive {
val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
- val parsedPlan = if (args.nonEmpty) {
- // Resolve and validate parameters first
- val paramMap = args.zipWithIndex.map { case (arg, idx) =>
- s"_pos_$idx" -> lit(arg).expr
- }.toMap
- val resolvedParams = resolveAndValidateParameters(paramMap)
+ val parsedPlan = {
+ // Always parse with parameter context to detect unbound parameter markers.
+ // Even if args is empty, we need to detect and reject parameter markers in the SQL.
+ val (paramMap, resolvedParams) = if (args.nonEmpty) {
+ val pMap = args.zipWithIndex.map { case (arg, idx) =>
+ s"_pos_$idx" -> lit(arg).expr
+ }.toMap
+ (pMap, resolveAndValidateParameters(pMap))
+ } else {
+ (Map.empty[String, Expression], Map.empty[String, Expression])
+ }
+
val paramContext = PositionalParameterContext(resolvedParams.values.toSeq)
val parsed = sessionState.sqlParser.parsePlanWithParameters(sqlText, paramContext)
+
// Check for SQL scripting with positional parameters
- if (parsed.isInstanceOf[CompoundBody]) {
+ if (parsed.isInstanceOf[CompoundBody] && args.nonEmpty) {
throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting()
}
// In legacy mode, wrap with PosParameterizedQuery for analyzer binding
- if (sessionState.conf.legacyParameterSubstitutionConstantsOnly) {
+ if (args.nonEmpty && sessionState.conf.legacyParameterSubstitutionConstantsOnly) {
PosParameterizedQuery(parsed, paramContext.params)
} else {
parsed
}
- } else {
- sessionState.sqlParser.parsePlan(sqlText)
}
parsedPlan
}
@@ -554,30 +559,29 @@ class SparkSession private(
args: Map[String, Any],
tracker: QueryPlanningTracker): DataFrame =
withActive {
- // Always set parameter context if we have actual parameters
- if (args.nonEmpty) {
- // Resolve and validate parameters first
- val resolvedParams = resolveAndValidateParameters(args.transform((_, v) => lit(v).expr))
- val paramContext = NamedParameterContext(resolvedParams)
- val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
- val parsedPlan = sessionState.sqlParser.parsePlanWithParameters(sqlText, paramContext)
- // In legacy mode, wrap the parsed plan with NameParameterizedQuery
- // so that the BindParameters analyzer rule can bind the parameters
- if (sessionState.conf.legacyParameterSubstitutionConstantsOnly) {
- NameParameterizedQuery(parsedPlan, paramContext.params)
- } else {
- parsedPlan
- }
- }
-
- Dataset.ofRows(self, plan, tracker)
+ // Always parse with parameter context to detect unbound parameter markers.
+ // Even if args is empty, we need to detect and reject parameter markers in the SQL.
+ val resolvedParams = if (args.nonEmpty) {
+ resolveAndValidateParameters(args.transform((_, v) => lit(v).expr))
} else {
- // No parameters - parse normally without parameter context
- val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
- sessionState.sqlParser.parsePlan(sqlText)
+ Map.empty[String, Expression]
+ }
+ val paramContext = NamedParameterContext(resolvedParams)
+ val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
+ val parsedPlan = sessionState.sqlParser.parsePlanWithParameters(sqlText, paramContext)
+ val queryPlan = parsedPlan match {
+ case compoundBody: CompoundBody => compoundBody
+ case logicalPlan: LogicalPlan =>
+ // In legacy mode, wrap with NameParameterizedQuery for analyzer binding
+ if (args.nonEmpty && sessionState.conf.legacyParameterSubstitutionConstantsOnly) {
+ NameParameterizedQuery(logicalPlan, paramContext.params)
+ } else {
+ logicalPlan
+ }
}
- Dataset.ofRows(self, plan, tracker)
+ queryPlan
}
+ Dataset.ofRows(self, plan, tracker)
}
/** @inheritdoc */
@@ -610,6 +614,8 @@ class SparkSession private(
tracker: QueryPlanningTracker): DataFrame =
withActive {
val plan = tracker.measurePhase(QueryPlanningTracker.PARSING) {
+ // Always parse with parameter context to detect unbound parameter markers.
+ // Even if args is empty, we need to detect and reject parameter markers in the SQL.
val parsedPlan = if (args.nonEmpty) {
// Resolve and validate parameter arguments
val paramMap = args.zipWithIndex.map { case (arg, idx) =>
@@ -643,11 +649,6 @@ class SparkSession private(
val parsed = sessionState.sqlParser.parsePlanWithParameters(sqlText, paramContext)
- // Check for SQL scripting with positional parameters
- if (parsed.isInstanceOf[CompoundBody] && paramNames.isEmpty) {
- throw SqlScriptingErrors.positionalParametersAreNotSupportedWithSqlScripting()
- }
-
// In legacy mode, wrap with GeneralParameterizedQuery for analyzer binding
if (sessionState.conf.legacyParameterSubstitutionConstantsOnly) {
GeneralParameterizedQuery(
@@ -659,8 +660,16 @@ class SparkSession private(
parsed
}
} else {
- sessionState.sqlParser.parsePlan(sqlText)
+ // No arguments provided, but still need to detect parameter markers
+ val paramContext = HybridParameterContext(Seq.empty, Seq.empty)
+ sessionState.sqlParser.parsePlanWithParameters(sqlText, paramContext)
}
+
+ // Check for SQL scripts in EXECUTE IMMEDIATE (applies to both empty and non-empty args)
+ if (parsedPlan.isInstanceOf[CompoundBody]) {
+ throw QueryCompilationErrors.sqlScriptInExecuteImmediate(sqlText)
+ }
+
parsedPlan
}
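The reworked sql() overloads above always parse with a parameter context so that unbound markers can be detected even when no arguments are supplied. A short usage sketch of the public parameterized entry points (queries and values are illustrative):

import org.apache.spark.sql.SparkSession

object ParameterizedSqlSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("params").getOrCreate()
    // Named parameters: :limit is bound from the map.
    spark.sql("SELECT id FROM range(10) WHERE id < :limit", Map("limit" -> 3)).show()
    // Positional parameters: ? markers are bound left to right.
    spark.sql("SELECT id FROM range(10) WHERE id < ?", Array(3)).show()
    // A marker with no corresponding argument is rejected at parse time
    // instead of slipping through as unparameterized SQL.
    spark.stop()
  }
}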
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala
index bef09703025ef..72ae3b21d662a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/StreamingQueryManager.scala
@@ -213,7 +213,7 @@ class StreamingQueryManager private[sql] (
sink,
outputMode,
df.sparkSession.sessionState.newHadoopConf(),
- trigger.isInstanceOf[ContinuousTrigger],
+ trigger,
analyzedPlan,
catalogAndIdent,
catalogTable)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/TableValuedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/TableValuedFunction.scala
index d2034033fee7e..839faa85aed48 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/TableValuedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/TableValuedFunction.scala
@@ -97,4 +97,8 @@ class TableValuedFunction(sparkSession: SparkSession)
/** @inheritdoc */
override def variant_explode_outer(input: Column): Dataset[Row] =
fn("variant_explode_outer", Seq(input))
+
+ /** @inheritdoc */
+ override def python_worker_logs(): Dataset[Row] =
+ fn("python_worker_logs", Seq.empty)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 671fcb765648d..a551762b83899 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -17,24 +17,30 @@
package org.apache.spark.sql.execution
+import scala.util.control.NonFatal
+
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.internal.{Logging, MessageWithContext}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
+import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.analysis.V2TableReference
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, SubqueryExpression}
import org.apache.spark.sql.catalyst.optimizer.EliminateResolvedHint
-import org.apache.spark.sql.catalyst.plans.logical.{IgnoreCachedData, LogicalPlan, ResolvedHint, View}
+import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, ResolvedHint, View}
import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.classic.{Dataset, SparkSession}
-import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper
+import org.apache.spark.sql.connector.catalog.CatalogPlugin
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper}
+import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.command.CommandUtils
import org.apache.spark.sql.execution.datasources.{FileIndex, HadoopFsRelation, LogicalRelation, LogicalRelationWithTable}
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2Table, FileTable}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2CatalogAndIdentifier, ExtractV2Table, FileTable, V2TableRefreshUtil}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK
@@ -132,7 +138,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
storageLevel: StorageLevel): Unit = {
if (storageLevel == StorageLevel.NONE) {
// Do nothing for StorageLevel.NONE since it will not actually cache any data.
- } else if (unnormalizedPlan.isInstanceOf[IgnoreCachedData]) {
+ } else if (unnormalizedPlan.isInstanceOf[Command]) {
logWarning(
log"Asked to cache a plan that is inapplicable for caching: " +
log"${MDC(LOGICAL_PLAN, unnormalizedPlan)}"
@@ -238,28 +244,51 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
name: Seq[String],
conf: SQLConf,
includeTimeTravel: Boolean): Boolean = {
- def isSameName(nameInCache: Seq[String]): Boolean = {
- nameInCache.length == name.length && nameInCache.zip(name).forall(conf.resolver.tupled)
- }
+ isMatchedTableOrView(plan, name, conf.resolver, includeTimeTravel)
+ }
+
+ private def isMatchedTableOrView(
+ plan: LogicalPlan,
+ name: Seq[String],
+ resolver: Resolver,
+ includeTimeTravel: Boolean): Boolean = {
EliminateSubqueryAliases(plan) match {
case LogicalRelationWithTable(_, Some(catalogTable)) =>
- isSameName(catalogTable.identifier.nameParts)
+ isSameName(name, catalogTable.identifier.nameParts, resolver)
case DataSourceV2Relation(_, _, Some(catalog), Some(v2Ident), _, timeTravelSpec) =>
val nameInCache = v2Ident.toQualifiedNameParts(catalog)
- isSameName(nameInCache) && (includeTimeTravel || timeTravelSpec.isEmpty)
+ isSameName(name, nameInCache, resolver) && (includeTimeTravel || timeTravelSpec.isEmpty)
+
+ case r: V2TableReference =>
+ isSameName(name, r.identifier.toQualifiedNameParts(r.catalog), resolver)
case v: View =>
- isSameName(v.desc.identifier.nameParts)
+ isSameName(name, v.desc.identifier.nameParts, resolver)
case HiveTableRelation(catalogTable, _, _, _, _) =>
- isSameName(catalogTable.identifier.nameParts)
+ isSameName(name, catalogTable.identifier.nameParts, resolver)
case _ => false
}
}
+ private def isSameName(
+ name: Seq[String],
+ catalog: CatalogPlugin,
+ ident: Identifier,
+ resolver: Resolver): Boolean = {
+ isSameName(name, ident.toQualifiedNameParts(catalog), resolver)
+ }
+
+ private def isSameName(
+ name: Seq[String],
+ nameInCache: Seq[String],
+ resolver: Resolver): Boolean = {
+ nameInCache.length == name.length && nameInCache.zip(name).forall(resolver.tupled)
+ }
+
private def uncacheByCondition(
spark: SparkSession,
isMatchedPlan: LogicalPlan => Boolean,
@@ -347,24 +376,97 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
}
needToRecache.foreach { cd =>
cd.cachedRepresentation.cacheBuilder.clearCache()
- val sessionWithConfigsOff = getOrCloneSessionWithConfigsOff(spark)
- val newCache = sessionWithConfigsOff.withActive {
- val qe = sessionWithConfigsOff.sessionState.executePlan(cd.plan)
- InMemoryRelation(cd.cachedRepresentation.cacheBuilder, qe)
- }
- val recomputedPlan = cd.copy(cachedRepresentation = newCache)
- this.synchronized {
- if (lookupCachedDataInternal(recomputedPlan.plan).nonEmpty) {
- logWarning("While recaching, data was already added to cache.")
- } else {
- cachedData = recomputedPlan +: cachedData
- CacheManager.logCacheOperation(log"Re-cached Dataframe cache entry:" +
- log"${MDC(DATAFRAME_CACHE_ENTRY, recomputedPlan)}")
+ tryRebuildCacheEntry(spark, cd).foreach { entry =>
+ this.synchronized {
+ if (lookupCachedDataInternal(entry.plan).nonEmpty) {
+ logWarning("While recaching, data was already added to cache.")
+ } else {
+ cachedData = entry +: cachedData
+ CacheManager.logCacheOperation(log"Re-cached Dataframe cache entry:" +
+ log"${MDC(DATAFRAME_CACHE_ENTRY, entry)}")
+ }
}
}
}
}
+ private def tryRebuildCacheEntry(spark: SparkSession, cd: CachedData): Option[CachedData] = {
+ val sessionWithConfigsOff = getOrCloneSessionWithConfigsOff(spark)
+ sessionWithConfigsOff.withActive {
+ tryRefreshPlan(sessionWithConfigsOff, cd.plan).map { refreshedPlan =>
+ val qe = QueryExecution.create(
+ sessionWithConfigsOff,
+ refreshedPlan,
+ refreshPhaseEnabled = false)
+ val newKey = qe.normalized
+ val newCache = InMemoryRelation(cd.cachedRepresentation.cacheBuilder, qe)
+ cd.copy(plan = newKey, cachedRepresentation = newCache)
+ }
+ }
+ }
+
+ /**
+ * Attempts to refresh table metadata loaded through the catalog.
+ *
+ * If the table state is cached (e.g., via `CACHE TABLE t`), the relation is replaced with
+ * updated metadata as long as the table ID still matches, ensuring that all schema changes
+ * are reflected. Otherwise, a new plan is produced using refreshed table metadata but
+ * retaining the original schema, provided the schema changes are still compatible with the
+ * query (e.g., adding new columns should be acceptable).
+ *
+ * Note this logic applies only to V2 tables at the moment.
+ *
+ * @return the refreshed plan if refresh succeeds, None otherwise
+ */
+ private def tryRefreshPlan(spark: SparkSession, plan: LogicalPlan): Option[LogicalPlan] = {
+ try {
+ EliminateSubqueryAliases(plan) match {
+ case r @ ExtractV2CatalogAndIdentifier(catalog, ident) if r.timeTravelSpec.isEmpty =>
+ val table = catalog.loadTable(ident)
+ if (r.table.id == table.id) {
+ Some(DataSourceV2Relation.create(table, Some(catalog), Some(ident)))
+ } else {
+ None
+ }
+ case _ =>
+ Some(V2TableRefreshUtil.refresh(spark, plan))
+ }
+ } catch {
+ case NonFatal(e) =>
+ logWarning(log"Failed to refresh plan while attempting to recache", e)
+ None
+ }
+ }
+
+ private[sql] def lookupCachedTable(
+ name: Seq[String],
+ resolver: Resolver): Option[LogicalPlan] = {
+ val cachedRelations = findCachedRelations(name, resolver)
+ cachedRelations match {
+ case cachedRelation +: _ =>
+ CacheManager.logCacheOperation(
+ log"Relation cache hit for table ${MDC(TABLE_NAME, name.quoted)}")
+ Some(cachedRelation)
+ case _ =>
+ None
+ }
+ }
+
+ private def findCachedRelations(
+ name: Seq[String],
+ resolver: Resolver): Seq[LogicalPlan] = {
+ cachedData.flatMap { cd =>
+ val plan = EliminateSubqueryAliases(cd.plan)
+ plan match {
+ case r @ ExtractV2CatalogAndIdentifier(catalog, ident)
+ if isSameName(name, catalog, ident, resolver) && r.timeTravelSpec.isEmpty =>
+ Some(r)
+ case _ =>
+ None
+ }
+ }
+ }
+
/**
* Optionally returns cached data for the given [[Dataset]]
*/
@@ -397,7 +499,7 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper {
*/
private[sql] def useCachedData(plan: LogicalPlan): LogicalPlan = {
val newPlan = plan transformDown {
- case command: IgnoreCachedData => command
+ case command: Command => command
case currentFragment =>
lookupCachedDataInternal(currentFragment).map { cached =>
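The helpers above match cache entries by multipart name through the session resolver. A minimal sketch of the user-visible behavior via the catalog API (the table name is illustrative; with the default case-insensitive resolver, differently-cased names refer to the same entry):

import org.apache.spark.sql.SparkSession

object CacheByNameSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("cache").getOrCreate()
    spark.range(10).createOrReplaceTempView("sales")
    spark.catalog.cacheTable("sales")          // adds an InMemoryRelation entry
    assert(spark.catalog.isCached("sales"))
    // Lookup goes through the resolver, so "SALES" names the same cached relation
    // under the default case-insensitive configuration.
    spark.catalog.uncacheTable("SALES")
    assert(!spark.catalog.isCached("sales"))
    spark.stop()
  }
}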
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 6148fb30783e8..06085497de19a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -360,23 +360,25 @@ case class OneRowRelationExec() extends LeafExecNode
override val output: Seq[Attribute] = Nil
private val rdd: RDD[InternalRow] = {
- val numOutputRows = longMetric("numOutputRows")
session
.sparkContext
.parallelize(Seq(""), 1)
.mapPartitionsInternal { _ =>
val proj = UnsafeProjection.create(Seq.empty[Expression])
- Iterator(proj.apply(InternalRow.empty)).map { r =>
- numOutputRows += 1
- r
- }
+ Iterator(proj.apply(InternalRow.empty))
}
}
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
- protected override def doExecute(): RDD[InternalRow] = rdd
+ protected override def doExecute(): RDD[InternalRow] = {
+ val numOutputRows = longMetric("numOutputRows")
+ rdd.map { r =>
+ numOutputRows += 1
+ r
+ }
+ }
override def simpleString(maxFields: Int): String = s"$nodeName[]"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
index 5a789179219ad..0a70021dc858c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyGroupedPartitionedScan.scala
@@ -34,7 +34,7 @@ trait KeyGroupedPartitionedScan[T] {
def getOutputKeyGroupedPartitioning(
basePartitioning: KeyGroupedPartitioning,
spjParams: StoragePartitionJoinParams): KeyGroupedPartitioning = {
- val expressions = spjParams.joinKeyPositions match {
+ val projectedExpressions = spjParams.joinKeyPositions match {
case Some(projectionPositions) =>
projectionPositions.map(i => basePartitioning.expressions(i))
case _ => basePartitioning.expressions
@@ -50,15 +50,16 @@ trait KeyGroupedPartitionedScan[T] {
case None =>
spjParams.joinKeyPositions match {
case Some(projectionPositions) => basePartitioning.partitionValues.map { r =>
- val projectedRow = KeyGroupedPartitioning.project(expressions,
+ val projectedRow = KeyGroupedPartitioning.project(basePartitioning.expressions,
projectionPositions, r)
- InternalRowComparableWrapper(projectedRow, expressions)
+ InternalRowComparableWrapper(projectedRow, projectedExpressions)
}.distinct.map(_.row)
case _ => basePartitioning.partitionValues
}
}
- basePartitioning.copy(expressions = expressions, numPartitions = newPartValues.length,
- partitionValues = newPartValues)
+ basePartitioning.copy(expressions = projectedExpressions, numPartitions = newPartValues.length,
+ partitionValues = newPartValues,
+ isPartiallyClustered = spjParams.applyPartialClustering)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 27d6eec46b69a..3e0aef962e719 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.execution.adaptive.{AdaptiveExecutionContext, InsertAdaptiveSparkPlan}
import org.apache.spark.sql.execution.bucketing.{CoalesceBucketsInJoin, DisableUnnecessaryBucketedScan}
+import org.apache.spark.sql.execution.datasources.v2.V2TableRefreshUtil
import org.apache.spark.sql.execution.dynamicpruning.PlanDynamicPruningFilters
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery
@@ -65,7 +66,8 @@ class QueryExecution(
val logical: LogicalPlan,
val tracker: QueryPlanningTracker = new QueryPlanningTracker,
val mode: CommandExecutionMode.Value = CommandExecutionMode.ALL,
- val shuffleCleanupMode: ShuffleCleanupMode = DoNotCleanup) extends Logging {
+ val shuffleCleanupMode: ShuffleCleanupMode = DoNotCleanup,
+ val refreshPhaseEnabled: Boolean = true) extends Logging {
val id: Long = QueryExecution.nextExecutionId
@@ -177,7 +179,7 @@ class QueryExecution(
// for eagerly executed commands we mark this place as beginning of execution.
tracker.setReadyForExecution()
val qe = new QueryExecution(sparkSession, p, mode = mode,
- shuffleCleanupMode = shuffleCleanupMode)
+ shuffleCleanupMode = shuffleCleanupMode, refreshPhaseEnabled = refreshPhaseEnabled)
val result = QueryExecution.withInternalError(s"Eagerly executed $name failed.") {
SQLExecution.withNewExecutionId(qe, Some(name)) {
qe.executedPlan.executeCollect()
@@ -203,8 +205,20 @@ class QueryExecution(
}
}
+ // There may be a delay between analysis and subsequent phases, so refresh the captured
+ // table versions to reflect the latest data.
+ private val lazyTableVersionsRefreshed = LazyTry {
+ if (refreshPhaseEnabled) {
+ V2TableRefreshUtil.refresh(sparkSession, commandExecuted, versionedOnly = true)
+ } else {
+ commandExecuted
+ }
+ }
+
+ private[sql] def tableVersionsRefreshed: LogicalPlan = lazyTableVersionsRefreshed.get
+
private val lazyNormalized = LazyTry {
- QueryExecution.normalize(sparkSession, commandExecuted, Some(tracker))
+ QueryExecution.normalize(sparkSession, tableVersionsRefreshed, Some(tracker))
}
// The plan that has been normalized by custom rules, so that it's more likely to hit cache.
@@ -560,6 +574,18 @@ object QueryExecution {
private def nextExecutionId: Long = _nextExecutionId.getAndIncrement
+ private[execution] def create(
+ sparkSession: SparkSession,
+ logical: LogicalPlan,
+ refreshPhaseEnabled: Boolean = true): QueryExecution = {
+ new QueryExecution(
+ sparkSession,
+ logical,
+ mode = CommandExecutionMode.ALL,
+ shuffleCleanupMode = determineShuffleCleanupMode(sparkSession.sessionState.conf),
+ refreshPhaseEnabled = refreshPhaseEnabled)
+ }
+
/**
* Construct a sequence of rules that are used to prepare a planned [[SparkPlan]] for execution.
* These rules will make sure subqueries are planned, make sure the data partitioning and ordering
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
index 1cab0f8d35af5..19bafeb196122 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala
@@ -71,11 +71,12 @@ object SQLExecution extends Logging {
}
private def extractShuffleIds(plan: SparkPlan): Seq[Int] = {
- plan match {
+ val shuffleIdsOption = plan.collectFirst {
case ae: AdaptiveSparkPlanExec =>
ae.context.shuffleIds.asScala.keys.toSeq
- case nonAdaptivePlan =>
- nonAdaptivePlan.collect {
+ }
+ shuffleIdsOption.getOrElse {
+ plan.collect {
case exec: ShuffleExchangeLike => exec.shuffleId
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
index 58bffbed3e69e..d1e021d564c3b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala
@@ -118,14 +118,22 @@ class SparkSqlParser extends AbstractSqlParser {
// Step 2: Apply parameter substitution if a parameter context is provided.
val (paramSubstituted, positionMapper, hasParameters) = parameterContext match {
case Some(context) =>
+ // Check if the context actually contains parameters
+ val contextHasParams = context match {
+ case NamedParameterContext(params) => params.nonEmpty
+ case PositionalParameterContext(params) => params.nonEmpty
+ case HybridParameterContext(args, _) => args.nonEmpty
+ }
if (SQLConf.get.legacyParameterSubstitutionConstantsOnly) {
// Legacy mode: Parameters are detected but substitution is deferred to analysis phase.
- (variableSubstituted, PositionMapper.identity(variableSubstituted), true)
+ // Only set hasParameters if the context actually contains parameters.
+ (variableSubstituted, PositionMapper.identity(variableSubstituted), contextHasParams)
} else {
// Modern mode: Perform parameter substitution during parsing.
val (substituted, mapper) =
ParameterHandler.substituteParameters(variableSubstituted, context)
- (substituted, mapper, true)
+ // Only set hasParameters if the context actually contains parameters.
+ (substituted, mapper, contextHasParams)
}
case None =>
// No parameter context provided; skip parameter substitution.
@@ -179,7 +187,7 @@ class SparkSqlAstBuilder extends AstBuilder {
(ident, _) => builder(ident))
} else if (ctx.errorCapturingIdentifier() != null) {
// resolve immediately
- builder.apply(Seq(ctx.errorCapturingIdentifier().getText))
+ builder.apply(Seq(getIdentifierText(ctx.errorCapturingIdentifier())))
} else if (ctx.stringLit() != null) {
// resolve immediately
builder.apply(Seq(string(visitStringLit(ctx.stringLit()))))
@@ -559,7 +567,7 @@ class SparkSqlAstBuilder extends AstBuilder {
* - '/path/to/fileOrJar'
*/
override def visitManageResource(ctx: ManageResourceContext): LogicalPlan = withOrigin(ctx) {
- val rawArg = remainder(ctx.identifier).trim
+ val rawArg = remainder(ctx.simpleIdentifier).trim
val maybePaths = strLiteralDef.findAllIn(rawArg).toSeq.map {
case p if p.startsWith("\"") || p.startsWith("'") => unescapeSQLString(p)
case p => p
@@ -567,14 +575,14 @@ class SparkSqlAstBuilder extends AstBuilder {
ctx.op.getType match {
case SqlBaseParser.ADD =>
- ctx.identifier.getText.toLowerCase(Locale.ROOT) match {
+ ctx.simpleIdentifier.getText.toLowerCase(Locale.ROOT) match {
case "files" | "file" => AddFilesCommand(maybePaths)
case "jars" | "jar" => AddJarsCommand(maybePaths)
case "archives" | "archive" => AddArchivesCommand(maybePaths)
case other => operationNotAllowed(s"ADD with resource type '$other'", ctx)
}
case SqlBaseParser.LIST =>
- ctx.identifier.getText.toLowerCase(Locale.ROOT) match {
+ ctx.simpleIdentifier.getText.toLowerCase(Locale.ROOT) match {
case "files" | "file" =>
if (maybePaths.length > 0) {
ListFilesCommand(maybePaths)
@@ -629,7 +637,8 @@ class SparkSqlAstBuilder extends AstBuilder {
val userSpecifiedColumns = Option(ctx.identifierCommentList).toSeq.flatMap { icl =>
icl.identifierComment.asScala.map { ic =>
- ic.identifier.getText -> Option(ic.commentSpec()).map(visitCommentSpec)
+ // Use getIdentifierText to handle both regular identifiers and IDENTIFIER('literal')
+ getIdentifierText(ic.identifier) -> Option(ic.commentSpec()).map(visitCommentSpec)
}
}
@@ -726,7 +735,7 @@ class SparkSqlAstBuilder extends AstBuilder {
*/
override def visitCreateFunction(ctx: CreateFunctionContext): LogicalPlan = withOrigin(ctx) {
val resources = ctx.resource.asScala.map { resource =>
- val resourceType = resource.identifier.getText.toLowerCase(Locale.ROOT)
+ val resourceType = resource.simpleIdentifier.getText.toLowerCase(Locale.ROOT)
resourceType match {
case "jar" | "file" | "archive" =>
FunctionResource(FunctionResourceType.fromString(resourceType),
@@ -1299,7 +1308,7 @@ class SparkSqlAstBuilder extends AstBuilder {
} else {
DescribeColumn(
relation,
- UnresolvedAttribute(ctx.describeColName.nameParts.asScala.map(_.getText).toSeq),
+ UnresolvedAttribute(ctx.describeColName.nameParts.asScala.map(getIdentifierText).toSeq),
isExtended)
}
} else {
@@ -1376,7 +1385,7 @@ class SparkSqlAstBuilder extends AstBuilder {
if (colConstraints.nonEmpty) {
throw operationNotAllowed("Pipeline datasets do not currently support column constraints. " +
- "Please remove and CHECK, UNIQUE, PK, and FK constraints specified on the pipeline " +
+ "Please remove any CHECK, UNIQUE, PK, and FK constraints specified on the pipeline " +
"dataset.", ctx)
}
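visitManageResource above handles the ADD/LIST resource commands whose resource keyword is now read via simpleIdentifier. A brief usage sketch of those commands (the path is a placeholder and must exist when ADD FILE runs):

import org.apache.spark.sql.SparkSession

object ManageResourceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("resources").getOrCreate()
    // Resource keywords are matched case-insensitively: FILE/FILES, JAR/JARS, ARCHIVE/ARCHIVES.
    spark.sql("ADD FILE '/tmp/lookup.csv'")          // assumes /tmp/lookup.csv exists
    spark.sql("LIST FILES").show(truncate = false)
    spark.stop()
  }
}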
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 3c36d3e2d4173..5efad83bcba78 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -23,6 +23,7 @@ import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{execution, AnalysisException}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, NamedRelation}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper, NormalizeFloatingNumbers}
@@ -31,12 +32,14 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.{SparkStrategy => Strategy}
import org.apache.spark.sql.execution.aggregate.AggUtils
import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
import org.apache.spark.sql.execution.command._
-import org.apache.spark.sql.execution.datasources.{WriteFiles, WriteFilesExec}
+import org.apache.spark.sql.execution.datasources.{LogicalRelation, WriteFiles, WriteFilesExec}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.exchange.{REBALANCE_PARTITIONS_BY_COL, REBALANCE_PARTITIONS_BY_NONE, REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeExec}
import org.apache.spark.sql.execution.python._
import org.apache.spark.sql.execution.python.streaming.{FlatMapGroupsInPandasWithStateExec, TransformWithStateInPySparkExec}
@@ -1091,10 +1094,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case r: LogicalRDD =>
RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering,
r.stream) :: Nil
- case _: UpdateTable =>
- throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("UPDATE TABLE")
- case _: MergeIntoTable =>
- throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("MERGE INTO TABLE")
+ case u: UpdateTable =>
+ val tableName = extractTableNameForError(u.table)
+ throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("UPDATE TABLE", tableName)
+ case m: MergeIntoTable =>
+ val tableName = extractTableNameForError(m.targetTable)
+ throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("MERGE INTO TABLE", tableName)
case logical.CollectMetrics(name, metrics, child, _) =>
execution.CollectMetricsExec(name, metrics, planLater(child)) :: Nil
case WriteFiles(child, fileFormat, partitionColumns, bucket, options, staticPartitions) =>
@@ -1105,4 +1110,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => Nil
}
}
+
+ /**
+ * Extracts a user-friendly table name from a logical plan for error messages.
+ */
+ private def extractTableNameForError(table: LogicalPlan): String = {
+ val unwrapped = EliminateSubqueryAliases(table)
+ unwrapped match {
+ // Check specific types before NamedRelation since they extend it
+ case DataSourceV2Relation(_, _, catalog, Some(ident), _, _) =>
+ (catalog.map(_.name()).toSeq ++ ident.asMultipartIdentifier).mkString(".")
+ case LogicalRelation(_, _, Some(catalogTable), _, _) =>
+ catalogTable.identifier.unquotedString
+ case r: NamedRelation =>
+ r.name
+ case _ =>
+ "unknown"
+ }
+ }
}
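With the change above, the planner's temporarily-unsupported error for UPDATE/MERGE names the target table instead of being generic. A hedged repro sketch (the table name is illustrative; depending on the provider, the statement may instead be rejected earlier during analysis):

import scala.util.Try
import org.apache.spark.sql.SparkSession

object UpdateErrorSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("upd").getOrCreate()
    spark.sql("CREATE TABLE target_tbl (id INT, v INT) USING parquet")
    val result = Try(spark.sql("UPDATE target_tbl SET v = 0 WHERE id = 1").collect())
    // Expect a failure whose message mentions `target_tbl`, not just the operation name.
    result.failed.foreach(e => println(e.getMessage))
    spark.stop()
  }
}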
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 3072a12e3d587..3d435ac016ed1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -22,13 +22,17 @@ import java.nio.channels.{Channels, ReadableByteChannel}
import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
+import scala.util.control.NonFatal
+import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec}
import org.apache.arrow.flatbuf.MessageHeader
import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector._
+import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec}
import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel}
import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer}
+import org.apache.spark.SparkException
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
@@ -37,13 +41,13 @@ import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.util.{ArrowUtils, CloseableIterator, ConcatenatingArrowStreamReader, MessageIterator}
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
import org.apache.spark.util.{ByteBufferOutputStream, SizeEstimator, Utils}
import org.apache.spark.util.ArrayImplicits._
-
/**
* Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format.
*/
@@ -92,8 +96,26 @@ private[sql] object ArrowConverters extends Logging {
ArrowUtils.rootAllocator.newChildAllocator(
s"to${this.getClass.getSimpleName}", 0, Long.MaxValue)
- private val root = VectorSchemaRoot.create(arrowSchema, allocator)
- protected val unloader = new VectorUnloader(root)
+ protected val root = VectorSchemaRoot.create(arrowSchema, allocator)
+
+ // Create compression codec based on config
+ private val compressionCodecName = SQLConf.get.arrowCompressionCodec
+ private val codec = compressionCodecName match {
+ case "none" => NoCompressionCodec.INSTANCE
+ case "zstd" =>
+ val compressionLevel = SQLConf.get.arrowZstdCompressionLevel
+ val factory = CompressionCodec.Factory.INSTANCE
+ val codecType = new ZstdCompressionCodec(compressionLevel).getCodecType()
+ factory.createCodec(codecType)
+ case "lz4" =>
+ val factory = CompressionCodec.Factory.INSTANCE
+ val codecType = new Lz4CompressionCodec().getCodecType()
+ factory.createCodec(codecType)
+ case other =>
+ throw SparkException.internalError(
+ s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4")
+ }
+ protected val unloader = new VectorUnloader(root, true, codec, true)
protected val arrowWriter = ArrowWriter.create(root)
Option(context).foreach {_.addTaskCompletionListener[Unit] { _ =>
@@ -275,50 +297,33 @@ private[sql] object ArrowConverters extends Logging {
* @param context Task Context for Spark
*/
private[sql] class InternalRowIteratorFromIPCStream(
- input: Array[Byte],
- context: TaskContext) extends Iterator[InternalRow] {
-
- // Keep all the resources we have opened in order, should be closed
- // in reverse order finally.
- private val resources = new ArrayBuffer[AutoCloseable]()
+ ipcStreams: Iterator[Array[Byte]],
+ context: TaskContext)
+ extends CloseableIterator[InternalRow] {
// Create an allocator used for all Arrow related memory.
protected val allocator: BufferAllocator = ArrowUtils.rootAllocator.newChildAllocator(
s"to${this.getClass.getSimpleName}",
0,
Long.MaxValue)
- resources.append(allocator)
- private val reader = try {
- new ArrowStreamReader(new ByteArrayInputStream(input), allocator)
- } catch {
- case e: Exception =>
- closeAll(resources.toSeq.reverse: _*)
- throw new IllegalArgumentException(
- s"Failed to create ArrowStreamReader: ${e.getMessage}", e)
- }
- resources.append(reader)
-
- private val root: VectorSchemaRoot = try {
- reader.getVectorSchemaRoot
- } catch {
- case e: Exception =>
- closeAll(resources.toSeq.reverse: _*)
- throw new IllegalArgumentException(
- s"Failed to read schema from IPC stream: ${e.getMessage}", e)
+ private val reader = {
+ val messages = ipcStreams.map { bytes =>
+ new MessageIterator(new ByteArrayInputStream(bytes), allocator)
+ }
+ new ConcatenatingArrowStreamReader(allocator, messages, destructive = true)
}
- resources.append(root)
- val schema: StructType = try {
- ArrowUtils.fromArrowSchema(root.getSchema)
+ lazy val schema: StructType = try {
+ ArrowUtils.fromArrowSchema(reader.getVectorSchemaRoot.getSchema)
} catch {
- case e: Exception =>
- closeAll(resources.toSeq.reverse: _*)
- throw new IllegalArgumentException(s"Failed to convert Arrow schema: ${e.getMessage}", e)
+ case NonFatal(e) =>
+ // Since this triggers a read (which involves allocating buffers), we have to clean up.
+ close()
+ throw e
}
- // TODO: wrap in exception
- private var rowIterator: Iterator[InternalRow] = vectorSchemaRootToIter(root)
+ private var rowIterator: Iterator[InternalRow] = Iterator.empty
// Metrics to track batch processing
private var _batchesLoaded: Int = 0
@@ -326,36 +331,27 @@ private[sql] object ArrowConverters extends Logging {
if (context != null) {
context.addTaskCompletionListener[Unit] { _ =>
- closeAll(resources.toSeq.reverse: _*)
+ close()
}
}
// Public accessors for metrics
def batchesLoaded: Int = _batchesLoaded
def totalRowsProcessed: Long = _totalRowsProcessed
-
- // Loads the next batch from the Arrow reader and returns true or
- // false if the next batch could be loaded.
- private def loadNextBatch(): Boolean = {
- if (reader.loadNextBatch()) {
- rowIterator = vectorSchemaRootToIter(root)
- _batchesLoaded += 1
- true
- } else {
- false
- }
- }
+ def allocatedMemory: Long = allocator.getAllocatedMemory
+ def peakMemoryAllocation: Long = allocator.getPeakMemoryAllocation
override def hasNext: Boolean = {
- if (rowIterator.hasNext) {
- true
- } else {
- if (!loadNextBatch()) {
- false
+ while (!rowIterator.hasNext) {
+ if (reader.loadNextBatch()) {
+ rowIterator = vectorSchemaRootToIter(reader.getVectorSchemaRoot)
+ _batchesLoaded += 1
} else {
- hasNext
+ close()
+ return false
}
}
+ true
}
override def next(): InternalRow = {
@@ -365,6 +361,10 @@ private[sql] object ArrowConverters extends Logging {
_totalRowsProcessed += 1
rowIterator.next()
}
+
+ override def close(): Unit = {
+ closeAll(reader, allocator)
+ }
}
/**
@@ -490,15 +490,21 @@ private[sql] object ArrowConverters extends Logging {
* one schema and a varying number of record batches. Returns an iterator over the
* created InternalRow.
*/
- private[sql] def fromIPCStream(input: Array[Byte], context: TaskContext):
- (Iterator[InternalRow], StructType) = {
- fromIPCStreamWithIterator(input, context)
+ private[sql] def fromIPCStream(input: Array[Byte]):
+ (CloseableIterator[InternalRow], StructType) = {
+ fromIPCStream(Iterator.single(input))
+ }
+
+ private[sql] def fromIPCStream(inputs: Iterator[Array[Byte]]):
+ (CloseableIterator[InternalRow], StructType) = {
+ val iterator = new InternalRowIteratorFromIPCStream(inputs, null)
+ (iterator, iterator.schema)
}
// Overloaded method for tests to access the iterator with metrics
private[sql] def fromIPCStreamWithIterator(input: Array[Byte], context: TaskContext):
- (InternalRowIteratorFromIPCStream, StructType) = {
- val iterator = new InternalRowIteratorFromIPCStream(input, context)
+ (InternalRowIteratorFromIPCStream, StructType) = {
+ val iterator = new InternalRowIteratorFromIPCStream(Iterator.single(input), context)
(iterator, iterator.schema)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index bf7491625fa03..56faf2032065d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -406,11 +406,12 @@ object InMemoryRelation {
def apply(cacheBuilder: CachedRDDBuilder, qe: QueryExecution): InMemoryRelation = {
val optimizedPlan = qe.optimizedPlan
val serializer = cacheBuilder.serializer
- val newBuilder = if (serializer.supportsColumnarInput(optimizedPlan.output)) {
- cacheBuilder.copy(cachedPlan = serializer.convertToColumnarPlanIfPossible(qe.executedPlan))
+ val newCachedPlan = if (serializer.supportsColumnarInput(optimizedPlan.output)) {
+ serializer.convertToColumnarPlanIfPossible(qe.executedPlan)
} else {
- cacheBuilder.copy(cachedPlan = qe.executedPlan)
+ qe.executedPlan
}
+ val newBuilder = cacheBuilder.copy(cachedPlan = newCachedPlan, logicalPlan = qe.logical)
val relation = new InMemoryRelation(
newBuilder.cachedPlan.output, newBuilder, optimizedPlan.outputOrdering)
relation.statsOfPlanToCache = optimizedPlan.stats
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala
index 78ff514bf9e51..a3780a8bff197 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CreateUserDefinedFunctionCommand.scala
@@ -22,7 +22,6 @@ import java.util.Locale
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CapturesConfig, FunctionIdentifier}
import org.apache.spark.sql.catalyst.catalog.{LanguageSQL, RoutineLanguage, UserDefinedFunctionErrors}
-import org.apache.spark.sql.catalyst.plans.logical.IgnoreCachedData
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
@@ -30,7 +29,7 @@ import org.apache.spark.sql.types.StructType
* The base class for CreateUserDefinedFunctionCommand
*/
abstract class CreateUserDefinedFunctionCommand
- extends LeafRunnableCommand with IgnoreCachedData with CapturesConfig
+ extends LeafRunnableCommand with CapturesConfig
object CreateUserDefinedFunctionCommand {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
index e31e7e8d704ca..e248f0eea96de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.analysis.VariableResolution
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.catalyst.plans.logical.IgnoreCachedData
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.classic.ClassicConversions.castToImpl
import org.apache.spark.sql.errors.QueryCompilationErrors.toSQLId
@@ -218,7 +217,7 @@ object SetCommand {
* reset spark.sql.session.timeZone;
* }}}
*/
-case class ResetCommand(config: Option[String]) extends LeafRunnableCommand with IgnoreCachedData {
+case class ResetCommand(config: Option[String]) extends LeafRunnableCommand {
override def run(sparkSession: SparkSession): Seq[Row] = {
val globalInitialConfigs = sparkSession.sharedState.conf
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
index de5dbddbfa146..9c3ac9ef74191 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala
@@ -18,12 +18,11 @@
package org.apache.spark.sql.execution.command
import org.apache.spark.sql.{Row, SparkSession}
-import org.apache.spark.sql.catalyst.plans.logical.IgnoreCachedData
/**
* Clear all cached data from the in-memory cache.
*/
-case object ClearCacheCommand extends LeafRunnableCommand with IgnoreCachedData {
+case object ClearCacheCommand extends LeafRunnableCommand {
override def run(sparkSession: SparkSession): Seq[Row] = {
sparkSession.catalog.clearCache()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index 514b64f6abed2..11ec17ca57fd9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -27,6 +27,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CapturesConfig, SQLConfHelper, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, GlobalTempView, LocalTempView, SchemaEvolution, SchemaUnsupported, ViewSchemaMode, ViewType}
+import org.apache.spark.sql.catalyst.analysis.V2TableReference
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, TemporaryViewRelation}
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, SubqueryExpression, VariableReference}
import org.apache.spark.sql.catalyst.plans.logical.{AnalysisOnlyCommand, CreateTempView, CTEInChildren, CTERelationDef, LogicalPlan, Project, View, WithCTE}
@@ -34,6 +35,7 @@ import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.classic.ClassicConversions.castToImpl
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.NamespaceHelper
import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.types.{MetadataBuilder, StructType}
import org.apache.spark.sql.util.SchemaUtils
@@ -733,7 +735,17 @@ object ViewHelper extends SQLConfHelper with Logging with CapturesConfig {
} else {
TemporaryViewRelation(
prepareTemporaryViewStoringAnalyzedPlan(name, aliasedPlan, defaultCollation),
- Some(aliasedPlan))
+ Some(prepareTemporaryViewPlan(name, aliasedPlan)))
+ }
+ }
+
+ private def prepareTemporaryViewPlan(
+ viewName: TableIdentifier,
+ plan: LogicalPlan): LogicalPlan = {
+ plan transform {
+ case r: DataSourceV2Relation
+ if r.catalog.isDefined && r.identifier.isDefined && r.timeTravelSpec.isEmpty =>
+ V2TableReference.createForTempView(r, viewName.nameParts)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
index d43c9eab0a5ba..3fd82573f001a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
@@ -30,7 +30,7 @@ import org.apache.spark.{SparkException, SparkUpgradeException}
import org.apache.spark.sql.{sources, SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY}
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper}
-import org.apache.spark.sql.catalyst.util.RebaseDateTime
+import org.apache.spark.sql.catalyst.util.{RebaseDateTime, TypeUtils}
import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
@@ -93,6 +93,7 @@ object DataSourceUtils extends PredicateHelper {
* in a driver side.
*/
def verifySchema(format: FileFormat, schema: StructType, readOnly: Boolean = false): Unit = {
+ TypeUtils.failUnsupportedDataType(schema, SQLConf.get)
schema.foreach { field =>
val supported = if (readOnly) {
format.supportReadDataType(field.dataType)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
index 2cf1a5e9b8cdc..b0b20d08dccbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PushVariantIntoScan.scala
@@ -26,10 +26,8 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project, Subquery}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
-import org.apache.spark.sql.connector.read.{SupportsPushDownVariants, VariantAccessInfo}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -225,8 +223,11 @@ class VariantInRelation {
case Some(variants) =>
variants.get(path) match {
case Some(fields) =>
+ // Accessing the full variant value
addField(fields, RequestedVariantField.fullVariant)
case _ =>
+ // Accessing the struct containing a variant.
+ // This variant is not eligible for push down.
// Remove non-eligible variants.
variants.filterInPlace { case (key, _) => !key.startsWith(path) }
}
@@ -281,11 +282,6 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
relation @ LogicalRelationWithTable(
hadoopFsRelation@HadoopFsRelation(_, _, _, _, _: ParquetFileFormat, _), _)) =>
rewritePlan(p, projectList, filters, relation, hadoopFsRelation)
-
- case p@PhysicalOperation(projectList, filters,
- scanRelation @ DataSourceV2ScanRelation(
- relation, scan: SupportsPushDownVariants, output, _, _)) =>
- rewritePlanV2(p, projectList, filters, scanRelation, scan)
}
}
@@ -333,112 +329,10 @@ object PushVariantIntoScan extends Rule[LogicalPlan] {
hadoopFsRelation.sparkSession)
val newRelation = relation.copy(relation = newHadoopFsRelation, output = newOutput.toIndexedSeq)
- buildFilterAndProject(newRelation, projectList, filters, variants, attributeMap)
- }
-
- // DataSource V2 rewrite method using SupportsPushDownVariants API
- // Key differences from V1 implementation:
- // 1. V2 uses DataSourceV2ScanRelation instead of LogicalRelation
- // 2. Uses SupportsPushDownVariants API instead of directly manipulating scan
- // 3. Schema is already resolved in scanRelation.output (no need for relation.resolve())
- // 4. Scan rebuilding is handled by the scan implementation via the API
- // Data sources like Delta and Iceberg can implement this API to support variant pushdown.
- private def rewritePlanV2(
- originalPlan: LogicalPlan,
- projectList: Seq[NamedExpression],
- filters: Seq[Expression],
- scanRelation: DataSourceV2ScanRelation,
- scan: SupportsPushDownVariants): LogicalPlan = {
- val variants = new VariantInRelation
-
- // Extract schema attributes from V2 scan relation
- val schemaAttributes = scanRelation.output
-
- // Construct schema for default value resolution
- val structSchema = StructType(schemaAttributes.map(a =>
- StructField(a.name, a.dataType, a.nullable, a.metadata)))
-
- val defaultValues = ResolveDefaultColumns.existenceDefaultValues(structSchema)
-
- // Add variant fields from the V2 scan schema
- for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
- variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
- }
- if (variants.mapping.isEmpty) return originalPlan
-
- // Collect requested fields from project list and filters
- projectList.foreach(variants.collectRequestedFields)
- filters.foreach(variants.collectRequestedFields)
-
- // If no variant columns remain after collection, return original plan
- if (variants.mapping.forall(_._2.isEmpty)) return originalPlan
-
- // Build VariantAccessInfo array for the API
- val variantAccessInfoArray = schemaAttributes.flatMap { attr =>
- variants.mapping.get(attr.exprId).flatMap(_.get(Nil)).map { fields =>
- // Build extracted schema for this variant column
- val extractedFields = fields.toArray.sortBy(_._2).map { case (field, ordinal) =>
- StructField(ordinal.toString, field.targetType, metadata = field.path.toMetadata)
- }
- val extractedSchema = if (extractedFields.isEmpty) {
- // Add placeholder field to avoid empty struct
- val placeholder = VariantMetadata("$.__placeholder_field__",
- failOnError = false, timeZoneId = "UTC")
- StructType(Array(StructField("0", BooleanType, metadata = placeholder.toMetadata)))
- } else {
- StructType(extractedFields)
- }
- new VariantAccessInfo(attr.name, extractedSchema)
- }
- }.toArray
-
- // Call the API to push down variant access
- if (variantAccessInfoArray.isEmpty) return originalPlan
-
- val pushed = scan.pushVariantAccess(variantAccessInfoArray)
- if (!pushed) return originalPlan
-
- // Get what was actually pushed
- val pushedVariantAccess = scan.pushedVariantAccess()
- if (pushedVariantAccess.isEmpty) return originalPlan
-
- // Build new attribute mapping based on pushed variant access
- val pushedColumnNames = pushedVariantAccess.map(_.columnName()).toSet
- val attributeMap = schemaAttributes.map { a =>
- if (pushedColumnNames.contains(a.name) && variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
- val newType = variants.rewriteType(a.exprId, a.dataType, Nil)
- val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)(
- qualifier = a.qualifier)
- (a.exprId, newAttr)
- } else {
- (a.exprId, a)
- }
- }.toMap
-
- val newOutput = scanRelation.output.map(a => attributeMap.getOrElse(a.exprId, a))
-
- // The scan implementation should have updated its readSchema() based on the pushed info
- // We just need to create a new scan relation with the updated output
- val newScanRelation = scanRelation.copy(
- output = newOutput
- )
-
- buildFilterAndProject(newScanRelation, projectList, filters, variants, attributeMap)
- }
-
- /**
- * Build the final Project(Filter(relation)) plan with rewritten expressions.
- */
- private def buildFilterAndProject(
- relation: LogicalPlan,
- projectList: Seq[NamedExpression],
- filters: Seq[Expression],
- variants: VariantInRelation,
- attributeMap: Map[ExprId, AttributeReference]): LogicalPlan = {
val withFilter = if (filters.nonEmpty) {
- Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), relation)
+ Filter(filters.map(variants.rewriteExpr(_, attributeMap)).reduce(And), newRelation)
} else {
- relation
+ newRelation
}
val newProjectList = projectList.map { e =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
index 57e0efb993fb7..19b67d0c53900 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
@@ -83,6 +83,8 @@ case class BinaryFileFormat() extends FileFormat
override def shortName(): String = BINARY_FILE
+ override def toString: String = "BINARYFILE"
+
override protected def buildReader(
sparkSession: SparkSession,
dataSchema: StructType,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 4cc3fe61d22b4..08e545cb8c204 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -408,7 +408,7 @@ class ParquetFileFormat
}
override def supportDataType(dataType: DataType): Boolean = dataType match {
- case _: AtomicType => true
+ case _: AtomicType | _: NullType => true
case st: StructType => st.forall { f => supportDataType(f.dataType) }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
index 25efd326f23a1..7ee5b4d224b34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
@@ -28,7 +28,7 @@ import org.apache.parquet.hadoop.api.{InitContext, ReadSupport}
import org.apache.parquet.hadoop.api.ReadSupport.ReadContext
import org.apache.parquet.io.api.RecordMaterializer
import org.apache.parquet.schema._
-import org.apache.parquet.schema.LogicalTypeAnnotation.{ListLogicalTypeAnnotation, MapKeyValueTypeAnnotation, MapLogicalTypeAnnotation}
+import org.apache.parquet.schema.LogicalTypeAnnotation.{ListLogicalTypeAnnotation, MapKeyValueTypeAnnotation, MapLogicalTypeAnnotation, UnknownLogicalTypeAnnotation}
import org.apache.parquet.schema.Type.Repetition
import org.apache.spark.internal.Logging
@@ -562,6 +562,8 @@ object ParquetReadSupport extends Logging {
}
case primitiveType: PrimitiveType =>
val cost = primitiveType.getPrimitiveTypeName match {
+ case _ if primitiveType.getLogicalTypeAnnotation
+ .isInstanceOf[UnknownLogicalTypeAnnotation] => 0 // NullType is always preferred
case PrimitiveType.PrimitiveTypeName.BOOLEAN => 1
case PrimitiveType.PrimitiveTypeName.INT32 => 4
case PrimitiveType.PrimitiveTypeName.INT64 => 8
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index f9d50bf28ea85..271a1485dfd34 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -315,6 +315,13 @@ private[parquet] class ParquetRowConverter(
}
catalystType match {
+ case NullType
+ if parquetType.getLogicalTypeAnnotation.isInstanceOf[UnknownLogicalTypeAnnotation] =>
+ val parentUpdater = updater
+ // A converter that throws on any add* call, as we don't expect any values for NullType.
+ new PrimitiveConverter with HasParentContainerUpdater {
+ override def updater: ParentContainerUpdater = parentUpdater
+ }
case LongType if isUnsignedIntTypeMatched(32) =>
new ParquetPrimitiveConverter(updater) {
override def addInt(value: Int): Unit =
@@ -869,7 +876,11 @@ private[parquet] class ParquetRowConverter(
}
}
- /** Parquet converter for unshredded Variant */
+ /**
+ * Parquet converter for unshredded Variant. We use this converter when the
+ * `spark.sql.variant.allowReadingShredded` config is set to false. This option exists only as
+ * a fallback to the legacy logic, which will eventually be removed.
+ */
private final class ParquetUnshreddedVariantConverter(
parquetType: GroupType,
updater: ParentContainerUpdater)
@@ -883,29 +894,27 @@ private[parquet] class ParquetRowConverter(
// We may allow more than two children in the future, so consider this unsupported.
throw QueryCompilationErrors.invalidVariantWrongNumFieldsError()
}
- val valueAndMetadata = Seq("value", "metadata").map { colName =>
+ val Seq(value, metadata) = Seq("value", "metadata").map { colName =>
val idx = (0 until parquetType.getFieldCount())
- .find(parquetType.getFieldName(_) == colName)
- if (idx.isEmpty) {
- throw QueryCompilationErrors.invalidVariantMissingFieldError(colName)
- }
- val child = parquetType.getType(idx.get)
+ .find(parquetType.getFieldName(_) == colName)
+ .getOrElse(throw QueryCompilationErrors.invalidVariantMissingFieldError(colName))
+ val child = parquetType.getType(idx)
if (!child.isPrimitive || child.getRepetition != Type.Repetition.REQUIRED ||
- child.asPrimitiveType().getPrimitiveTypeName != BINARY) {
+ child.asPrimitiveType().getPrimitiveTypeName != BINARY) {
throw QueryCompilationErrors.invalidVariantNullableOrNotBinaryFieldError(colName)
}
- child
+ idx
}
- Array(
- // Converter for value
- newConverter(valueAndMetadata(0), BinaryType, new ParentContainerUpdater {
+ val result = new Array[Converter with HasParentContainerUpdater](2)
+ result(value) =
+ newConverter(parquetType.getType(value), BinaryType, new ParentContainerUpdater {
override def set(value: Any): Unit = currentValue = value
- }),
-
- // Converter for metadata
- newConverter(valueAndMetadata(1), BinaryType, new ParentContainerUpdater {
+ })
+ result(metadata) =
+ newConverter(parquetType.getType(metadata), BinaryType, new ParentContainerUpdater {
override def set(value: Any): Unit = currentMetadata = value
- }))
+ })
+ result
}
override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index 947c021c1bd3a..9e6f4447ca792 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -58,7 +58,9 @@ class ParquetToSparkSchemaConverter(
caseSensitive: Boolean = SQLConf.CASE_SENSITIVE.defaultValue.get,
inferTimestampNTZ: Boolean = SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.defaultValue.get,
nanosAsLong: Boolean = SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.defaultValue.get,
- useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get) {
+ useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get,
+ val ignoreVariantAnnotation: Boolean =
+ SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.defaultValue.get) {
def this(conf: SQLConf) = this(
assumeBinaryIsString = conf.isParquetBinaryAsString,
@@ -66,7 +68,8 @@ class ParquetToSparkSchemaConverter(
caseSensitive = conf.caseSensitiveAnalysis,
inferTimestampNTZ = conf.parquetInferTimestampNTZEnabled,
nanosAsLong = conf.legacyParquetNanosAsLong,
- useFieldId = conf.parquetFieldIdReadEnabled)
+ useFieldId = conf.parquetFieldIdReadEnabled,
+ ignoreVariantAnnotation = conf.parquetIgnoreVariantAnnotation)
def this(conf: Configuration) = this(
assumeBinaryIsString = conf.get(SQLConf.PARQUET_BINARY_AS_STRING.key).toBoolean,
@@ -75,7 +78,9 @@ class ParquetToSparkSchemaConverter(
inferTimestampNTZ = conf.get(SQLConf.PARQUET_INFER_TIMESTAMP_NTZ_ENABLED.key).toBoolean,
nanosAsLong = conf.get(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key).toBoolean,
useFieldId = conf.getBoolean(SQLConf.PARQUET_FIELD_ID_READ_ENABLED.key,
- SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get))
+ SQLConf.PARQUET_FIELD_ID_READ_ENABLED.defaultValue.get),
+ ignoreVariantAnnotation = conf.getBoolean(SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.key,
+ SQLConf.PARQUET_IGNORE_VARIANT_ANNOTATION.defaultValue.get))
/**
* Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
@@ -202,15 +207,17 @@ class ParquetToSparkSchemaConverter(
case primitiveColumn: PrimitiveColumnIO => convertPrimitiveField(primitiveColumn, targetType)
case groupColumn: GroupColumnIO if targetType.contains(VariantType) =>
if (SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED)) {
- val col = convertGroupField(groupColumn)
+ // We need the underlying file type regardless of the config.
+ val col = convertGroupField(groupColumn, ignoreVariantAnnotation = true)
col.copy(sparkType = VariantType, variantFileType = Some(col))
} else {
convertVariantField(groupColumn)
}
case groupColumn: GroupColumnIO if targetType.exists(VariantMetadata.isVariantStruct) =>
- val col = convertGroupField(groupColumn)
+ val col = convertGroupField(groupColumn, ignoreVariantAnnotation = true)
col.copy(sparkType = targetType.get, variantFileType = Some(col))
- case groupColumn: GroupColumnIO => convertGroupField(groupColumn, targetType)
+ case groupColumn: GroupColumnIO =>
+ convertGroupField(groupColumn, ignoreVariantAnnotation, targetType)
}
}
@@ -246,7 +253,9 @@ class ParquetToSparkSchemaConverter(
DecimalType(precision, scale)
}
- val sparkType = sparkReadType.getOrElse(typeName match {
+ val isUnknownType = typeAnnotation.isInstanceOf[UnknownLogicalTypeAnnotation]
+ val nullTypeOpt = Option.when(isUnknownType)(NullType)
+ val sparkType = sparkReadType.orElse(nullTypeOpt).getOrElse(typeName match {
case BOOLEAN => BooleanType
case FLOAT => FloatType
@@ -347,6 +356,7 @@ class ParquetToSparkSchemaConverter(
private def convertGroupField(
groupColumn: GroupColumnIO,
+ ignoreVariantAnnotation: Boolean,
sparkReadType: Option[DataType] = None): ParquetColumn = {
val field = groupColumn.getType.asGroupType()
@@ -371,6 +381,22 @@ class ParquetToSparkSchemaConverter(
Option(field.getLogicalTypeAnnotation).fold(
convertInternal(groupColumn, sparkReadType.map(_.asInstanceOf[StructType]))) {
+ case v: VariantLogicalTypeAnnotation if v.getSpecVersion == 1 =>
+ if (ignoreVariantAnnotation) {
+ convertInternal(groupColumn)
+ } else {
+ ParquetSchemaConverter.checkConversionRequirement(
+ sparkReadType.forall(_.isInstanceOf[VariantType]),
+ s"Invalid Spark read type: expected $field to be variant type but found " +
+ s"${sparkReadType.map(_.sql).getOrElse("None")}")
+ if (SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED)) {
+ val col = convertInternal(groupColumn)
+ col.copy(sparkType = VariantType, variantFileType = Some(col))
+ } else {
+ convertVariantField(groupColumn)
+ }
+ }
+
// A Parquet list is represented as a 3-level structure:
//
// group (LIST) {
@@ -550,7 +576,9 @@ class SparkToParquetSchemaConverter(
writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get,
outputTimestampType: SQLConf.ParquetOutputTimestampType.Value =
SQLConf.ParquetOutputTimestampType.INT96,
- useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.defaultValue.get) {
+ useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.defaultValue.get,
+ annotateVariantLogicalType: Boolean =
+ SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.defaultValue.get) {
def this(conf: SQLConf) = this(
writeLegacyParquetFormat = conf.writeLegacyParquetFormat,
@@ -561,7 +589,9 @@ class SparkToParquetSchemaConverter(
writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean,
outputTimestampType = SQLConf.ParquetOutputTimestampType.withName(
conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key)),
- useFieldId = conf.get(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key).toBoolean)
+ useFieldId = conf.get(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key).toBoolean,
+ annotateVariantLogicalType =
+ conf.get(SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key).toBoolean)
/**
* Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]].
@@ -815,14 +845,22 @@ class SparkToParquetSchemaConverter(
// ===========
case VariantType =>
- Types.buildGroup(repetition)
+ (if (annotateVariantLogicalType) {
+ Types.buildGroup(repetition).as(LogicalTypeAnnotation.variantType(1))
+ } else {
+ Types.buildGroup(repetition)
+ })
.addField(convertField(StructField("value", BinaryType, nullable = false), inShredded))
.addField(convertField(StructField("metadata", BinaryType, nullable = false), inShredded))
.named(field.name)
case s: StructType if SparkShreddingUtils.isVariantShreddingStruct(s) =>
// Variant struct takes a Variant and writes to Parquet as a shredded schema.
- val group = Types.buildGroup(repetition)
+ val group = if (annotateVariantLogicalType) {
+ Types.buildGroup(repetition).as(LogicalTypeAnnotation.variantType(1))
+ } else {
+ Types.buildGroup(repetition)
+ }
s.fields.foreach { f =>
group.addField(convertField(f, inShredded = true))
}
@@ -836,6 +874,10 @@ class SparkToParquetSchemaConverter(
case udt: UserDefinedType[_] =>
convertField(field.copy(dataType = udt.sqlType), inShredded)
+ case NullType => // The primitive type selected here has no significance.
+ Types.primitive(BOOLEAN, repetition).named(field.name)
+ .withLogicalTypeAnnotation(LogicalTypeAnnotation.unknownType())
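+ // For example, a nullable NullType column "c" maps to the Parquet field `optional boolean c`
+ // annotated with the UNKNOWN logical type, which readers treat as always null.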
+
case _ =>
throw QueryCompilationErrors.cannotConvertDataTypeToParquetTypeError(field)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
index 65a77322549f1..1f11a67b08fff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, Outpu
import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
import org.apache.spark.sql.internal.SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED
-import org.apache.spark.sql.types.{ArrayType, AtomicType, DataType, MapType, StructField, StructType, UserDefinedType, VariantType}
+import org.apache.spark.sql.types.{ArrayType, AtomicType, DataType, MapType, NullType, StructField, StructType, UserDefinedType, VariantType}
import org.apache.spark.util.ArrayImplicits._
object ParquetUtils extends Logging {
@@ -209,6 +209,8 @@ object ParquetUtils extends Logging {
def isBatchReadSupported(sqlConf: SQLConf, dt: DataType): Boolean = dt match {
case _: AtomicType =>
true
+ case _: NullType =>
+ sqlConf.parquetVectorizedReaderNullTypeEnabled
case at: ArrayType =>
sqlConf.parquetVectorizedReaderNestedColumnEnabled &&
isBatchReadSupported(sqlConf, at.elementType)
@@ -521,6 +523,10 @@ object ParquetUtils extends Logging {
SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key,
sqlConf.legacyParquetNanosAsLong.toString)
+ conf.set(
+ SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key,
+ sqlConf.parquetAnnotateVariantLogicalType.toString)
+
// Sets compression scheme
conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
index 2ab9fb64da43d..dcaf88fa8dfdb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala
@@ -29,7 +29,7 @@ import org.apache.parquet.hadoop.api.WriteSupport
import org.apache.parquet.hadoop.api.WriteSupport.WriteContext
import org.apache.parquet.io.api.{Binary, RecordConsumer}
-import org.apache.spark.{SPARK_VERSION_SHORT, SparkException}
+import org.apache.spark.{SPARK_VERSION_SHORT, SparkException, SparkUnsupportedOperationException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY}
import org.apache.spark.sql.catalyst.InternalRow
@@ -192,6 +192,9 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging {
// schema. This affects how timestamp values are written.
private def makeWriter(dataType: DataType, inShredded: Boolean): ValueWriter = {
dataType match {
+ case NullType => // No values of NullType should ever be written, as all values are null.
+ (_: SpecializedGetters, _: Int) => throw SparkUnsupportedOperationException()
+
case BooleanType =>
(row: SpecializedGetters, ordinal: Int) =>
recordConsumer.addBoolean(row.getBoolean(ordinal))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
index 1132f074f29d1..ca2defffba913 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala
@@ -646,7 +646,9 @@ case object SparkShreddingUtils {
def parquetTypeToSparkType(parquetType: ParquetType): DataType = {
val messageType = ParquetTypes.buildMessage().addField(parquetType).named("foo")
val column = new ColumnIOFactory().getColumnIO(messageType)
- new ParquetToSparkSchemaConverter().convertField(column.getChild(0)).sparkType
+ // We need the underlying file type regardless of the ignoreVariantAnnotation config.
+ val converter = new ParquetToSparkSchemaConverter(ignoreVariantAnnotation = true)
+ converter.convertField(column.getChild(0)).sparkType
}
class SparkShreddedResult(schema: VariantSchema) extends VariantShreddingWriter.ShreddedResult {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index c8cb5d7ce7c51..060d7fe72c0a4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -65,12 +65,8 @@ class ResolveSQLOnFile(sparkSession: SparkSession) extends Rule[LogicalPlan] {
messageParameters = e.getMessageParameters.asScala.toMap)
case _: ClassNotFoundException => None
case e: Exception if !e.isInstanceOf[AnalysisException] =>
- // the provider is valid, but failed to create a logical plan
- u.failAnalysis(
- errorClass = "UNSUPPORTED_DATASOURCE_FOR_DIRECT_QUERY",
- messageParameters = Map("dataSourceType" -> u.multipartIdentifier.head),
- cause = e
- )
+ throw QueryCompilationErrors.failedToCreatePlanForDirectQueryError(
+ u.multipartIdentifier.head, e)
}
case _ =>
None
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 6a07d3c3931a1..19a057c72506b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.v2
+import java.util.concurrent.ConcurrentHashMap
+
import scala.language.existentials
import org.apache.spark._
@@ -33,6 +35,19 @@ import org.apache.spark.util.ArrayImplicits._
class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition])
extends Partition with Serializable
+/**
+ * Holds the state for a reader in a task, used by the completion listener to access the most
+ * recently created reader and iterator for final metrics updates and cleanup.
+ *
+ * When `compute()` is called multiple times for the same task (e.g., when DataSourceRDD is
+ * coalesced), this state is updated on each call to track the most recent reader. The task
+ * completion listener then uses this most recent reader for final cleanup and metrics reporting.
+ *
+ * @param reader The partition reader
+ * @param iterator The metrics iterator wrapping the reader
+ */
+private case class ReaderState(reader: PartitionReader[_], iterator: MetricsIterator[_])
+
// TODO: we should have 2 RDDs: an RDD[InternalRow] for row-based scan, an `RDD[ColumnarBatch]` for
// columnar scan.
class DataSourceRDD(
@@ -43,6 +58,11 @@ class DataSourceRDD(
customMetrics: Map[String, SQLMetric])
extends RDD[InternalRow](sc, Nil) {
+ // Map from task attempt ID to the most recently created ReaderState for that task.
+ // When compute() is called multiple times for the same task (due to coalescing), the map entry
+ // is updated each time so the completion listener always closes the last reader.
+ @transient private lazy val taskReaderStates = new ConcurrentHashMap[Long, ReaderState]()
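+ // Lifecycle sketch (illustrative names): for one task attempt reading two input partitions,
+ // compute() registers the completion listener once; advanceToNextIter() first puts
+ // ReaderState(reader1, iter1) and later, after flushing metrics and closing reader1,
+ // replaces it with ReaderState(reader2, iter2); the listener finally flushes metrics for
+ // and closes reader2.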
+
override protected def getPartitions: Array[Partition] = {
inputPartitions.zipWithIndex.map {
case (inputPartitions, index) => new DataSourceRDDPartition(index, inputPartitions)
@@ -55,20 +75,34 @@ class DataSourceRDD(
}
override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
+ val taskAttemptId = context.taskAttemptId()
+
+ // Add completion listener only once per task attempt. When compute() is called a second time
+ // for the same task (e.g., due to coalescing), the first call will have already put a
+ // ReaderState into taskReaderStates, so containsKey returns true and we skip this block.
+ if (!taskReaderStates.containsKey(taskAttemptId)) {
+ context.addTaskCompletionListener[Unit] { ctx =>
+ // In case of early stopping before consuming the entire iterator,
+ // we need to do one more metric update at the end of the task.
+ try {
+ val readerState = taskReaderStates.get(ctx.taskAttemptId())
+ if (readerState != null) {
+ CustomMetrics.updateMetrics(
+ readerState.reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
+ readerState.iterator.forceUpdateMetrics()
+ readerState.reader.close()
+ }
+ } finally {
+ taskReaderStates.remove(ctx.taskAttemptId())
+ }
+ }
+ }
val iterator = new Iterator[Object] {
private val inputPartitions = castPartition(split).inputPartitions
private var currentIter: Option[Iterator[Object]] = None
private var currentIndex: Int = 0
- private val partitionMetricCallback = new PartitionMetricCallback(customMetrics)
-
- // In case of early stopping before consuming the entire iterator,
- // we need to do one more metric update at the end of the task.
- context.addTaskCompletionListener[Unit] { _ =>
- partitionMetricCallback.execute()
- }
-
override def hasNext: Boolean = currentIter.exists(_.hasNext) || advanceToNextIter()
override def next(): Object = {
@@ -96,8 +130,18 @@ class DataSourceRDD(
(iter, rowReader)
}
- // Once we advance to the next partition, update the metric callback for early finish
- partitionMetricCallback.advancePartition(iter, reader)
+ // Flush metrics and close the previous reader before advancing to the next one.
+ // Pass the accumulated metrics to the new reader so they carry forward correctly.
+ val prevState = taskReaderStates.get(taskAttemptId)
+ if (prevState != null) {
+ val metrics = prevState.reader.currentMetricsValues
+ CustomMetrics.updateMetrics(metrics.toImmutableArraySeq, customMetrics)
+ reader.initMetricsValues(metrics)
+ prevState.reader.close()
+ }
+
+ // Update the map so the completion listener always references the latest reader.
+ taskReaderStates.put(taskAttemptId, ReaderState(reader, iter))
currentIter = Some(iter)
hasNext
@@ -113,28 +157,6 @@ class DataSourceRDD(
}
}
-private class PartitionMetricCallback
- (customMetrics: Map[String, SQLMetric]) {
- private var iter: MetricsIterator[_] = null
- private var reader: PartitionReader[_] = null
-
- def advancePartition(iter: MetricsIterator[_], reader: PartitionReader[_]): Unit = {
- execute()
-
- this.iter = iter
- this.reader = reader
- }
-
- def execute(): Unit = {
- if (iter != null && reader != null) {
- CustomMetrics
- .updateMetrics(reader.currentMetricsValues.toImmutableArraySeq, customMetrics)
- iter.forceUpdateMetrics()
- reader.close()
- }
- }
-}
-
private class PartitionIterator[T](
reader: PartitionReader[T],
customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 9c624d951a76a..81bc1990404a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -571,7 +571,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
val condition = a.checkConstraint.condition
val change = TableChange.addConstraint(
check.toV2Constraint,
- d.relation.table.currentVersion)
+ d.relation.table.version)
ResolveTableConstraints.validateCatalogForTableChange(Seq(change), catalog, ident)
AddCheckConstraintExec(catalog, ident, change, condition, planLater(a.child)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
index 92bbca3c02966..a3b5c5aeb7995 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
@@ -115,7 +115,7 @@ private[sql] object DataSourceV2Utils extends Logging {
val finalOptions = sessionOptions.filter { case (k, _) => !optionsWithPath.contains(k) } ++
optionsWithPath.originalMap
val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava)
- val (table, catalog, ident) = provider match {
+ val (table, catalog, ident, timeTravelSpec) = provider match {
case _: SupportsCatalogOptions if userSpecifiedSchema.nonEmpty =>
throw new IllegalArgumentException(
s"$source does not support user specified schema. Please don't specify the schema.")
@@ -141,16 +141,17 @@ private[sql] object DataSourceV2Utils extends Logging {
}
val timeTravel = TimeTravelSpec.create(
timeTravelTimestamp, timeTravelVersion, conf.sessionLocalTimeZone)
- (CatalogV2Util.getTable(catalog, ident, timeTravel), Some(catalog), Some(ident))
+ val tbl = CatalogV2Util.getTable(catalog, ident, timeTravel)
+ (tbl, Some(catalog), Some(ident), timeTravel)
case _ =>
// TODO: Non-catalog paths for DSV2 are currently not well defined.
val tbl = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema)
- (tbl, None, None)
+ (tbl, None, None, None)
}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
table match {
case _: SupportsRead if table.supports(BATCH_READ) =>
- Option(DataSourceV2Relation.create(table, catalog, ident, dsOptions))
+ Option(DataSourceV2Relation.create(table, catalog, ident, dsOptions, timeTravelSpec))
case _ => None
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RealTimeStreamScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RealTimeStreamScanExec.scala
index c4e072f184e6a..3432f28e12cc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RealTimeStreamScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/RealTimeStreamScanExec.scala
@@ -83,7 +83,9 @@ case class LowLatencyReaderWrap(
reader.get()
}
- override def close(): Unit = {}
+ override def close(): Unit = {
+ reader.close()
+ }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
index 7ce95ced0d242..454a4041d36e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceTableExec.scala
@@ -83,6 +83,7 @@ case class AtomicReplaceTableExec(
.withColumns(columns)
.withPartitions(partitioning.toArray)
.withProperties(tableProperties.asJava)
+ .withConstraints(tableSpec.constraints.toArray)
.build()
catalog.stageCreateOrReplace(identifier, tableInfo)
} else if (catalog.tableExists(identifier)) {
@@ -91,6 +92,7 @@ case class AtomicReplaceTableExec(
.withColumns(columns)
.withPartitions(partitioning.toArray)
.withProperties(tableProperties.asJava)
+ .withConstraints(tableSpec.constraints.toArray)
.build()
catalog.stageReplace(identifier, tableInfo)
} catch {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 31a98e1ff96cb..adfd5ceacd675 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -21,8 +21,9 @@ import java.util.Locale
import scala.collection.mutable
+import org.apache.spark.SparkException
import org.apache.spark.internal.LogKeys.{AGGREGATE_FUNCTIONS, COLUMN_NAMES, GROUP_BY_EXPRS, JOIN_CONDITION, JOIN_TYPE, POST_SCAN_FILTERS, PUSHED_FILTERS, RELATION_NAME, RELATION_OUTPUT}
-import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, ExprId, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation}
@@ -32,10 +33,11 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownJoin, V1Scan}
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownJoin, SupportsPushDownVariantExtractions, V1Scan, VariantExtraction}
+import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, VariantInRelation}
+import org.apache.spark.sql.internal.connector.VariantExtractionImpl
import org.apache.spark.sql.sources
-import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, StructType}
+import org.apache.spark.sql.types.{DataType, DecimalType, IntegerType, StructField, StructType}
import org.apache.spark.sql.util.SchemaUtils._
import org.apache.spark.util.ArrayImplicits._
@@ -49,9 +51,11 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
pushDownFilters,
pushDownJoin,
pushDownAggregates,
+ pushDownVariants,
pushDownLimitAndOffset,
buildScanWithPushedAggregate,
buildScanWithPushedJoin,
+ buildScanWithPushedVariants,
pruneColumns)
pushdownRules.foldLeft(plan) { (newPlan, pushDownRule) =>
@@ -318,6 +322,139 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
case agg: Aggregate => rewriteAggregate(agg)
}
+ def pushDownVariants(plan: LogicalPlan): LogicalPlan = plan.transformDown {
+ case p@PhysicalOperation(projectList, filters, sHolder @ ScanBuilderHolder(_, _,
+ builder: SupportsPushDownVariantExtractions))
+ if conf.getConf(org.apache.spark.sql.internal.SQLConf.PUSH_VARIANT_INTO_SCAN) =>
+ pushVariantExtractions(p, projectList, filters, sHolder, builder)
+ }
+
+ /**
+ * Converts an ordinal path to a field name path.
+ *
+ * @param structType The top-level struct type
+ * @param ordinals The ordinal path (e.g., [1, 1] for nested.field)
+ * @return The field name path (e.g., ["nested", "field"])
+ */
+ private def getColumnName(structType: StructType, ordinals: Seq[Int]): Seq[String] = {
+ ordinals match {
+ case Seq() =>
+ // Base case: no more ordinals
+ Seq.empty
+ case ordinal +: rest =>
+ // Get the field at this ordinal
+ val field = structType.fields(ordinal)
+ if (rest.isEmpty) {
+ // Last ordinal in the path
+ Seq(field.name)
+ } else {
+ // Recurse into nested struct
+ field.dataType match {
+ case nestedStruct: StructType =>
+ field.name +: getColumnName(nestedStruct, rest)
+ case _ =>
+ throw SparkException.internalError(
+ s"Expected StructType at field '${field.name}' but got ${field.dataType}")
+ }
+ }
+ }
+ }
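+
+ // Worked example (illustrative schema): given
+ //   struct<id: int, nested: struct<x: long, field: variant>>
+ // the ordinal path Seq(1, 1) selects "nested" at the top level and then "field" inside it,
+ // so getColumnName(schema, Seq(1, 1)) returns Seq("nested", "field").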
+
+ private def pushVariantExtractions(
+ originalPlan: LogicalPlan,
+ projectList: Seq[NamedExpression],
+ filters: Seq[Expression],
+ sHolder: ScanBuilderHolder,
+ builder: SupportsPushDownVariantExtractions): LogicalPlan = {
+ val variants = new VariantInRelation
+
+ // Extract schema attributes from scan builder holder
+ val schemaAttributes = sHolder.output
+
+ // Construct schema for default value resolution
+ val structSchema = StructType(schemaAttributes.map(a =>
+ StructField(a.name, a.dataType, a.nullable, a.metadata)))
+
+ val defaultValues = org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.
+ existenceDefaultValues(structSchema)
+
+ // Add variant fields from the V2 scan schema
+ for ((a, defaultValue) <- schemaAttributes.zip(defaultValues)) {
+ variants.addVariantFields(a.exprId, a.dataType, defaultValue, Nil)
+ }
+ if (variants.mapping.isEmpty) return originalPlan
+
+ // Collect requested fields from project list and filters
+ projectList.foreach(variants.collectRequestedFields)
+ filters.foreach(variants.collectRequestedFields)
+
+ // If no variant columns remain after collection, return original plan
+ if (variants.mapping.forall(_._2.isEmpty)) return originalPlan
+
+ // Build individual VariantExtraction for each field access
+ // Track which extraction corresponds to which (attr, field, ordinal)
+ val extractionInfo = schemaAttributes.flatMap { topAttr =>
+ val variantFields = variants.mapping.get(topAttr.exprId)
+ if (variantFields.isEmpty || variantFields.get.isEmpty) {
+ // No variant fields for this attribute
+ Seq.empty
+ } else {
+ variantFields.get.toSeq.flatMap { case (pathToVariant, fields) =>
+ val columnName = if (pathToVariant.isEmpty) {
+ Seq(topAttr.name)
+ } else {
+ Seq(topAttr.name) ++
+ getColumnName(topAttr.dataType.asInstanceOf[StructType], pathToVariant)
+ }
+ fields.toArray.sortBy(_._2).map { case (field, ordinal) =>
+ val extraction = new VariantExtractionImpl(
+ columnName.toArray,
+ field.path.toMetadata,
+ field.targetType
+ )
+ (extraction, topAttr, field, ordinal)
+ }
+ }
+ }
+ }
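+
+ // Illustrative example: a projection such as variant_get(v, '$.a', 'int') over a top-level
+ // variant column `v` is expected to yield a single entry here, with columnName = Array("v"),
+ // the path metadata for "$.a", and IntegerType as the target type.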
+
+ // Call the API to push down variant extractions
+ if (extractionInfo.isEmpty) return originalPlan
+
+ val extractions: Array[VariantExtraction] = extractionInfo.map(_._1).toArray
+ val pushedResults = builder.pushVariantExtractions(extractions)
+
+ // Filter to only the accepted extractions
+ val acceptedExtractions = extractionInfo.zip(pushedResults).filter(_._2).map(_._1)
+ if (acceptedExtractions.isEmpty) return originalPlan
+
+ // Group accepted extractions by attribute to rebuild the struct schemas
+ val extractionsByAttr = acceptedExtractions.groupBy(_._2)
+ val pushedColumnNames = extractionsByAttr.keys.map(_.name).toSet
+
+ // Build new attribute mapping based on pushed variant extractions
+ val attributeMap = schemaAttributes.map { a =>
+ if (pushedColumnNames.contains(a.name) && variants.mapping.get(a.exprId).exists(_.nonEmpty)) {
+ val newType = variants.rewriteType(a.exprId, a.dataType, Nil)
+ val newAttr = AttributeReference(a.name, newType, a.nullable, a.metadata)(
+ qualifier = a.qualifier)
+ (a.exprId, newAttr)
+ } else {
+ (a.exprId, a.asInstanceOf[AttributeReference])
+ }
+ }.toMap
+
+ val newOutput = sHolder.output.map(a => attributeMap.getOrElse(a.exprId, a))
+
+ // Store the transformation info on the holder for later use
+ sHolder.pushedVariants = Some(variants)
+ sHolder.pushedVariantAttributeMap = attributeMap
+ sHolder.output = newOutput
+
+ // Return the original plan unchanged - transformation happens in buildScanWithPushedVariants
+ originalPlan
+ }
+
private def rewriteAggregate(agg: Aggregate): LogicalPlan = agg.child match {
case PhysicalOperation(project, Nil, holder @ ScanBuilderHolder(_, _,
r: SupportsPushDownAggregates)) if CollapseProject.canCollapseExpressions(
@@ -589,6 +726,48 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
Project(projectList, scanRelation)
}
+ def buildScanWithPushedVariants(plan: LogicalPlan): LogicalPlan = plan.transform {
+ case p@PhysicalOperation(projectList, filters, holder: ScanBuilderHolder)
+ if holder.pushedVariants.isDefined =>
+ val variants = holder.pushedVariants.get
+ val attributeMap = holder.pushedVariantAttributeMap
+
+ // Build the scan
+ val scan = holder.builder.build()
+ val realOutput = toAttributes(scan.readSchema())
+ val wrappedScan = getWrappedScan(scan, holder)
+ val scanRelation = DataSourceV2ScanRelation(holder.relation, wrappedScan, realOutput)
+
+ // Create projection to map real output to expected output (with transformed types)
+ val outputProjection = realOutput.zip(holder.output).map { case (realAttr, expectedAttr) =>
+ Alias(realAttr, expectedAttr.name)(expectedAttr.exprId)
+ }
+
+ // Rewrite filter expressions using the variant transformation
+ val rewrittenFilters = if (filters.nonEmpty) {
+ val rewrittenFilterExprs = filters.map(variants.rewriteExpr(_, attributeMap))
+ Some(rewrittenFilterExprs.reduce(And))
+ } else {
+ None
+ }
+
+ // Rewrite project list expressions using the variant transformation
+ val rewrittenProjectList = projectList.map { e =>
+ val rewritten = variants.rewriteExpr(e, attributeMap)
+ rewritten match {
+ case n: NamedExpression => n
+ // When the variant column is directly selected, we replace the attribute
+ // reference with a struct access, which is not a NamedExpression. Wrap it with Alias.
+ case _ => Alias(rewritten, e.name)(e.exprId, e.qualifier)
+ }
+ }
+
+ // Build the plan: Project(outputProjection) -> [Filter?] -> scanRelation
+ val withProjection = Project(outputProjection, scanRelation)
+ val withFilter = rewrittenFilters.map(Filter(_, withProjection)).getOrElse(withProjection)
+ Project(rewrittenProjectList, withFilter)
+ }
+
def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform {
case ScanOperation(project, filtersStayUp, filtersPushDown, sHolder: ScanBuilderHolder) =>
// column pruning
@@ -834,6 +1013,10 @@ case class ScanBuilderHolder(
var joinedRelationsPushedDownOperators: Seq[PushedDownOperators] = Seq.empty[PushedDownOperators]
var pushedJoinOutputMap: AttributeMap[Expression] = AttributeMap.empty[Expression]
+
+ var pushedVariantAttributeMap: Map[ExprId, AttributeReference] = Map.empty
+
+ var pushedVariants: Option[VariantInRelation] = None
}
// A wrapper for v1 scan to carry the translated filters and the handled ones, along with
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala
new file mode 100644
index 0000000000000..151329de9e6f2
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan}
+import org.apache.spark.sql.classic.SparkSession
+import org.apache.spark.sql.connector.catalog.{Identifier, Table, TableCatalog, V2TableUtil}
+import org.apache.spark.sql.connector.catalog.CatalogV2Util
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.util.SchemaValidationMode
+import org.apache.spark.sql.util.SchemaValidationMode.ALLOW_NEW_FIELDS
+import org.apache.spark.sql.util.SchemaValidationMode.PROHIBIT_CHANGES
+
+private[sql] object V2TableRefreshUtil extends SQLConfHelper with Logging {
+ /**
+ * Refreshes table metadata for tables in the plan.
+ *
+ * This method reloads table metadata from the catalog and validates:
+ * - Table identity: Ensures table ID has not changed
+ * - Data columns: Verifies captured columns align with the current schema
+ * - Metadata columns: Checks metadata column consistency
+ *
+ * Tables with time travel specifications are skipped as they reference a specific point
+ * in time and don't have to be refreshed.
+ *
+ * Schema validation mode depends on the underlying plan. Commands, for instance,
+ * prohibit any schema changes, while queries permit adding columns.
+ *
+ * @param spark the currently active Spark session
+ * @param plan the logical plan to refresh
+ * @param versionedOnly indicates whether to refresh only versioned tables
+ * @return plan with refreshed table metadata
+ */
+ def refresh(
+ spark: SparkSession,
+ plan: LogicalPlan,
+ versionedOnly: Boolean = false): LogicalPlan = {
+ refresh(spark, plan, versionedOnly, determineSchemaValidationMode(plan))
+ }
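+
+ // Example usage (illustrative; `analyzedPlan` is a placeholder): refresh only the versioned
+ // tables captured in a previously analyzed plan before re-executing it:
+ //   val refreshed = V2TableRefreshUtil.refresh(spark, analyzedPlan, versionedOnly = true)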
+
+ /**
+ * Refreshes table metadata for tables in the plan.
+ *
+ * This method reloads table metadata from the catalog and validates:
+ * - Table identity: Ensures table ID has not changed
+ * - Data columns: Verifies captured columns align with the current schema
+ * - Metadata columns: Checks metadata column consistency
+ *
+ * Tables with time travel specifications are skipped as they reference a specific point
+ * in time and don't have to be refreshed.
+ *
+ * @param spark the currently active Spark session
+ * @param plan the logical plan to refresh
+ * @param versionedOnly indicates whether to refresh only versioned tables
+ * @param schemaValidationMode schema validation mode to use
+ * @return plan with refreshed table metadata
+ */
+ def refresh(
+ spark: SparkSession,
+ plan: LogicalPlan,
+ versionedOnly: Boolean,
+ schemaValidationMode: SchemaValidationMode): LogicalPlan = {
+ val currentTables = mutable.HashMap.empty[(TableCatalog, Identifier), Table]
+ plan transformWithSubqueries {
+ case r @ ExtractV2CatalogAndIdentifier(catalog, ident)
+ if (r.isVersioned || !versionedOnly) && r.timeTravelSpec.isEmpty =>
+ val currentTable = currentTables.getOrElseUpdate((catalog, ident), {
+ val tableName = V2TableUtil.toQualifiedName(catalog, ident)
+ lookupCachedRelation(spark, catalog, ident, r.table) match {
+ case Some(cached) =>
+ logDebug(s"Refreshing table metadata for $tableName using shared relation cache")
+ cached.table
+ case None =>
+ logDebug(s"Refreshing table metadata for $tableName using catalog")
+ catalog.loadTable(ident)
+ }
+ })
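+ // Validate that the reloaded table still matches what was captured at analysis time before swapping it in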
+ validateTableIdentity(currentTable, r)
+ validateDataColumns(currentTable, r, schemaValidationMode)
+ validateMetadataColumns(currentTable, r, schemaValidationMode)
+ r.copy(table = currentTable)
+ }
+ }
+
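+ // Prefer the session's shared relation cache over a catalog load when a matching cached relation exists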
+ private def lookupCachedRelation(
+ spark: SparkSession,
+ catalog: TableCatalog,
+ ident: Identifier,
+ table: Table): Option[DataSourceV2Relation] = {
+ CatalogV2Util.lookupCachedRelation(spark.sharedState.relationCache, catalog, ident, table, conf)
+ }
+
+ // it is not safe to allow any schema changes in commands (e.g. CTAS, RTAS, MERGE)
+ private def determineSchemaValidationMode(plan: LogicalPlan): SchemaValidationMode = {
+ if (containsCommand(plan)) PROHIBIT_CHANGES else ALLOW_NEW_FIELDS
+ }
+
+ private def containsCommand(plan: LogicalPlan): Boolean = {
+ plan.find(_.isInstanceOf[Command]).isDefined
+ }
+
+ private def validateTableIdentity(currentTable: Table, relation: DataSourceV2Relation): Unit = {
+ if (relation.table.id != null && relation.table.id != currentTable.id) {
+ throw QueryCompilationErrors.tableIdChangedAfterAnalysis(
+ relation.name,
+ capturedTableId = relation.table.id,
+ currentTableId = currentTable.id)
+ }
+ }
+
+ private def validateDataColumns(
+ currentTable: Table,
+ relation: DataSourceV2Relation,
+ mode: SchemaValidationMode): Unit = {
+ val errors = V2TableUtil.validateCapturedColumns(currentTable, relation, mode)
+ if (errors.nonEmpty) {
+ throw QueryCompilationErrors.columnsChangedAfterAnalysis(relation.name, errors)
+ }
+ }
+
+ private def validateMetadataColumns(
+ currentTable: Table,
+ relation: DataSourceV2Relation,
+ mode: SchemaValidationMode): Unit = {
+ val errors = V2TableUtil.validateCapturedMetadataColumns(currentTable, relation, mode)
+ if (errors.nonEmpty) {
+ throw QueryCompilationErrors.metadataColumnsChangedAfterAnalysis(relation.name, errors)
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 7a2795b729f5a..464f0d9658d1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -33,10 +33,11 @@ import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, DeltaWrite, DeltaWriter, MergeSummaryImpl, PhysicalWriteInfoImpl, Write, WriterCommitMessage, WriteSummary}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnaryExecNode}
+import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SQLExecution, UnaryExecNode}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric, SQLMetrics}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.SchemaValidationMode.PROHIBIT_CHANGES
import org.apache.spark.util.{LongAccumulator, Utils}
import org.apache.spark.util.ArrayImplicits._
@@ -167,6 +168,14 @@ case class ReplaceTableAsSelectExec(
// 1. Creating the new table fails,
// 2. Writing to the new table fails,
// 3. The table returned by catalog.createTable doesn't support writing.
+ //
+ // RTAS must refresh and pin table versions in the query so that it reads from the original
+ // table versions rather than from the newly created empty table, which only serves as the
+ // target of the append/overwrite.
+ val refreshedQuery = V2TableRefreshUtil.refresh(
+ session,
+ query,
+ versionedOnly = true,
+ schemaValidationMode = PROHIBIT_CHANGES)
if (catalog.tableExists(ident)) {
invalidateCache(catalog, ident)
catalog.dropTable(ident)
@@ -174,13 +183,15 @@ case class ReplaceTableAsSelectExec(
throw QueryCompilationErrors.cannotReplaceMissingTableError(ident)
}
val tableInfo = new TableInfo.Builder()
- .withColumns(getV2Columns(query.schema, catalog.useNullableQuerySchema))
+ .withColumns(getV2Columns(refreshedQuery.schema, catalog.useNullableQuerySchema))
.withPartitions(partitioning.toArray)
.withProperties(properties.asJava)
.build()
val table = Option(catalog.createTable(ident, tableInfo))
.getOrElse(catalog.loadTable(ident, Set(TableWritePrivilege.INSERT).asJava))
- writeToTable(catalog, table, writeOptions, ident, query, overwrite = true)
+ writeToTable(
+ catalog, table, writeOptions, ident, refreshedQuery,
+ overwrite = true, refreshPhaseEnabled = false)
}
}
@@ -344,7 +355,7 @@ case class WriteToDataSourceV2Exec(
query: SparkPlan,
writeMetrics: Seq[CustomMetric]) extends V2TableWriteExec {
- override val stringArgs: Iterator[Any] = Iterator(batchWrite, query)
+ override def stringArgs: Iterator[Any] = Iterator(batchWrite, query)
override val customMetrics: Map[String, SQLMetric] = writeMetrics.map { customMetric =>
customMetric.name() -> SQLMetrics.createV2CustomMetric(sparkContext, customMetric)
@@ -364,7 +375,7 @@ trait V2ExistingTableWriteExec extends V2TableWriteExec {
def refreshCache: () => Unit
def write: Write
- override val stringArgs: Iterator[Any] = Iterator(query, write)
+ override def stringArgs: Iterator[Any] = Iterator(query, write)
override val customMetrics: Map[String, SQLMetric] =
write.supportedCustomMetrics().map { customMetric =>
@@ -717,7 +728,8 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec {
writeOptions: Map[String, String],
ident: Identifier,
query: LogicalPlan,
- overwrite: Boolean): Seq[InternalRow] = {
+ overwrite: Boolean,
+ refreshPhaseEnabled: Boolean = true): Seq[InternalRow] = {
Utils.tryWithSafeFinallyAndFailureCallbacks({
val relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
val writeCommand = if (overwrite) {
@@ -725,7 +737,7 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec {
} else {
AppendData.byPosition(relation, query, writeOptions)
}
- val qe = session.sessionState.executePlan(writeCommand)
+ val qe = QueryExecution.create(session, writeCommand, refreshPhaseEnabled)
qe.assertCommandExecuted()
DataSourceV2Utils.commitStagedChanges(sparkContext, table, metrics)
Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
index d347cb04f0bcf..5a427aad5f895 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
@@ -25,13 +25,13 @@ import org.apache.parquet.hadoop.ParquetInputFormat
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
-import org.apache.spark.sql.connector.read.{PartitionReaderFactory, SupportsPushDownVariants, VariantAccessInfo}
-import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex}
+import org.apache.spark.sql.connector.read.{PartitionReaderFactory, VariantExtraction}
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex, VariantMetadata}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport}
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{BooleanType, DataType, StructField, StructType, VariantType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.SerializableConfiguration
@@ -48,8 +48,7 @@ case class ParquetScan(
pushedAggregate: Option[Aggregation] = None,
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty,
- pushedVariantAccessInfo: Array[VariantAccessInfo] = Array.empty) extends FileScan
- with SupportsPushDownVariants {
+ pushedVariantExtractions: Array[VariantExtraction] = Array.empty) extends FileScan {
override def isSplitable(path: Path): Boolean = {
// If aggregate is pushed down, only the file footer will be read once,
// so file should not be split across multiple tasks.
@@ -58,20 +57,57 @@ case class ParquetScan(
// Build transformed schema if variant pushdown is active
private def effectiveReadDataSchema: StructType = {
- if (_pushedVariantAccess.isEmpty) {
+ if (pushedVariantExtractions.isEmpty) {
readDataSchema
} else {
- // Build a mapping from column name to extracted schema
- val variantSchemaMap = _pushedVariantAccess.map(info =>
- info.columnName() -> info.extractedSchema()).toMap
-
- // Transform the read data schema by replacing variant columns with their extracted schemas
- StructType(readDataSchema.map { field =>
- variantSchemaMap.get(field.name) match {
- case Some(extractedSchema) => field.copy(dataType = extractedSchema)
- case None => field
+ rewriteVariantPushdownSchema(readDataSchema)
+ }
+ }
+
+ private def rewriteVariantPushdownSchema(schema: StructType): StructType = {
+ // Group extractions by column name and build extracted schemas
+ val variantSchemaMap: Map[Seq[String], StructType] = pushedVariantExtractions
+ .groupBy(e => e.columnName().toSeq)
+ .map { case (colName, extractions) =>
+ // Build struct schema with ordinal-named fields for each extraction
+ var fields = extractions.zipWithIndex.map { case (extraction, idx) =>
+ // Attach VariantMetadata so Parquet reader knows this is a variant extraction
+ StructField(idx.toString, extraction.expectedDataType(), nullable = true,
+ extraction.metadata())
+ }
+
+ // Avoid producing an empty struct of requested fields. This happens
+ // if the variant is not used, or is used only in `IsNotNull/IsNull` expressions.
+ // The value of the placeholder field doesn't matter.
+ if (fields.size == 1 && fields.head.dataType.isInstanceOf[VariantType]) {
+ val placeholder = VariantMetadata("$.__placeholder_field__",
+ failOnError = false, timeZoneId = "UTC")
+ fields = Array(StructField("0", BooleanType,
+ metadata = placeholder.toMetadata))
}
- })
+
+ colName -> StructType(fields)
+ }.toMap
+
+ rewriteType(schema, Seq.empty, variantSchemaMap).asInstanceOf[StructType]
+ }
+
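+ // Recursively replace fields whose path matches a pushed variant extraction with the extracted struct schema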
+ private def rewriteType(
+ dataType: DataType,
+ path: Seq[String],
+ mapping: Map[Seq[String], StructType]): DataType = {
+ dataType match {
+ case structType: StructType if !VariantMetadata.isVariantStruct(structType) =>
+ val fields = structType.fields.map { field =>
+ mapping.get(path :+ field.name) match {
+ case Some(extractedSchema) =>
+ field.copy(dataType = extractedSchema)
+ case None =>
+ field.copy(dataType = rewriteType(field.dataType, path :+ field.name, mapping))
+ }
+ }
+ StructType(fields)
+ case otherType => otherType
}
}
@@ -84,38 +120,14 @@ case class ParquetScan(
// super.readSchema() combines readDataSchema + readPartitionSchema
// Apply variant transformation if variant pushdown is active
val baseSchema = super.readSchema()
- if (_pushedVariantAccess.isEmpty) {
+ if (pushedVariantExtractions.isEmpty) {
baseSchema
} else {
- val variantSchemaMap = _pushedVariantAccess.map(info =>
- info.columnName() -> info.extractedSchema()).toMap
- StructType(baseSchema.map { field =>
- variantSchemaMap.get(field.name) match {
- case Some(extractedSchema) => field.copy(dataType = extractedSchema)
- case None => field
- }
- })
+ rewriteVariantPushdownSchema(baseSchema)
}
}
}
- // SupportsPushDownVariants API implementation
- private var _pushedVariantAccess: Array[VariantAccessInfo] = pushedVariantAccessInfo
-
- override def pushVariantAccess(variantAccessInfo: Array[VariantAccessInfo]): Boolean = {
- // Parquet supports variant pushdown for all variant accesses
- if (variantAccessInfo.nonEmpty) {
- _pushedVariantAccess = variantAccessInfo
- true
- } else {
- false
- }
- }
-
- override def pushedVariantAccess(): Array[VariantAccessInfo] = {
- _pushedVariantAccess
- }
-
override def createReaderFactory(): PartitionReaderFactory = {
val effectiveSchema = effectiveReadDataSchema
val readDataSchemaAsJson = effectiveSchema.json
@@ -171,8 +183,8 @@ case class ParquetScan(
pushedAggregate.isEmpty && p.pushedAggregate.isEmpty
}
val pushedVariantEqual =
- java.util.Arrays.equals(_pushedVariantAccess.asInstanceOf[Array[Object]],
- p._pushedVariantAccess.asInstanceOf[Array[Object]])
+ java.util.Arrays.equals(pushedVariantExtractions.asInstanceOf[Array[Object]],
+ p.pushedVariantExtractions.asInstanceOf[Array[Object]])
super.equals(p) && dataSchema == p.dataSchema && options == p.options &&
equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual &&
pushedVariantEqual
@@ -189,15 +201,17 @@ case class ParquetScan(
}
override def getMetaData(): Map[String, String] = {
- val variantAccessStr = if (_pushedVariantAccess.nonEmpty) {
- _pushedVariantAccess.map(info =>
- s"${info.columnName()}->${info.extractedSchema()}").mkString("[", ", ", "]")
+ val variantExtractionStr = if (pushedVariantExtractions.nonEmpty) {
+ pushedVariantExtractions.map { extraction =>
+ val colName = extraction.columnName().mkString(".")
+ s"$colName:${extraction.metadata()}:${extraction.expectedDataType()}"
+ }.mkString("[", ", ", "]")
} else {
"[]"
}
super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters.toImmutableArraySeq)) ++
Map("PushedAggregation" -> pushedAggregationsStr) ++
Map("PushedGroupBy" -> pushedGroupByStr) ++
- Map("PushedVariantAccess" -> variantAccessStr)
+ Map("PushedVariantExtractions" -> variantExtractionStr)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
index 01367675e65b9..94da53f229349 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
@@ -22,7 +22,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
-import org.apache.spark.sql.connector.read.SupportsPushDownAggregates
+import org.apache.spark.sql.connector.read.{SupportsPushDownAggregates, SupportsPushDownVariantExtractions, VariantExtraction}
import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter}
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
@@ -39,7 +39,8 @@ case class ParquetScanBuilder(
dataSchema: StructType,
options: CaseInsensitiveStringMap)
extends FileScanBuilder(sparkSession, fileIndex, dataSchema)
- with SupportsPushDownAggregates {
+ with SupportsPushDownAggregates
+ with SupportsPushDownVariantExtractions {
lazy val hadoopConf = {
val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
// Hadoop Configurations are case sensitive.
@@ -50,6 +51,8 @@ case class ParquetScanBuilder(
private var pushedAggregations = Option.empty[Aggregation]
+ private var pushedVariantExtractions = Array.empty[VariantExtraction]
+
override protected val supportsNestedSchemaPruning: Boolean = true
override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = {
@@ -99,6 +102,14 @@ case class ParquetScanBuilder(
}
}
+ // SupportsPushDownVariantExtractions API implementation
+ override def pushVariantExtractions(extractions: Array[VariantExtraction]): Array[Boolean] = {
+ // Parquet supports variant pushdown for all variant extractions
+ pushedVariantExtractions = extractions
+ // Return true for all extractions (Parquet can handle all of them)
+ Array.fill(extractions.length)(true)
+ }
+
override def build(): ParquetScan = {
// the `finalSchema` is either pruned in pushAggregation (if aggregates are
// pushed down), or pruned in readDataSchema() (in regular column pruning). These
@@ -108,6 +119,6 @@ case class ParquetScanBuilder(
}
ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema,
readPartitionSchema(), pushedDataFilters, options, pushedAggregations,
- partitionFilters, dataFilters)
+ partitionFilters, dataFilters, pushedVariantExtractions)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index c147030037cd8..47e64a5b4041a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -172,7 +172,8 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
pythonRunnerConf,
metrics,
jobArtifactUUID,
- sessionUUID)
+ sessionUUID,
+ conf.pythonUDFProfiler)
}
def createPythonMetrics(): Array[CustomMetric] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index b97d765afcf79..0d180bd336221 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -140,6 +140,13 @@ case class EnsureRequirements(
// Choose all the specs that can be used to shuffle other children
val candidateSpecs = specs
.filter(_._2.canCreatePartitioning)
+ .filter {
+ // To choose a KeyGroupedShuffleSpec, we must be able to push down SPJ parameters into
+ // the scan (for join key positions). If these parameters can't be pushed down, this
+ // spec can't be used to shuffle other children.
+ case (idx, _: KeyGroupedShuffleSpec) => canPushDownSPJParamsToScan(children(idx))
+ case _ => true
+ }
.filter(p => !shouldConsiderMinParallelism ||
children(p._1).outputPartitioning.numPartitions >= conf.defaultNumShufflePartitions)
val bestSpecOpt = if (candidateSpecs.isEmpty) {
@@ -300,7 +307,7 @@ case class EnsureRequirements(
private def ensureOrdering(plan: SparkPlan, distribution: Distribution) = {
(plan.outputPartitioning, distribution) match {
- case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _),
+ case (p @ KeyGroupedPartitioning(expressions, _, partitionValues, _, _),
d @ OrderedDistribution(ordering)) if p.satisfies(d) =>
val attrs = expressions.flatMap(_.collectLeaves()).map(_.asInstanceOf[Attribute])
val partitionOrdering: Ordering[InternalRow] = {
@@ -333,12 +340,12 @@ case class EnsureRequirements(
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, rightExpressions, rightKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, leftPartitioning, None))
- case (Some(KeyGroupedPartitioning(clustering, _, _, _)), _) =>
+ case (Some(KeyGroupedPartitioning(clustering, _, _, _, _)), _) =>
val leafExprs = clustering.flatMap(_.collectLeaves())
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, leftKeys)
.orElse(reorderJoinKeysRecursively(
leftKeys, rightKeys, None, rightPartitioning))
- case (_, Some(KeyGroupedPartitioning(clustering, _, _, _))) =>
+ case (_, Some(KeyGroupedPartitioning(clustering, _, _, _, _))) =>
val leafExprs = clustering.flatMap(_.collectLeaves())
reorder(leftKeys.toIndexedSeq, rightKeys.toIndexedSeq, leafExprs, rightKeys)
.orElse(reorderJoinKeysRecursively(
@@ -402,6 +409,24 @@ case class EnsureRequirements(
}
}
+ /**
+ * Whether SPJ params can be pushed down to the leaf nodes of a physical plan. For a plan to be
+ * eligible for SPJ parameter pushdown, all of its leaf nodes must be KeyGroupedPartitioning-aware
+ * scans.
+ *
+ * Notably, if the leaf of `plan` is an [[RDDScanExec]] created by checkpointing a DSv2 scan, the
+ * reported partitioning will be a [[KeyGroupedPartitioning]], but this plan will _not_ be
+ * eligible for SPJ parameter pushdown (as the partitioning is static and can't be easily
+ * re-grouped or padded with empty partitions according to the partition values on the other side
+ * of the join).
+ */
+ private def canPushDownSPJParamsToScan(plan: SparkPlan): Boolean = {
+ plan.collectLeaves().forall {
+ case _: KeyGroupedPartitionedScan[_] => true
+ case _ => false
+ }
+ }
+
/**
* Checks whether two children, `left` and `right`, of a join operator have compatible
* `KeyGroupedPartitioning`, and can benefit from storage-partitioned join.
@@ -413,6 +438,12 @@ case class EnsureRequirements(
left: SparkPlan,
right: SparkPlan,
requiredChildDistribution: Seq[Distribution]): Option[Seq[SparkPlan]] = {
+ // If SPJ params can't be pushed down to either the left or right side, it's unsafe to do an
+ // SPJ.
+ if (!canPushDownSPJParamsToScan(left) || !canPushDownSPJParamsToScan(right)) {
+ return None
+ }
+
parent match {
case smj: SortMergeJoinExec =>
checkKeyGroupCompatible(left, right, smj.joinType, requiredChildDistribution)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index f052bd9068805..c59bb4d39b096 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -370,7 +370,7 @@ object ShuffleExchangeExec {
ascending = true,
samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition)
case SinglePartition => new ConstantPartitioner
- case k @ KeyGroupedPartitioning(expressions, n, _, _) =>
+ case k @ KeyGroupedPartitioning(expressions, n, _, _, _) =>
val valueMap = k.uniquePartitionValues.zipWithIndex.map {
case (partition, index) => (partition.toSeq(expressions.map(_.dataType)), index)
}.toMap
@@ -401,7 +401,7 @@ object ShuffleExchangeExec {
val projection = UnsafeProjection.create(sortingExpressions.map(_.child), outputAttributes)
row => projection(row)
case SinglePartition => identity
- case KeyGroupedPartitioning(expressions, _, _, _) =>
+ case KeyGroupedPartitioning(expressions, _, _, _, _) =>
row => bindReferences(expressions, outputAttributes).map(_.eval(row))
case s: ShufflePartitionIdPassThrough =>
// For ShufflePartitionIdPassThrough, the expression directly evaluates to the partition ID
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 368534d05b1f0..944ee3b059092 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -25,7 +25,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
+import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, HashPartitioningLike, Partitioning, PartitioningCollection, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan}
@@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
* broadcast relation. This data is then placed in a Spark broadcast variable. The streamed
* relation is not shuffled.
*/
-case class BroadcastHashJoinExec(
+case class BroadcastHashJoinExec private(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
@@ -129,7 +129,10 @@ case class BroadcastHashJoinExec(
val hashed = broadcastRelation.value.asReadOnlyCopy()
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
if (hashed == EmptyHashedRelation) {
- streamedIter
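+ // With an empty build side the streamed rows pass through unchanged; count them so the
+ // numOutputRows metric stays accurate.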
+ streamedIter.map { row =>
+ numOutputRows += 1
+ row
+ }
} else if (hashed == HashedRelationWithAllNullKeys) {
Iterator.empty
} else {
@@ -242,3 +245,27 @@ case class BroadcastHashJoinExec(
newLeft: SparkPlan, newRight: SparkPlan): BroadcastHashJoinExec =
copy(left = newLeft, right = newRight)
}
+
+object BroadcastHashJoinExec extends JoinSelectionHelper {
+ def apply(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ buildSide: BuildSide,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan,
+ isNullAwareAntiJoin: Boolean = false): BroadcastHashJoinExec = {
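+ // Inject CollationKey into collated join keys so hashing and comparison use binary-stable values.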
+ val (normalizedLeftKeys, normalizedRightKeys) = HashJoin.normalizeJoinKeys(leftKeys, rightKeys)
+
+ new BroadcastHashJoinExec(
+ normalizedLeftKeys,
+ normalizedRightKeys,
+ joinType,
+ buildSide,
+ condition,
+ left,
+ right,
+ isNullAwareAntiJoin)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index a1abb64e262df..fab14dba444dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.util.UnsafeRowUtils
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, RowIterator}
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -41,6 +42,9 @@ private[joins] case class HashedRelationInfo(
isEmpty: Boolean)
trait HashJoin extends JoinCodegenSupport {
+ assert(leftKeys.forall(key => UnsafeRowUtils.isBinaryStable(key.dataType)))
+ assert(rightKeys.forall(key => UnsafeRowUtils.isBinaryStable(key.dataType)))
+
def buildSide: BuildSide
override def simpleStringWithNodeId(): String = {
@@ -724,6 +728,18 @@ trait HashJoin extends JoinCodegenSupport {
object HashJoin extends CastSupport with SQLConfHelper {
+ /**
+ * Normalize join keys by injecting `CollationKey` when the keys are collated.
+ */
+ def normalizeJoinKeys(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+ (
+ leftKeys.map(CollationKey.injectCollationKey),
+ rightKeys.map(CollationKey.injectCollationKey)
+ )
+ }
+
private def canRewriteAsLongType(keys: Seq[Expression]): Boolean = {
// TODO: support BooleanType, DateType and TimestampType
keys.forall(_.dataType.isInstanceOf[IntegralType]) &&
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 85c1982905420..c67b55fd1d50c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -143,8 +143,8 @@ private[execution] object HashedRelation {
new TaskMemoryManager(
new UnifiedMemoryManager(
new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
- Long.MaxValue,
- Long.MaxValue / 2,
+ Runtime.getRuntime.maxMemory,
+ Runtime.getRuntime.maxMemory / 2,
1),
0)
}
@@ -401,8 +401,8 @@ private[joins] class UnsafeHashedRelation(
val taskMemoryManager = new TaskMemoryManager(
new UnifiedMemoryManager(
new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
- Long.MaxValue,
- Long.MaxValue / 2,
+ Runtime.getRuntime.maxMemory,
+ Runtime.getRuntime.maxMemory / 2,
1),
0)
@@ -576,8 +576,8 @@ private[execution] final class LongToUnsafeRowMap(
new TaskMemoryManager(
new UnifiedMemoryManager(
new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
- Long.MaxValue,
- Long.MaxValue / 2,
+ Runtime.getRuntime.maxMemory,
+ Runtime.getRuntime.maxMemory / 2,
1),
0),
0)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 97ca74aee30c0..0f90f443ad41d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -35,7 +35,7 @@ import org.apache.spark.util.collection.{BitSet, OpenHashSet}
/**
* Performs a hash join of two child relations by first shuffling the data using the join keys.
*/
-case class ShuffledHashJoinExec(
+case class ShuffledHashJoinExec private (
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
@@ -659,3 +659,27 @@ case class ShuffledHashJoinExec(
newLeft: SparkPlan, newRight: SparkPlan): ShuffledHashJoinExec =
copy(left = newLeft, right = newRight)
}
+
+object ShuffledHashJoinExec {
+ def apply(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ joinType: JoinType,
+ buildSide: BuildSide,
+ condition: Option[Expression],
+ left: SparkPlan,
+ right: SparkPlan,
+ isSkewJoin: Boolean = false): ShuffledHashJoinExec = {
+ val (normalizedLeftKeys, normalizedRightKeys) = HashJoin.normalizeJoinKeys(leftKeys, rightKeys)
+
+ new ShuffledHashJoinExec(
+ normalizedLeftKeys,
+ normalizedRightKeys,
+ joinType,
+ buildSide,
+ condition,
+ left,
+ right,
+ isSkewJoin)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index b94e00bc11ef2..f5f968ee95228 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -60,6 +60,8 @@ abstract class BaseArrowPythonRunner[IN, OUT <: AnyRef](
override val killOnIdleTimeout: Boolean = SQLConf.get.pythonUDFWorkerKillOnIdleTimeout
override val tracebackDumpIntervalSeconds: Long =
SQLConf.get.pythonUDFWorkerTracebackDumpIntervalSeconds
+ override val killWorkerOnFlushFailure: Boolean =
+ SQLConf.get.pythonUDFDaemonKillWorkerOnFlushFailure
override val errorOnDuplicatedFieldNames: Boolean = true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index 7b73818bf0ec1..1d5df9bad9247 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -83,6 +83,8 @@ class ArrowPythonUDTFRunner(
override val killOnIdleTimeout: Boolean = SQLConf.get.pythonUDFWorkerKillOnIdleTimeout
override val tracebackDumpIntervalSeconds: Long =
SQLConf.get.pythonUDFWorkerTracebackDumpIntervalSeconds
+ override val killWorkerOnFlushFailure: Boolean =
+ SQLConf.get.pythonUDFDaemonKillWorkerOnFlushFailure
override val errorOnDuplicatedFieldNames: Boolean = true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index 00eb9039d05cf..7f6efbae8881d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -20,7 +20,11 @@ package org.apache.spark.sql.execution.python
import java.io.DataOutputStream
import java.util
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec}
+import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
+import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec}
+
+import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, PythonWorker}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriterWrapper
@@ -67,6 +71,8 @@ class CoGroupedArrowPythonRunner(
override val killOnIdleTimeout: Boolean = SQLConf.get.pythonUDFWorkerKillOnIdleTimeout
override val tracebackDumpIntervalSeconds: Long =
SQLConf.get.pythonUDFWorkerTracebackDumpIntervalSeconds
+ override val killWorkerOnFlushFailure: Boolean =
+ SQLConf.get.pythonUDFDaemonKillWorkerOnFlushFailure
override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
@@ -76,6 +82,27 @@ class CoGroupedArrowPythonRunner(
if (v > 0) v else Int.MaxValue
}
private val maxBytesPerBatch: Long = SQLConf.get.arrowMaxBytesPerBatch
+ private val compressionCodecName: String = SQLConf.get.arrowCompressionCodec
+
+ // Helper method to create VectorUnloader with compression
+ private def createUnloader(root: VectorSchemaRoot): VectorUnloader = {
+ val codec = compressionCodecName match {
+ case "none" => NoCompressionCodec.INSTANCE
+ case "zstd" =>
+ val compressionLevel = SQLConf.get.arrowZstdCompressionLevel
+ val factory = CompressionCodec.Factory.INSTANCE
+ val codecType = new ZstdCompressionCodec(compressionLevel).getCodecType()
+ factory.createCodec(codecType)
+ case "lz4" =>
+ val factory = CompressionCodec.Factory.INSTANCE
+ val codecType = new Lz4CompressionCodec().getCodecType()
+ factory.createCodec(codecType)
+ case other =>
+ throw SparkException.internalError(
+ s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4")
+ }
+ new VectorUnloader(root, true, codec, true)
+ }
protected def newWriter(
env: SparkEnv,
@@ -136,13 +163,17 @@ class CoGroupedArrowPythonRunner(
leftGroupArrowWriter = ArrowWriterWrapper.createAndStartArrowWriter(leftSchema,
timeZoneId, pythonExec + " (left)", errorOnDuplicatedFieldNames = true,
largeVarTypes, dataOut, context)
+ // Set the unloader with compression after creating the writer
+ leftGroupArrowWriter.unloader = createUnloader(leftGroupArrowWriter.root)
}
numRowsInBatch = BatchedPythonArrowInput.writeSizedBatch(
leftGroupArrowWriter.arrowWriter,
leftGroupArrowWriter.streamWriter,
nextBatchInLeftGroup,
maxBytesPerBatch,
- maxRecordsPerBatch)
+ maxRecordsPerBatch,
+ leftGroupArrowWriter.unloader,
+ dataOut)
if (!nextBatchInLeftGroup.hasNext) {
leftGroupArrowWriter.streamWriter.end()
@@ -155,13 +186,17 @@ class CoGroupedArrowPythonRunner(
rightGroupArrowWriter = ArrowWriterWrapper.createAndStartArrowWriter(rightSchema,
timeZoneId, pythonExec + " (right)", errorOnDuplicatedFieldNames = true,
largeVarTypes, dataOut, context)
+ // Set the unloader with compression after creating the writer
+ rightGroupArrowWriter.unloader = createUnloader(rightGroupArrowWriter.root)
}
numRowsInBatch = BatchedPythonArrowInput.writeSizedBatch(
rightGroupArrowWriter.arrowWriter,
rightGroupArrowWriter.streamWriter,
nextBatchInRightGroup,
maxBytesPerBatch,
- maxRecordsPerBatch)
+ maxRecordsPerBatch,
+ rightGroupArrowWriter.unloader,
+ dataOut)
if (!nextBatchInRightGroup.hasNext) {
rightGroupArrowWriter.streamWriter.end()
rightGroupArrowWriter.close()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index 212cc5db124ce..33622ca7349a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -29,9 +29,9 @@ import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData, STUtils}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
+import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, UTF8String, VariantVal}
object EvaluatePython {
@@ -43,7 +43,7 @@ object EvaluatePython {
def needConversionInPython(dt: DataType): Boolean = dt match {
case DateType | TimestampType | TimestampNTZType | VariantType | _: DayTimeIntervalType
- | _: TimeType => true
+ | _: TimeType | _: GeometryType | _: GeographyType => true
case _: StructType => true
case _: UserDefinedType[_] => true
case ArrayType(elementType, _) => needConversionInPython(elementType)
@@ -92,6 +92,10 @@ object EvaluatePython {
case (s: UTF8String, _: StringType) => s.toString
+ case (g: GeometryVal, gt: GeometryType) => STUtils.deserializeGeom(g, gt)
+
+ case (g: GeographyVal, gt: GeographyType) => STUtils.deserializeGeog(g, gt)
+
case (bytes: Array[Byte], BinaryType) =>
if (binaryAsBytes) {
new BytesWrapper(bytes)
@@ -228,6 +232,23 @@ object EvaluatePython {
)
}
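+ // The Python side sends geography/geometry values as a map with "wkb" (bytes) and "srid" (int) entries.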
+ case g: GeographyType => (obj: Any) => nullSafeConvert(obj) {
+ case s: java.util.HashMap[_, _] =>
+ val geographySrid = s.get("srid").asInstanceOf[Int]
+ g.assertSridAllowedForType(geographySrid)
+ STUtils.stGeogFromWKB(
+ s.get("wkb").asInstanceOf[Array[Byte]])
+ }
+
+ case g: GeometryType => (obj: Any) => nullSafeConvert(obj) {
+ case s: java.util.HashMap[_, _] =>
+ val geometrySrid = s.get("srid").asInstanceOf[Int]
+ g.assertSridAllowedForType(geometrySrid)
+ STUtils.stGeomFromWKB(
+ s.get("wkb").asInstanceOf[Array[Byte]],
+ geometrySrid)
+ }
+
case other => (obj: Any) => nullSafeConvert(obj)(PartialFunction.empty)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 4e78b3035a7ec..51909df26a567 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -41,7 +41,8 @@ class MapInBatchEvaluatorFactory(
pythonRunnerConf: Map[String, String],
val pythonMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String],
- sessionUUID: Option[String])
+ sessionUUID: Option[String],
+ profiler: Option[String])
extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] =
@@ -74,7 +75,7 @@ class MapInBatchEvaluatorFactory(
pythonMetrics,
jobArtifactUUID,
sessionUUID,
- None) with BatchedPythonArrowInput
+ profiler) with BatchedPythonArrowInput
val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context)
val unsafeProj = UnsafeProjection.create(output, output)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 1d03c0cf76037..c4f090674e7c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -70,7 +70,8 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
pythonRunnerConf,
pythonMetrics,
jobArtifactUUID,
- sessionUUID)
+ sessionUUID,
+ conf.pythonUDFProfiler)
val rdd = if (isBarrier) {
val rddBarrier = child.execute().barrier()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index b2ec96c5b29f8..f77b0a9342b02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -17,11 +17,16 @@
package org.apache.spark.sql.execution.python
import java.io.DataOutputStream
+import java.nio.channels.Channels
-import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.compression.{Lz4CompressionCodec, ZstdCompressionCodec}
+import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader}
+import org.apache.arrow.vector.compression.{CompressionCodec, NoCompressionCodec}
import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.apache.arrow.vector.ipc.WriteChannel
+import org.apache.arrow.vector.ipc.message.MessageSerializer
-import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.{SparkEnv, SparkException, TaskContext}
import org.apache.spark.api.python.{BasePythonRunner, PythonRDD, PythonWorker}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow
@@ -70,6 +75,26 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
protected val allocator =
ArrowUtils.rootAllocator.newChildAllocator(s"stdout writer for $pythonExec", 0, Long.MaxValue)
protected val root = VectorSchemaRoot.create(arrowSchema, allocator)
+
+ // Create compression codec based on config
+ private val compressionCodecName = SQLConf.get.arrowCompressionCodec
+ private val codec = compressionCodecName match {
+ case "none" => NoCompressionCodec.INSTANCE
+ case "zstd" =>
+ val compressionLevel = SQLConf.get.arrowZstdCompressionLevel
+ val factory = CompressionCodec.Factory.INSTANCE
+ val codecType = new ZstdCompressionCodec(compressionLevel).getCodecType()
+ factory.createCodec(codecType)
+ case "lz4" =>
+ val factory = CompressionCodec.Factory.INSTANCE
+ val codecType = new Lz4CompressionCodec().getCodecType()
+ factory.createCodec(codecType)
+ case other =>
+ throw SparkException.internalError(
+ s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4")
+ }
+ protected val unloader = new VectorUnloader(root, true, codec, true)
+
protected var writer: ArrowStreamWriter = _
protected def close(): Unit = {
@@ -137,7 +162,14 @@ private[python] trait BasicPythonArrowInput extends PythonArrowInput[Iterator[In
}
arrowWriter.finish()
- writer.writeBatch()
+ // Use unloader to get compressed batch and write it manually
+ val batch = unloader.getRecordBatch()
+ try {
+ val writeChannel = new WriteChannel(Channels.newChannel(dataOut))
+ MessageSerializer.serialize(writeChannel, batch)
+ } finally {
+ batch.close()
+ }
arrowWriter.reset()
val deltaData = dataOut.size() - startData
pythonMetrics("pythonDataSent") += deltaData
@@ -169,7 +201,8 @@ private[python] trait BatchedPythonArrowInput extends BasicPythonArrowInput {
val startData = dataOut.size()
val numRowsInBatch = BatchedPythonArrowInput.writeSizedBatch(
- arrowWriter, writer, nextBatchStart, maxBytesPerBatch, maxRecordsPerBatch)
+ arrowWriter, writer, nextBatchStart, maxBytesPerBatch, maxRecordsPerBatch, unloader,
+ dataOut)
assert(0 < numRowsInBatch && numRowsInBatch <= maxRecordsPerBatch, numRowsInBatch)
val deltaData = dataOut.size() - startData
@@ -209,7 +242,9 @@ private[python] object BatchedPythonArrowInput {
writer: ArrowStreamWriter,
rowIter: Iterator[InternalRow],
maxBytesPerBatch: Long,
- maxRecordsPerBatch: Int): Int = {
+ maxRecordsPerBatch: Int,
+ unloader: VectorUnloader,
+ dataOut: DataOutputStream): Int = {
var numRowsInBatch: Int = 0
def underBatchSizeLimit: Boolean =
@@ -221,7 +256,14 @@ private[python] object BatchedPythonArrowInput {
numRowsInBatch += 1
}
arrowWriter.finish()
- writer.writeBatch()
+ // Use unloader to get compressed batch and write it manually
+ val batch = unloader.getRecordBatch()
+ try {
+ val writeChannel = new WriteChannel(Channels.newChannel(dataOut))
+ MessageSerializer.serialize(writeChannel, batch)
+ } finally {
+ batch.close()
+ }
arrowWriter.reset()
numRowsInBatch
}
@@ -231,6 +273,26 @@ private[python] object BatchedPythonArrowInput {
* Enables an optimization that splits each group into the sized batches.
*/
private[python] trait GroupedPythonArrowInput { self: RowInputArrowPythonRunner =>
+
+ // Helper method to create VectorUnloader with compression for grouped operations
+ private def createUnloaderForGroup(root: VectorSchemaRoot): VectorUnloader = {
+ val codec = SQLConf.get.arrowCompressionCodec match {
+ case "none" => NoCompressionCodec.INSTANCE
+ case "zstd" =>
+ val compressionLevel = SQLConf.get.arrowZstdCompressionLevel
+ val factory = CompressionCodec.Factory.INSTANCE
+ val codecType = new ZstdCompressionCodec(compressionLevel).getCodecType()
+ factory.createCodec(codecType)
+ case "lz4" =>
+ val factory = CompressionCodec.Factory.INSTANCE
+ val codecType = new Lz4CompressionCodec().getCodecType()
+ factory.createCodec(codecType)
+ case other =>
+ throw SparkException.internalError(
+ s"Unsupported Arrow compression codec: $other. Supported values: none, zstd, lz4")
+ }
+ new VectorUnloader(root, true, codec, true)
+ }
protected override def newWriter(
env: SparkEnv,
worker: PythonWorker,
@@ -255,13 +317,16 @@ private[python] trait GroupedPythonArrowInput { self: RowInputArrowPythonRunner
writer = ArrowWriterWrapper.createAndStartArrowWriter(
schema, timeZoneId, pythonExec,
errorOnDuplicatedFieldNames, largeVarTypes, dataOut, context)
+ // Set the unloader with compression after creating the writer
+ writer.unloader = createUnloaderForGroup(writer.root)
nextBatchStart = inputIterator.next()
}
}
if (nextBatchStart.hasNext) {
val startData = dataOut.size()
val numRowsInBatch: Int = BatchedPythonArrowInput.writeSizedBatch(writer.arrowWriter,
- writer.streamWriter, nextBatchStart, maxBytesPerBatch, maxRecordsPerBatch)
+ writer.streamWriter, nextBatchStart, maxBytesPerBatch, maxRecordsPerBatch,
+ writer.unloader, dataOut)
if (!nextBatchStart.hasNext) {
writer.streamWriter.end()
// We don't need a try catch block here as the close() method is registered with
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index 018619a5207df..1e8f4ebfd1fee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -75,6 +75,9 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
private var processor: ArrowOutputProcessor = _
context.addTaskCompletionListener[Unit] { _ =>
+ if (processor != null) {
+ processor.close()
+ }
if (reader != null) {
reader.close(false)
}
@@ -241,6 +244,7 @@ abstract class BaseSliceArrowOutputProcessor(
prevVectors.foreach(_.close())
prevRoot.close()
}
+ super.close()
}
}
@@ -284,8 +288,8 @@ class SliceBytesArrowOutputProcessorImpl(
}
}
- private def getBatchBytes(root: VectorSchemaRoot): Int = {
- var batchBytes = 0
+ private def getBatchBytes(root: VectorSchemaRoot): Long = {
+ var batchBytes = 0L
root.getFieldVectors.asScala.foreach { vector =>
batchBytes += vector.getBufferSize
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
index 0f4ac4ddad719..92e99cdc11d97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
@@ -58,6 +58,7 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) extends Logging {
val idleTimeoutSeconds: Long = SQLConf.get.pythonUDFWorkerIdleTimeoutSeconds
val killOnIdleTimeout: Boolean = SQLConf.get.pythonUDFWorkerKillOnIdleTimeout
val tracebackDumpIntervalSeconds: Long = SQLConf.get.pythonUDFWorkerTracebackDumpIntervalSeconds
+ val killWorkerOnFlushFailure: Boolean = SQLConf.get.pythonUDFDaemonKillWorkerOnFlushFailure
val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory
@@ -98,6 +99,9 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) extends Logging {
if (tracebackDumpIntervalSeconds > 0L) {
envVars.put("PYTHON_TRACEBACK_DUMP_INTERVAL_SECONDS", tracebackDumpIntervalSeconds.toString)
}
+ if (useDaemon && killWorkerOnFlushFailure) {
+ envVars.put("PYTHON_DAEMON_KILL_WORKER_ON_FLUSH_FAILURE", "1")
+ }
envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))
sessionUUID.foreach { uuid =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
index 61f493deeee49..759aa998832db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -59,6 +59,8 @@ abstract class BasePythonUDFRunner(
override val killOnIdleTimeout: Boolean = SQLConf.get.pythonUDFWorkerKillOnIdleTimeout
override val tracebackDumpIntervalSeconds: Long =
SQLConf.get.pythonUDFWorkerTracebackDumpIntervalSeconds
+ override val killWorkerOnFlushFailure: Boolean =
+ SQLConf.get.pythonUDFDaemonKillWorkerOnFlushFailure
override val bufferSize: Int = SQLConf.get.getConf(SQLConf.PYTHON_UDF_BUFFER_SIZE)
override val batchSizeForPythonUDF: Int =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
index 51d9f6f523a23..14054ba89a948 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/ApplyInPandasWithStatePythonRunner.scala
@@ -79,6 +79,8 @@ class ApplyInPandasWithStatePythonRunner(
override val killOnIdleTimeout: Boolean = SQLConf.get.pythonUDFWorkerKillOnIdleTimeout
override val tracebackDumpIntervalSeconds: Long =
SQLConf.get.pythonUDFWorkerTracebackDumpIntervalSeconds
+ override val killWorkerOnFlushFailure: Boolean =
+ SQLConf.get.pythonUDFDaemonKillWorkerOnFlushFailure
private val sqlConf = SQLConf.get
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala
index 37716d2d8413b..cc7745210a4d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonForeachWriter.scala
@@ -106,6 +106,8 @@ class PythonForeachWriter(func: PythonFunction, schema: StructType)
override val killOnIdleTimeout: Boolean = SQLConf.get.pythonUDFWorkerKillOnIdleTimeout
override val tracebackDumpIntervalSeconds: Long =
SQLConf.get.pythonUDFWorkerTracebackDumpIntervalSeconds
+ override val killWorkerOnFlushFailure: Boolean =
+ SQLConf.get.pythonUDFDaemonKillWorkerOnFlushFailure
override val hideTraceback: Boolean = SQLConf.get.pysparkHideTraceback
override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala
index 7f801392c2f4a..637d11ad890bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/ChecksumCheckpointFileManager.scala
@@ -133,11 +133,21 @@ case class ChecksumFile(path: Path) {
* number of threads using file manager * 2.
* Setting this differently can lead to file operation being blocked waiting for
* a free thread.
+ * @param skipCreationIfFileMissingChecksum (ES-1629547): If true, when a file already exists
+ * but its checksum file does not, fall back to using the underlying file manager
+ * directly instead of creating the file with a checksum. This keeps compatibility with
+ * files created before checksums were enabled. Consider the case where a batch fails
+ * after its state files have been written: on the next run we try to upload both a new
+ * file and its checksum file, and the new file could fail to upload while the checksum
+ * file is uploaded successfully. The old file could then be loaded and compared against
+ * the new file's checksum, failing checksum verification.
*/
class ChecksumCheckpointFileManager(
private val underlyingFileMgr: CheckpointFileManager,
val allowConcurrentDelete: Boolean = false,
- val numThreads: Int)
+ val numThreads: Int,
+ val skipCreationIfFileMissingChecksum: Boolean)
extends CheckpointFileManager with Logging {
assert(numThreads % 2 == 0, "numThreads must be a multiple of 2, we need 1 for the main file" +
"and another for the checksum file")
@@ -160,9 +170,18 @@ class ChecksumCheckpointFileManager(
underlyingFileMgr.mkdirs(path)
}
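+ // True when the data file already exists but its checksum file does not, e.g. it was written
+ // before checksums were enabled.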
+ private def shouldSkipChecksumCreation(path: Path): Boolean = {
+ skipCreationIfFileMissingChecksum &&
+ underlyingFileMgr.exists(path) && !underlyingFileMgr.exists(getChecksumPath(path))
+ }
+
override def createAtomic(path: Path,
overwriteIfPossible: Boolean): CancellableFSDataOutputStream = {
- createWithChecksum(path, underlyingFileMgr.createAtomic(_, overwriteIfPossible))
+ if (shouldSkipChecksumCreation(path)) {
+ underlyingFileMgr.createAtomic(path, overwriteIfPossible)
+ } else {
+ createWithChecksum(path, underlyingFileMgr.createAtomic(_, overwriteIfPossible))
+ }
}
private def createWithChecksum(path: Path,
@@ -327,8 +346,11 @@ class ChecksumFSDataInputStream(
override def close(): Unit = {
if (!closed) {
// We verify the checksum only when the client is done reading.
- verifyChecksum()
- closeInternal()
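+ // Always close the underlying stream, even if checksum verification throws.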
+ try {
+ verifyChecksum()
+ } finally {
+ closeInternal()
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/FileStreamSource.scala
index d5503f1c247da..9847bd9d76448 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/FileStreamSource.scala
@@ -149,8 +149,7 @@ class FileStreamSource(
var rSize = BigInt(0)
val lFiles = ArrayBuffer[NewFileEntry]()
val rFiles = ArrayBuffer[NewFileEntry]()
- for (i <- files.indices) {
- val file = files(i)
+ files.zipWithIndex.foreach { case (file, i) =>
val newSize = lSize + file.size
if (i == 0 || rFiles.isEmpty && newSize <= Long.MaxValue && newSize <= maxSize) {
lSize += file.size
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
index d4f9dc8cea93a..cf2fca3d3cd8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/MicroBatchExecution.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, RealTimeStreamScanExec, StreamingDataSourceV2Relation, StreamingDataSourceV2ScanRelation, StreamWriterCommitProgress, WriteToDataSourceV2Exec}
-import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, Offset, OneTimeTrigger, ProcessingTimeTrigger, RealTimeTrigger, Sink, Source, StreamingQueryPlanTraverseHelper}
+import org.apache.spark.sql.execution.streaming.{AvailableNowTrigger, Offset, OneTimeTrigger, ProcessingTimeTrigger, RealTimeModeAllowlist, RealTimeTrigger, Sink, Source, StreamingQueryPlanTraverseHelper}
import org.apache.spark.sql.execution.streaming.checkpointing.{CheckpointFileManager, CommitMetadata, OffsetSeq, OffsetSeqMetadata}
import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOperatorStateInfo, StatefulOpStateStoreCheckpointInfo, StateStoreWriter}
import org.apache.spark.sql.execution.streaming.runtime.AcceptsLatestSeenOffsetHandler
@@ -436,7 +436,10 @@ class MicroBatchExecution(
}
}
- if (containsStatefulOperator(analyzedPlan)) {
+ if (trigger.isInstanceOf[RealTimeTrigger]) {
+ logWarning(log"Disabling AQE since AQE is not supported for Real-time Mode.")
+ sparkSessionToRunBatches.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
+ } else if (containsStatefulOperator(analyzedPlan)) {
// SPARK-53941: We disable AQE for stateful workloads as of now.
logWarning(log"Disabling AQE since AQE is not supported in stateful workloads.")
sparkSessionToRunBatches.conf.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "false")
@@ -1042,6 +1045,14 @@ class MicroBatchExecution(
markMicroBatchExecutionStart(execCtx)
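+ // Real-Time Mode only supports an allowlisted set of physical operators, so validate the
+ // executed plan before running the batch (throw or just log depending on the conf).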
+ if (trigger.isInstanceOf[RealTimeTrigger]) {
+ RealTimeModeAllowlist.checkAllowedPhysicalOperator(
+ execCtx.executionPlan.executedPlan,
+ sparkSession.sessionState.conf.getConf(
+ SQLConf.STREAMING_REAL_TIME_MODE_ALLOWLIST_CHECK)
+ )
+ }
+
if (execCtx.previousContext.isEmpty) {
purgeStatefulMetadataAsync(execCtx.executionPlan.executedPlan)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/RealTimeModeAllowlist.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/RealTimeModeAllowlist.scala
new file mode 100644
index 0000000000000..443c7fa1a1cf6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/RealTimeModeAllowlist.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.SparkIllegalArgumentException
+import org.apache.spark.internal.{Logging, LogKeys, MDC, MessageWithContext}
+import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.datasources.v2.RealTimeStreamScanExec
+import org.apache.spark.sql.execution.streaming.operators.stateful._
+
+object RealTimeModeAllowlist extends Logging {
+ private val allowedSinks = Set(
+ "org.apache.spark.sql.execution.streaming.ConsoleTable$",
+ "org.apache.spark.sql.execution.streaming.sources.ContinuousMemorySink",
+ "org.apache.spark.sql.execution.streaming.sources.ForeachWriterTable",
+ "org.apache.spark.sql.kafka010.KafkaSourceProvider$KafkaTable"
+ )
+
+ private val allowedOperators = Set(
+ "org.apache.spark.sql.execution.AppendColumnsExec",
+ "org.apache.spark.sql.execution.CollectMetricsExec",
+ "org.apache.spark.sql.execution.ColumnarToRowExec",
+ "org.apache.spark.sql.execution.DeserializeToObjectExec",
+ "org.apache.spark.sql.execution.ExpandExec",
+ "org.apache.spark.sql.execution.FileSourceScanExec",
+ "org.apache.spark.sql.execution.FilterExec",
+ "org.apache.spark.sql.execution.GenerateExec",
+ "org.apache.spark.sql.execution.InputAdapter",
+ "org.apache.spark.sql.execution.LocalTableScanExec",
+ "org.apache.spark.sql.execution.MapElementsExec",
+ "org.apache.spark.sql.execution.MapPartitionsExec",
+ "org.apache.spark.sql.execution.PlanLater",
+ "org.apache.spark.sql.execution.ProjectExec",
+ "org.apache.spark.sql.execution.RangeExec",
+ "org.apache.spark.sql.execution.SerializeFromObjectExec",
+ "org.apache.spark.sql.execution.UnionExec",
+ "org.apache.spark.sql.execution.WholeStageCodegenExec",
+ "org.apache.spark.sql.execution.datasources.v2.RealTimeStreamScanExec",
+ "org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2Exec",
+ "org.apache.spark.sql.execution.exchange.BroadcastExchangeExec",
+ "org.apache.spark.sql.execution.exchange.ReusedExchangeExec",
+ "org.apache.spark.sql.execution.joins.BroadcastHashJoinExec",
+ classOf[EventTimeWatermarkExec].getName
+ )
+
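+ // Renders the sorted class names as "A is" / "A, B are" so callers can append the rest of
+ // the warning or error message.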
+ private def classNamesString(classNames: Seq[String]): MessageWithContext = {
+ val sortedClassNames = classNames.sorted
+ var message = log"${MDC(LogKeys.CLASS_NAME, sortedClassNames.head)}"
+ sortedClassNames.tail.foreach(
+ name => message += log", ${MDC(LogKeys.CLASS_NAME, name)}"
+ )
+ if (sortedClassNames.size > 1) {
+ message + log" are"
+ } else {
+ message + log" is"
+ }
+ }
+
+ private def notInRTMAllowlistException(
+ errorType: String,
+ classNames: Seq[String]): SparkIllegalArgumentException = {
+ assert(classNames.nonEmpty)
+ new SparkIllegalArgumentException(
+ errorClass = "STREAMING_REAL_TIME_MODE.OPERATOR_OR_SINK_NOT_IN_ALLOWLIST",
+ messageParameters = Map(
+ "errorType" -> errorType,
+ "message" -> classNamesString(classNames).message
+ )
+ )
+ }
+
+ def checkAllowedSink(sink: Table, throwException: Boolean): Unit = {
+ if (!allowedSinks.contains(sink.getClass.getName)) {
+ if (throwException) {
+ throw notInRTMAllowlistException("sink", Seq(sink.getClass.getName))
+ } else {
+ logWarning(
+ log"The sink: " + classNamesString(Seq(sink.getClass.getName)) +
+ log" not in the sink allowlist for Real-Time Mode."
+ )
+ }
+ }
+ }
+
+ // Collect ALL nodes whose entire subtree contains RealTimeStreamScanExec.
+ private def collectRealtimeNodes(root: SparkPlan): Seq[SparkPlan] = {
+
+ def collectNodesWhoseSubtreeHasRTS(n: SparkPlan): (Boolean, List[SparkPlan]) = {
+ n match {
+ case _: RealTimeStreamScanExec =>
+ // Subtree has RTS, but we don't collect the RTS node itself.
+ (true, Nil)
+ case _ if n.children.isEmpty =>
+ (false, Nil)
+ case _ =>
+ val kidResults = n.children.map(collectNodesWhoseSubtreeHasRTS)
+ val anyChildHasRTS = kidResults.exists(_._1)
+ val collectedKids = kidResults.iterator.flatMap(_._2).toList
+ val collectedHere = if (anyChildHasRTS) n :: collectedKids else collectedKids
+ (anyChildHasRTS, collectedHere)
+ }
+ }
+
+ collectNodesWhoseSubtreeHasRTS(root)._2
+ }
+
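+ /**
+ * Validates every plan node whose subtree contains a RealTimeStreamScanExec against the
+ * operator allowlist. Throws if `throwException` is true, otherwise logs a warning.
+ */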
+ def checkAllowedPhysicalOperator(operator: SparkPlan, throwException: Boolean): Unit = {
+ val nodesToCheck = collectRealtimeNodes(operator)
+ val violations = nodesToCheck
+ .map(_.getClass.getName)
+ .filterNot(allowedOperators.contains)
+ .distinct
+
+ if (violations.nonEmpty) {
+ if (throwException) {
+ throw notInRTMAllowlistException("operator", violations)
+ } else {
+ logWarning(
+ log"The operator(s): " + classNamesString(violations.toSet.toSeq) +
+ log" not in the operator allowlist for Real-Time Mode."
+ )
+ }
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ResolveWriteToStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ResolveWriteToStream.scala
index ee7bf67eb9121..86d48b1e88c5c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ResolveWriteToStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/ResolveWriteToStream.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.streaming.{WriteToStream, WriteToStreamStatement}
import org.apache.spark.sql.connector.catalog.SupportsWrite
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.execution.streaming.{ContinuousTrigger, RealTimeTrigger}
import org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils
@@ -48,7 +49,12 @@ object ResolveWriteToStream extends Rule[LogicalPlan] {
}
if (conf.isUnsupportedOperationCheckEnabled) {
- if (s.sink.isInstanceOf[SupportsWrite] && s.isContinuousTrigger) {
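+ // Real-time trigger queries have additional constraints, checked on top of the regular
+ // streaming checks below.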
+ if (s.trigger.isInstanceOf[RealTimeTrigger]) {
+ UnsupportedOperationChecker.
+ checkAdditionalRealTimeModeConstraints(s.inputQuery, s.outputMode)
+ }
+
+ if (s.sink.isInstanceOf[SupportsWrite] && s.trigger.isInstanceOf[ContinuousTrigger]) {
UnsupportedOperationChecker.checkForContinuous(s.inputQuery, s.outputMode)
} else {
UnsupportedOperationChecker.checkForStreaming(s.inputQuery, s.outputMode)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala
index bf67ed670ec81..c7556ed478599 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala
@@ -43,36 +43,51 @@ import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
-object MemoryStream extends LowPriorityMemoryStreamImplicits {
+object MemoryStream {
protected val currentBlockId = new AtomicInteger(0)
protected val memoryStreamId = new AtomicInteger(0)
- def apply[A : Encoder](implicit sparkSession: SparkSession): MemoryStream[A] =
- new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
-
- def apply[A : Encoder](numPartitions: Int)(implicit sparkSession: SparkSession): MemoryStream[A] =
- new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, Some(numPartitions))
-}
-
-/**
- * Provides lower-priority implicits for MemoryStream to prevent ambiguity when both
- * SparkSession and SQLContext are in scope. The implicits in the companion object,
- * which use SparkSession, take higher precedence.
- */
-trait LowPriorityMemoryStreamImplicits {
- this: MemoryStream.type =>
-
- // Deprecated: Used when an implicit SQLContext is in scope
- @deprecated("Use MemoryStream.apply with an implicit SparkSession instead of SQLContext", "4.1.0")
- def apply[A: Encoder]()(implicit sqlContext: SQLContext): MemoryStream[A] =
+ /**
+ * Creates a MemoryStream with an implicit SQLContext (backward compatible).
+ * Usage: `MemoryStream[Int]`
+ */
+ def apply[A: Encoder](implicit sqlContext: SQLContext): MemoryStream[A] =
new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession)
- @deprecated("Use MemoryStream.apply with an implicit SparkSession instead of SQLContext", "4.1.0")
- def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): MemoryStream[A] =
+ /**
+ * Creates a MemoryStream with specified partitions using implicit SQLContext.
+ * Usage: `MemoryStream[Int](numPartitions)`
+ */
+ def apply[A: Encoder](numPartitions: Int)(
+ implicit sqlContext: SQLContext): MemoryStream[A] =
new MemoryStream[A](
memoryStreamId.getAndIncrement(),
sqlContext.sparkSession,
Some(numPartitions))
+
+ /**
+ * Creates a MemoryStream with explicit SparkSession.
+ * Usage: `MemoryStream[Int](spark)`
+ */
+ def apply[A: Encoder](sparkSession: SparkSession): MemoryStream[A] =
+ new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
+
+ /**
+ * Creates a MemoryStream with specified partitions using explicit SparkSession.
+ * Usage: `MemoryStream[Int](spark, numPartitions)`
+ */
+ def apply[A: Encoder](sparkSession: SparkSession, numPartitions: Int): MemoryStream[A] =
+ new MemoryStream[A](
+ memoryStreamId.getAndIncrement(),
+ sparkSession,
+ Some(numPartitions))
+
+ /**
+ * Creates a MemoryStream with explicit encoder and SparkSession.
+ * Usage: `MemoryStream(Encoders.scalaInt, spark)`
+ */
+ def apply[A](encoder: Encoder[A], sparkSession: SparkSession): MemoryStream[A] =
+ new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)(encoder)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
index 8042cacf1374b..885f9ada22c9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
@@ -112,47 +112,36 @@ class ContinuousMemoryStream[A : Encoder](
override def commit(end: Offset): Unit = {}
}
-object ContinuousMemoryStream extends LowPriorityContinuousMemoryStreamImplicits {
+object ContinuousMemoryStream {
protected val memoryStreamId = new AtomicInteger(0)
- def apply[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] =
- new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
-
- def apply[A : Encoder](numPartitions: Int)(implicit sparkSession: SparkSession):
- ContinuousMemoryStream[A] =
- new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, numPartitions)
-
- def singlePartition[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] =
- new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1)
-}
-
-/**
- * Provides lower-priority implicits for ContinuousMemoryStream to prevent ambiguity when both
- * SparkSession and SQLContext are in scope. The implicits in the companion object,
- * which use SparkSession, take higher precedence.
- */
-trait LowPriorityContinuousMemoryStreamImplicits {
- this: ContinuousMemoryStream.type =>
-
- // Deprecated: Used when an implicit SQLContext is in scope
- @deprecated("Use ContinuousMemoryStream with an implicit SparkSession " +
- "instead of SQLContext", "4.1.0")
- def apply[A: Encoder]()(implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
+ /** Creates a ContinuousMemoryStream with an implicit SQLContext (backward compatible). */
+ def apply[A: Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession)
- @deprecated("Use ContinuousMemoryStream with an implicit SparkSession " +
- "instead of SQLContext", "4.1.0")
- def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext):
- ContinuousMemoryStream[A] =
+ /** Creates a ContinuousMemoryStream with specified partitions (SQLContext). */
+ def apply[A: Encoder](numPartitions: Int)(
+ implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
new ContinuousMemoryStream[A](
memoryStreamId.getAndIncrement(),
sqlContext.sparkSession,
numPartitions)
- @deprecated("Use ContinuousMemoryStream.singlePartition with an implicit SparkSession " +
- "instead of SQLContext", "4.1.0")
- def singlePartition[A: Encoder]()(implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
+ /** Creates a ContinuousMemoryStream with explicit SparkSession. */
+ def apply[A: Encoder](sparkSession: SparkSession): ContinuousMemoryStream[A] =
+ new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
+
+ /** Creates a ContinuousMemoryStream with specified partitions (SparkSession). */
+ def apply[A: Encoder](sparkSession: SparkSession, numPartitions: Int): ContinuousMemoryStream[A] =
+ new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, numPartitions)
+
+ /** Creates a single partition ContinuousMemoryStream (SQLContext). */
+ def singlePartition[A: Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession, 1)
+
+ /** Creates a single partition ContinuousMemoryStream (SparkSession). */
+ def singlePartition[A: Encoder](sparkSession: SparkSession): ContinuousMemoryStream[A] =
+ new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala
index d04f4b5d011ca..6dfeb0cc46032 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/LowLatencyMemoryStream.scala
@@ -172,53 +172,39 @@ class LowLatencyMemoryStream[A: Encoder](
}
}
-object LowLatencyMemoryStream extends LowPriorityLowLatencyMemoryStreamImplicits {
+object LowLatencyMemoryStream {
protected val memoryStreamId = new AtomicInteger(0)
- def apply[A: Encoder](implicit sparkSession: SparkSession): LowLatencyMemoryStream[A] =
- new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
+ /** Creates a LowLatencyMemoryStream with an implicit SQLContext (backward compatible). */
+ def apply[A: Encoder](implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
+ new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession)
+ /** Creates a LowLatencyMemoryStream with specified partitions (SQLContext). */
def apply[A: Encoder](numPartitions: Int)(
- implicit
- sparkSession: SparkSession): LowLatencyMemoryStream[A] =
+ implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
new LowLatencyMemoryStream[A](
memoryStreamId.getAndIncrement(),
- sparkSession,
- numPartitions = numPartitions
- )
-
- def singlePartition[A: Encoder](implicit sparkSession: SparkSession): LowLatencyMemoryStream[A] =
- new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1)
-}
-
-/**
- * Provides lower-priority implicits for LowLatencyMemoryStream to prevent ambiguity when both
- * SparkSession and SQLContext are in scope. The implicits in the companion object,
- * which use SparkSession, take higher precedence.
- */
-trait LowPriorityLowLatencyMemoryStreamImplicits {
- this: LowLatencyMemoryStream.type =>
+ sqlContext.sparkSession,
+ numPartitions = numPartitions)
- // Deprecated: Used when an implicit SQLContext is in scope
- @deprecated("Use LowLatencyMemoryStream with an implicit SparkSession " +
- "instead of SQLContext", "4.1.0")
- def apply[A: Encoder]()(implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
- new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession)
+ /** Creates a LowLatencyMemoryStream with explicit SparkSession. */
+ def apply[A: Encoder](sparkSession: SparkSession): LowLatencyMemoryStream[A] =
+ new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)
- @deprecated("Use LowLatencyMemoryStream with an implicit SparkSession " +
- "instead of SQLContext", "4.1.0")
- def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext):
- LowLatencyMemoryStream[A] =
+ /** Creates a LowLatencyMemoryStream with specified partitions (SparkSession). */
+ def apply[A: Encoder](sparkSession: SparkSession, numPartitions: Int): LowLatencyMemoryStream[A] =
new LowLatencyMemoryStream[A](
memoryStreamId.getAndIncrement(),
- sqlContext.sparkSession,
- numPartitions = numPartitions
- )
+ sparkSession,
+ numPartitions = numPartitions)
- @deprecated("Use LowLatencyMemoryStream.singlePartition with an implicit SparkSession " +
- "instead of SQLContext", "4.1.0")
- def singlePartition[A: Encoder]()(implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
+ /** Creates a single partition LowLatencyMemoryStream (SQLContext). */
+ def singlePartition[A: Encoder](implicit sqlContext: SQLContext): LowLatencyMemoryStream[A] =
new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession, 1)
+
+ /** Creates a single partition LowLatencyMemoryStream (SparkSession). */
+ def singlePartition[A: Encoder](sparkSession: SparkSession): LowLatencyMemoryStream[A] =
+ new LowLatencyMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index aa4fa9bfaf627..9dc236e0ffd44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -466,7 +466,9 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
// (one for main file and another for checksum file).
// Since this fm is used by both query task and maintenance thread,
// then we need 2 * 2 = 4 threads.
- numThreads = 4)
+ numThreads = 4,
+ skipCreationIfFileMissingChecksum =
+ storeConf.checkpointFileChecksumSkipCreationIfFileMissingChecksum)
} else {
mgr
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index fb3ef606b8f34..b12cdb9bba95c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -69,7 +69,7 @@ case object StoreTaskCompletionListener extends RocksDBOpType("store_task_comple
class RocksDB(
dfsRootDir: String,
val conf: RocksDBConf,
- localRootDir: File = Utils.createTempDir(),
+ val localRootDir: File = Utils.createTempDir(),
hadoopConf: Configuration = new Configuration,
loggingId: String = "",
useColumnFamilies: Boolean = false,
@@ -150,20 +150,29 @@ class RocksDB(
localTempDir: File,
hadoopConf: Configuration,
codecName: String,
- loggingId: String): RocksDBFileManager = {
+ loggingId: String,
+ storeConf: StateStoreConf): RocksDBFileManager = {
new RocksDBFileManager(
dfsRootDir,
localTempDir,
hadoopConf,
codecName,
loggingId = loggingId,
+ storeConf,
fileChecksumEnabled = conf.fileChecksumEnabled,
fileChecksumThreadPoolSize = fileChecksumThreadPoolSize
)
}
- private[spark] val fileManager = createFileManager(dfsRootDir, createTempDir("fileManager"),
- hadoopConf, conf.compressionCodec, loggingId = loggingId)
+ private[spark] val fileManager = createFileManager(
+ dfsRootDir,
+ createTempDir("fileManager"),
+ hadoopConf,
+ conf.compressionCodec,
+ loggingId = loggingId,
+ storeConf = conf.stateStoreConf
+ )
+
private val byteArrayPair = new ByteArrayPair()
private val commitLatencyMs = new mutable.HashMap[String, Long]()
@@ -2067,7 +2076,8 @@ case class RocksDBConf(
compression: String,
reportSnapshotUploadLag: Boolean,
fileChecksumEnabled: Boolean,
- maxVersionsToDeletePerMaintenance: Int)
+ maxVersionsToDeletePerMaintenance: Int,
+ stateStoreConf: StateStoreConf)
object RocksDBConf {
/** Common prefix of all confs in SQLConf that affects RocksDB */
@@ -2267,7 +2277,8 @@ object RocksDBConf {
getStringConf(COMPRESSION_CONF),
storeConf.reportSnapshotUploadLag,
storeConf.checkpointFileChecksumEnabled,
- storeConf.maxVersionsToDeletePerMaintenance)
+ storeConf.maxVersionsToDeletePerMaintenance,
+ storeConf)
}
def apply(): RocksDBConf = apply(new StateStoreConf())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
index 7bef692e264a2..ed34a46889d85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
@@ -132,6 +132,7 @@ class RocksDBFileManager(
hadoopConf: Configuration,
codecName: String = CompressionCodec.ZSTD,
loggingId: String = "",
+ storeConf: StateStoreConf = StateStoreConf.empty,
fileChecksumEnabled: Boolean = false,
fileChecksumThreadPoolSize: Option[Int] = None)
extends Logging {
@@ -149,7 +150,9 @@ class RocksDBFileManager(
mgr,
// Allowing this for perf, since we do orphan checksum file cleanup in maintenance anyway
allowConcurrentDelete = true,
- numThreads = fileChecksumThreadPoolSize.get)
+ numThreads = fileChecksumThreadPoolSize.get,
+ skipCreationIfFileMissingChecksum =
+ storeConf.checkpointFileChecksumSkipCreationIfFileMissingChecksum)
} else {
mgr
}
@@ -1123,7 +1126,7 @@ object RocksDBCheckpointMetadata {
/** Used to convert between classes and JSON. */
lazy val mapper = {
val _mapper = new ObjectMapper with ClassTagExtensions
- _mapper.setSerializationInclusion(Include.NON_ABSENT)
+ _mapper.setDefaultPropertyInclusion(Include.NON_ABSENT)
_mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
_mapper.registerModule(DefaultScalaModule)
_mapper
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
index f49c79f96b9ce..f6fe4dbea576c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateEncoder.scala
@@ -628,7 +628,7 @@ class UnsafeRowDataEncoder(
override def decodeRemainingKey(bytes: Array[Byte]): UnsafeRow = {
keyStateEncoderSpec match {
case PrefixKeyScanStateEncoderSpec(_, numColsPrefixKey) =>
- decodeToUnsafeRow(bytes, numFields = numColsPrefixKey)
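+ // The remaining key holds the non-prefix columns, so it has
+ // keySchema.length - numColsPrefixKey fields (mirroring the range-scan case below).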
+ decodeToUnsafeRow(bytes, numFields = keySchema.length - numColsPrefixKey)
case RangeKeyScanStateEncoderSpec(_, orderingOrdinals) =>
decodeToUnsafeRow(bytes, keySchema.length - orderingOrdinals.length)
case _ => throw unsupportedOperationForKeyStateEncoder("decodeRemainingKey")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index 2cc4c8a870aea..1058c02c9304e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -837,6 +837,9 @@ private[sql] class RocksDBStateStoreProvider
@volatile private var stateSchemaProvider: Option[StateSchemaProvider] = _
@volatile private var rocksDBEventForwarder: Option[RocksDBEventForwarder] = _
@volatile private var stateStoreProviderId: StateStoreProviderId = _
+ // Exposed for testing
+ @volatile private[sql] var sparkConf: SparkConf = Option(SparkEnv.get).map(_.conf)
+ .getOrElse(new SparkConf)
protected def createRocksDB(
dfsRootDir: String,
@@ -867,8 +870,7 @@ private[sql] class RocksDBStateStoreProvider
val storeIdStr = s"StateStoreId(opId=${stateStoreId.operatorId}," +
s"partId=${stateStoreId.partitionId},name=${stateStoreId.storeName})"
val loggingId = stateStoreProviderId.toString
- val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
- val localRootDir = Utils.createTempDir(Utils.getLocalDir(sparkConf), storeIdStr)
+ val localRootDir = Utils.createExecutorLocalTempDir(sparkConf, storeIdStr)
createRocksDB(dfsRootDir, RocksDBConf(storeConf), localRootDir, hadoopConf, loggingId,
useColumnFamilies, storeConf.enableStateStoreCheckpointIds, stateStoreId.partitionId,
rocksDBEventForwarder,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index 74904a37f4504..3e190eedc9f44 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -114,6 +114,20 @@ class StateStoreConf(
val enableStateStoreCheckpointIds =
StatefulOperatorStateInfo.enableStateStoreCheckpointIds(sqlConf)
+ /**
+ * Whether to skip checksum file creation when the main file exists but is missing its checksum.
+ *
+ * Consider the case with STATE_STORE_CHECKPOINT_FORMAT_VERSION = 1 where a batch fails after
+ * state files have been written. If, on the next run, we try to upload both a new state file
+ * and its checksum file, the file upload could fail while the checksum upload succeeds. The
+ * old file could then be loaded and compared against the new file's checksum, which would fail
+ * checksum verification. This issue does not happen with STATE_STORE_CHECKPOINT_FORMAT_VERSION
+ * = 2, since unique checkpoint IDs are created for each batch run.
+ */
+ val checkpointFileChecksumSkipCreationIfFileMissingChecksum: Boolean =
+ sqlConf.checkpointFileChecksumSkipCreationIfFileMissingChecksum &&
+ !enableStateStoreCheckpointIds
+
/**
* Whether the coordinator is reporting state stores trailing behind in snapshot uploads.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index ced4b6224c884..3ddf0b69e762c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -113,9 +113,12 @@ object SparkPlanGraph {
// Subquery should not be included in WholeStageCodegen
buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges)
case "Subquery" if exchanges.contains(planInfo) =>
- // Point to the re-used subquery
val node = exchanges(planInfo)
- edges += SparkPlanGraphEdge(node.id, parent.id)
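+ // A reused subquery can be reached more than once from the same parent, so only add
+ // the edge if it is not already present.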
+ val newEdge = SparkPlanGraphEdge(node.id, parent.id)
+ if (!edges.contains(newEdge)) {
+ // Point to the re-used subquery
+ edges += newEdge
+ }
case "ReusedSubquery" =>
// Re-used subquery might appear before the original subquery, so skip this node and let
// the previous `case` make sure the re-used and the original point to the same node.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index c967497b660c7..a8c23c8e126f1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTr
import org.apache.spark.sql.catalyst.analysis.resolver.ResolverExtension
import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExtractSemiStructuredFields}
+import org.apache.spark.sql.catalyst.normalizer.NormalizeCTEIds
import org.apache.spark.sql.catalyst.optimizer.Optimizer
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -176,6 +177,8 @@ abstract class BaseSessionStateBuilder(
protected lazy val catalogManager = new CatalogManager(v2SessionCatalog, catalog)
+ protected lazy val sharedRelationCache = session.sharedState.relationCache
+
/**
* Interface exposed to the user for registering user-defined functions.
*
@@ -197,7 +200,7 @@ abstract class BaseSessionStateBuilder(
*
* Note: this depends on the `conf` and `catalog` fields.
*/
- protected def analyzer: Analyzer = new Analyzer(catalogManager) {
+ protected def analyzer: Analyzer = new Analyzer(catalogManager, sharedRelationCache) {
override val hintResolutionRules: Seq[Rule[LogicalPlan]] =
customHintResolutionRules
@@ -401,6 +404,7 @@ abstract class BaseSessionStateBuilder(
}
protected def planNormalizationRules: Seq[Rule[LogicalPlan]] = {
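+ // Normalize CTE ids first so that plans differing only in CTE id assignment normalize to
+ // the same plan.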
+ NormalizeCTEIds +:
extensions.buildPlanNormalizationRules(session)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
index af1f38caab686..8e641294bf8cc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala
@@ -31,6 +31,7 @@ import org.apache.hadoop.fs.{FsUrlStreamHandlerFactory, Path}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{CONFIG, CONFIG2, PATH, VALUE}
+import org.apache.spark.sql.catalyst.analysis.RelationCache
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.CacheManager
@@ -96,6 +97,13 @@ private[sql] class SharedState(
*/
val cacheManager: CacheManager = new CacheManager
+ /**
+ * A relation cache backed by the cache manager.
+ */
+ private[sql] val relationCache: RelationCache = {
+ (nameParts, resolver) => cacheManager.lookupCachedTable(nameParts, resolver)
+ }
+
/** A global lock for all streaming query lifecycle tracking and management. */
private[sql] val activeQueriesLock = new Object
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index ce4c347cad349..6b4e743ce989c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -772,7 +772,9 @@ abstract class JdbcDialect extends Serializable with Logging {
}
@Since("4.1.0")
- def isObjectNotFoundException(e: SQLException): Boolean = true
+ def isObjectNotFoundException(e: SQLException): Boolean = {
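+ // SQLSTATE class "42" (syntax error or access rule violation) is what most databases report
+ // for missing tables or schemas, e.g. 42P01 in PostgreSQL or 42S02 via ODBC/MySQL.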
+ Option(e.getSQLState).exists(_.startsWith("42"))
+ }
/**
* Gets a dialect exception, classifies it and wraps it by `AnalysisException`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
index 2a849aa2d6040..1dee2ae80f508 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecution.scala
@@ -88,7 +88,8 @@ class SqlScriptingExecution(
/**
* Helper method to execute interrupts to ConditionalStatements.
- * This method should only interrupt when the statement that throws is a conditional statement.
+ * This method should only interrupt when the exception was thrown during evaluation of
+ * the conditional statement's condition.
* @param executionPlan Execution plan.
*/
private def interruptConditionalStatements(executionPlan: NonLeafStatementExec): Unit = {
@@ -101,8 +102,42 @@ class SqlScriptingExecution(
}
currExecPlan match {
- case exec: ConditionalStatementExec =>
- exec.interrupted = true
+ case exec: ConditionalStatementExec if exec.isInCondition =>
+ // Only interrupt the conditional if its condition/query was being evaluated when the
+ // exception occurred. This distinguishes between two scenarios:
+ // 1. Exception during condition evaluation -> interrupt (skip the conditional)
+ // 2. Exception before reaching the conditional -> don't interrupt (execute normally)
+ //
+ // Different conditional statements track evaluation state differently:
+ // - SimpleCaseStatementExec: hasStartedCaseVariableEvaluation flag is set when
+ // validateCache() begins evaluating the case variable expression.
+ // - ForStatementExec: hasStartedQueryEvaluation flag is set when cachedQueryResult()
+ // begins evaluating the FOR loop's query.
+ // - IF/ELSEIF, WHILE, REPEAT, SEARCHED CASE: curr.isExecuted flag is set by
+ // evaluateBooleanCondition() before evaluating each condition.
+ val shouldInterrupt =
+ exec match {
+ case simpleCaseStmt: SimpleCaseStatementExec
+ if simpleCaseStmt.hasStartedCaseVariableEvaluation =>
+ // Only interrupt if case variable evaluation was attempted.
+ true
+ case forStmt: ForStatementExec =>
+ // Only interrupt if query evaluation was attempted.
+ forStmt.hasStartedQueryEvaluation
+ case _ =>
+ // For IF, WHILE, REPEAT, SEARCHED/SIMPLE CASE: check if condition was executed.
+ // evaluateBooleanCondition sets isExecuted=true before evaluation, so if an
+ // exception occurs during evaluation, isExecuted will be true. If the exception
+ // happened before reaching the conditional, isExecuted will still be false.
+ exec.curr match {
+ case Some(stmt: SingleStatementExec) => stmt.isExecuted
+ case _ => false
+ }
+ }
+
+ if (shouldInterrupt) {
+ exec.interrupted = true
+ }
case _ =>
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
index aa2c2f405021a..953301ab8ed54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/scripting/SqlScriptingExecutionNode.scala
@@ -120,6 +120,22 @@ trait ConditionalStatementExec extends NonLeafStatementExec {
* HANDLER.
*/
protected[scripting] var interrupted: Boolean = false
+
+ /**
+ * Returns true if the conditional statement is currently evaluating its condition,
+ * false if it's executing its body. This is used by CONTINUE HANDLER to determine
+ * whether to interrupt the conditional statement when an exception occurs.
+ *
+ * For loop statements (WHILE, REPEAT, FOR), this should return true when evaluating
+ * the loop condition and false when executing the loop body. This distinction is
+ * critical because:
+ * - Exception in condition: loop should be skipped (interrupted)
+ * - Exception in body: loop should continue to next iteration (not interrupted)
+ *
+ * For IF/CASE statements, this should return true when evaluating the condition
+ * expression and false when executing any branch body.
+ */
+ protected[scripting] def isInCondition: Boolean
}
/**
@@ -479,6 +495,8 @@ class IfElseStatementExec(
override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
+ override protected[scripting] def isInCondition: Boolean = state == IfElseState.Condition
+
override def reset(): Unit = {
state = IfElseState.Condition
curr = Some(conditions.head)
@@ -536,6 +554,18 @@ class WhileStatementExec(
throw SparkException.internalError("Unexpected statement type in WHILE condition.")
}
case WhileState.Body =>
+ // Check if body has more statements before calling next(). When an exception in a
+ // conditional statement's condition is handled by a CONTINUE handler, the conditional
+ // is interrupted. If it's the last statement in the loop body, calling next() on the
+ // exhausted iterator would fail. Instead, we return NoOpStatementExec and transition
+ // back to the condition.
+ if (!body.getTreeIterator.hasNext) {
+ state = WhileState.Condition
+ curr = Some(condition)
+ condition.reset()
+ return new NoOpStatementExec
+ }
+
val retStmt = body.getTreeIterator.next()
// Handle LEAVE or ITERATE statement if it has been encountered.
@@ -565,6 +595,8 @@ class WhileStatementExec(
override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
+ override protected[scripting] def isInCondition: Boolean = state == WhileState.Condition
+
override def reset(): Unit = {
state = WhileState.Condition
curr = Some(condition)
@@ -654,6 +686,8 @@ class SearchedCaseStatementExec(
override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
+ override protected[scripting] def isInCondition: Boolean = state == CaseState.Condition
+
override def reset(): Unit = {
state = CaseState.Condition
curr = Some(conditions.head)
@@ -694,15 +728,23 @@ class SimpleCaseStatementExec(
private var conditionBodyTupleIterator: Iterator[(SingleStatementExec, CompoundBodyExec)] = _
private var caseVariableLiteral: Literal = _
+ // Flag to track if case variable evaluation has been attempted. Used by CONTINUE handler
+ // mechanism to determine if an exception occurred during case variable evaluation vs. before
+ // the CASE statement was reached.
+ protected[scripting] var hasStartedCaseVariableEvaluation: Boolean = false
+
private var isCacheValid = false
private def validateCache(): Unit = {
if (!isCacheValid) {
+ // Set flag before evaluation so the CONTINUE handler can detect if an exception happened here.
+ hasStartedCaseVariableEvaluation = true
val values = caseVariableExec.buildDataFrame(session).collect()
caseVariableExec.isExecuted = true
caseVariableLiteral = Literal(values.head.get(0))
conditionBodyTupleIterator = createConditionBodyIterator
isCacheValid = true
+ hasStartedCaseVariableEvaluation = false
}
}
@@ -793,6 +835,8 @@ class SimpleCaseStatementExec(
override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
+ override protected[scripting] def isInCondition: Boolean = state == CaseState.Condition
+
override def reset(): Unit = {
state = CaseState.Condition
bodyExec = None
@@ -802,6 +846,7 @@ class SimpleCaseStatementExec(
caseVariableExec.reset()
conditionalBodies.foreach(b => b.reset())
elseBody.foreach(b => b.reset())
+ hasStartedCaseVariableEvaluation = false
}
}
@@ -852,6 +897,18 @@ class RepeatStatementExec(
throw SparkException.internalError("Unexpected statement type in REPEAT condition.")
}
case RepeatState.Body =>
+ // Check if body has more statements before calling next(). When an exception in a
+ // conditional statement's condition is handled by a CONTINUE handler, the conditional
+ // is interrupted. If it's the last statement in the loop body, calling next() on the
+ // exhausted iterator would fail. Instead, we return NoOpStatementExec and transition
+ // back to the condition.
+ if (!body.getTreeIterator.hasNext) {
+ state = RepeatState.Condition
+ curr = Some(condition)
+ condition.reset()
+ return new NoOpStatementExec
+ }
+
val retStmt = body.getTreeIterator.next()
retStmt match {
@@ -880,6 +937,8 @@ class RepeatStatementExec(
override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
+ override protected[scripting] def isInCondition: Boolean = state == RepeatState.Condition
+
override def reset(): Unit = {
state = RepeatState.Body
curr = Some(body)
@@ -1017,17 +1076,24 @@ class ForStatementExec(
}
private var state = ForState.VariableAssignment
+ // Flag to track if FOR query evaluation has been attempted. Used by CONTINUE handler
+ // mechanism to determine if an exception occurred during query evaluation vs. before
+ // the FOR statement was reached.
+ protected[scripting] var hasStartedQueryEvaluation = false
private var queryResult: util.Iterator[Row] = _
private var queryColumnNameToDataType: Map[String, DataType] = _
private var isResultCacheValid = false
private def cachedQueryResult(): util.Iterator[Row] = {
if (!isResultCacheValid) {
+ // Set flag before evaluation so CONTINUE handler can detect if exception happened here.
+ hasStartedQueryEvaluation = true
val df = query.buildDataFrame(session)
queryResult = df.toLocalIterator()
queryColumnNameToDataType = df.schema.fields.map(f => f.name -> f.dataType).toMap
query.isExecuted = true
isResultCacheValid = true
+ hasStartedQueryEvaluation = false
}
queryResult
}
@@ -1215,6 +1281,8 @@ class ForStatementExec(
override def getTreeIterator: Iterator[CompoundStatementExec] = treeIterator
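+ // FOR evaluates its driving query while assigning loop variables, so VariableAssignment is
+ // the "condition" phase for CONTINUE-handler interruption purposes.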
+ override protected[scripting] def isInCondition: Boolean = state == ForState.VariableAssignment
+
override def reset(): Unit = {
state = ForState.VariableAssignment
isResultCacheValid = false
@@ -1222,6 +1290,7 @@ class ForStatementExec(
curr = None
bodyWithVariables = None
firstIteration = true
+ hasStartedQueryEvaluation = false
}
}
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 29d194f5715e6..a5c98675c9773 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -182,6 +182,21 @@
| org.apache.spark.sql.catalyst.expressions.JsonObjectKeys | json_object_keys | SELECT json_object_keys('{}') | struct> |
| org.apache.spark.sql.catalyst.expressions.JsonToStructs | from_json | SELECT from_json('{"a":1, "b":0.8}', 'a INT, b DOUBLE') | struct> |
| org.apache.spark.sql.catalyst.expressions.JsonTuple | json_tuple | SELECT json_tuple('{"a":1, "b":2}', 'a', 'b') | struct |
+| org.apache.spark.sql.catalyst.expressions.KllSketchGetNBigint | kll_sketch_get_n_bigint | SELECT kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col)) FROM VALUES (1), (2), (3), (4), (5) tab(col) | struct |
+| org.apache.spark.sql.catalyst.expressions.KllSketchGetNDouble | kll_sketch_get_n_double | SELECT kll_sketch_get_n_double(kll_sketch_agg_double(col)) FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col) | struct |
+| org.apache.spark.sql.catalyst.expressions.KllSketchGetNFloat | kll_sketch_get_n_float | SELECT kll_sketch_get_n_float(kll_sketch_agg_float(col)) FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col) | struct |
+| org.apache.spark.sql.catalyst.expressions.KllSketchGetQuantileBigint | kll_sketch_get_quantile_bigint | SELECT kll_sketch_get_quantile_bigint(kll_sketch_agg_bigint(col), 0.5) > 1 FROM VALUES (1), (2), (3), (4), (5) tab(col) | struct<(kll_sketch_get_quantile_bigint(kll_sketch_agg_bigint(col), 0.5) > 1):boolean> |
+| org.apache.spark.sql.catalyst.expressions.KllSketchGetQuantileDouble | kll_sketch_get_quantile_double | SELECT kll_sketch_get_quantile_double(kll_sketch_agg_double(col), 0.5) > 1 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col) | struct<(kll_sketch_get_quantile_double(kll_sketch_agg_double(col), 0.5) > 1):boolean> |
+| org.apache.spark.sql.catalyst.expressions.KllSketchGetQuantileFloat | kll_sketch_get_quantile_float | SELECT kll_sketch_get_quantile_float(kll_sketch_agg_float(col), 0.5) > 1 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col) | struct<(kll_sketch_get_quantile_float(kll_sketch_agg_float(col), 0.5) > 1):boolean> |
+| org.apache.spark.sql.catalyst.expressions.KllSketchGetRankBigint | kll_sketch_get_rank_bigint | SELECT kll_sketch_get_rank_bigint(kll_sketch_agg_bigint(col), 3) > 0.3 FROM VALUES (1), (2), (3), (4), (5) tab(col) | struct<(kll_sketch_get_rank_bigint(kll_sketch_agg_bigint(col), 3) > 0.3):boolean> |
+| org.apache.spark.sql.catalyst.expressions.KllSketchGetRankDouble | kll_sketch_get_rank_double | SELECT kll_sketch_get_rank_double(kll_sketch_agg_double(col), 3.0) > 0.3 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col) | struct<(kll_sketch_get_rank_double(kll_sketch_agg_double(col), 3.0) > 0.3):boolean> |
+| org.apache.spark.sql.catalyst.expressions.KllSketchGetRankFloat | kll_sketch_get_rank_float | SELECT kll_sketch_get_rank_float(kll_sketch_agg_float(col), 3.0) > 0.3 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col) | struct<(kll_sketch_get_rank_float(kll_sketch_agg_float(col), 3.0) > 0.3):boolean> |
+| org.apache.spark.sql.catalyst.expressions.KllSketchMergeBigint | kll_sketch_merge_bigint | SELECT LENGTH(kll_sketch_to_string_bigint(kll_sketch_merge_bigint(kll_sketch_agg_bigint(col), kll_sketch_agg_bigint(col)))) > 0 FROM VALUES (1), (2), (3), (4), (5) tab(col) | struct<(length(kll_sketch_to_string_bigint(kll_sketch_merge_bigint(kll_sketch_agg_bigint(col), kll_sketch_agg_bigint(col)))) > 0):boolean> |
+| org.apache.spark.sql.catalyst.expressions.KllSketchMergeDouble | kll_sketch_merge_double | SELECT LENGTH(kll_sketch_to_string_double(kll_sketch_merge_double(kll_sketch_agg_double(col), kll_sketch_agg_double(col)))) > 0 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col) | struct<(length(kll_sketch_to_string_double(kll_sketch_merge_double(kll_sketch_agg_double(col), kll_sketch_agg_double(col)))) > 0):boolean> |
+| org.apache.spark.sql.catalyst.expressions.KllSketchMergeFloat | kll_sketch_merge_float | SELECT LENGTH(kll_sketch_to_string_float(kll_sketch_merge_float(kll_sketch_agg_float(col), kll_sketch_agg_float(col)))) > 0 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col) | struct<(length(kll_sketch_to_string_float(kll_sketch_merge_float(kll_sketch_agg_float(col), kll_sketch_agg_float(col)))) > 0):boolean> |
+| org.apache.spark.sql.catalyst.expressions.KllSketchToStringBigint | kll_sketch_to_string_bigint | SELECT LENGTH(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col))) > 0 FROM VALUES (1), (2), (3), (4), (5) tab(col) | struct<(length(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col))) > 0):boolean> |
+| org.apache.spark.sql.catalyst.expressions.KllSketchToStringDouble | kll_sketch_to_string_double | SELECT LENGTH(kll_sketch_to_string_double(kll_sketch_agg_double(col))) > 0 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col) | struct<(length(kll_sketch_to_string_double(kll_sketch_agg_double(col))) > 0):boolean> |
+| org.apache.spark.sql.catalyst.expressions.KllSketchToStringFloat | kll_sketch_to_string_float | SELECT LENGTH(kll_sketch_to_string_float(kll_sketch_agg_float(col))) > 0 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col) | struct<(length(kll_sketch_to_string_float(kll_sketch_agg_float(col))) > 0):boolean> |
| org.apache.spark.sql.catalyst.expressions.LPadExpressionBuilder | lpad | SELECT lpad('hi', 5, '??') | struct |
| org.apache.spark.sql.catalyst.expressions.Lag | lag | SELECT a, b, lag(b) OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct |
| org.apache.spark.sql.catalyst.expressions.LastDay | last_day | SELECT last_day('2009-01-12') | struct |
@@ -287,10 +302,6 @@
| org.apache.spark.sql.catalyst.expressions.Rint | rint | SELECT rint(12.3456) | struct |
| org.apache.spark.sql.catalyst.expressions.Round | round | SELECT round(2.5, 0) | struct |
| org.apache.spark.sql.catalyst.expressions.RowNumber | row_number | SELECT a, b, row_number() OVER (PARTITION BY a ORDER BY b) FROM VALUES ('A1', 2), ('A1', 1), ('A2', 3), ('A1', 1) tab(a, b) | struct |
-| org.apache.spark.sql.catalyst.expressions.ST_AsBinary | st_asbinary | SELECT hex(st_asbinary(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040'))) | struct |
-| org.apache.spark.sql.catalyst.expressions.ST_GeogFromWKB | st_geogfromwkb | SELECT hex(st_asbinary(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040'))) | struct |
-| org.apache.spark.sql.catalyst.expressions.ST_GeomFromWKB | st_geomfromwkb | SELECT hex(st_asbinary(st_geomfromwkb(X'0101000000000000000000F03F0000000000000040'))) | struct |
-| org.apache.spark.sql.catalyst.expressions.ST_Srid | st_srid | SELECT st_srid(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040')) | struct |
| org.apache.spark.sql.catalyst.expressions.SchemaOfCsv | schema_of_csv | SELECT schema_of_csv('1,abc') | struct |
| org.apache.spark.sql.catalyst.expressions.SchemaOfJson | schema_of_json | SELECT schema_of_json('[{"col":0}]') | struct |
| org.apache.spark.sql.catalyst.expressions.SchemaOfXml | schema_of_xml | SELECT schema_of_xml('<p><a>1</a></p>') | struct<schema_of_xml(<p><a>1</a></p>):string> |
@@ -444,6 +455,12 @@
| org.apache.spark.sql.catalyst.expressions.aggregate.HllSketchAgg | hll_sketch_agg | SELECT hll_sketch_estimate(hll_sketch_agg(col, 12)) FROM VALUES (1), (1), (2), (2), (3) tab(col) | struct |
| org.apache.spark.sql.catalyst.expressions.aggregate.HllUnionAgg | hll_union_agg | SELECT hll_sketch_estimate(hll_union_agg(sketch, true)) FROM (SELECT hll_sketch_agg(col) as sketch FROM VALUES (1) tab(col) UNION ALL SELECT hll_sketch_agg(col, 20) as sketch FROM VALUES (1) tab(col)) | struct |
| org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus | approx_count_distinct | SELECT approx_count_distinct(col1) FROM VALUES (1), (1), (2), (2), (3) tab(col1) | struct |
+| org.apache.spark.sql.catalyst.expressions.aggregate.KllMergeAggBigint | kll_merge_agg_bigint | SELECT kll_sketch_get_n_bigint(kll_merge_agg_bigint(sketch)) FROM (SELECT kll_sketch_agg_bigint(col) as sketch FROM VALUES (1), (2), (3) tab(col) UNION ALL SELECT kll_sketch_agg_bigint(col) as sketch FROM VALUES (4), (5), (6) tab(col)) t | struct |
+| org.apache.spark.sql.catalyst.expressions.aggregate.KllMergeAggDouble | kll_merge_agg_double | SELECT kll_sketch_get_n_double(kll_merge_agg_double(sketch)) FROM (SELECT kll_sketch_agg_double(col) as sketch FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)) tab(col) UNION ALL SELECT kll_sketch_agg_double(col) as sketch FROM VALUES (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)), (CAST(6.0 AS DOUBLE)) tab(col)) t | struct |
+| org.apache.spark.sql.catalyst.expressions.aggregate.KllMergeAggFloat | kll_merge_agg_float | SELECT kll_sketch_get_n_float(kll_merge_agg_float(sketch)) FROM (SELECT kll_sketch_agg_float(col) as sketch FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)) tab(col) UNION ALL SELECT kll_sketch_agg_float(col) as sketch FROM VALUES (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)), (CAST(6.0 AS FLOAT)) tab(col)) t | struct |
+| org.apache.spark.sql.catalyst.expressions.aggregate.KllSketchAggBigint | kll_sketch_agg_bigint | SELECT LENGTH(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col))) > 0 FROM VALUES (1), (2), (3), (4), (5) tab(col) | struct<(length(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col))) > 0):boolean> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.KllSketchAggDouble | kll_sketch_agg_double | SELECT LENGTH(kll_sketch_to_string_double(kll_sketch_agg_double(col))) > 0 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col) | struct<(length(kll_sketch_to_string_double(kll_sketch_agg_double(col))) > 0):boolean> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.KllSketchAggFloat | kll_sketch_agg_float | SELECT LENGTH(kll_sketch_to_string_float(kll_sketch_agg_float(col))) > 0 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col) | struct<(length(kll_sketch_to_string_float(kll_sketch_agg_float(col))) > 0):boolean> |
| org.apache.spark.sql.catalyst.expressions.aggregate.Kurtosis | kurtosis | SELECT kurtosis(col) FROM VALUES (-10), (-20), (100), (1000) AS tab(col) | struct |
| org.apache.spark.sql.catalyst.expressions.aggregate.Last | last | SELECT last(col) FROM VALUES (10), (5), (20) AS tab(col) | struct |
| org.apache.spark.sql.catalyst.expressions.aggregate.Last | last_value | SELECT last_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct |
@@ -481,6 +498,11 @@
| org.apache.spark.sql.catalyst.expressions.aggregate.VariancePop | var_pop | SELECT var_pop(col) FROM VALUES (1), (2), (3) AS tab(col) | struct |
| org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | var_samp | SELECT var_samp(col) FROM VALUES (1), (2), (3) AS tab(col) | struct |
| org.apache.spark.sql.catalyst.expressions.aggregate.VarianceSamp | variance | SELECT variance(col) FROM VALUES (1), (2), (3) AS tab(col) | struct |
+| org.apache.spark.sql.catalyst.expressions.st.ST_AsBinary | st_asbinary | SELECT hex(st_asbinary(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040'))) | struct |
+| org.apache.spark.sql.catalyst.expressions.st.ST_GeogFromWKB | st_geogfromwkb | SELECT hex(st_asbinary(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040'))) | struct |
+| org.apache.spark.sql.catalyst.expressions.st.ST_GeomFromWKB | st_geomfromwkb | SELECT hex(st_asbinary(st_geomfromwkb(X'0101000000000000000000F03F0000000000000040'))) | struct |
+| org.apache.spark.sql.catalyst.expressions.st.ST_SetSrid | st_setsrid | SELECT st_srid(st_setsrid(ST_GeogFromWKB(X'0101000000000000000000F03F0000000000000040'), 4326)) | struct |
+| org.apache.spark.sql.catalyst.expressions.st.ST_Srid | st_srid | SELECT st_srid(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040')) | struct |
| org.apache.spark.sql.catalyst.expressions.variant.IsVariantNull | is_variant_null | SELECT is_variant_null(parse_json('null')) | struct |
| org.apache.spark.sql.catalyst.expressions.variant.ParseJsonExpressionBuilder | parse_json | SELECT parse_json('{"a":1,"b":0.8}') | struct |
| org.apache.spark.sql.catalyst.expressions.variant.SchemaOfVariant | schema_of_variant | SELECT schema_of_variant(parse_json('null')) | struct |
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/execute-immediate.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/execute-immediate.sql.out
index 1271f730d1e53..c874945badb1b 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/execute-immediate.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/execute-immediate.sql.out
@@ -1224,3 +1224,43 @@ EXECUTE IMMEDIATE 'SELECT typeof(:p) as type, :p as val' USING MAP(1, 'one', 2,
-- !query analysis
Project [typeof(map(1, one, 2, two)) AS type#x, map(1, one, 2, two) AS val#x]
+- OneRowRelation
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT :param'
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNBOUND_SQL_PARAMETER",
+ "sqlState" : "42P02",
+ "messageParameters" : {
+ "name" : "param"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 13,
+ "fragment" : ":param"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT ?'
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNBOUND_SQL_PARAMETER",
+ "sqlState" : "42P02",
+ "messageParameters" : {
+ "name" : "_7"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 8,
+ "fragment" : "?"
+ } ]
+}
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/hll.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/hll.sql.out
index 167c8f930d25d..291f071ef06c2 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/hll.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/hll.sql.out
@@ -233,6 +233,49 @@ Aggregate [hll_sketch_agg(col#x, 40, 0, 0) AS hll_sketch_agg(col, 40)#x]
+- LocalRelation [col#x]
+-- !query
+SELECT hll_sketch_agg(col, CAST(NULL AS INT)) AS k_is_null
+FROM VALUES (15), (16), (17) tab(col)
+-- !query analysis
+Aggregate [hll_sketch_agg(col#x, cast(null as int), 0, 0) AS k_is_null#x]
++- SubqueryAlias tab
+ +- LocalRelation [col#x]
+
+
+-- !query
+SELECT hll_sketch_agg(col, CAST(col AS INT)) AS k_non_constant
+FROM VALUES (15), (16), (17) tab(col)
+-- !query analysis
+Aggregate [hll_sketch_agg(col#x, cast(col#x as int), 0, 0) AS k_non_constant#x]
++- SubqueryAlias tab
+ +- LocalRelation [col#x]
+
+
+-- !query
+SELECT hll_sketch_agg(col, '15')
+FROM VALUES (50), (60), (60) tab(col)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"15\"",
+ "inputType" : "\"STRING\"",
+ "paramIndex" : "second",
+ "requiredType" : "\"INT\"",
+ "sqlExpr" : "\"hll_sketch_agg(col, 15)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 32,
+ "fragment" : "hll_sketch_agg(col, '15')"
+ } ]
+}
+
+
-- !query
SELECT hll_union(
hll_sketch_agg(col1, 12),
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out
new file mode 100644
index 0000000000000..94fff8f586972
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause-legacy.sql.out
@@ -0,0 +1,2596 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+SET hivevar:colname = 'c'
+-- !query analysis
+SetCommand (hivevar:colname,Some('c'))
+
+
+-- !query
+SELECT IDENTIFIER(${colname} || '_1') FROM VALUES(1) AS T(c_1)
+-- !query analysis
+Project [c_1#x]
++- SubqueryAlias T
+ +- LocalRelation [c_1#x]
+
+
+-- !query
+SELECT IDENTIFIER('c1') FROM VALUES(1) AS T(c1)
+-- !query analysis
+Project [c1#x]
++- SubqueryAlias T
+ +- LocalRelation [c1#x]
+
+
+-- !query
+SELECT IDENTIFIER('t.c1') FROM VALUES(1) AS T(c1)
+-- !query analysis
+Project [c1#x]
++- SubqueryAlias T
+ +- LocalRelation [c1#x]
+
+
+-- !query
+SELECT IDENTIFIER('`t`.c1') FROM VALUES(1) AS T(c1)
+-- !query analysis
+Project [c1#x]
++- SubqueryAlias T
+ +- LocalRelation [c1#x]
+
+
+-- !query
+SELECT IDENTIFIER('`c 1`') FROM VALUES(1) AS T(`c 1`)
+-- !query analysis
+Project [c 1#x]
++- SubqueryAlias T
+ +- LocalRelation [c 1#x]
+
+
+-- !query
+SELECT IDENTIFIER('``') FROM VALUES(1) AS T(``)
+-- !query analysis
+Project [#x]
++- SubqueryAlias T
+ +- LocalRelation [#x]
+
+
+-- !query
+SELECT IDENTIFIER('c' || '1') FROM VALUES(1) AS T(c1)
+-- !query analysis
+Project [c1#x]
++- SubqueryAlias T
+ +- LocalRelation [c1#x]
+
+
+-- !query
+CREATE SCHEMA IF NOT EXISTS s
+-- !query analysis
+CreateNamespace true
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [s]
+
+
+-- !query
+CREATE TABLE s.tab(c1 INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`s`.`tab`, false
+
+
+-- !query
+USE SCHEMA s
+-- !query analysis
+SetNamespaceCommand [s]
+
+
+-- !query
+INSERT INTO IDENTIFIER('ta' || 'b') VALUES(1)
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/s.db/tab, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/s.db/tab], Append, `spark_catalog`.`s`.`tab`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/s.db/tab), [c1]
++- Project [col1#x AS c1#x]
+ +- LocalRelation [col1#x]
+
+
+-- !query
+DELETE FROM IDENTIFIER('ta' || 'b') WHERE 1=0
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ "sqlState" : "0A000",
+ "messageParameters" : {
+ "operation" : "DELETE",
+ "tableName" : "`spark_catalog`.`s`.`tab`"
+ }
+}
+
+
+-- !query
+UPDATE IDENTIFIER('ta' || 'b') SET c1 = 2
+-- !query analysis
+org.apache.spark.SparkUnsupportedOperationException
+{
+ "errorClass" : "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ "sqlState" : "0A000",
+ "messageParameters" : {
+ "operation" : "UPDATE TABLE",
+ "tableName" : "`spark_catalog`.`s`.`tab`"
+ }
+}
+
+
+-- !query
+MERGE INTO IDENTIFIER('ta' || 'b') AS t USING IDENTIFIER('ta' || 'b') AS s ON s.c1 = t.c1
+ WHEN MATCHED THEN UPDATE SET c1 = 3
+-- !query analysis
+org.apache.spark.SparkUnsupportedOperationException
+{
+ "errorClass" : "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ "sqlState" : "0A000",
+ "messageParameters" : {
+ "operation" : "MERGE INTO TABLE",
+ "tableName" : "`spark_catalog`.`s`.`tab`"
+ }
+}
+
+
+-- !query
+SELECT * FROM IDENTIFIER('tab')
+-- !query analysis
+Project [c1#x]
++- SubqueryAlias spark_catalog.s.tab
+ +- Relation spark_catalog.s.tab[c1#x] csv
+
+
+-- !query
+SELECT * FROM IDENTIFIER('s.tab')
+-- !query analysis
+Project [c1#x]
++- SubqueryAlias spark_catalog.s.tab
+ +- Relation spark_catalog.s.tab[c1#x] csv
+
+
+-- !query
+SELECT * FROM IDENTIFIER('`s`.`tab`')
+-- !query analysis
+Project [c1#x]
++- SubqueryAlias spark_catalog.s.tab
+ +- Relation spark_catalog.s.tab[c1#x] csv
+
+
+-- !query
+SELECT * FROM IDENTIFIER('t' || 'a' || 'b')
+-- !query analysis
+Project [c1#x]
++- SubqueryAlias spark_catalog.s.tab
+ +- Relation spark_catalog.s.tab[c1#x] csv
+
+
+-- !query
+USE SCHEMA default
+-- !query analysis
+SetNamespaceCommand [default]
+
+
+-- !query
+DROP TABLE s.tab
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), s.tab
+
+
+-- !query
+DROP SCHEMA s
+-- !query analysis
+DropNamespace false, false
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [s]
+
+
+-- !query
+SELECT IDENTIFIER('COAL' || 'ESCE')(NULL, 1)
+-- !query analysis
+Project [coalesce(cast(null as int), 1) AS coalesce(NULL, 1)#x]
++- OneRowRelation
+
+
+-- !query
+SELECT IDENTIFIER('abs')(c1) FROM VALUES(-1) AS T(c1)
+-- !query analysis
+Project [abs(c1#x) AS abs(c1)#x]
++- SubqueryAlias T
+ +- LocalRelation [c1#x]
+
+
+-- !query
+SELECT * FROM IDENTIFIER('ra' || 'nge')(0, 1)
+-- !query analysis
+Project [id#xL]
++- Range (0, 1, step=1)
+
+
+-- !query
+CREATE TABLE IDENTIFIER('tab')(c1 INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`default`.`tab`, false
+
+
+-- !query
+DROP TABLE IF EXISTS IDENTIFIER('ta' || 'b')
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.tab
+
+
+-- !query
+CREATE SCHEMA identifier_clauses
+-- !query analysis
+CreateNamespace false
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [identifier_clauses]
+
+
+-- !query
+USE identifier_clauses
+-- !query analysis
+SetCatalogAndNamespace
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [identifier_clauses]
+
+
+-- !query
+CREATE TABLE IDENTIFIER('ta' || 'b')(c1 INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clauses`.`tab`, false
+
+
+-- !query
+DROP TABLE IF EXISTS IDENTIFIER('identifier_clauses.' || 'tab')
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clauses.tab
+
+
+-- !query
+CREATE TABLE IDENTIFIER('identifier_clauses.' || 'tab')(c1 INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clauses`.`tab`, false
+
+
+-- !query
+REPLACE TABLE IDENTIFIER('identifier_clauses.' || 'tab')(c1 INT) USING CSV
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ "sqlState" : "0A000",
+ "messageParameters" : {
+ "operation" : "REPLACE TABLE",
+ "tableName" : "`spark_catalog`.`identifier_clauses`.`tab`"
+ }
+}
+
+
+-- !query
+CACHE TABLE IDENTIFIER('ta' || 'b')
+-- !query analysis
+CacheTable [tab], false, true
+ +- SubqueryAlias spark_catalog.identifier_clauses.tab
+ +- Relation spark_catalog.identifier_clauses.tab[c1#x] csv
+
+
+-- !query
+UNCACHE TABLE IDENTIFIER('ta' || 'b')
+-- !query analysis
+UncacheTable false, true
+ +- SubqueryAlias spark_catalog.identifier_clauses.tab
+ +- Relation spark_catalog.identifier_clauses.tab[c1#x] csv
+
+
+-- !query
+DROP TABLE IF EXISTS IDENTIFIER('ta' || 'b')
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clauses.tab
+
+
+-- !query
+USE default
+-- !query analysis
+SetCatalogAndNamespace
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [default]
+
+
+-- !query
+DROP SCHEMA identifier_clauses
+-- !query analysis
+DropNamespace false, false
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [identifier_clauses]
+
+
+-- !query
+CREATE TABLE tab(c1 INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`default`.`tab`, false
+
+
+-- !query
+INSERT INTO tab VALUES (1)
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/tab, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/tab], Append, `spark_catalog`.`default`.`tab`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/tab), [c1]
++- Project [col1#x AS c1#x]
+ +- LocalRelation [col1#x]
+
+
+-- !query
+SELECT c1 FROM tab
+-- !query analysis
+Project [c1#x]
++- SubqueryAlias spark_catalog.default.tab
+ +- Relation spark_catalog.default.tab[c1#x] csv
+
+
+-- !query
+DESCRIBE IDENTIFIER('ta' || 'b')
+-- !query analysis
+DescribeTableCommand `spark_catalog`.`default`.`tab`, false, [col_name#x, data_type#x, comment#x]
+
+
+-- !query
+ANALYZE TABLE IDENTIFIER('ta' || 'b') COMPUTE STATISTICS
+-- !query analysis
+AnalyzeTableCommand `spark_catalog`.`default`.`tab`, false
+
+
+-- !query
+ALTER TABLE IDENTIFIER('ta' || 'b') ADD COLUMN c2 INT
+-- !query analysis
+AlterTableAddColumnsCommand `spark_catalog`.`default`.`tab`, [StructField(c2,IntegerType,true)]
+
+
+-- !query
+SHOW TBLPROPERTIES IDENTIFIER('ta' || 'b')
+-- !query analysis
+ShowTableProperties [key#x, value#x]
++- ResolvedTable V2SessionCatalog(spark_catalog), default.tab, V1Table(default.tab), [c1#x, c2#x]
+
+
+-- !query
+SHOW COLUMNS FROM IDENTIFIER('ta' || 'b')
+-- !query analysis
+ShowColumnsCommand `spark_catalog`.`default`.`tab`, [col_name#x]
+
+
+-- !query
+COMMENT ON TABLE IDENTIFIER('ta' || 'b') IS 'hello'
+-- !query analysis
+CommentOnTable hello
++- ResolvedTable V2SessionCatalog(spark_catalog), default.tab, V1Table(default.tab), [c1#x, c2#x]
+
+
+-- !query
+REFRESH TABLE IDENTIFIER('ta' || 'b')
+-- !query analysis
+RefreshTableCommand `spark_catalog`.`default`.`tab`
+
+
+-- !query
+REPAIR TABLE IDENTIFIER('ta' || 'b')
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_PARTITIONED_TABLE",
+ "sqlState" : "42809",
+ "messageParameters" : {
+ "operation" : "MSCK REPAIR TABLE",
+ "tableIdentWithDB" : "`spark_catalog`.`default`.`tab`"
+ }
+}
+
+
+-- !query
+TRUNCATE TABLE IDENTIFIER('ta' || 'b')
+-- !query analysis
+TruncateTableCommand `spark_catalog`.`default`.`tab`
+
+
+-- !query
+DROP TABLE IF EXISTS tab
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.tab
+
+
+-- !query
+CREATE OR REPLACE VIEW IDENTIFIER('v')(c1) AS VALUES(1)
+-- !query analysis
+CreateViewCommand `spark_catalog`.`default`.`v`, [(c1,None)], VALUES(1), false, true, PersistedView, COMPENSATION, true
+ +- LocalRelation [col1#x]
+
+
+-- !query
+SELECT * FROM v
+-- !query analysis
+Project [c1#x]
++- SubqueryAlias spark_catalog.default.v
+ +- View (`spark_catalog`.`default`.`v`, [c1#x])
+ +- Project [cast(col1#x as int) AS c1#x]
+ +- LocalRelation [col1#x]
+
+
+-- !query
+ALTER VIEW IDENTIFIER('v') AS VALUES(2)
+-- !query analysis
+AlterViewAsCommand `spark_catalog`.`default`.`v`, VALUES(2), true
+ +- LocalRelation [col1#x]
+
+
+-- !query
+DROP VIEW IDENTIFIER('v')
+-- !query analysis
+DropTableCommand `spark_catalog`.`default`.`v`, false, true, false
+
+
+-- !query
+CREATE TEMPORARY VIEW IDENTIFIER('v')(c1) AS VALUES(1)
+-- !query analysis
+CreateViewCommand `v`, [(c1,None)], VALUES(1), false, false, LocalTempView, UNSUPPORTED, true
+ +- LocalRelation [col1#x]
+
+
+-- !query
+DROP VIEW IDENTIFIER('v')
+-- !query analysis
+DropTempViewCommand v
+
+
+-- !query
+CREATE SCHEMA IDENTIFIER('id' || 'ent')
+-- !query analysis
+CreateNamespace false
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [ident]
+
+
+-- !query
+ALTER SCHEMA IDENTIFIER('id' || 'ent') SET PROPERTIES (somekey = 'somevalue')
+-- !query analysis
+SetNamespaceProperties [somekey=somevalue]
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [ident]
+
+
+-- !query
+ALTER SCHEMA IDENTIFIER('id' || 'ent') SET LOCATION 'someloc'
+-- !query analysis
+SetNamespaceLocation someloc
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [ident]
+
+
+-- !query
+COMMENT ON SCHEMA IDENTIFIER('id' || 'ent') IS 'some comment'
+-- !query analysis
+CommentOnNamespace some comment
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [ident]
+
+
+-- !query
+DESCRIBE SCHEMA IDENTIFIER('id' || 'ent')
+-- !query analysis
+DescribeNamespace false, [info_name#x, info_value#x]
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [ident]
+
+
+-- !query
+SHOW TABLES IN IDENTIFIER('id' || 'ent')
+-- !query analysis
+ShowTables [namespace#x, tableName#x, isTemporary#x]
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [ident]
+
+
+-- !query
+SHOW TABLE EXTENDED IN IDENTIFIER('id' || 'ent') LIKE 'hello'
+-- !query analysis
+ShowTablesCommand ident, hello, [namespace#x, tableName#x, isTemporary#x, information#x], true
+
+
+-- !query
+USE IDENTIFIER('id' || 'ent')
+-- !query analysis
+SetCatalogAndNamespace
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [ident]
+
+
+-- !query
+SHOW CURRENT SCHEMA
+-- !query analysis
+ShowCurrentNamespaceCommand
+
+
+-- !query
+USE SCHEMA IDENTIFIER('id' || 'ent')
+-- !query analysis
+SetNamespaceCommand [ident]
+
+
+-- !query
+USE SCHEMA default
+-- !query analysis
+SetNamespaceCommand [default]
+
+
+-- !query
+DROP SCHEMA IDENTIFIER('id' || 'ent')
+-- !query analysis
+DropNamespace false, false
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [ident]
+
+
+-- !query
+CREATE SCHEMA ident
+-- !query analysis
+CreateNamespace false
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [ident]
+
+
+-- !query
+CREATE FUNCTION IDENTIFIER('ident.' || 'myDoubleAvg') AS 'test.org.apache.spark.sql.MyDoubleAvg'
+-- !query analysis
+CreateFunctionCommand spark_catalog.ident.myDoubleAvg, test.org.apache.spark.sql.MyDoubleAvg, false, false, false
+
+
+-- !query
+DESCRIBE FUNCTION IDENTIFIER('ident.' || 'myDoubleAvg')
+-- !query analysis
+DescribeFunctionCommand org.apache.spark.sql.catalyst.expressions.ExpressionInfo@xxxxxxxx, false
+
+
+-- !query
+REFRESH FUNCTION IDENTIFIER('ident.' || 'myDoubleAvg')
+-- !query analysis
+RefreshFunctionCommand ident, mydoubleavg
+
+
+-- !query
+DROP FUNCTION IDENTIFIER('ident.' || 'myDoubleAvg')
+-- !query analysis
+DropFunctionCommand spark_catalog.ident.mydoubleavg, false, false
+
+
+-- !query
+DROP SCHEMA ident
+-- !query analysis
+DropNamespace false, false
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [ident]
+
+
+-- !query
+CREATE TEMPORARY FUNCTION IDENTIFIER('my' || 'DoubleAvg') AS 'test.org.apache.spark.sql.MyDoubleAvg'
+-- !query analysis
+CreateFunctionCommand myDoubleAvg, test.org.apache.spark.sql.MyDoubleAvg, true, false, false
+
+
+-- !query
+DROP TEMPORARY FUNCTION IDENTIFIER('my' || 'DoubleAvg')
+-- !query analysis
+DropFunctionCommand myDoubleAvg, false, true
+
+
+-- !query
+DECLARE var = 'sometable'
+-- !query analysis
+CreateVariable defaultvalueexpression(sometable, 'sometable'), false
++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.var
+
+
+-- !query
+CREATE TABLE IDENTIFIER(var)(c1 INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`default`.`sometable`, false
+
+
+-- !query
+SET VAR var = 'c1'
+-- !query analysis
+SetVariable [variablereference(system.session.var='sometable')]
++- Project [c1 AS var#x]
+ +- OneRowRelation
+
+
+-- !query
+SELECT IDENTIFIER(var) FROM VALUES(1) AS T(c1)
+-- !query analysis
+Project [c1#x]
++- SubqueryAlias T
+ +- LocalRelation [c1#x]
+
+
+-- !query
+SET VAR var = 'some'
+-- !query analysis
+SetVariable [variablereference(system.session.var='c1')]
++- Project [some AS var#x]
+ +- OneRowRelation
+
+
+-- !query
+DROP TABLE IDENTIFIER(var || 'table')
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.sometable
+
+
+-- !query
+SELECT IDENTIFIER('c 1') FROM VALUES(1) AS T(`c 1`)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'1'",
+ "hint" : ": extra input '1'"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 24,
+ "fragment" : "IDENTIFIER('c 1')"
+ } ]
+}
+
+
+-- !query
+SELECT IDENTIFIER('') FROM VALUES(1) AS T(``)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_EMPTY_STATEMENT",
+ "sqlState" : "42617",
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 21,
+ "fragment" : "IDENTIFIER('')"
+ } ]
+}
+
+
+-- !query
+VALUES(IDENTIFIER(CAST(NULL AS STRING)))
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NULL",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "CAST(NULL AS STRING)",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 19,
+ "stopIndex" : 38,
+ "fragment" : "CAST(NULL AS STRING)"
+ } ]
+}
+
+
+-- !query
+VALUES(IDENTIFIER(1))
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.WRONG_TYPE",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "dataType" : "int",
+ "expr" : "1",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 19,
+ "stopIndex" : 19,
+ "fragment" : "1"
+ } ]
+}
+
+
+-- !query
+VALUES(IDENTIFIER(SUBSTR('HELLO', 1, RAND() + 1)))
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "substr('HELLO', 1, CAST((rand() + CAST(1 AS DOUBLE)) AS INT))",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 19,
+ "stopIndex" : 48,
+ "fragment" : "SUBSTR('HELLO', 1, RAND() + 1)"
+ } ]
+}
+
+
+-- !query
+SELECT `IDENTIFIER`('abs')(c1) FROM VALUES(-1) AS T(c1)
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNRESOLVED_ROUTINE",
+ "sqlState" : "42883",
+ "messageParameters" : {
+ "routineName" : "`IDENTIFIER`",
+ "searchPath" : "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 26,
+ "fragment" : "`IDENTIFIER`('abs')"
+ } ]
+}
+
+
+-- !query
+CREATE TABLE t(col1 INT)
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`default`.`t`, false
+
+
+-- !query
+SELECT * FROM IDENTIFIER((SELECT 't'))
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery()",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 26,
+ "stopIndex" : 37,
+ "fragment" : "(SELECT 't')"
+ } ]
+}
+
+
+-- !query
+SELECT * FROM (SELECT IDENTIFIER((SELECT 'col1')) FROM IDENTIFIER((SELECT 't')))
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery()",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 67,
+ "stopIndex" : 78,
+ "fragment" : "(SELECT 't')"
+ } ]
+}
+
+
+-- !query
+SELECT IDENTIFIER((SELECT 'col1')) FROM VALUES(1)
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery()",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 19,
+ "stopIndex" : 33,
+ "fragment" : "(SELECT 'col1')"
+ } ]
+}
+
+
+-- !query
+SELECT col1, IDENTIFIER((SELECT col1)) FROM VALUES(1)
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery(col1)",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 25,
+ "stopIndex" : 37,
+ "fragment" : "(SELECT col1)"
+ } ]
+}
+
+
+-- !query
+SELECT IDENTIFIER((SELECT 'col1', 'col2')) FROM VALUES(1,2)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "UNSUPPORTED_TYPED_LITERAL",
+ "sqlState" : "0A000",
+ "messageParameters" : {
+ "supportedTypes" : "\"DATE\", \"TIMESTAMP_NTZ\", \"TIMESTAMP_LTZ\", \"TIMESTAMP\", \"INTERVAL\", \"X\", \"TIME\"",
+ "unsupportedType" : "\"SELECT\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 20,
+ "stopIndex" : 32,
+ "fragment" : "SELECT 'col1'"
+ } ]
+}
+
+
+-- !query
+DROP TABLE t
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t
+
+
+-- !query
+CREATE TABLE IDENTIFIER(1)(c1 INT) USING csv
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.WRONG_TYPE",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "dataType" : "int",
+ "expr" : "1",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 25,
+ "stopIndex" : 25,
+ "fragment" : "1"
+ } ]
+}
+
+
+-- !query
+CREATE TABLE IDENTIFIER('a.b.c')(c1 INT) USING csv
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+ "sqlState" : "42K05",
+ "messageParameters" : {
+ "namespace" : "`a`.`b`",
+ "sessionCatalog" : "spark_catalog"
+ }
+}
+
+
+-- !query
+CREATE VIEW IDENTIFIER('a.b.c')(c1) AS VALUES(1)
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+ "sqlState" : "42K05",
+ "messageParameters" : {
+ "namespace" : "`a`.`b`",
+ "sessionCatalog" : "spark_catalog"
+ }
+}
+
+
+-- !query
+DROP TABLE IDENTIFIER('a.b.c')
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+ "sqlState" : "42K05",
+ "messageParameters" : {
+ "namespace" : "`a`.`b`",
+ "sessionCatalog" : "spark_catalog"
+ }
+}
+
+
+-- !query
+DROP VIEW IDENTIFIER('a.b.c')
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+ "sqlState" : "42K05",
+ "messageParameters" : {
+ "namespace" : "`a`.`b`",
+ "sessionCatalog" : "spark_catalog"
+ }
+}
+
+
+-- !query
+COMMENT ON TABLE IDENTIFIER('a.b.c.d') IS 'hello'
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+ "sqlState" : "42K05",
+ "messageParameters" : {
+ "namespace" : "`a`.`b`.`c`",
+ "sessionCatalog" : "spark_catalog"
+ }
+}
+
+
+-- !query
+VALUES(IDENTIFIER(1)())
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.WRONG_TYPE",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "dataType" : "int",
+ "expr" : "1",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 19,
+ "stopIndex" : 19,
+ "fragment" : "1"
+ } ]
+}
+
+
+-- !query
+VALUES(IDENTIFIER('a.b.c.d')())
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "identifier" : "`a`.`b`.`c`.`d`",
+ "limit" : "2"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 30,
+ "fragment" : "IDENTIFIER('a.b.c.d')()"
+ } ]
+}
+
+
+-- !query
+CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.apache.spark.sql.MyDoubleAvg'
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_DATABASE",
+ "sqlState" : "42000",
+ "messageParameters" : {
+ "database" : "`default`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 108,
+ "fragment" : "CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.apache.spark.sql.MyDoubleAvg'"
+ } ]
+}
+
+
+-- !query
+DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg')
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
+ "sqlState" : "42000",
+ "messageParameters" : {
+ "name" : "`default`.`myDoubleAvg`",
+ "statement" : "DROP TEMPORARY FUNCTION"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 63,
+ "fragment" : "DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg')"
+ } ]
+}
+
+
+-- !query
+CREATE TEMPORARY VIEW IDENTIFIER('default.v')(c1) AS VALUES(1)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "TEMP_VIEW_NAME_TOO_MANY_NAME_PARTS",
+ "sqlState" : "428EK",
+ "messageParameters" : {
+ "actualName" : "`default`.`v`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 62,
+ "fragment" : "CREATE TEMPORARY VIEW IDENTIFIER('default.v')(c1) AS VALUES(1)"
+ } ]
+}
+
+
+-- !query
+create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
+-- !query analysis
+CreateViewCommand `v1`, (select my_col from (values (1), (2), (1) as (my_col)) group by 1), false, false, LocalTempView, UNSUPPORTED, true
+ +- Aggregate [my_col#x], [my_col#x]
+ +- SubqueryAlias __auto_generated_subquery_name
+ +- SubqueryAlias as
+ +- LocalRelation [my_col#x]
+
+
+-- !query
+cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
+-- !query analysis
+CacheTableAsSelect t1, (select my_col from (values (1), (2), (1) as (my_col)) group by 1), false, true
+ +- Aggregate [my_col#x], [my_col#x]
+ +- SubqueryAlias __auto_generated_subquery_name
+ +- SubqueryAlias as
+ +- LocalRelation [my_col#x]
+
+
+-- !query
+create table identifier('t2') using csv as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
+-- !query analysis
+CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`t2`, ErrorIfExists, [my_col]
+ +- Aggregate [my_col#x], [my_col#x]
+ +- SubqueryAlias __auto_generated_subquery_name
+ +- SubqueryAlias as
+ +- LocalRelation [my_col#x]
+
+
+-- !query
+insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/t2, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/t2], Append, `spark_catalog`.`default`.`t2`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/t2), [my_col]
++- Project [my_col#x AS my_col#x]
+ +- Aggregate [my_col#x], [my_col#x]
+ +- SubqueryAlias __auto_generated_subquery_name
+ +- SubqueryAlias as
+ +- LocalRelation [my_col#x]
+
+
+-- !query
+drop view v1
+-- !query analysis
+DropTempViewCommand v1
+
+
+-- !query
+drop table t1
+-- !query analysis
+DropTempViewCommand t1
+
+
+-- !query
+drop table t2
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t2
+
+
+-- !query
+DECLARE agg = 'max'
+-- !query analysis
+CreateVariable defaultvalueexpression(max, 'max'), false
++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.agg
+
+
+-- !query
+DECLARE col = 'c1'
+-- !query analysis
+CreateVariable defaultvalueexpression(c1, 'c1'), false
++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.col
+
+
+-- !query
+DECLARE tab = 'T'
+-- !query analysis
+CreateVariable defaultvalueexpression(T, 'T'), false
++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.tab
+
+
+-- !query
+WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
+ T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
+SELECT IDENTIFIER(agg)(IDENTIFIER(col)) FROM IDENTIFIER(tab)
+-- !query analysis
+WithCTE
+:- CTERelationDef xxxx, false
+: +- SubqueryAlias S
+: +- Project [col1#x AS c1#x, col2#x AS c2#x]
+: +- LocalRelation [col1#x, col2#x]
+:- CTERelationDef xxxx, false
+: +- SubqueryAlias T
+: +- Project [col1#x AS c1#x, col2#x AS c2#x]
+: +- LocalRelation [col1#x, col2#x]
++- Aggregate [max(c1#x) AS max(c1)#x]
+ +- SubqueryAlias T
+ +- CTERelationRef xxxx, true, [c1#x, c2#x], false, false, 2
+
+
+-- !query
+WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
+ T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
+SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('T')
+-- !query analysis
+WithCTE
+:- CTERelationDef xxxx, false
+: +- SubqueryAlias S
+: +- Project [col1#x AS c1#x, col2#x AS c2#x]
+: +- LocalRelation [col1#x, col2#x]
+:- CTERelationDef xxxx, false
+: +- SubqueryAlias T
+: +- Project [col1#x AS c1#x, col2#x AS c2#x]
+: +- LocalRelation [col1#x, col2#x]
++- Aggregate [max(c1#x) AS max(c1)#x]
+ +- SubqueryAlias T
+ +- CTERelationRef xxxx, true, [c1#x, c2#x], false, false, 2
+
+
+-- !query
+WITH ABC(c1, c2) AS (VALUES(1, 2), (2, 3))
+SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('A' || 'BC')
+-- !query analysis
+WithCTE
+:- CTERelationDef xxxx, false
+: +- SubqueryAlias ABC
+: +- Project [col1#x AS c1#x, col2#x AS c2#x]
+: +- LocalRelation [col1#x, col2#x]
++- Aggregate [max(c1#x) AS max(c1)#x]
+ +- SubqueryAlias ABC
+ +- CTERelationRef xxxx, true, [c1#x, c2#x], false, false, 2
+
+
+-- !query
+SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''x.win''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT T1.c1 FROM VALUES(1) AS T1(c1) JOIN VALUES(1) AS T2(c1) USING (IDENTIFIER('c1'))
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT IDENTIFIER('t').c1 FROM VALUES(1) AS T(c1)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+ "sqlState" : "42703",
+ "messageParameters" : {
+ "objectName" : "`t`",
+ "proposal" : "`c1`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 22,
+ "fragment" : "IDENTIFIER('t')"
+ } ]
+}
+
+
+-- !query
+SELECT map('a', 1).IDENTIFIER('a') FROM VALUES(1) AS T(c1)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''a''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT named_struct('a', 1).IDENTIFIER('a') FROM VALUES(1) AS T(c1)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''a''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT * FROM s.IDENTIFIER('tab')
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT * FROM IDENTIFIER('s').IDENTIFIER('tab')
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'IDENTIFIER'",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT * FROM IDENTIFIER('s').tab
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'IDENTIFIER'",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT row_number() OVER IDENTIFIER('win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''win''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT row_number() OVER win FROM VALUES(1) AS T(c1) WINDOW IDENTIFIER('win') AS (ORDER BY c1)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ": missing 'AS'"
+ }
+}
+
+
+-- !query
+SELECT 1 AS IDENTIFIER('col1')
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT my_table.* FROM VALUES (1, 2) AS IDENTIFIER('my_table')(IDENTIFIER('c1'), IDENTIFIER('c2'))
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''my_table''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+WITH identifier('v')(identifier('c1')) AS (VALUES(1)) (SELECT c1 FROM v)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''v''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+CREATE OR REPLACE VIEW v(IDENTIFIER('c1')) AS VALUES(1)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT c1 FROM v
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`v`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 16,
+ "stopIndex" : 16,
+ "fragment" : "v"
+ } ]
+}
+
+
+-- !query
+DROP VIEW IF EXISTS v
+-- !query analysis
+DropTableCommand `spark_catalog`.`default`.`v`, true, true, false
+
+
+-- !query
+CREATE TABLE tab(IDENTIFIER('c1') INT) USING CSV
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+INSERT INTO tab(IDENTIFIER('c1')) VALUES(1)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ": missing ')'"
+ }
+}
+
+
+-- !query
+SELECT c1 FROM tab
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`tab`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 16,
+ "stopIndex" : 18,
+ "fragment" : "tab"
+ } ]
+}
+
+
+-- !query
+ALTER TABLE IDENTIFIER('tab') RENAME COLUMN IDENTIFIER('c1') TO IDENTIFIER('col1')
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT col1 FROM tab
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`tab`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 18,
+ "stopIndex" : 20,
+ "fragment" : "tab"
+ } ]
+}
+
+
+-- !query
+ALTER TABLE IDENTIFIER('tab') ADD COLUMN IDENTIFIER('c2') INT
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT c2 FROM tab
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`tab`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 16,
+ "stopIndex" : 18,
+ "fragment" : "tab"
+ } ]
+}
+
+
+-- !query
+ALTER TABLE IDENTIFIER('tab') DROP COLUMN IDENTIFIER('c2')
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+ALTER TABLE IDENTIFIER('tab') RENAME TO IDENTIFIER('tab_renamed')
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT * FROM tab_renamed
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`tab_renamed`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 15,
+ "stopIndex" : 25,
+ "fragment" : "tab_renamed"
+ } ]
+}
+
+
+-- !query
+DROP TABLE IF EXISTS tab_renamed
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.tab_renamed
+
+
+-- !query
+DROP TABLE IF EXISTS tab
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.tab
+
+
+-- !query
+CREATE TABLE test_col_with_dot(IDENTIFIER('`col.with.dot`') INT) USING CSV
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+DROP TABLE IF EXISTS test_col_with_dot
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.test_col_with_dot
+
+
+-- !query
+SELECT * FROM VALUES (1, 2) AS IDENTIFIER('schema.table')(c1, c2)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''schema.table''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT 1 AS IDENTIFIER('col1.col2')
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+CREATE SCHEMA identifier_clause_test_schema
+-- !query analysis
+CreateNamespace false
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [identifier_clause_test_schema]
+
+
+-- !query
+USE identifier_clause_test_schema
+-- !query analysis
+SetCatalogAndNamespace
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [identifier_clause_test_schema]
+
+
+-- !query
+CREATE TABLE test_show(c1 INT, c2 STRING) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_show`, false
+
+
+-- !query
+SHOW VIEWS IN IDENTIFIER('identifier_clause_test_schema')
+-- !query analysis
+ShowViewsCommand identifier_clause_test_schema, [namespace#x, viewName#x, isTemporary#x]
+
+
+-- !query
+SHOW PARTITIONS IDENTIFIER('test_show')
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "INVALID_PARTITION_OPERATION.PARTITION_SCHEMA_IS_EMPTY",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "name" : "`spark_catalog`.`identifier_clause_test_schema`.`test_show`"
+ }
+}
+
+
+-- !query
+SHOW CREATE TABLE IDENTIFIER('test_show')
+-- !query analysis
+ShowCreateTable false, [createtab_stmt#x]
++- ResolvedTable V2SessionCatalog(spark_catalog), identifier_clause_test_schema.test_show, V1Table(identifier_clause_test_schema.test_show), [c1#x, c2#x]
+
+
+-- !query
+DROP TABLE test_show
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.test_show
+
+
+-- !query
+CREATE TABLE test_desc(c1 INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_desc`, false
+
+
+-- !query
+DESCRIBE TABLE IDENTIFIER('test_desc')
+-- !query analysis
+DescribeTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_desc`, false, [col_name#x, data_type#x, comment#x]
+
+
+-- !query
+DESCRIBE FORMATTED IDENTIFIER('test_desc')
+-- !query analysis
+DescribeTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_desc`, true, [col_name#x, data_type#x, comment#x]
+
+
+-- !query
+DESCRIBE EXTENDED IDENTIFIER('test_desc')
+-- !query analysis
+DescribeTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_desc`, true, [col_name#x, data_type#x, comment#x]
+
+
+-- !query
+DESC IDENTIFIER('test_desc')
+-- !query analysis
+DescribeTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_desc`, false, [col_name#x, data_type#x, comment#x]
+
+
+-- !query
+DROP TABLE test_desc
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.test_desc
+
+
+-- !query
+CREATE TABLE test_comment(c1 INT, c2 STRING) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_comment`, false
+
+
+-- !query
+COMMENT ON TABLE IDENTIFIER('test_comment') IS 'table comment'
+-- !query analysis
+CommentOnTable table comment
++- ResolvedTable V2SessionCatalog(spark_catalog), identifier_clause_test_schema.test_comment, V1Table(identifier_clause_test_schema.test_comment), [c1#x, c2#x]
+
+
+-- !query
+ALTER TABLE test_comment ALTER COLUMN IDENTIFIER('c1') COMMENT 'column comment'
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+DROP TABLE test_comment
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.test_comment
+
+
+-- !query
+CREATE TABLE identifier_clause_test_schema.test_table(c1 INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_table`, false
+
+
+-- !query
+ANALYZE TABLE IDENTIFIER('identifier_clause_test_schema.test_table') COMPUTE STATISTICS
+-- !query analysis
+AnalyzeTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_table`, false
+
+
+-- !query
+REFRESH TABLE IDENTIFIER('identifier_clause_test_schema.test_table')
+-- !query analysis
+RefreshTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_table`
+
+
+-- !query
+DESCRIBE IDENTIFIER('identifier_clause_test_schema.test_table')
+-- !query analysis
+DescribeTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_table`, false, [col_name#x, data_type#x, comment#x]
+
+
+-- !query
+SHOW COLUMNS FROM IDENTIFIER('identifier_clause_test_schema.test_table')
+-- !query analysis
+ShowColumnsCommand `spark_catalog`.`identifier_clause_test_schema`.`test_table`, [col_name#x]
+
+
+-- !query
+DROP TABLE IDENTIFIER('identifier_clause_test_schema.test_table')
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.test_table
+
+
+-- !query
+DECLARE IDENTIFIER('my_var') = 'value'
+-- !query analysis
+CreateVariable defaultvalueexpression(value, 'value'), false
++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.my_var
+
+
+-- !query
+SET VAR IDENTIFIER('my_var') = 'new_value'
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ": missing EQ"
+ }
+}
+
+
+-- !query
+SELECT IDENTIFIER('my_var')
+-- !query analysis
+Project [variablereference(system.session.my_var='value') AS variablereference(system.session.my_var='value')#x]
++- OneRowRelation
+
+
+-- !query
+DROP TEMPORARY VARIABLE IDENTIFIER('my_var')
+-- !query analysis
+DropVariable false
++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.my_var
+
+
+-- !query
+CREATE TEMPORARY FUNCTION test_udf(IDENTIFIER('param1') INT, IDENTIFIER('param2') STRING)
+RETURNS INT
+RETURN IDENTIFIER('param1') + length(IDENTIFIER('param2'))
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT test_udf(5, 'hello')
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNRESOLVED_ROUTINE",
+ "sqlState" : "42883",
+ "messageParameters" : {
+ "routineName" : "`test_udf`",
+ "searchPath" : "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`identifier_clause_test_schema`]"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 27,
+ "fragment" : "test_udf(5, 'hello')"
+ } ]
+}
+
+
+-- !query
+DROP TEMPORARY FUNCTION test_udf
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.NoSuchTempFunctionException
+{
+ "errorClass" : "ROUTINE_NOT_FOUND",
+ "sqlState" : "42883",
+ "messageParameters" : {
+ "routineName" : "`test_udf`"
+ }
+}
+
+
+-- !query
+CREATE TEMPORARY FUNCTION test_table_udf(IDENTIFIER('input_val') INT)
+RETURNS TABLE(IDENTIFIER('col1') INT, IDENTIFIER('col2') STRING)
+RETURN SELECT IDENTIFIER('input_val'), 'result'
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT * FROM test_table_udf(42)
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNRESOLVABLE_TABLE_VALUED_FUNCTION",
+ "sqlState" : "42883",
+ "messageParameters" : {
+ "name" : "`test_table_udf`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 15,
+ "stopIndex" : 32,
+ "fragment" : "test_table_udf(42)"
+ } ]
+}
+
+
+-- !query
+DROP TEMPORARY FUNCTION test_table_udf
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.NoSuchTempFunctionException
+{
+ "errorClass" : "ROUTINE_NOT_FOUND",
+ "sqlState" : "42883",
+ "messageParameters" : {
+ "routineName" : "`test_table_udf`"
+ }
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:tab \'b\').c1 FROM VALUES(1) AS tab(c1)' USING 'ta' AS tab
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "INVALID_EXTRACT_BASE_FIELD_TYPE",
+ "sqlState" : "42000",
+ "messageParameters" : {
+ "base" : "\"variablereference(system.session.tab='T')\"",
+ "other" : "\"STRING\""
+ }
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:col1 ''.c2'') FROM VALUES(named_struct(''c2'', 42)) AS T(c1)'
+ USING 'c1' AS col1
+-- !query analysis
+Project [c1#x.c2 AS c1.c2#x]
++- SubqueryAlias T
+ +- LocalRelation [c1#x]
+
+
+-- !query
+CREATE TABLE integration_test(c1 INT, c2 STRING) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`integration_test`, false
+
+
+-- !query
+INSERT INTO integration_test VALUES (1, 'a'), (2, 'b')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test], Append, `spark_catalog`.`identifier_clause_test_schema`.`integration_test`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test), [c1, c2]
++- Project [col1#x AS c1#x, col2#x AS c2#x]
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT * FROM IDENTIFIER(:schema ''.'' :table) ORDER BY ALL'
+ USING 'identifier_clause_test_schema' AS schema, 'integration_test' AS table
+-- !query analysis
+Sort [c1#x ASC NULLS FIRST, c2#x ASC NULLS FIRST], true
++- Project [c1#x, c2#x]
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:prefix ''1''), IDENTIFIER(:prefix ''2'') FROM integration_test ORDER BY ALL'
+ USING 'c' AS prefix
+-- !query analysis
+Sort [c1#x ASC NULLS FIRST, c2#x ASC NULLS FIRST], true
++- Project [c1#x, c2#x]
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT * FROM integration_test WHERE IDENTIFIER(:col) = :val'
+ USING 'c1' AS col, 1 AS val
+-- !query analysis
+Project [c1#x, c2#x]
++- Filter (c1#x = 1)
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x] csv
+
+
+-- !query
+CREATE TABLE integration_test2(c1 INT, c3 STRING) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`integration_test2`, false
+
+
+-- !query
+INSERT INTO integration_test2 VALUES (1, 'x'), (2, 'y')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test2, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test2], Append, `spark_catalog`.`identifier_clause_test_schema`.`integration_test2`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test2), [c1, c3]
++- Project [col1#x AS c1#x, col2#x AS c3#x]
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT t1.*, t2.* FROM IDENTIFIER(:t1) t1 JOIN IDENTIFIER(:t2) t2 USING (IDENTIFIER(:col)) ORDER BY ALL'
+ USING 'integration_test' AS t1, 'integration_test2' AS t2, 'c1' AS col
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 103,
+ "fragment" : "SELECT t1.*, t2.* FROM IDENTIFIER(:t1) t1 JOIN IDENTIFIER(:t2) t2 USING (IDENTIFIER(:col)) ORDER BY ALL"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE
+ 'SELECT IDENTIFIER(:col1), IDENTIFIER(:col2), row_number() OVER (PARTITION BY IDENTIFIER(:part) ORDER BY IDENTIFIER(:ord)) as rn FROM integration_test'
+ USING 'c1' AS col1, 'c2' AS col2, 'c2' AS part, 'c1' AS ord
+-- !query analysis
+Project [c1#x, c2#x, rn#x]
++- Project [c1#x, c2#x, rn#x, rn#x]
+ +- Window [row_number() windowspecdefinition(c2#x, c1#x ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#x], [c2#x], [c1#x ASC NULLS FIRST]
+ +- Project [c1#x, c2#x]
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:prefix ''2''), IDENTIFIER(:agg)(IDENTIFIER(:col)) FROM integration_test GROUP BY IDENTIFIER(:prefix ''2'') ORDER BY ALL'
+ USING 'c' AS prefix, 'count' AS agg, 'c1' AS col
+-- !query analysis
+Sort [c2#x ASC NULLS FIRST, count(c1)#xL ASC NULLS FIRST], true
++- Aggregate [c2#x], [c2#x, count(c1#x) AS count(c1)#xL]
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT * FROM integration_test ORDER BY IDENTIFIER(:col1) DESC, IDENTIFIER(:col2)'
+ USING 'c1' AS col1, 'c2' AS col2
+-- !query analysis
+Sort [c1#x DESC NULLS LAST, c2#x ASC NULLS FIRST], true
++- Project [c1#x, c2#x]
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE 'INSERT INTO integration_test(IDENTIFIER(:col1), IDENTIFIER(:col2)) VALUES (:val1, :val2)'
+ USING 'c1' AS col1, 'c2' AS col2, 3 AS val1, 'c' AS val2
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ": missing ')'"
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 88,
+ "fragment" : "INSERT INTO integration_test(IDENTIFIER(:col1), IDENTIFIER(:col2)) VALUES (:val1, :val2)"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(concat(:schema, ''.'', :table, ''.c1'')) FROM VALUES(named_struct(''c1'', 100)) AS IDENTIFIER(:alias)(IDENTIFIER(:schema ''.'' :table))'
+ USING 'identifier_clause_test_schema' AS schema, 'my_table' AS table, 't' AS alias
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "':'",
+ "hint" : ": extra input ':'"
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 145,
+ "fragment" : "SELECT IDENTIFIER(concat(:schema, '.', :table, '.c1')) FROM VALUES(named_struct('c1', 100)) AS IDENTIFIER(:alias)(IDENTIFIER(:schema '.' :table))"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'WITH IDENTIFIER(:cte_name)(c1) AS (VALUES(1)) SELECT c1 FROM IDENTIFIER(:cte_name)'
+ USING 'my_cte' AS cte_name
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "':'",
+ "hint" : ""
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 82,
+ "fragment" : "WITH IDENTIFIER(:cte_name)(c1) AS (VALUES(1)) SELECT c1 FROM IDENTIFIER(:cte_name)"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'CREATE OR REPLACE TEMPORARY VIEW IDENTIFIER(:view_name)(IDENTIFIER(:col_name)) AS VALUES(1)'
+ USING 'test_view' AS view_name, 'test_col' AS col_name
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 91,
+ "fragment" : "CREATE OR REPLACE TEMPORARY VIEW IDENTIFIER(:view_name)(IDENTIFIER(:col_name)) AS VALUES(1)"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:col) FROM IDENTIFIER(:view)'
+ USING 'test_col' AS col, 'test_view' AS view
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`test_view`"
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 30,
+ "stopIndex" : 46,
+ "fragment" : "IDENTIFIER(:view)"
+ } ]
+}
+
+
+-- !query
+DROP VIEW test_view
+-- !query analysis
+org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+{
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`spark_catalog`.`identifier_clause_test_schema`.`test_view`"
+ }
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'ALTER TABLE IDENTIFIER(:tab) ADD COLUMN IDENTIFIER(:new_col) INT'
+ USING 'integration_test' AS tab, 'c4' AS new_col
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 64,
+ "fragment" : "ALTER TABLE IDENTIFIER(:tab) ADD COLUMN IDENTIFIER(:new_col) INT"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'ALTER TABLE IDENTIFIER(:tab) RENAME COLUMN IDENTIFIER(:old_col) TO IDENTIFIER(:new_col)'
+ USING 'integration_test' AS tab, 'c4' AS old_col, 'c5' AS new_col
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 87,
+ "fragment" : "ALTER TABLE IDENTIFIER(:tab) RENAME COLUMN IDENTIFIER(:old_col) TO IDENTIFIER(:new_col)"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT map(:key, :val).IDENTIFIER(:key) AS result'
+ USING 'mykey' AS key, 42 AS val
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "':'",
+ "hint" : ""
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 49,
+ "fragment" : "SELECT map(:key, :val).IDENTIFIER(:key) AS result"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:alias ''.c1'') FROM integration_test AS IDENTIFIER(:alias) ORDER BY ALL'
+ USING 't' AS alias
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "':'",
+ "hint" : ": extra input ':'"
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 88,
+ "fragment" : "SELECT IDENTIFIER(:alias '.c1') FROM integration_test AS IDENTIFIER(:alias) ORDER BY ALL"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE
+ 'SELECT IDENTIFIER(:col1), IDENTIFIER(:p ''2'') FROM IDENTIFIER(:schema ''.'' :tab) WHERE IDENTIFIER(:col1) > 0 ORDER BY IDENTIFIER(:p ''1'')'
+ USING 'c1' AS col1, 'c' AS p, 'identifier_clause_test_schema' AS schema, 'integration_test' AS tab
+-- !query analysis
+Sort [c1#x ASC NULLS FIRST], true
++- Project [c1#x, c2#x]
+ +- Filter (c1#x > 0)
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT * FROM IDENTIFIER(:schema ''.'' :table) WHERE IDENTIFIER(concat(:tab_alias, ''.c1'')) > 0 ORDER BY ALL'
+ USING 'identifier_clause_test_schema' AS schema, 'integration_test' AS table, 'integration_test' AS tab_alias
+-- !query analysis
+Sort [c1#x ASC NULLS FIRST, c2#x ASC NULLS FIRST], true
++- Project [c1#x, c2#x]
+ +- Filter (c1#x > 0)
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT 1 AS IDENTIFIER(:schema ''.'' :col)'
+ USING 'identifier_clause_test_schema' AS schema, 'col1' AS col
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 40,
+ "fragment" : "SELECT 1 AS IDENTIFIER(:schema '.' :col)"
+ } ]
+}
+
+
+-- !query
+DROP TABLE integration_test
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.integration_test
+
+
+-- !query
+DROP TABLE integration_test2
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.integration_test2
+
+
+-- !query
+CREATE TABLE lateral_test(arr ARRAY) USING PARQUET
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`lateral_test`, false
+
+
+-- !query
+INSERT INTO lateral_test VALUES (array(1, 2, 3))
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/lateral_test, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/lateral_test], Append, `spark_catalog`.`identifier_clause_test_schema`.`lateral_test`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/lateral_test), [arr]
++- Project [col1#x AS arr#x]
+ +- LocalRelation [col1#x]
+
+
+-- !query
+SELECT * FROM lateral_test LATERAL VIEW explode(arr) IDENTIFIER('tbl') AS IDENTIFIER('col') ORDER BY ALL
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT * FROM lateral_test LATERAL VIEW OUTER explode(arr) IDENTIFIER('my_table') AS IDENTIFIER('my_col') ORDER BY ALL
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+DROP TABLE lateral_test
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.lateral_test
+
+
+-- !query
+CREATE TABLE unpivot_test(id INT, a INT, b INT, c INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`unpivot_test`, false
+
+
+-- !query
+INSERT INTO unpivot_test VALUES (1, 10, 20, 30)
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/unpivot_test, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/unpivot_test], Append, `spark_catalog`.`identifier_clause_test_schema`.`unpivot_test`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/unpivot_test), [id, a, b, c]
++- Project [col1#x AS id#x, col2#x AS a#x, col3#x AS b#x, col4#x AS c#x]
+ +- LocalRelation [col1#x, col2#x, col3#x, col4#x]
+
+
+-- !query
+SELECT * FROM unpivot_test UNPIVOT (val FOR col IN (a AS IDENTIFIER('col_a'), b AS IDENTIFIER('col_b'))) ORDER BY ALL
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT * FROM unpivot_test UNPIVOT ((v1, v2) FOR col IN ((a, b) AS IDENTIFIER('cols_ab'), (b, c) AS IDENTIFIER('cols_bc'))) ORDER BY ALL
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+DROP TABLE unpivot_test
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.unpivot_test
+
+
+-- !query
+CREATE TABLE describe_col_test(c1 INT, c2 STRING, c3 DOUBLE) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`describe_col_test`, false
+
+
+-- !query
+DESCRIBE describe_col_test IDENTIFIER('c1')
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+DESCRIBE describe_col_test IDENTIFIER('c2')
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+DROP TABLE describe_col_test
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.describe_col_test
+
+
+-- !query
+SELECT :IDENTIFIER('param1') FROM VALUES(1) AS T(c1)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''param1''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+CREATE TABLE hint_test(c1 INT, c2 INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`hint_test`, false
+
+
+-- !query
+INSERT INTO hint_test VALUES (1, 2), (3, 4)
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/hint_test, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/hint_test], Append, `spark_catalog`.`identifier_clause_test_schema`.`hint_test`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/hint_test), [c1, c2]
++- Project [col1#x AS c1#x, col2#x AS c2#x]
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+SELECT /*+ IDENTIFIER('BROADCAST')(hint_test) */ * FROM hint_test
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT /*+ IDENTIFIER('MERGE')(hint_test) */ * FROM hint_test
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+DROP TABLE hint_test
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.hint_test
+
+
+-- !query
+SHOW IDENTIFIER('USER') FUNCTIONS
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT EXTRACT(IDENTIFIER('YEAR') FROM DATE'2024-01-15')
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT TIMESTAMPADD(IDENTIFIER('YEAR'), 1, DATE'2024-01-15')
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNRESOLVED_ROUTINE",
+ "sqlState" : "42883",
+ "messageParameters" : {
+ "routineName" : "`TIMESTAMPADD`",
+ "searchPath" : "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`identifier_clause_test_schema`]"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 60,
+ "fragment" : "TIMESTAMPADD(IDENTIFIER('YEAR'), 1, DATE'2024-01-15')"
+ } ]
+}
+
+
+-- !query
+DROP SCHEMA identifier_clause_test_schema
+-- !query analysis
+DropNamespace false, false
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [identifier_clause_test_schema]
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
index 13d911c988381..e6a406072c48b 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/identifier-clause.sql.out
@@ -107,9 +107,11 @@ UPDATE IDENTIFIER('ta' || 'b') SET c1 = 2
-- !query analysis
org.apache.spark.SparkUnsupportedOperationException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_2096",
+ "errorClass" : "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ "sqlState" : "0A000",
"messageParameters" : {
- "ddl" : "UPDATE TABLE"
+ "operation" : "UPDATE TABLE",
+ "tableName" : "`spark_catalog`.`s`.`tab`"
}
}
@@ -120,9 +122,11 @@ MERGE INTO IDENTIFIER('ta' || 'b') AS t USING IDENTIFIER('ta' || 'b') AS s ON s.
-- !query analysis
org.apache.spark.SparkUnsupportedOperationException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_2096",
+ "errorClass" : "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ "sqlState" : "0A000",
"messageParameters" : {
- "ddl" : "MERGE INTO TABLE"
+ "operation" : "MERGE INTO TABLE",
+ "tableName" : "`spark_catalog`.`s`.`tab`"
}
}
@@ -731,6 +735,124 @@ org.apache.spark.sql.AnalysisException
}
+-- !query
+CREATE TABLE t(col1 INT)
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`default`.`t`, false
+
+
+-- !query
+SELECT * FROM IDENTIFIER((SELECT 't'))
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery()",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 26,
+ "stopIndex" : 37,
+ "fragment" : "(SELECT 't')"
+ } ]
+}
+
+
+-- !query
+SELECT * FROM (SELECT IDENTIFIER((SELECT 'col1')) FROM IDENTIFIER((SELECT 't')))
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery()",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 67,
+ "stopIndex" : 78,
+ "fragment" : "(SELECT 't')"
+ } ]
+}
+
+
+-- !query
+SELECT IDENTIFIER((SELECT 'col1')) FROM VALUES(1)
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery()",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 19,
+ "stopIndex" : 33,
+ "fragment" : "(SELECT 'col1')"
+ } ]
+}
+
+
+-- !query
+SELECT col1, IDENTIFIER((SELECT col1)) FROM VALUES(1)
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery(col1)",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 25,
+ "stopIndex" : 37,
+ "fragment" : "(SELECT col1)"
+ } ]
+}
+
+
+-- !query
+SELECT IDENTIFIER((SELECT 'col1', 'col2')) FROM VALUES(1,2)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "UNSUPPORTED_TYPED_LITERAL",
+ "sqlState" : "0A000",
+ "messageParameters" : {
+ "supportedTypes" : "\"DATE\", \"TIMESTAMP_NTZ\", \"TIMESTAMP_LTZ\", \"TIMESTAMP\", \"INTERVAL\", \"X\", \"TIME\"",
+ "unsupportedType" : "\"SELECT\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 20,
+ "stopIndex" : 32,
+ "fragment" : "SELECT 'col1'"
+ } ]
+}
+
+
+-- !query
+DROP TABLE t
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t
+
+
-- !query
CREATE TABLE IDENTIFIER(1)(c1 INT) USING csv
-- !query analysis
@@ -853,7 +975,8 @@ org.apache.spark.sql.AnalysisException
"errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
"sqlState" : "42601",
"messageParameters" : {
- "identifier" : "`a`.`b`.`c`.`d`"
+ "identifier" : "`a`.`b`.`c`.`d`",
+ "limit" : "2"
},
"queryContext" : [ {
"objectType" : "",
@@ -1064,27 +1187,32 @@ SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win
-- !query analysis
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "PARSE_SYNTAX_ERROR",
+ "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
"sqlState" : "42601",
"messageParameters" : {
- "error" : "''x.win''",
- "hint" : ""
- }
+ "identifier" : "`x`.`win`",
+ "limit" : "1"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 26,
+ "stopIndex" : 44,
+ "fragment" : "IDENTIFIER('x.win')"
+ } ]
}
-- !query
SELECT T1.c1 FROM VALUES(1) AS T1(c1) JOIN VALUES(1) AS T2(c1) USING (IDENTIFIER('c1'))
-- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
- "errorClass" : "PARSE_SYNTAX_ERROR",
- "sqlState" : "42601",
- "messageParameters" : {
- "error" : "'('",
- "hint" : ""
- }
-}
+Project [c1#x]
++- Project [c1#x]
+ +- Join Inner, (c1#x = c1#x)
+ :- SubqueryAlias T1
+ : +- LocalRelation [c1#x]
+ +- SubqueryAlias T2
+ +- LocalRelation [c1#x]
-- !query
@@ -1111,40 +1239,28 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
SELECT map('a', 1).IDENTIFIER('a') FROM VALUES(1) AS T(c1)
-- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
- "errorClass" : "PARSE_SYNTAX_ERROR",
- "sqlState" : "42601",
- "messageParameters" : {
- "error" : "''a''",
- "hint" : ""
- }
-}
+Project [map(a, 1)[a] AS map(a, 1)[a]#x]
++- SubqueryAlias T
+ +- LocalRelation [c1#x]
-- !query
SELECT named_struct('a', 1).IDENTIFIER('a') FROM VALUES(1) AS T(c1)
-- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
- "errorClass" : "PARSE_SYNTAX_ERROR",
- "sqlState" : "42601",
- "messageParameters" : {
- "error" : "''a''",
- "hint" : ""
- }
-}
+Project [named_struct(a, 1).a AS named_struct(a, 1).a#x]
++- SubqueryAlias T
+ +- LocalRelation [c1#x]
-- !query
SELECT * FROM s.IDENTIFIER('tab')
-- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
- "errorClass" : "INVALID_SQL_SYNTAX.INVALID_TABLE_VALUED_FUNC_NAME",
- "sqlState" : "42000",
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
"messageParameters" : {
- "funcName" : "`s`.`IDENTIFIER`"
+ "relationName" : "`s`.`tab`"
},
"queryContext" : [ {
"objectType" : "",
@@ -1159,110 +1275,1041 @@ org.apache.spark.sql.catalyst.parser.ParseException
-- !query
SELECT * FROM IDENTIFIER('s').IDENTIFIER('tab')
-- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
- "errorClass" : "PARSE_SYNTAX_ERROR",
- "sqlState" : "42601",
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
"messageParameters" : {
- "error" : "'.'",
- "hint" : ""
- }
+ "relationName" : "`s`.`tab`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 15,
+ "stopIndex" : 47,
+ "fragment" : "IDENTIFIER('s').IDENTIFIER('tab')"
+ } ]
}
-- !query
SELECT * FROM IDENTIFIER('s').tab
-- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
- "errorClass" : "PARSE_SYNTAX_ERROR",
- "sqlState" : "42601",
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
"messageParameters" : {
- "error" : "'.'",
- "hint" : ""
- }
+ "relationName" : "`s`.`tab`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 15,
+ "stopIndex" : 33,
+ "fragment" : "IDENTIFIER('s').tab"
+ } ]
}
-- !query
SELECT row_number() OVER IDENTIFIER('win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1)
-- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
- "errorClass" : "PARSE_SYNTAX_ERROR",
- "sqlState" : "42601",
- "messageParameters" : {
- "error" : "''win''",
- "hint" : ""
- }
-}
+Project [row_number() OVER (ORDER BY c1 ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x]
++- Project [c1#x, row_number() OVER (ORDER BY c1 ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x, row_number() OVER (ORDER BY c1 ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x]
+ +- Window [row_number() windowspecdefinition(c1#x ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS row_number() OVER (ORDER BY c1 ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x], [c1#x ASC NULLS FIRST]
+ +- Project [c1#x]
+ +- SubqueryAlias T
+ +- LocalRelation [c1#x]
-- !query
SELECT row_number() OVER win FROM VALUES(1) AS T(c1) WINDOW IDENTIFIER('win') AS (ORDER BY c1)
-- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
- "errorClass" : "PARSE_SYNTAX_ERROR",
- "sqlState" : "42601",
- "messageParameters" : {
- "error" : "'WINDOW'",
- "hint" : ""
- }
-}
+Project [row_number() OVER (ORDER BY c1 ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x]
++- Project [c1#x, row_number() OVER (ORDER BY c1 ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x, row_number() OVER (ORDER BY c1 ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x]
+ +- Window [row_number() windowspecdefinition(c1#x ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS row_number() OVER (ORDER BY c1 ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)#x], [c1#x ASC NULLS FIRST]
+ +- Project [c1#x]
+ +- SubqueryAlias T
+ +- LocalRelation [c1#x]
+
+
+-- !query
+SELECT 1 AS IDENTIFIER('col1')
+-- !query analysis
+Project [1 AS col1#x]
++- OneRowRelation
+
+
+-- !query
+SELECT my_table.* FROM VALUES (1, 2) AS IDENTIFIER('my_table')(IDENTIFIER('c1'), IDENTIFIER('c2'))
+-- !query analysis
+Project [c1#x, c2#x]
++- SubqueryAlias my_table
+ +- LocalRelation [c1#x, c2#x]
-- !query
WITH identifier('v')(identifier('c1')) AS (VALUES(1)) (SELECT c1 FROM v)
-- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
- "errorClass" : "PARSE_SYNTAX_ERROR",
- "sqlState" : "42601",
- "messageParameters" : {
- "error" : "''v''",
- "hint" : ""
- }
-}
+WithCTE
+:- CTERelationDef xxxx, false
+: +- SubqueryAlias v
+: +- Project [col1#x AS c1#x]
+: +- LocalRelation [col1#x]
++- Project [c1#x]
+ +- SubqueryAlias v
+ +- CTERelationRef xxxx, true, [c1#x], false, false, 1
+
+
+-- !query
+CREATE OR REPLACE VIEW v(IDENTIFIER('c1')) AS VALUES(1)
+-- !query analysis
+CreateViewCommand `spark_catalog`.`default`.`v`, [(c1,None)], VALUES(1), false, true, PersistedView, COMPENSATION, true
+ +- LocalRelation [col1#x]
+
+
+-- !query
+SELECT c1 FROM v
+-- !query analysis
+Project [c1#x]
++- SubqueryAlias spark_catalog.default.v
+ +- View (`spark_catalog`.`default`.`v`, [c1#x])
+ +- Project [cast(col1#x as int) AS c1#x]
+ +- LocalRelation [col1#x]
+
+
+-- !query
+DROP VIEW IF EXISTS v
+-- !query analysis
+DropTableCommand `spark_catalog`.`default`.`v`, true, true, false
+
+
+-- !query
+CREATE TABLE tab(IDENTIFIER('c1') INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`default`.`tab`, false
-- !query
INSERT INTO tab(IDENTIFIER('c1')) VALUES(1)
-- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
-{
- "errorClass" : "PARSE_SYNTAX_ERROR",
- "sqlState" : "42601",
- "messageParameters" : {
- "error" : "'('",
- "hint" : ": missing ')'"
- }
-}
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/tab, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/tab], Append, `spark_catalog`.`default`.`tab`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/tab), [c1]
++- Project [c1#x AS c1#x]
+ +- Project [col1#x AS c1#x]
+ +- LocalRelation [col1#x]
-- !query
-CREATE OR REPLACE VIEW v(IDENTIFIER('c1')) AS VALUES(1)
+SELECT c1 FROM tab
-- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
+Project [c1#x]
++- SubqueryAlias spark_catalog.default.tab
+ +- Relation spark_catalog.default.tab[c1#x] csv
+
+
+-- !query
+ALTER TABLE IDENTIFIER('tab') RENAME COLUMN IDENTIFIER('c1') TO IDENTIFIER('col1')
+-- !query analysis
+org.apache.spark.sql.AnalysisException
{
- "errorClass" : "PARSE_SYNTAX_ERROR",
- "sqlState" : "42601",
+ "errorClass" : "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ "sqlState" : "0A000",
"messageParameters" : {
- "error" : "'('",
- "hint" : ""
+ "operation" : "RENAME COLUMN",
+ "tableName" : "`spark_catalog`.`default`.`tab`"
}
}
-- !query
-CREATE TABLE tab(IDENTIFIER('c1') INT) USING CSV
+SELECT col1 FROM tab
-- !query analysis
-org.apache.spark.sql.catalyst.parser.ParseException
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
{
- "errorClass" : "PARSE_SYNTAX_ERROR",
- "sqlState" : "42601",
+ "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+ "sqlState" : "42703",
+ "messageParameters" : {
+ "objectName" : "`col1`",
+ "proposal" : "`c1`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 11,
+ "fragment" : "col1"
+ } ]
+}
+
+
+-- !query
+ALTER TABLE IDENTIFIER('tab') ADD COLUMN IDENTIFIER('c2') INT
+-- !query analysis
+AlterTableAddColumnsCommand `spark_catalog`.`default`.`tab`, [StructField(c2,IntegerType,true)]
+
+
+-- !query
+SELECT c2 FROM tab
+-- !query analysis
+Project [c2#x]
++- SubqueryAlias spark_catalog.default.tab
+ +- Relation spark_catalog.default.tab[c1#x,c2#x] csv
+
+
+-- !query
+ALTER TABLE IDENTIFIER('tab') DROP COLUMN IDENTIFIER('c2')
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ "sqlState" : "0A000",
+ "messageParameters" : {
+ "operation" : "DROP COLUMN",
+ "tableName" : "`spark_catalog`.`default`.`tab`"
+ }
+}
+
+
+-- !query
+ALTER TABLE IDENTIFIER('tab') RENAME TO IDENTIFIER('tab_renamed')
+-- !query analysis
+AlterTableRenameCommand `spark_catalog`.`default`.`tab`, `tab_renamed`, false
+
+
+-- !query
+SELECT * FROM tab_renamed
+-- !query analysis
+Project [c1#x, c2#x]
++- SubqueryAlias spark_catalog.default.tab_renamed
+ +- Relation spark_catalog.default.tab_renamed[c1#x,c2#x] csv
+
+
+-- !query
+DROP TABLE IF EXISTS tab_renamed
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.tab_renamed
+
+
+-- !query
+DROP TABLE IF EXISTS tab
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.tab
+
+
+-- !query
+CREATE TABLE test_col_with_dot(IDENTIFIER('`col.with.dot`') INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`default`.`test_col_with_dot`, false
+
+
+-- !query
+DROP TABLE IF EXISTS test_col_with_dot
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.test_col_with_dot
+
+
+-- !query
+SELECT * FROM VALUES (1, 2) AS IDENTIFIER('schema.table')(c1, c2)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "identifier" : "`schema`.`table`",
+ "limit" : "1"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 15,
+ "stopIndex" : 65,
+ "fragment" : "VALUES (1, 2) AS IDENTIFIER('schema.table')(c1, c2)"
+ } ]
+}
+
+
+-- !query
+SELECT 1 AS IDENTIFIER('col1.col2')
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "identifier" : "`col1`.`col2`",
+ "limit" : "1"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 35,
+ "fragment" : "1 AS IDENTIFIER('col1.col2')"
+ } ]
+}
+
+
+-- !query
+CREATE SCHEMA identifier_clause_test_schema
+-- !query analysis
+CreateNamespace false
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [identifier_clause_test_schema]
+
+
+-- !query
+USE identifier_clause_test_schema
+-- !query analysis
+SetCatalogAndNamespace
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [identifier_clause_test_schema]
+
+
+-- !query
+CREATE TABLE test_show(c1 INT, c2 STRING) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_show`, false
+
+
+-- !query
+SHOW VIEWS IN IDENTIFIER('identifier_clause_test_schema')
+-- !query analysis
+ShowViewsCommand identifier_clause_test_schema, [namespace#x, viewName#x, isTemporary#x]
+
+
+-- !query
+SHOW PARTITIONS IDENTIFIER('test_show')
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "INVALID_PARTITION_OPERATION.PARTITION_SCHEMA_IS_EMPTY",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "name" : "`spark_catalog`.`identifier_clause_test_schema`.`test_show`"
+ }
+}
+
+
+-- !query
+SHOW CREATE TABLE IDENTIFIER('test_show')
+-- !query analysis
+ShowCreateTable false, [createtab_stmt#x]
++- ResolvedTable V2SessionCatalog(spark_catalog), identifier_clause_test_schema.test_show, V1Table(identifier_clause_test_schema.test_show), [c1#x, c2#x]
+
+
+-- !query
+DROP TABLE test_show
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.test_show
+
+
+-- !query
+CREATE TABLE test_desc(c1 INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_desc`, false
+
+
+-- !query
+DESCRIBE TABLE IDENTIFIER('test_desc')
+-- !query analysis
+DescribeTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_desc`, false, [col_name#x, data_type#x, comment#x]
+
+
+-- !query
+DESCRIBE FORMATTED IDENTIFIER('test_desc')
+-- !query analysis
+DescribeTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_desc`, true, [col_name#x, data_type#x, comment#x]
+
+
+-- !query
+DESCRIBE EXTENDED IDENTIFIER('test_desc')
+-- !query analysis
+DescribeTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_desc`, true, [col_name#x, data_type#x, comment#x]
+
+
+-- !query
+DESC IDENTIFIER('test_desc')
+-- !query analysis
+DescribeTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_desc`, false, [col_name#x, data_type#x, comment#x]
+
+
+-- !query
+DROP TABLE test_desc
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.test_desc
+
+
+-- !query
+CREATE TABLE test_comment(c1 INT, c2 STRING) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_comment`, false
+
+
+-- !query
+COMMENT ON TABLE IDENTIFIER('test_comment') IS 'table comment'
+-- !query analysis
+CommentOnTable table comment
++- ResolvedTable V2SessionCatalog(spark_catalog), identifier_clause_test_schema.test_comment, V1Table(identifier_clause_test_schema.test_comment), [c1#x, c2#x]
+
+
+-- !query
+ALTER TABLE test_comment ALTER COLUMN IDENTIFIER('c1') COMMENT 'column comment'
+-- !query analysis
+AlterTableChangeColumnCommand `spark_catalog`.`identifier_clause_test_schema`.`test_comment`, c1, StructField(c1,IntegerType,true)
+
+
+-- !query
+DROP TABLE test_comment
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.test_comment
+
+
+-- !query
+CREATE TABLE identifier_clause_test_schema.test_table(c1 INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_table`, false
+
+
+-- !query
+ANALYZE TABLE IDENTIFIER('identifier_clause_test_schema.test_table') COMPUTE STATISTICS
+-- !query analysis
+AnalyzeTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_table`, false
+
+
+-- !query
+REFRESH TABLE IDENTIFIER('identifier_clause_test_schema.test_table')
+-- !query analysis
+RefreshTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_table`
+
+
+-- !query
+DESCRIBE IDENTIFIER('identifier_clause_test_schema.test_table')
+-- !query analysis
+DescribeTableCommand `spark_catalog`.`identifier_clause_test_schema`.`test_table`, false, [col_name#x, data_type#x, comment#x]
+
+
+-- !query
+SHOW COLUMNS FROM IDENTIFIER('identifier_clause_test_schema.test_table')
+-- !query analysis
+ShowColumnsCommand `spark_catalog`.`identifier_clause_test_schema`.`test_table`, [col_name#x]
+
+
+-- !query
+DROP TABLE IDENTIFIER('identifier_clause_test_schema.test_table')
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.test_table
+
+
+-- !query
+DECLARE IDENTIFIER('my_var') = 'value'
+-- !query analysis
+CreateVariable defaultvalueexpression(value, 'value'), false
++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.my_var
+
+
+-- !query
+SET VAR IDENTIFIER('my_var') = 'new_value'
+-- !query analysis
+SetVariable [variablereference(system.session.my_var='value')]
++- Project [new_value AS my_var#x]
+ +- OneRowRelation
+
+
+-- !query
+SELECT IDENTIFIER('my_var')
+-- !query analysis
+Project [variablereference(system.session.my_var='new_value') AS variablereference(system.session.my_var='new_value')#x]
++- OneRowRelation
+
+
+-- !query
+DROP TEMPORARY VARIABLE IDENTIFIER('my_var')
+-- !query analysis
+DropVariable false
++- ResolvedIdentifier org.apache.spark.sql.catalyst.analysis.FakeSystemCatalog$@xxxxxxxx, session.my_var
+
+
+-- !query
+CREATE TEMPORARY FUNCTION test_udf(IDENTIFIER('param1') INT, IDENTIFIER('param2') STRING)
+RETURNS INT
+RETURN IDENTIFIER('param1') + length(IDENTIFIER('param2'))
+-- !query analysis
+CreateSQLFunctionCommand test_udf, IDENTIFIER('param1') INT, IDENTIFIER('param2') STRING, INT, IDENTIFIER('param1') + length(IDENTIFIER('param2')), false, true, false, false
+
+
+-- !query
+SELECT test_udf(5, 'hello')
+-- !query analysis
+Project [test_udf(param1#x, param2#x) AS test_udf(5, hello)#x]
++- Project [cast(5 as int) AS param1#x, cast(hello as string) AS param2#x]
+ +- OneRowRelation
+
+
+-- !query
+DROP TEMPORARY FUNCTION test_udf
+-- !query analysis
+DropFunctionCommand test_udf, false, true
+
+
+-- !query
+CREATE TEMPORARY FUNCTION test_table_udf(IDENTIFIER('input_val') INT)
+RETURNS TABLE(IDENTIFIER('col1') INT, IDENTIFIER('col2') STRING)
+RETURN SELECT IDENTIFIER('input_val'), 'result'
+-- !query analysis
+CreateSQLFunctionCommand test_table_udf, IDENTIFIER('input_val') INT, IDENTIFIER('col1') INT, IDENTIFIER('col2') STRING, SELECT IDENTIFIER('input_val'), 'result', true, true, false, false
+
+
+-- !query
+SELECT * FROM test_table_udf(42)
+-- !query analysis
+Project [col1#x, col2#x]
++- SQLFunctionNode test_table_udf
+ +- SubqueryAlias test_table_udf
+ +- Project [cast(input_val#x as int) AS col1#x, cast(result#x as string) AS col2#x]
+ +- Project [cast(42 as int) AS input_val#x, result AS result#x]
+ +- OneRowRelation
+
+
+-- !query
+DROP TEMPORARY FUNCTION test_table_udf
+-- !query analysis
+DropFunctionCommand test_table_udf, false, true
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:tab \'b\').c1 FROM VALUES(1) AS tab(c1)' USING 'ta' AS tab
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "INVALID_EXTRACT_BASE_FIELD_TYPE",
+ "sqlState" : "42000",
+ "messageParameters" : {
+ "base" : "\"variablereference(system.session.tab='T')\"",
+ "other" : "\"STRING\""
+ }
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:col1 ''.c2'') FROM VALUES(named_struct(''c2'', 42)) AS T(c1)'
+ USING 'c1' AS col1
+-- !query analysis
+Project [c1#x.c2 AS c1.c2#x]
++- SubqueryAlias T
+ +- LocalRelation [c1#x]
+
+
+-- !query
+CREATE TABLE integration_test(c1 INT, c2 STRING) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`integration_test`, false
+
+
+-- !query
+INSERT INTO integration_test VALUES (1, 'a'), (2, 'b')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test], Append, `spark_catalog`.`identifier_clause_test_schema`.`integration_test`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test), [c1, c2]
++- Project [col1#x AS c1#x, col2#x AS c2#x]
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT * FROM IDENTIFIER(:schema ''.'' :table) ORDER BY ALL'
+ USING 'identifier_clause_test_schema' AS schema, 'integration_test' AS table
+-- !query analysis
+Sort [c1#x ASC NULLS FIRST, c2#x ASC NULLS FIRST], true
++- Project [c1#x, c2#x]
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:prefix ''1''), IDENTIFIER(:prefix ''2'') FROM integration_test ORDER BY ALL'
+ USING 'c' AS prefix
+-- !query analysis
+Sort [c1#x ASC NULLS FIRST, c2#x ASC NULLS FIRST], true
++- Project [c1#x, c2#x]
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT * FROM integration_test WHERE IDENTIFIER(:col) = :val'
+ USING 'c1' AS col, 1 AS val
+-- !query analysis
+Project [c1#x, c2#x]
++- Filter (c1#x = 1)
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x] csv
+
+
+-- !query
+CREATE TABLE integration_test2(c1 INT, c3 STRING) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`integration_test2`, false
+
+
+-- !query
+INSERT INTO integration_test2 VALUES (1, 'x'), (2, 'y')
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test2, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test2], Append, `spark_catalog`.`identifier_clause_test_schema`.`integration_test2`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test2), [c1, c3]
++- Project [col1#x AS c1#x, col2#x AS c3#x]
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT t1.*, t2.* FROM IDENTIFIER(:t1) t1 JOIN IDENTIFIER(:t2) t2 USING (IDENTIFIER(:col)) ORDER BY ALL'
+ USING 'integration_test' AS t1, 'integration_test2' AS t2, 'c1' AS col
+-- !query analysis
+Sort [c1#x ASC NULLS FIRST, c2#x ASC NULLS FIRST, c1#x ASC NULLS FIRST, c3#x ASC NULLS FIRST], true
++- Project [c1#x, c2#x, c1#x, c3#x]
+ +- Project [c1#x, c2#x, c3#x, c1#x]
+ +- Join Inner, (c1#x = c1#x)
+ :- SubqueryAlias t1
+ : +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ : +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x] csv
+ +- SubqueryAlias t2
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test2
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test2[c1#x,c3#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE
+ 'SELECT IDENTIFIER(:col1), IDENTIFIER(:col2), row_number() OVER (PARTITION BY IDENTIFIER(:part) ORDER BY IDENTIFIER(:ord)) as rn FROM integration_test'
+ USING 'c1' AS col1, 'c2' AS col2, 'c2' AS part, 'c1' AS ord
+-- !query analysis
+Project [c1#x, c2#x, rn#x]
++- Project [c1#x, c2#x, rn#x, rn#x]
+ +- Window [row_number() windowspecdefinition(c2#x, c1#x ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS rn#x], [c2#x], [c1#x ASC NULLS FIRST]
+ +- Project [c1#x, c2#x]
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:prefix ''2''), IDENTIFIER(:agg)(IDENTIFIER(:col)) FROM integration_test GROUP BY IDENTIFIER(:prefix ''2'') ORDER BY ALL'
+ USING 'c' AS prefix, 'count' AS agg, 'c1' AS col
+-- !query analysis
+Sort [c2#x ASC NULLS FIRST, count(c1)#xL ASC NULLS FIRST], true
++- Aggregate [c2#x], [c2#x, count(c1#x) AS count(c1)#xL]
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT * FROM integration_test ORDER BY IDENTIFIER(:col1) DESC, IDENTIFIER(:col2)'
+ USING 'c1' AS col1, 'c2' AS col2
+-- !query analysis
+Sort [c1#x DESC NULLS LAST, c2#x ASC NULLS FIRST], true
++- Project [c1#x, c2#x]
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE 'INSERT INTO integration_test(IDENTIFIER(:col1), IDENTIFIER(:col2)) VALUES (:val1, :val2)'
+ USING 'c1' AS col1, 'c2' AS col2, 3 AS val1, 'c' AS val2
+-- !query analysis
+CommandResult Execute InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test], Append, `spark_catalog`.`identifier_clause_test_schema`.`integration_test`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test), [c1, c2]
+ +- InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test], Append, `spark_catalog`.`identifier_clause_test_schema`.`integration_test`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/integration_test), [c1, c2]
+ +- Project [c1#x AS c1#x, c2#x AS c2#x]
+ +- Project [col1#x AS c1#x, col2#x AS c2#x]
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(concat(:schema, ''.'', :table, ''.c1'')) FROM VALUES(named_struct(''c1'', 100)) AS IDENTIFIER(:alias)(IDENTIFIER(:schema ''.'' :table))'
+ USING 'identifier_clause_test_schema' AS schema, 'my_table' AS table, 't' AS alias
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "identifier" : "`identifier_clause_test_schema`.`my_table`",
+ "limit" : "1"
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 115,
+ "stopIndex" : 144,
+ "fragment" : "IDENTIFIER(:schema '.' :table)"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'WITH IDENTIFIER(:cte_name)(c1) AS (VALUES(1)) SELECT c1 FROM IDENTIFIER(:cte_name)'
+ USING 'my_cte' AS cte_name
+-- !query analysis
+WithCTE
+:- CTERelationDef xxxx, false
+: +- SubqueryAlias my_cte
+: +- Project [col1#x AS c1#x]
+: +- LocalRelation [col1#x]
++- Project [c1#x]
+ +- SubqueryAlias my_cte
+ +- CTERelationRef xxxx, true, [c1#x], false, false, 1
+
+
+-- !query
+EXECUTE IMMEDIATE 'CREATE OR REPLACE TEMPORARY VIEW IDENTIFIER(:view_name)(IDENTIFIER(:col_name)) AS VALUES(1)'
+ USING 'test_view' AS view_name, 'test_col' AS col_name
+-- !query analysis
+CommandResult Execute CreateViewCommand
+ +- CreateViewCommand `test_view`, [(test_col,None)], VALUES(1), false, true, LocalTempView, UNSUPPORTED, true
+ +- LocalRelation [col1#x]
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:col) FROM IDENTIFIER(:view)'
+ USING 'test_col' AS col, 'test_view' AS view
+-- !query analysis
+Project [test_col#x]
++- SubqueryAlias test_view
+ +- View (`test_view`, [test_col#x])
+ +- Project [cast(col1#x as int) AS test_col#x]
+ +- LocalRelation [col1#x]
+
+
+-- !query
+DROP VIEW test_view
+-- !query analysis
+DropTempViewCommand test_view
+
+
+-- !query
+EXECUTE IMMEDIATE 'ALTER TABLE IDENTIFIER(:tab) ADD COLUMN IDENTIFIER(:new_col) INT'
+ USING 'integration_test' AS tab, 'c4' AS new_col
+-- !query analysis
+CommandResult Execute AlterTableAddColumnsCommand
+ +- AlterTableAddColumnsCommand `spark_catalog`.`identifier_clause_test_schema`.`integration_test`, [StructField(c4,IntegerType,true)]
+
+
+-- !query
+EXECUTE IMMEDIATE 'ALTER TABLE IDENTIFIER(:tab) RENAME COLUMN IDENTIFIER(:old_col) TO IDENTIFIER(:new_col)'
+ USING 'integration_test' AS tab, 'c4' AS old_col, 'c5' AS new_col
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ "sqlState" : "0A000",
+ "messageParameters" : {
+ "operation" : "RENAME COLUMN",
+ "tableName" : "`spark_catalog`.`identifier_clause_test_schema`.`integration_test`"
+ }
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT map(:key, :val).IDENTIFIER(:key) AS result'
+ USING 'mykey' AS key, 42 AS val
+-- !query analysis
+Project [map(mykey, 42)[mykey] AS result#x]
++- OneRowRelation
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:alias ''.c1'') FROM integration_test AS IDENTIFIER(:alias) ORDER BY ALL'
+ USING 't' AS alias
+-- !query analysis
+Sort [c1#x ASC NULLS FIRST], true
++- Project [c1#x]
+ +- SubqueryAlias t
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x,c4#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE
+ 'SELECT IDENTIFIER(:col1), IDENTIFIER(:p ''2'') FROM IDENTIFIER(:schema ''.'' :tab) WHERE IDENTIFIER(:col1) > 0 ORDER BY IDENTIFIER(:p ''1'')'
+ USING 'c1' AS col1, 'c' AS p, 'identifier_clause_test_schema' AS schema, 'integration_test' AS tab
+-- !query analysis
+Sort [c1#x ASC NULLS FIRST], true
++- Project [c1#x, c2#x]
+ +- Filter (c1#x > 0)
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x,c4#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT * FROM IDENTIFIER(:schema ''.'' :table) WHERE IDENTIFIER(concat(:tab_alias, ''.c1'')) > 0 ORDER BY ALL'
+ USING 'identifier_clause_test_schema' AS schema, 'integration_test' AS table, 'integration_test' AS tab_alias
+-- !query analysis
+Sort [c1#x ASC NULLS FIRST, c2#x ASC NULLS FIRST, c4#x ASC NULLS FIRST], true
++- Project [c1#x, c2#x, c4#x]
+ +- Filter (c1#x > 0)
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.integration_test
+ +- Relation spark_catalog.identifier_clause_test_schema.integration_test[c1#x,c2#x,c4#x] csv
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT 1 AS IDENTIFIER(:schema ''.'' :col)'
+ USING 'identifier_clause_test_schema' AS schema, 'col1' AS col
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "identifier" : "`identifier_clause_test_schema`.`col1`",
+ "limit" : "1"
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 40,
+ "fragment" : "1 AS IDENTIFIER(:schema '.' :col)"
+ } ]
+}
+
+
+-- !query
+DROP TABLE integration_test
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.integration_test
+
+
+-- !query
+DROP TABLE integration_test2
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.integration_test2
+
+
+-- !query
+CREATE TABLE lateral_test(arr ARRAY) USING PARQUET
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`lateral_test`, false
+
+
+-- !query
+INSERT INTO lateral_test VALUES (array(1, 2, 3))
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/lateral_test, false, Parquet, [path=file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/lateral_test], Append, `spark_catalog`.`identifier_clause_test_schema`.`lateral_test`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/lateral_test), [arr]
++- Project [col1#x AS arr#x]
+ +- LocalRelation [col1#x]
+
+
+-- !query
+SELECT * FROM lateral_test LATERAL VIEW explode(arr) IDENTIFIER('tbl') AS IDENTIFIER('col') ORDER BY ALL
+-- !query analysis
+Sort [arr#x ASC NULLS FIRST, col#x ASC NULLS FIRST], true
++- Project [arr#x, col#x]
+ +- Generate explode(arr#x), false, tbl, [col#x]
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.lateral_test
+ +- Relation spark_catalog.identifier_clause_test_schema.lateral_test[arr#x] parquet
+
+
+-- !query
+SELECT * FROM lateral_test LATERAL VIEW OUTER explode(arr) IDENTIFIER('my_table') AS IDENTIFIER('my_col') ORDER BY ALL
+-- !query analysis
+Sort [arr#x ASC NULLS FIRST, my_col#x ASC NULLS FIRST], true
++- Project [arr#x, my_col#x]
+ +- Generate explode(arr#x), true, my_table, [my_col#x]
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.lateral_test
+ +- Relation spark_catalog.identifier_clause_test_schema.lateral_test[arr#x] parquet
+
+
+-- !query
+DROP TABLE lateral_test
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.lateral_test
+
+
+-- !query
+CREATE TABLE unpivot_test(id INT, a INT, b INT, c INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`unpivot_test`, false
+
+
+-- !query
+INSERT INTO unpivot_test VALUES (1, 10, 20, 30)
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/unpivot_test, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/unpivot_test], Append, `spark_catalog`.`identifier_clause_test_schema`.`unpivot_test`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/unpivot_test), [id, a, b, c]
++- Project [col1#x AS id#x, col2#x AS a#x, col3#x AS b#x, col4#x AS c#x]
+ +- LocalRelation [col1#x, col2#x, col3#x, col4#x]
+
+
+-- !query
+SELECT * FROM unpivot_test UNPIVOT (val FOR col IN (a AS IDENTIFIER('col_a'), b AS IDENTIFIER('col_b'))) ORDER BY ALL
+-- !query analysis
+Sort [id#x ASC NULLS FIRST, c#x ASC NULLS FIRST, col#x ASC NULLS FIRST, val#x ASC NULLS FIRST], true
++- Project [id#x, c#x, col#x, val#x]
+ +- Filter isnotnull(coalesce(val#x))
+ +- Expand [[id#x, c#x, col_a, a#x], [id#x, c#x, col_b, b#x]], [id#x, c#x, col#x, val#x]
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.unpivot_test
+ +- Relation spark_catalog.identifier_clause_test_schema.unpivot_test[id#x,a#x,b#x,c#x] csv
+
+
+-- !query
+SELECT * FROM unpivot_test UNPIVOT ((v1, v2) FOR col IN ((a, b) AS IDENTIFIER('cols_ab'), (b, c) AS IDENTIFIER('cols_bc'))) ORDER BY ALL
+-- !query analysis
+Sort [id#x ASC NULLS FIRST, col#x ASC NULLS FIRST, v1#x ASC NULLS FIRST, v2#x ASC NULLS FIRST], true
++- Project [id#x, col#x, v1#x, v2#x]
+ +- Filter isnotnull(coalesce(v1#x, v2#x))
+ +- Expand [[id#x, cols_ab, a#x, b#x], [id#x, cols_bc, b#x, c#x]], [id#x, col#x, v1#x, v2#x]
+ +- SubqueryAlias spark_catalog.identifier_clause_test_schema.unpivot_test
+ +- Relation spark_catalog.identifier_clause_test_schema.unpivot_test[id#x,a#x,b#x,c#x] csv
+
+
+-- !query
+DROP TABLE unpivot_test
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.unpivot_test
+
+
+-- !query
+CREATE TABLE describe_col_test(c1 INT, c2 STRING, c3 DOUBLE) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`describe_col_test`, false
+
+
+-- !query
+DESCRIBE describe_col_test IDENTIFIER('c1')
+-- !query analysis
+DescribeColumnCommand `spark_catalog`.`identifier_clause_test_schema`.`describe_col_test`, [spark_catalog, identifier_clause_test_schema, describe_col_test, c1], false, [info_name#x, info_value#x]
+
+
+-- !query
+DESCRIBE describe_col_test IDENTIFIER('c2')
+-- !query analysis
+DescribeColumnCommand `spark_catalog`.`identifier_clause_test_schema`.`describe_col_test`, [spark_catalog, identifier_clause_test_schema, describe_col_test, c2], false, [info_name#x, info_value#x]
+
+
+-- !query
+DROP TABLE describe_col_test
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.describe_col_test
+
+
+-- !query
+SELECT :IDENTIFIER('param1') FROM VALUES(1) AS T(c1)
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''param1''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+CREATE TABLE hint_test(c1 INT, c2 INT) USING CSV
+-- !query analysis
+CreateDataSourceTableCommand `spark_catalog`.`identifier_clause_test_schema`.`hint_test`, false
+
+
+-- !query
+INSERT INTO hint_test VALUES (1, 2), (3, 4)
+-- !query analysis
+InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/hint_test, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/hint_test], Append, `spark_catalog`.`identifier_clause_test_schema`.`hint_test`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/hint_test), [c1, c2]
++- Project [col1#x AS c1#x, col2#x AS c2#x]
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+SELECT /*+ IDENTIFIER('BROADCAST')(hint_test) */ * FROM hint_test
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
"messageParameters" : {
"error" : "'('",
"hint" : ""
}
}
+
+
+-- !query
+SELECT /*+ IDENTIFIER('MERGE')(hint_test) */ * FROM hint_test
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+DROP TABLE hint_test
+-- !query analysis
+DropTable false, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), identifier_clause_test_schema.hint_test
+
+
+-- !query
+SHOW IDENTIFIER('USER') FUNCTIONS
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT EXTRACT(IDENTIFIER('YEAR') FROM DATE'2024-01-15')
+-- !query analysis
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'FROM'",
+ "hint" : ": missing ')'"
+ }
+}
+
+
+-- !query
+SELECT TIMESTAMPADD(IDENTIFIER('YEAR'), 1, DATE'2024-01-15')
+-- !query analysis
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNRESOLVED_ROUTINE",
+ "sqlState" : "42883",
+ "messageParameters" : {
+ "routineName" : "`TIMESTAMPADD`",
+ "searchPath" : "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`identifier_clause_test_schema`]"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 60,
+ "fragment" : "TIMESTAMPADD(IDENTIFIER('YEAR'), 1, DATE'2024-01-15')"
+ } ]
+}
+
+
+-- !query
+DROP SCHEMA identifier_clause_test_schema
+-- !query analysis
+DropNamespace false, false
++- ResolvedNamespace V2SessionCatalog(spark_catalog), [identifier_clause_test_schema]
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out
index c023e3b56f117..88121e4dc3c21 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/interval.sql.out
@@ -61,15 +61,21 @@ Project [(INTERVAL '2147483647' MONTH / 0.5) AS (INTERVAL '2147483647' MONTH / 0
-- !query
select interval 2147483647 day * 2
-- !query analysis
-java.lang.ArithmeticException
-long overflow
+org.apache.spark.SparkArithmeticException
+{
+ "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+ "sqlState" : "22015"
+}
-- !query
select interval 2147483647 day / 0.5
-- !query analysis
-java.lang.ArithmeticException
-long overflow
+org.apache.spark.SparkArithmeticException
+{
+ "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+ "sqlState" : "22015"
+}
-- !query
@@ -3212,3 +3218,13 @@ SELECT width_bucket(INTERVAL '-59' MINUTE, INTERVAL -'1 01' DAY TO HOUR, INTERVA
-- !query analysis
Project [width_bucket(INTERVAL '-59' MINUTE, INTERVAL '-1 01' DAY TO HOUR, INTERVAL '1 02:03:04.001' DAY TO SECOND, cast(10 as bigint)) AS width_bucket(INTERVAL '-59' MINUTE, INTERVAL '-1 01' DAY TO HOUR, INTERVAL '1 02:03:04.001' DAY TO SECOND, 10)#xL]
+- OneRowRelation
+
+
+-- !query
+SELECT interval 106751991 day 4 hour 0 minute 54.776 second
+-- !query analysis
+org.apache.spark.SparkArithmeticException
+{
+ "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+ "sqlState" : "22015"
+}
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/kllquantiles.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/kllquantiles.sql.out
new file mode 100644
index 0000000000000..8a2b50131627a
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/kllquantiles.sql.out
@@ -0,0 +1,1926 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+DROP TABLE IF EXISTS t_int_1_5_through_7_11
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t_int_1_5_through_7_11
+
+
+-- !query
+CREATE TABLE t_int_1_5_through_7_11 AS
+VALUES
+ (1, 5), (2, 6), (3, 7), (4, 8), (5, 9), (6, 10), (7, 11) AS tab(col1, col2)
+-- !query analysis
+CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`t_int_1_5_through_7_11`, ErrorIfExists, [col1, col2]
+ +- SubqueryAlias tab
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+DROP TABLE IF EXISTS t_long_1_5_through_7_11
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t_long_1_5_through_7_11
+
+
+-- !query
+CREATE TABLE t_long_1_5_through_7_11 AS
+VALUES
+ (1L, 5L), (2L, 6L), (3L, 7L), (4L, 8L), (5L, 9L), (6L, 10L), (7L, 11L) AS tab(col1, col2)
+-- !query analysis
+CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`t_long_1_5_through_7_11`, ErrorIfExists, [col1, col2]
+ +- SubqueryAlias tab
+ +- LocalRelation [col1#xL, col2#xL]
+
+
+-- !query
+DROP TABLE IF EXISTS t_short_1_5_through_7_11
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t_short_1_5_through_7_11
+
+
+-- !query
+CREATE TABLE t_short_1_5_through_7_11 AS
+VALUES
+ (CAST(1 AS SMALLINT), CAST(5 AS SMALLINT)),
+ (CAST(2 AS SMALLINT), CAST(6 AS SMALLINT)),
+ (CAST(3 AS SMALLINT), CAST(7 AS SMALLINT)),
+ (CAST(4 AS SMALLINT), CAST(8 AS SMALLINT)),
+ (CAST(5 AS SMALLINT), CAST(9 AS SMALLINT)),
+ (CAST(6 AS SMALLINT), CAST(10 AS SMALLINT)),
+ (CAST(7 AS SMALLINT), CAST(11 AS SMALLINT))
+ AS tab(col1, col2)
+-- !query analysis
+CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`t_short_1_5_through_7_11`, ErrorIfExists, [col1, col2]
+ +- SubqueryAlias tab
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+DROP TABLE IF EXISTS t_byte_1_5_through_7_11
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t_byte_1_5_through_7_11
+
+
+-- !query
+CREATE TABLE t_byte_1_5_through_7_11 AS
+VALUES
+ (CAST(1 AS TINYINT), CAST(5 AS TINYINT)),
+ (CAST(2 AS TINYINT), CAST(6 AS TINYINT)),
+ (CAST(3 AS TINYINT), CAST(7 AS TINYINT)),
+ (CAST(4 AS TINYINT), CAST(8 AS TINYINT)),
+ (CAST(5 AS TINYINT), CAST(9 AS TINYINT)),
+ (CAST(6 AS TINYINT), CAST(10 AS TINYINT)),
+ (CAST(7 AS TINYINT), CAST(11 AS TINYINT))
+ AS tab(col1, col2)
+-- !query analysis
+CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`t_byte_1_5_through_7_11`, ErrorIfExists, [col1, col2]
+ +- SubqueryAlias tab
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+DROP TABLE IF EXISTS t_float_1_5_through_7_11
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t_float_1_5_through_7_11
+
+
+-- !query
+CREATE TABLE t_float_1_5_through_7_11 AS
+VALUES
+ (CAST(1 AS FLOAT), CAST(5 AS FLOAT)),
+ (CAST(2 AS FLOAT), CAST(6 AS FLOAT)),
+ (CAST(3 AS FLOAT), CAST(7 AS FLOAT)),
+ (CAST(4 AS FLOAT), CAST(8 AS FLOAT)),
+ (CAST(5 AS FLOAT), CAST(9 AS FLOAT)),
+ (CAST(6 AS FLOAT), CAST(10 AS FLOAT)),
+ (CAST(7 AS FLOAT), CAST(11 AS FLOAT)) AS tab(col1, col2)
+-- !query analysis
+CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`t_float_1_5_through_7_11`, ErrorIfExists, [col1, col2]
+ +- SubqueryAlias tab
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+DROP TABLE IF EXISTS t_double_1_5_through_7_11
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t_double_1_5_through_7_11
+
+
+-- !query
+CREATE TABLE t_double_1_5_through_7_11 AS
+VALUES
+ (CAST(1 AS DOUBLE), CAST(5 AS DOUBLE)),
+ (CAST(2 AS DOUBLE), CAST(6 AS DOUBLE)),
+ (CAST(3 AS DOUBLE), CAST(7 AS DOUBLE)),
+ (CAST(4 AS DOUBLE), CAST(8 AS DOUBLE)),
+ (CAST(5 AS DOUBLE), CAST(9 AS DOUBLE)),
+ (CAST(6 AS DOUBLE), CAST(10 AS DOUBLE)),
+ (CAST(7 AS DOUBLE), CAST(11 AS DOUBLE)) AS tab(col1, col2)
+-- !query analysis
+CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`t_double_1_5_through_7_11`, ErrorIfExists, [col1, col2]
+ +- SubqueryAlias tab
+ +- LocalRelation [col1#x, col2#x]
+
+
+-- !query
+SELECT lower(kll_sketch_to_string_bigint(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_bigint(agg, 0.5) - 4) < 1 AS median_close_to_4,
+ abs(kll_sketch_get_rank_bigint(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_byte_1_5_through_7_11
+)
+-- !query analysis
+Project [lower(kll_sketch_to_string_bigint(agg#x)) LIKE %kll% AS str_contains_kll#x, (abs((kll_sketch_get_quantile_bigint(agg#x, cast(0.5 as double)) - cast(4 as bigint))) < cast(1 as bigint)) AS median_close_to_4#x, (abs((kll_sketch_get_rank_bigint(agg#x, cast(3 as bigint)) - cast(0.4 as double))) < cast(0.1 as double)) AS rank3_close_to_0_4#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_bigint(col1#x, None, 0, 0) AS agg#x]
+ +- SubqueryAlias spark_catalog.default.t_byte_1_5_through_7_11
+ +- Relation spark_catalog.default.t_byte_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT lower(kll_sketch_to_string_bigint(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_bigint(agg, 0.5) - 4) < 1 AS median_close_to_4,
+ abs(kll_sketch_get_rank_bigint(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_int_1_5_through_7_11
+)
+-- !query analysis
+Project [lower(kll_sketch_to_string_bigint(agg#x)) LIKE %kll% AS str_contains_kll#x, (abs((kll_sketch_get_quantile_bigint(agg#x, cast(0.5 as double)) - cast(4 as bigint))) < cast(1 as bigint)) AS median_close_to_4#x, (abs((kll_sketch_get_rank_bigint(agg#x, cast(3 as bigint)) - cast(0.4 as double))) < cast(0.1 as double)) AS rank3_close_to_0_4#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_bigint(col1#x, None, 0, 0) AS agg#x]
+ +- SubqueryAlias spark_catalog.default.t_int_1_5_through_7_11
+ +- Relation spark_catalog.default.t_int_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT lower(kll_sketch_to_string_bigint(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_bigint(agg, 0.5) - 4) < 1 AS median_close_to_4,
+ abs(kll_sketch_get_rank_bigint(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_long_1_5_through_7_11
+)
+-- !query analysis
+Project [lower(kll_sketch_to_string_bigint(agg#x)) LIKE %kll% AS str_contains_kll#x, (abs((kll_sketch_get_quantile_bigint(agg#x, cast(0.5 as double)) - cast(4 as bigint))) < cast(1 as bigint)) AS median_close_to_4#x, (abs((kll_sketch_get_rank_bigint(agg#x, cast(3 as bigint)) - cast(0.4 as double))) < cast(0.1 as double)) AS rank3_close_to_0_4#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_bigint(col1#xL, None, 0, 0) AS agg#x]
+ +- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+
+
+-- !query
+SELECT lower(kll_sketch_to_string_bigint(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_bigint(agg, 0.5) - 4) < 1 AS median_close_to_4,
+ abs(kll_sketch_get_rank_bigint(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_short_1_5_through_7_11
+)
+-- !query analysis
+Project [lower(kll_sketch_to_string_bigint(agg#x)) LIKE %kll% AS str_contains_kll#x, (abs((kll_sketch_get_quantile_bigint(agg#x, cast(0.5 as double)) - cast(4 as bigint))) < cast(1 as bigint)) AS median_close_to_4#x, (abs((kll_sketch_get_rank_bigint(agg#x, cast(3 as bigint)) - cast(0.4 as double))) < cast(0.1 as double)) AS rank3_close_to_0_4#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_bigint(col1#x, None, 0, 0) AS agg#x]
+ +- SubqueryAlias spark_catalog.default.t_short_1_5_through_7_11
+ +- Relation spark_catalog.default.t_short_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT lower(kll_sketch_to_string_float(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_float(agg, 0.5) - 4.0) < 0.5 AS median_close_to_4,
+ abs(kll_sketch_get_rank_float(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS agg
+ FROM t_float_1_5_through_7_11
+)
+-- !query analysis
+Project [lower(kll_sketch_to_string_float(agg#x)) LIKE %kll% AS str_contains_kll#x, (abs((cast(kll_sketch_get_quantile_float(agg#x, cast(0.5 as double)) as double) - cast(4.0 as double))) < cast(0.5 as double)) AS median_close_to_4#x, (abs((kll_sketch_get_rank_float(agg#x, cast(3 as float)) - cast(0.4 as double))) < cast(0.1 as double)) AS rank3_close_to_0_4#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_float(col1#x, None, 0, 0) AS agg#x]
+ +- SubqueryAlias spark_catalog.default.t_float_1_5_through_7_11
+ +- Relation spark_catalog.default.t_float_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT lower(kll_sketch_to_string_double(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_double(agg, 0.5) - 4.0) < 0.5 AS median_close_to_4,
+ abs(kll_sketch_get_rank_double(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_double(col1) AS agg
+ FROM t_double_1_5_through_7_11
+)
+-- !query analysis
+Project [lower(kll_sketch_to_string_double(agg#x)) LIKE %kll% AS str_contains_kll#x, (abs((kll_sketch_get_quantile_double(agg#x, cast(0.5 as double)) - cast(4.0 as double))) < cast(0.5 as double)) AS median_close_to_4#x, (abs((kll_sketch_get_rank_double(agg#x, cast(3 as double)) - cast(0.4 as double))) < cast(0.1 as double)) AS rank3_close_to_0_4#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_double(col1#x, None, 0, 0) AS agg#x]
+ +- SubqueryAlias spark_catalog.default.t_double_1_5_through_7_11
+ +- Relation spark_catalog.default.t_double_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT lower(kll_sketch_to_string_double(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_double(agg, 0.5) - 4.0) < 0.5 AS median_close_to_4,
+ abs(kll_sketch_get_rank_double(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_double(col1) AS agg
+ FROM t_float_1_5_through_7_11
+)
+-- !query analysis
+Project [lower(kll_sketch_to_string_double(agg#x)) LIKE %kll% AS str_contains_kll#x, (abs((kll_sketch_get_quantile_double(agg#x, cast(0.5 as double)) - cast(4.0 as double))) < cast(0.5 as double)) AS median_close_to_4#x, (abs((kll_sketch_get_rank_double(agg#x, cast(3 as double)) - cast(0.4 as double))) < cast(0.1 as double)) AS rank3_close_to_0_4#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_double(col1#x, None, 0, 0) AS agg#x]
+ +- SubqueryAlias spark_catalog.default.t_float_1_5_through_7_11
+ +- Relation spark_catalog.default.t_float_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT
+ split(
+ kll_sketch_to_string_bigint(
+ kll_sketch_merge_bigint(
+ kll_sketch_agg_bigint(col1),
+ kll_sketch_agg_bigint(col1)
+ )
+ ),
+ '\n'
+ )[1] AS result
+ FROM t_byte_1_5_through_7_11
+-- !query analysis
+Aggregate [split(kll_sketch_to_string_bigint(kll_sketch_merge_bigint(kll_sketch_agg_bigint(col1#x, None, 0, 0), kll_sketch_agg_bigint(col1#x, None, 0, 0))),
+, -1)[1] AS result#x]
++- SubqueryAlias spark_catalog.default.t_byte_1_5_through_7_11
+ +- Relation spark_catalog.default.t_byte_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT
+ split(
+ kll_sketch_to_string_float(
+ kll_sketch_merge_float(
+ kll_sketch_agg_float(col1),
+ kll_sketch_agg_float(col1)
+ )
+ ),
+ '\n'
+ )[1] AS result
+FROM t_byte_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"TINYINT\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"FLOAT\"",
+ "sqlExpr" : "\"kll_sketch_agg_float(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 87,
+ "stopIndex" : 112,
+ "fragment" : "kll_sketch_agg_float(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT
+ split(
+ kll_sketch_to_string_double(
+ kll_sketch_merge_double(
+ kll_sketch_agg_double(col1),
+ kll_sketch_agg_double(col1)
+ )
+ ),
+ '\n'
+ )[1] AS result
+FROM t_byte_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"TINYINT\"",
+ "paramIndex" : "first",
+ "requiredType" : "(\"FLOAT\" or \"DOUBLE\")",
+ "sqlExpr" : "\"kll_sketch_agg_double(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 89,
+ "stopIndex" : 115,
+ "fragment" : "kll_sketch_agg_double(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT
+ parity,
+ kll_sketch_get_n_bigint(kll_merge_agg_bigint(sketch_col)) AS total_count
+FROM (
+ SELECT
+ col1 % 2 AS parity,
+ kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_int_1_5_through_7_11
+ GROUP BY col1 % 2
+) grouped_sketches
+GROUP BY parity
+HAVING kll_sketch_get_n_bigint(kll_merge_agg_bigint(sketch_col)) > 3
+-- !query analysis
+Filter (total_count#xL > cast(3 as bigint))
++- Aggregate [parity#x], [parity#x, kll_sketch_get_n_bigint(kll_merge_agg_bigint(sketch_col#x, None, 0, 0)) AS total_count#xL]
+ +- SubqueryAlias grouped_sketches
+ +- Aggregate [(col1#x % 2)], [(col1#x % 2) AS parity#x, kll_sketch_agg_bigint(col1#x, None, 0, 0) AS sketch_col#x]
+ +- SubqueryAlias spark_catalog.default.t_int_1_5_through_7_11
+ +- Relation spark_catalog.default.t_int_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT kll_sketch_get_n_bigint(kll_merge_agg_bigint(sketch_col)) AS empty_merge_n
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_int_1_5_through_7_11
+ WHERE col1 > 1000
+) empty_sketches
+-- !query analysis
+Aggregate [kll_sketch_get_n_bigint(kll_merge_agg_bigint(sketch_col#x, None, 0, 0)) AS empty_merge_n#xL]
++- SubqueryAlias empty_sketches
+ +- Aggregate [kll_sketch_agg_bigint(col1#x, None, 0, 0) AS sketch_col#x]
+ +- Filter (col1#x > 1000)
+ +- SubqueryAlias spark_catalog.default.t_int_1_5_through_7_11
+ +- Relation spark_catalog.default.t_int_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT kll_sketch_get_n_float(kll_merge_agg_float(sketch_col)) AS empty_merge_n
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS sketch_col
+ FROM t_float_1_5_through_7_11
+ WHERE col1 > 1000.0
+) empty_sketches
+-- !query analysis
+Aggregate [kll_sketch_get_n_float(kll_merge_agg_float(sketch_col#x, None, 0, 0)) AS empty_merge_n#xL]
++- SubqueryAlias empty_sketches
+ +- Aggregate [kll_sketch_agg_float(col1#x, None, 0, 0) AS sketch_col#x]
+ +- Filter (cast(col1#x as double) > cast(1000.0 as double))
+ +- SubqueryAlias spark_catalog.default.t_float_1_5_through_7_11
+ +- Relation spark_catalog.default.t_float_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT kll_sketch_get_n_double(kll_merge_agg_double(sketch_col)) AS empty_merge_n
+FROM (
+ SELECT kll_sketch_agg_double(col1) AS sketch_col
+ FROM t_double_1_5_through_7_11
+ WHERE col1 > 1000.0
+) empty_sketches
+-- !query analysis
+Aggregate [kll_sketch_get_n_double(kll_merge_agg_double(sketch_col#x, None, 0, 0)) AS empty_merge_n#xL]
++- SubqueryAlias empty_sketches
+ +- Aggregate [kll_sketch_agg_double(col1#x, None, 0, 0) AS sketch_col#x]
+ +- Filter (col1#x > cast(1000.0 as double))
+ +- SubqueryAlias spark_catalog.default.t_double_1_5_through_7_11
+ +- Relation spark_catalog.default.t_double_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT lower(kll_sketch_to_string_bigint(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_bigint(agg, 0.5) - 4) < 1 AS median_close_to_4,
+ abs(kll_sketch_get_rank_bigint(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_merge_agg_bigint(sketch_col) AS agg
+ FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_int_1_5_through_7_11
+ UNION ALL
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_short_1_5_through_7_11
+ ) sketches
+)
+-- !query analysis
+Project [lower(kll_sketch_to_string_bigint(agg#x)) LIKE %kll% AS str_contains_kll#x, (abs((kll_sketch_get_quantile_bigint(agg#x, cast(0.5 as double)) - cast(4 as bigint))) < cast(1 as bigint)) AS median_close_to_4#x, (abs((kll_sketch_get_rank_bigint(agg#x, cast(3 as bigint)) - cast(0.4 as double))) < cast(0.1 as double)) AS rank3_close_to_0_4#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_merge_agg_bigint(sketch_col#x, None, 0, 0) AS agg#x]
+ +- SubqueryAlias sketches
+ +- Union false, false
+ :- Aggregate [kll_sketch_agg_bigint(col1#x, None, 0, 0) AS sketch_col#x]
+ : +- SubqueryAlias spark_catalog.default.t_int_1_5_through_7_11
+ : +- Relation spark_catalog.default.t_int_1_5_through_7_11[col1#x,col2#x] parquet
+ +- Aggregate [kll_sketch_agg_bigint(col1#x, None, 0, 0) AS sketch_col#x]
+ +- SubqueryAlias spark_catalog.default.t_short_1_5_through_7_11
+ +- Relation spark_catalog.default.t_short_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT lower(kll_sketch_to_string_float(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_float(agg, 0.5) - 5.5) < 1.0 AS median_close_to_5_5,
+ abs(kll_sketch_get_rank_float(agg, 5.0) - 0.35) < 0.15 AS rank5_close_to_0_35
+FROM (
+ SELECT kll_merge_agg_float(sketch_col) AS agg
+ FROM (
+ SELECT kll_sketch_agg_float(col1) AS sketch_col
+ FROM t_float_1_5_through_7_11
+ UNION ALL
+ SELECT kll_sketch_agg_float(col2) AS sketch_col
+ FROM t_float_1_5_through_7_11
+ ) sketches
+)
+-- !query analysis
+Project [lower(kll_sketch_to_string_float(agg#x)) LIKE %kll% AS str_contains_kll#x, (abs((cast(kll_sketch_get_quantile_float(agg#x, cast(0.5 as double)) as double) - cast(5.5 as double))) < cast(1.0 as double)) AS median_close_to_5_5#x, (abs((kll_sketch_get_rank_float(agg#x, cast(5.0 as float)) - cast(0.35 as double))) < cast(0.15 as double)) AS rank5_close_to_0_35#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_merge_agg_float(sketch_col#x, None, 0, 0) AS agg#x]
+ +- SubqueryAlias sketches
+ +- Union false, false
+ :- Aggregate [kll_sketch_agg_float(col1#x, None, 0, 0) AS sketch_col#x]
+ : +- SubqueryAlias spark_catalog.default.t_float_1_5_through_7_11
+ : +- Relation spark_catalog.default.t_float_1_5_through_7_11[col1#x,col2#x] parquet
+ +- Aggregate [kll_sketch_agg_float(col2#x, None, 0, 0) AS sketch_col#x]
+ +- SubqueryAlias spark_catalog.default.t_float_1_5_through_7_11
+ +- Relation spark_catalog.default.t_float_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT lower(kll_sketch_to_string_double(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_double(agg, 0.5) - 6.0) < 1.0 AS median_close_to_6,
+ abs(kll_sketch_get_rank_double(agg, 5.0) - 0.35) < 0.15 AS rank5_close_to_0_35
+FROM (
+ SELECT kll_merge_agg_double(sketch_col) AS agg
+ FROM (
+ SELECT kll_sketch_agg_double(col1) AS sketch_col
+ FROM t_double_1_5_through_7_11
+ UNION ALL
+ SELECT kll_sketch_agg_double(col2) AS sketch_col
+ FROM t_float_1_5_through_7_11
+ ) sketches
+)
+-- !query analysis
+Project [lower(kll_sketch_to_string_double(agg#x)) LIKE %kll% AS str_contains_kll#x, (abs((kll_sketch_get_quantile_double(agg#x, cast(0.5 as double)) - cast(6.0 as double))) < cast(1.0 as double)) AS median_close_to_6#x, (abs((kll_sketch_get_rank_double(agg#x, cast(5.0 as double)) - cast(0.35 as double))) < cast(0.15 as double)) AS rank5_close_to_0_35#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_merge_agg_double(sketch_col#x, None, 0, 0) AS agg#x]
+ +- SubqueryAlias sketches
+ +- Union false, false
+ :- Aggregate [kll_sketch_agg_double(col1#x, None, 0, 0) AS sketch_col#x]
+ : +- SubqueryAlias spark_catalog.default.t_double_1_5_through_7_11
+ : +- Relation spark_catalog.default.t_double_1_5_through_7_11[col1#x,col2#x] parquet
+ +- Aggregate [kll_sketch_agg_double(col2#x, None, 0, 0) AS sketch_col#x]
+ +- SubqueryAlias spark_catalog.default.t_float_1_5_through_7_11
+ +- Relation spark_catalog.default.t_float_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT LENGTH(kll_sketch_to_string_bigint(kll_merge_agg_bigint(sketch_col, 400))) > 0 AS merged_with_k
+FROM (
+ SELECT kll_sketch_agg_bigint(col1, 400) AS sketch_col
+ FROM t_long_1_5_through_7_11
+ UNION ALL
+ SELECT kll_sketch_agg_bigint(col2, 400) AS sketch_col
+ FROM t_byte_1_5_through_7_11
+) sketches
+-- !query analysis
+Aggregate [(length(kll_sketch_to_string_bigint(kll_merge_agg_bigint(sketch_col#x, Some(400), 0, 0))) > 0) AS merged_with_k#x]
++- SubqueryAlias sketches
+ +- Union false, false
+ :- Aggregate [kll_sketch_agg_bigint(col1#xL, Some(400), 0, 0) AS sketch_col#x]
+ : +- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ : +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+ +- Aggregate [kll_sketch_agg_bigint(col2#x, Some(400), 0, 0) AS sketch_col#x]
+ +- SubqueryAlias spark_catalog.default.t_byte_1_5_through_7_11
+ +- Relation spark_catalog.default.t_byte_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT LENGTH(kll_sketch_to_string_float(kll_merge_agg_float(sketch_col, 300))) > 0 AS merged_with_k
+FROM (
+ SELECT kll_sketch_agg_float(col1, 300) AS sketch_col
+ FROM t_float_1_5_through_7_11
+) sketches
+-- !query analysis
+Aggregate [(length(kll_sketch_to_string_float(kll_merge_agg_float(sketch_col#x, Some(300), 0, 0))) > 0) AS merged_with_k#x]
++- SubqueryAlias sketches
+ +- Aggregate [kll_sketch_agg_float(col1#x, Some(300), 0, 0) AS sketch_col#x]
+ +- SubqueryAlias spark_catalog.default.t_float_1_5_through_7_11
+ +- Relation spark_catalog.default.t_float_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT LENGTH(kll_sketch_to_string_double(kll_merge_agg_double(sketch_col, 500))) > 0 AS merged_with_k
+FROM (
+ SELECT kll_sketch_agg_double(col1, 500) AS sketch_col
+ FROM t_double_1_5_through_7_11
+) sketches
+-- !query analysis
+Aggregate [(length(kll_sketch_to_string_double(kll_merge_agg_double(sketch_col#x, Some(500), 0, 0))) > 0) AS merged_with_k#x]
++- SubqueryAlias sketches
+ +- Aggregate [kll_sketch_agg_double(col1#x, Some(500), 0, 0) AS sketch_col#x]
+ +- SubqueryAlias spark_catalog.default.t_double_1_5_through_7_11
+ +- Relation spark_catalog.default.t_double_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT abs(kll_sketch_get_quantile_bigint(agg_with_nulls, 0.5) -
+ kll_sketch_get_quantile_bigint(agg_without_nulls, 0.5)) < 1 AS medians_match
+FROM (
+ SELECT kll_merge_agg_bigint(sketch_col) AS agg_with_nulls
+ FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_long_1_5_through_7_11
+ UNION ALL
+ SELECT CAST(NULL AS BINARY) AS sketch_col
+ UNION ALL
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_byte_1_5_through_7_11
+ ) sketches_with_nulls
+) WITH_NULLS,
+(
+ SELECT kll_merge_agg_bigint(sketch_col) AS agg_without_nulls
+ FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_long_1_5_through_7_11
+ UNION ALL
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_byte_1_5_through_7_11
+ ) sketches_without_nulls
+) WITHOUT_NULLS
+-- !query analysis
+Project [(abs((kll_sketch_get_quantile_bigint(agg_with_nulls#x, cast(0.5 as double)) - kll_sketch_get_quantile_bigint(agg_without_nulls#x, cast(0.5 as double)))) < cast(1 as bigint)) AS medians_match#x]
++- Join Inner
+ :- SubqueryAlias WITH_NULLS
+ : +- Aggregate [kll_merge_agg_bigint(sketch_col#x, None, 0, 0) AS agg_with_nulls#x]
+ : +- SubqueryAlias sketches_with_nulls
+ : +- Union false, false
+ : :- Union false, false
+ : : :- Aggregate [kll_sketch_agg_bigint(col1#xL, None, 0, 0) AS sketch_col#x]
+ : : : +- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ : : : +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+ : : +- Project [cast(null as binary) AS sketch_col#x]
+ : : +- OneRowRelation
+ : +- Aggregate [kll_sketch_agg_bigint(col1#x, None, 0, 0) AS sketch_col#x]
+ : +- SubqueryAlias spark_catalog.default.t_byte_1_5_through_7_11
+ : +- Relation spark_catalog.default.t_byte_1_5_through_7_11[col1#x,col2#x] parquet
+ +- SubqueryAlias WITHOUT_NULLS
+ +- Aggregate [kll_merge_agg_bigint(sketch_col#x, None, 0, 0) AS agg_without_nulls#x]
+ +- SubqueryAlias sketches_without_nulls
+ +- Union false, false
+ :- Aggregate [kll_sketch_agg_bigint(col1#xL, None, 0, 0) AS sketch_col#x]
+ : +- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ : +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+ +- Aggregate [kll_sketch_agg_bigint(col1#x, None, 0, 0) AS sketch_col#x]
+ +- SubqueryAlias spark_catalog.default.t_byte_1_5_through_7_11
+ +- Relation spark_catalog.default.t_byte_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT abs(kll_sketch_get_quantile_bigint(agg_with_nulls, 0.5) -
+ kll_sketch_get_quantile_bigint(agg_without_nulls, 0.5)) < 1 AS medians_match,
+ abs(kll_sketch_get_rank_bigint(agg_with_nulls, 4) -
+ kll_sketch_get_rank_bigint(agg_without_nulls, 4)) < 0.1 AS ranks_match
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg_with_nulls
+ FROM (VALUES (1L), (CAST(NULL AS BIGINT)), (3L), (5L), (CAST(NULL AS BIGINT)), (7L)) AS tab(col1)
+) WITH_NULLS,
+(
+ SELECT kll_sketch_agg_bigint(col1) AS agg_without_nulls
+ FROM (VALUES (1L), (3L), (5L), (7L)) AS tab(col1)
+) WITHOUT_NULLS
+-- !query analysis
+Project [(abs((kll_sketch_get_quantile_bigint(agg_with_nulls#x, cast(0.5 as double)) - kll_sketch_get_quantile_bigint(agg_without_nulls#x, cast(0.5 as double)))) < cast(1 as bigint)) AS medians_match#x, (abs((kll_sketch_get_rank_bigint(agg_with_nulls#x, cast(4 as bigint)) - kll_sketch_get_rank_bigint(agg_without_nulls#x, cast(4 as bigint)))) < cast(0.1 as double)) AS ranks_match#x]
++- Join Inner
+ :- SubqueryAlias WITH_NULLS
+ : +- Aggregate [kll_sketch_agg_bigint(col1#xL, None, 0, 0) AS agg_with_nulls#x]
+ : +- SubqueryAlias tab
+ : +- Project [col1#xL AS col1#xL]
+ : +- LocalRelation [col1#xL]
+ +- SubqueryAlias WITHOUT_NULLS
+ +- Aggregate [kll_sketch_agg_bigint(col1#xL, None, 0, 0) AS agg_without_nulls#x]
+ +- SubqueryAlias tab
+ +- Project [col1#xL AS col1#xL]
+ +- LocalRelation [col1#xL]
+
+
+-- !query
+SELECT abs(kll_sketch_get_quantile_float(agg_with_nulls, 0.5) -
+ kll_sketch_get_quantile_float(agg_without_nulls, 0.5)) < 0.5 AS medians_match,
+ abs(kll_sketch_get_rank_float(agg_with_nulls, 4.0) -
+ kll_sketch_get_rank_float(agg_without_nulls, 4.0)) < 0.1 AS ranks_match
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS agg_with_nulls
+ FROM (VALUES (1.0F), (CAST(NULL AS FLOAT)), (3.0F), (5.0F), (CAST(NULL AS FLOAT)), (7.0F)) AS tab(col1)
+) WITH_NULLS,
+(
+ SELECT kll_sketch_agg_float(col1) AS agg_without_nulls
+ FROM (VALUES (1.0F), (3.0F), (5.0F), (7.0F)) AS tab(col1)
+) WITHOUT_NULLS
+-- !query analysis
+Project [(cast(abs((kll_sketch_get_quantile_float(agg_with_nulls#x, cast(0.5 as double)) - kll_sketch_get_quantile_float(agg_without_nulls#x, cast(0.5 as double)))) as double) < cast(0.5 as double)) AS medians_match#x, (abs((kll_sketch_get_rank_float(agg_with_nulls#x, cast(4.0 as float)) - kll_sketch_get_rank_float(agg_without_nulls#x, cast(4.0 as float)))) < cast(0.1 as double)) AS ranks_match#x]
++- Join Inner
+ :- SubqueryAlias WITH_NULLS
+ : +- Aggregate [kll_sketch_agg_float(col1#x, None, 0, 0) AS agg_with_nulls#x]
+ : +- SubqueryAlias tab
+ : +- Project [col1#x AS col1#x]
+ : +- LocalRelation [col1#x]
+ +- SubqueryAlias WITHOUT_NULLS
+ +- Aggregate [kll_sketch_agg_float(col1#x, None, 0, 0) AS agg_without_nulls#x]
+ +- SubqueryAlias tab
+ +- Project [col1#x AS col1#x]
+ +- LocalRelation [col1#x]
+
+
+-- !query
+SELECT abs(kll_sketch_get_quantile_double(agg_with_nulls, 0.5) -
+ kll_sketch_get_quantile_double(agg_without_nulls, 0.5)) < 0.5 AS medians_match,
+ abs(kll_sketch_get_rank_double(agg_with_nulls, 4.0) -
+ kll_sketch_get_rank_double(agg_without_nulls, 4.0)) < 0.1 AS ranks_match
+FROM (
+ SELECT kll_sketch_agg_double(col1) AS agg_with_nulls
+ FROM (VALUES (1.0D), (CAST(NULL AS DOUBLE)), (3.0D), (5.0D), (CAST(NULL AS DOUBLE)), (7.0D)) AS tab(col1)
+) WITH_NULLS,
+(
+ SELECT kll_sketch_agg_double(col1) AS agg_without_nulls
+ FROM (VALUES (1.0D), (3.0D), (5.0D), (7.0D)) AS tab(col1)
+) WITHOUT_NULLS
+-- !query analysis
+Project [(abs((kll_sketch_get_quantile_double(agg_with_nulls#x, cast(0.5 as double)) - kll_sketch_get_quantile_double(agg_without_nulls#x, cast(0.5 as double)))) < cast(0.5 as double)) AS medians_match#x, (abs((kll_sketch_get_rank_double(agg_with_nulls#x, cast(4.0 as double)) - kll_sketch_get_rank_double(agg_without_nulls#x, cast(4.0 as double)))) < cast(0.1 as double)) AS ranks_match#x]
++- Join Inner
+ :- SubqueryAlias WITH_NULLS
+ : +- Aggregate [kll_sketch_agg_double(col1#x, None, 0, 0) AS agg_with_nulls#x]
+ : +- SubqueryAlias tab
+ : +- Project [col1#x AS col1#x]
+ : +- LocalRelation [col1#x]
+ +- SubqueryAlias WITHOUT_NULLS
+ +- Aggregate [kll_sketch_agg_double(col1#x, None, 0, 0) AS agg_without_nulls#x]
+ +- SubqueryAlias tab
+ +- Project [col1#x AS col1#x]
+ +- LocalRelation [col1#x]
+
+
+-- !query
+SELECT kll_sketch_get_quantile_bigint(CAST(NULL AS BINARY), 0.5) AS null_sketch
+-- !query analysis
+Project [kll_sketch_get_quantile_bigint(cast(null as binary), cast(0.5 as double)) AS null_sketch#xL]
++- OneRowRelation
+
+
+-- !query
+SELECT kll_sketch_get_rank_float(CAST(NULL AS BINARY), 5.0) AS null_sketch
+-- !query analysis
+Project [kll_sketch_get_rank_float(cast(null as binary), cast(5.0 as float)) AS null_sketch#x]
++- OneRowRelation
+
+
+-- !query
+SELECT LENGTH(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col1, 8))) > 0 AS k_min_value
+FROM t_long_1_5_through_7_11
+-- !query analysis
+Aggregate [(length(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col1#xL, Some(8), 0, 0))) > 0) AS k_min_value#x]
++- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+
+
+-- !query
+SELECT LENGTH(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col1, 200))) > 0 AS k_default_value
+FROM t_long_1_5_through_7_11
+-- !query analysis
+Aggregate [(length(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col1#xL, Some(200), 0, 0))) > 0) AS k_default_value#x]
++- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+
+
+-- !query
+SELECT LENGTH(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col1, 400))) > 0 AS k_custom_value
+FROM t_long_1_5_through_7_11
+-- !query analysis
+Aggregate [(length(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col1#xL, Some(400), 0, 0))) > 0) AS k_custom_value#x]
++- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+
+
+-- !query
+SELECT LENGTH(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col1, 65535))) > 0 AS k_max_value
+FROM t_long_1_5_through_7_11
+-- !query analysis
+Aggregate [(length(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col1#xL, Some(65535), 0, 0))) > 0) AS k_max_value#x]
++- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+
+
+-- !query
+SELECT LENGTH(kll_sketch_to_string_float(kll_sketch_agg_float(col1, 100))) > 0 AS k_float_sketch
+FROM t_float_1_5_through_7_11
+-- !query analysis
+Aggregate [(length(kll_sketch_to_string_float(kll_sketch_agg_float(col1#x, Some(100), 0, 0))) > 0) AS k_float_sketch#x]
++- SubqueryAlias spark_catalog.default.t_float_1_5_through_7_11
+ +- Relation spark_catalog.default.t_float_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT LENGTH(kll_sketch_to_string_double(kll_sketch_agg_double(col1, 300))) > 0 AS k_double_sketch
+FROM t_double_1_5_through_7_11
+-- !query analysis
+Aggregate [(length(kll_sketch_to_string_double(kll_sketch_agg_double(col1#x, Some(300), 0, 0))) > 0) AS k_double_sketch#x]
++- SubqueryAlias spark_catalog.default.t_double_1_5_through_7_11
+ +- Relation spark_catalog.default.t_double_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col1)) AS n_bigint
+FROM t_long_1_5_through_7_11
+-- !query analysis
+Aggregate [kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col1#xL, None, 0, 0)) AS n_bigint#xL]
++- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+
+
+-- !query
+SELECT kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col1)) AS n_byte
+FROM t_byte_1_5_through_7_11
+-- !query analysis
+Aggregate [kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col1#x, None, 0, 0)) AS n_byte#xL]
++- SubqueryAlias spark_catalog.default.t_byte_1_5_through_7_11
+ +- Relation spark_catalog.default.t_byte_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col1)) AS n_short
+FROM t_short_1_5_through_7_11
+-- !query analysis
+Aggregate [kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col1#x, None, 0, 0)) AS n_short#xL]
++- SubqueryAlias spark_catalog.default.t_short_1_5_through_7_11
+ +- Relation spark_catalog.default.t_short_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col1)) AS n_int
+FROM t_int_1_5_through_7_11
+-- !query analysis
+Aggregate [kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col1#x, None, 0, 0)) AS n_int#xL]
++- SubqueryAlias spark_catalog.default.t_int_1_5_through_7_11
+ +- Relation spark_catalog.default.t_int_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT kll_sketch_get_n_float(kll_sketch_agg_float(col1)) AS n_float
+FROM t_float_1_5_through_7_11
+-- !query analysis
+Aggregate [kll_sketch_get_n_float(kll_sketch_agg_float(col1#x, None, 0, 0)) AS n_float#xL]
++- SubqueryAlias spark_catalog.default.t_float_1_5_through_7_11
+ +- Relation spark_catalog.default.t_float_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT kll_sketch_get_n_double(kll_sketch_agg_double(col1)) AS n_double
+FROM t_double_1_5_through_7_11
+-- !query analysis
+Aggregate [kll_sketch_get_n_double(kll_sketch_agg_double(col1#x, None, 0, 0)) AS n_double#xL]
++- SubqueryAlias spark_catalog.default.t_double_1_5_through_7_11
+ +- Relation spark_catalog.default.t_double_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col1, 100)) AS n_k_100
+FROM t_long_1_5_through_7_11
+-- !query analysis
+Aggregate [kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col1#xL, Some(100), 0, 0)) AS n_k_100#xL]
++- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+
+
+-- !query
+SELECT lower(kll_sketch_to_string_bigint(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_bigint(agg, 0.5) - 4) < 1 AS median_close_to_4,
+ abs(kll_sketch_get_rank_bigint(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_double_1_5_through_7_11
+)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"DOUBLE\"",
+ "paramIndex" : "first",
+ "requiredType" : "(\"TINYINT\" or \"INT\" or \"BIGINT\" or \"SMALLINT\")",
+ "sqlExpr" : "\"kll_sketch_agg_bigint(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 264,
+ "stopIndex" : 290,
+ "fragment" : "kll_sketch_agg_bigint(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT lower(kll_sketch_to_string_bigint(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_bigint(agg, 0.5) - 4) < 1 AS median_close_to_4,
+ abs(kll_sketch_get_rank_bigint(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_float_1_5_through_7_11
+)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"FLOAT\"",
+ "paramIndex" : "first",
+ "requiredType" : "(\"TINYINT\" or \"INT\" or \"BIGINT\" or \"SMALLINT\")",
+ "sqlExpr" : "\"kll_sketch_agg_bigint(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 264,
+ "stopIndex" : 290,
+ "fragment" : "kll_sketch_agg_bigint(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT lower(kll_sketch_to_string_float(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_float(agg, 0.5) - 4.0) < 0.5 AS median_close_to_4,
+ abs(kll_sketch_get_rank_float(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS agg
+ FROM t_double_1_5_through_7_11
+)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"DOUBLE\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"FLOAT\"",
+ "sqlExpr" : "\"kll_sketch_agg_float(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 265,
+ "stopIndex" : 290,
+ "fragment" : "kll_sketch_agg_float(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_agg_float(col1) AS invalid_float_bigint
+FROM t_long_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"BIGINT\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"FLOAT\"",
+ "sqlExpr" : "\"kll_sketch_agg_float(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 33,
+ "fragment" : "kll_sketch_agg_float(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_agg_float(col1) AS invalid_float_int
+FROM t_int_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"INT\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"FLOAT\"",
+ "sqlExpr" : "\"kll_sketch_agg_float(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 33,
+ "fragment" : "kll_sketch_agg_float(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_agg_float(col1) AS invalid_float_short
+FROM t_short_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"SMALLINT\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"FLOAT\"",
+ "sqlExpr" : "\"kll_sketch_agg_float(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 33,
+ "fragment" : "kll_sketch_agg_float(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_agg_float(col1) AS invalid_float_byte
+FROM t_byte_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"TINYINT\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"FLOAT\"",
+ "sqlExpr" : "\"kll_sketch_agg_float(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 33,
+ "fragment" : "kll_sketch_agg_float(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_agg_double(col1) AS invalid_double_bigint
+FROM t_long_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"BIGINT\"",
+ "paramIndex" : "first",
+ "requiredType" : "(\"FLOAT\" or \"DOUBLE\")",
+ "sqlExpr" : "\"kll_sketch_agg_double(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 34,
+ "fragment" : "kll_sketch_agg_double(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_agg_double(col1) AS invalid_double_int
+FROM t_int_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"INT\"",
+ "paramIndex" : "first",
+ "requiredType" : "(\"FLOAT\" or \"DOUBLE\")",
+ "sqlExpr" : "\"kll_sketch_agg_double(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 34,
+ "fragment" : "kll_sketch_agg_double(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_agg_double(col1) AS invalid_double_short
+FROM t_short_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"SMALLINT\"",
+ "paramIndex" : "first",
+ "requiredType" : "(\"FLOAT\" or \"DOUBLE\")",
+ "sqlExpr" : "\"kll_sketch_agg_double(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 34,
+ "fragment" : "kll_sketch_agg_double(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_agg_double(col1) AS invalid_double_byte
+FROM t_byte_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"TINYINT\"",
+ "paramIndex" : "first",
+ "requiredType" : "(\"FLOAT\" or \"DOUBLE\")",
+ "sqlExpr" : "\"kll_sketch_agg_double(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 34,
+ "fragment" : "kll_sketch_agg_double(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_quantile_bigint(agg, -0.5) AS invalid_quantile
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_long_1_5_through_7_11
+)
+-- !query analysis
+Project [kll_sketch_get_quantile_bigint(agg#x, cast(-0.5 as double)) AS invalid_quantile#xL]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_bigint(col1#xL, None, 0, 0) AS agg#x]
+ +- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+
+
+-- !query
+SELECT kll_sketch_get_quantile_bigint(agg, 1.5) AS invalid_quantile
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_long_1_5_through_7_11
+)
+-- !query analysis
+Project [kll_sketch_get_quantile_bigint(agg#x, cast(1.5 as double)) AS invalid_quantile#xL]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_bigint(col1#xL, None, 0, 0) AS agg#x]
+ +- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+
+
+-- !query
+SELECT kll_sketch_get_quantile_float(agg, array(-0.1, 0.5, 1.5)) AS invalid_quantiles
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS agg
+ FROM t_float_1_5_through_7_11
+)
+-- !query analysis
+Project [kll_sketch_get_quantile_float(agg#x, cast(array(cast(-0.1 as decimal(2,1)), cast(0.5 as decimal(2,1)), 1.5) as array<double>)) AS invalid_quantiles#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_float(col1#x, None, 0, 0) AS agg#x]
+ +- SubqueryAlias spark_catalog.default.t_float_1_5_through_7_11
+ +- Relation spark_catalog.default.t_float_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT kll_sketch_get_rank_bigint(agg, 5) AS wrong_type
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS agg
+ FROM t_float_1_5_through_7_11
+)
+-- !query analysis
+Project [kll_sketch_get_rank_bigint(agg#x, cast(5 as bigint)) AS wrong_type#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_float(col1#x, None, 0, 0) AS agg#x]
+ +- SubqueryAlias spark_catalog.default.t_float_1_5_through_7_11
+ +- Relation spark_catalog.default.t_float_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT kll_sketch_merge_bigint(agg1, agg2) AS incompatible_merge
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg1,
+ kll_sketch_agg_float(CAST(col1 AS FLOAT)) AS agg2
+ FROM t_long_1_5_through_7_11
+)
+-- !query analysis
+Project [kll_sketch_merge_bigint(agg1#x, agg2#x) AS incompatible_merge#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_bigint(col1#xL, None, 0, 0) AS agg1#x, kll_sketch_agg_float(cast(col1#xL as float), None, 0, 0) AS agg2#x]
+ +- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+
+
+-- !query
+SELECT kll_sketch_get_quantile_bigint(CAST('not_a_sketch' AS BINARY), 0.5) AS invalid_binary
+-- !query analysis
+Project [kll_sketch_get_quantile_bigint(cast(not_a_sketch as binary), cast(0.5 as double)) AS invalid_binary#xL]
++- OneRowRelation
+
+
+-- !query
+SELECT kll_sketch_get_quantile_float(agg, 0.5) IS NOT NULL AS returns_value
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_long_1_5_through_7_11
+)
+-- !query analysis
+Project [isnotnull(kll_sketch_get_quantile_float(agg#x, cast(0.5 as double))) AS returns_value#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_bigint(col1#xL, None, 0, 0) AS agg#x]
+ +- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+
+
+-- !query
+SELECT lower(kll_sketch_to_string_double(agg)) LIKE '%kll%' AS contains_kll_header
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_long_1_5_through_7_11
+)
+-- !query analysis
+Project [lower(kll_sketch_to_string_double(agg#x)) LIKE %kll% AS contains_kll_header#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_bigint(col1#xL, None, 0, 0) AS agg#x]
+ +- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+
+
+-- !query
+SELECT kll_sketch_agg_bigint(col1, 7) AS k_too_small
+FROM t_long_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "hint" : "",
+ "msg" : "[KLL_SKETCH_K_OUT_OF_RANGE] For function `kll_sketch_agg_bigint`, the k parameter must be between 8 and 65535 (inclusive), but got 7. SQLSTATE: 22003",
+ "sqlExpr" : "\"kll_sketch_agg_bigint(col1, 7)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 37,
+ "fragment" : "kll_sketch_agg_bigint(col1, 7)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_agg_bigint(col1, 65536) AS k_too_large
+FROM t_long_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "hint" : "",
+ "msg" : "[KLL_SKETCH_K_OUT_OF_RANGE] For function `kll_sketch_agg_bigint`, the k parameter must be between 8 and 65535 (inclusive), but got 65536. SQLSTATE: 22003",
+ "sqlExpr" : "\"kll_sketch_agg_bigint(col1, 65536)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 41,
+ "fragment" : "kll_sketch_agg_bigint(col1, 65536)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_agg_float(col1, CAST(NULL AS INT)) AS k_is_null
+FROM t_float_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_NULL",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "exprName" : "k",
+ "sqlExpr" : "\"kll_sketch_agg_float(col1, CAST(NULL AS INT))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 52,
+ "fragment" : "kll_sketch_agg_float(col1, CAST(NULL AS INT))"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_agg_double(col1, CAST(col1 AS INT)) AS k_non_constant
+FROM t_double_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputExpr" : "CAST(spark_catalog.default.t_double_1_5_through_7_11.col1 AS INT)",
+ "inputName" : "k",
+ "inputType" : "int",
+ "sqlExpr" : "\"kll_sketch_agg_double(col1, CAST(col1 AS INT))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 53,
+ "fragment" : "kll_sketch_agg_double(col1, CAST(col1 AS INT))"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_agg_bigint(col1, '100') AS k_wrong_type
+FROM t_long_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"100\"",
+ "inputType" : "\"STRING\"",
+ "paramIndex" : "second",
+ "requiredType" : "\"INT\"",
+ "sqlExpr" : "\"kll_sketch_agg_bigint(col1, 100)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 41,
+ "fragment" : "kll_sketch_agg_bigint(col1, '100')"
+ } ]
+}
+
+
+-- !query
+SELECT kll_merge_agg_bigint(sketch_col) AS wrong_type_merge
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS sketch_col
+ FROM t_float_1_5_through_7_11
+) float_sketches
+-- !query analysis
+Aggregate [kll_merge_agg_bigint(sketch_col#x, None, 0, 0) AS wrong_type_merge#x]
++- SubqueryAlias float_sketches
+ +- Aggregate [kll_sketch_agg_float(col1#x, None, 0, 0) AS sketch_col#x]
+ +- SubqueryAlias spark_catalog.default.t_float_1_5_through_7_11
+ +- Relation spark_catalog.default.t_float_1_5_through_7_11[col1#x,col2#x] parquet
+
+
+-- !query
+SELECT kll_merge_agg_bigint(col1) AS merge_wrong_type
+FROM t_long_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"BIGINT\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"BINARY\"",
+ "sqlExpr" : "\"kll_merge_agg_bigint(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 33,
+ "fragment" : "kll_merge_agg_bigint(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_merge_agg_float(col1) AS merge_wrong_type
+FROM t_float_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"FLOAT\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"BINARY\"",
+ "sqlExpr" : "\"kll_merge_agg_float(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 32,
+ "fragment" : "kll_merge_agg_float(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_merge_agg_double(col1) AS merge_wrong_type
+FROM t_double_1_5_through_7_11
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"col1\"",
+ "inputType" : "\"DOUBLE\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"BINARY\"",
+ "sqlExpr" : "\"kll_merge_agg_double(col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 33,
+ "fragment" : "kll_merge_agg_double(col1)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_merge_agg_bigint(sketch_col) AS invalid_merge
+FROM (
+ SELECT CAST('not_a_sketch' AS BINARY) AS sketch_col
+) invalid_data
+-- !query analysis
+Aggregate [kll_merge_agg_bigint(sketch_col#x, None, 0, 0) AS invalid_merge#x]
++- SubqueryAlias invalid_data
+ +- Project [cast(not_a_sketch as binary) AS sketch_col#x]
+ +- OneRowRelation
+
+
+-- !query
+SELECT kll_merge_agg_float(sketch_col) AS invalid_merge
+FROM (
+ SELECT X'deadbeef' AS sketch_col
+) invalid_data
+-- !query analysis
+Aggregate [kll_merge_agg_float(sketch_col#x, None, 0, 0) AS invalid_merge#x]
++- SubqueryAlias invalid_data
+ +- Project [0xDEADBEEF AS sketch_col#x]
+ +- OneRowRelation
+
+
+-- !query
+SELECT kll_merge_agg_double(sketch_col) AS invalid_merge
+FROM (
+ SELECT X'cafebabe' AS sketch_col
+) invalid_data
+-- !query analysis
+Aggregate [kll_merge_agg_double(sketch_col#x, None, 0, 0) AS invalid_merge#x]
++- SubqueryAlias invalid_data
+ +- Project [0xCAFEBABE AS sketch_col#x]
+ +- OneRowRelation
+
+
+-- !query
+SELECT kll_merge_agg_bigint(sketch_col, 7) AS k_too_small
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_long_1_5_through_7_11
+) sketches
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "hint" : "",
+ "msg" : "[KLL_SKETCH_K_OUT_OF_RANGE] For function `kll_merge_agg_bigint`, the k parameter must be between 8 and 65535 (inclusive), but got 7. SQLSTATE: 22003",
+ "sqlExpr" : "\"kll_merge_agg_bigint(sketch_col, 7)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 42,
+ "fragment" : "kll_merge_agg_bigint(sketch_col, 7)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_merge_agg_float(sketch_col, 65536) AS k_too_large
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS sketch_col
+ FROM t_float_1_5_through_7_11
+) sketches
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "hint" : "",
+ "msg" : "[KLL_SKETCH_K_OUT_OF_RANGE] For function `kll_merge_agg_float`, the k parameter must be between 8 and 65535 (inclusive), but got 65536. SQLSTATE: 22003",
+ "sqlExpr" : "\"kll_merge_agg_float(sketch_col, 65536)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 45,
+ "fragment" : "kll_merge_agg_float(sketch_col, 65536)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_merge_agg_double(sketch_col, CAST(NULL AS INT)) AS k_is_null
+FROM (
+ SELECT kll_sketch_agg_double(col1) AS sketch_col
+ FROM t_double_1_5_through_7_11
+) sketches
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_NULL",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "exprName" : "k",
+ "sqlExpr" : "\"kll_merge_agg_double(sketch_col, CAST(NULL AS INT))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 58,
+ "fragment" : "kll_merge_agg_double(sketch_col, CAST(NULL AS INT))"
+ } ]
+}
+
+
+-- !query
+SELECT kll_merge_agg_bigint(sketch_col, CAST(RAND() * 100 AS INT) + 200) AS k_non_constant
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_long_1_5_through_7_11
+) sketches
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputExpr" : "(CAST((rand() * CAST(100 AS DOUBLE)) AS INT) + 200)",
+ "inputName" : "k",
+ "inputType" : "int",
+ "sqlExpr" : "\"kll_merge_agg_bigint(sketch_col, (CAST((rand() * 100) AS INT) + 200))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 72,
+ "fragment" : "kll_merge_agg_bigint(sketch_col, CAST(RAND() * 100 AS INT) + 200)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_n_bigint(X'deadbeef') AS invalid_binary_bigint
+-- !query analysis
+Project [kll_sketch_get_n_bigint(0xDEADBEEF) AS invalid_binary_bigint#xL]
++- OneRowRelation
+
+
+-- !query
+SELECT kll_sketch_get_n_float(X'cafebabe') AS invalid_binary_float
+-- !query analysis
+Project [kll_sketch_get_n_float(0xCAFEBABE) AS invalid_binary_float#xL]
++- OneRowRelation
+
+
+-- !query
+SELECT kll_sketch_get_n_double(X'12345678') AS invalid_binary_double
+-- !query analysis
+Project [kll_sketch_get_n_double(0x12345678) AS invalid_binary_double#xL]
++- OneRowRelation
+
+
+-- !query
+SELECT kll_sketch_get_n_bigint(42) AS wrong_argument_type
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"42\"",
+ "inputType" : "\"INT\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"BINARY\"",
+ "sqlExpr" : "\"kll_sketch_get_n_bigint(42)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 34,
+ "fragment" : "kll_sketch_get_n_bigint(42)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_n_float(42.0) AS wrong_argument_type
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"42.0\"",
+ "inputType" : "\"DECIMAL(3,1)\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"BINARY\"",
+ "sqlExpr" : "\"kll_sketch_get_n_float(42.0)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 35,
+ "fragment" : "kll_sketch_get_n_float(42.0)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_n_double(42.0D) AS wrong_argument_type
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"42.0\"",
+ "inputType" : "\"DOUBLE\"",
+ "paramIndex" : "first",
+ "requiredType" : "\"BINARY\"",
+ "sqlExpr" : "\"kll_sketch_get_n_double(42.0)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 37,
+ "fragment" : "kll_sketch_get_n_double(42.0D)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_quantile_bigint(agg, 'invalid') AS quantile_string
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_long_1_5_through_7_11
+)
+-- !query analysis
+Project [kll_sketch_get_quantile_bigint(agg#x, cast(invalid as double)) AS quantile_string#xL]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_bigint(col1#xL, None, 0, 0) AS agg#x]
+ +- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+
+
+-- !query
+SELECT kll_sketch_get_quantile_float(agg, X'deadbeef') AS quantile_binary
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS agg
+ FROM t_float_1_5_through_7_11
+)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"X'DEADBEEF'\"",
+ "inputType" : "\"BINARY\"",
+ "paramIndex" : "second",
+ "requiredType" : "(\"DOUBLE\" or \"ARRAY\")",
+ "sqlExpr" : "\"kll_sketch_get_quantile_float(agg, X'DEADBEEF')\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 54,
+ "fragment" : "kll_sketch_get_quantile_float(agg, X'deadbeef')"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_quantile_double(agg, true) AS quantile_boolean
+FROM (
+ SELECT kll_sketch_agg_double(col1) AS agg
+ FROM t_double_1_5_through_7_11
+)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"true\"",
+ "inputType" : "\"BOOLEAN\"",
+ "paramIndex" : "second",
+ "requiredType" : "(\"DOUBLE\" or \"ARRAY\")",
+ "sqlExpr" : "\"kll_sketch_get_quantile_double(agg, true)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 48,
+ "fragment" : "kll_sketch_get_quantile_double(agg, true)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_rank_bigint(agg, 'invalid') AS rank_string
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_long_1_5_through_7_11
+)
+-- !query analysis
+Project [kll_sketch_get_rank_bigint(agg#x, cast(invalid as bigint)) AS rank_string#x]
++- SubqueryAlias __auto_generated_subquery_name
+ +- Aggregate [kll_sketch_agg_bigint(col1#xL, None, 0, 0) AS agg#x]
+ +- SubqueryAlias spark_catalog.default.t_long_1_5_through_7_11
+ +- Relation spark_catalog.default.t_long_1_5_through_7_11[col1#xL,col2#xL] parquet
+
+
+-- !query
+SELECT kll_sketch_get_rank_float(agg, X'cafebabe') AS rank_binary
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS agg
+ FROM t_float_1_5_through_7_11
+)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"X'CAFEBABE'\"",
+ "inputType" : "\"BINARY\"",
+ "paramIndex" : "second",
+ "requiredType" : "(\"FLOAT\" or \"ARRAY\")",
+ "sqlExpr" : "\"kll_sketch_get_rank_float(agg, X'CAFEBABE')\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 50,
+ "fragment" : "kll_sketch_get_rank_float(agg, X'cafebabe')"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_rank_double(agg, false) AS rank_boolean
+FROM (
+ SELECT kll_sketch_agg_double(col1) AS agg
+ FROM t_double_1_5_through_7_11
+)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"false\"",
+ "inputType" : "\"BOOLEAN\"",
+ "paramIndex" : "second",
+ "requiredType" : "(\"DOUBLE\" or \"ARRAY\")",
+ "sqlExpr" : "\"kll_sketch_get_rank_double(agg, false)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 45,
+ "fragment" : "kll_sketch_get_rank_double(agg, false)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_quantile_bigint(agg, CAST(col1 AS DOUBLE) / 10.0) AS non_foldable_scalar_rank
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+ FROM t_long_1_5_through_7_11
+ GROUP BY col1
+)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputExpr" : "\"(CAST(col1 AS DOUBLE) / 10.0)\"",
+ "inputName" : "`rank`",
+ "inputType" : "\"DOUBLE\"",
+ "sqlExpr" : "\"kll_sketch_get_quantile_bigint(agg, (CAST(col1 AS DOUBLE) / 10.0))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 71,
+ "fragment" : "kll_sketch_get_quantile_bigint(agg, CAST(col1 AS DOUBLE) / 10.0)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_quantile_bigint(agg, array(0.25, CAST(col1 AS DOUBLE) / 10.0, 0.75)) AS non_foldable_array_rank
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+ FROM t_long_1_5_through_7_11
+ GROUP BY col1
+)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputExpr" : "\"array(0.25, (CAST(col1 AS DOUBLE) / 10.0), 0.75)\"",
+ "inputName" : "`rank`",
+ "inputType" : "\"ARRAY\"",
+ "sqlExpr" : "\"kll_sketch_get_quantile_bigint(agg, array(0.25, (CAST(col1 AS DOUBLE) / 10.0), 0.75))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 90,
+ "fragment" : "kll_sketch_get_quantile_bigint(agg, array(0.25, CAST(col1 AS DOUBLE) / 10.0, 0.75))"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_rank_bigint(agg, col1) AS non_foldable_scalar_quantile
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+ FROM t_long_1_5_through_7_11
+ GROUP BY col1
+)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputExpr" : "\"col1\"",
+ "inputName" : "`quantile`",
+ "inputType" : "\"BIGINT\"",
+ "sqlExpr" : "\"kll_sketch_get_rank_bigint(agg, col1)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 44,
+ "fragment" : "kll_sketch_get_rank_bigint(agg, col1)"
+ } ]
+}
+
+
+-- !query
+SELECT kll_sketch_get_rank_bigint(agg, array(1L, col1, 5L)) AS non_foldable_array_quantile
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+ FROM t_long_1_5_through_7_11
+ GROUP BY col1
+)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputExpr" : "\"array(1, col1, 5)\"",
+ "inputName" : "`quantile`",
+ "inputType" : "\"ARRAY\"",
+ "sqlExpr" : "\"kll_sketch_get_rank_bigint(agg, array(1, col1, 5))\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 59,
+ "fragment" : "kll_sketch_get_rank_bigint(agg, array(1L, col1, 5L))"
+ } ]
+}
+
+
+-- !query
+DROP TABLE IF EXISTS t_int_1_5_through_7_11
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t_int_1_5_through_7_11
+
+
+-- !query
+DROP TABLE IF EXISTS t_long_1_5_through_7_11
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t_long_1_5_through_7_11
+
+
+-- !query
+DROP TABLE IF EXISTS t_short_1_5_through_7_11
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t_short_1_5_through_7_11
+
+
+-- !query
+DROP TABLE IF EXISTS t_byte_1_5_through_7_11
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t_byte_1_5_through_7_11
+
+
+-- !query
+DROP TABLE IF EXISTS t_float_1_5_through_7_11
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t_float_1_5_through_7_11
+
+
+-- !query
+DROP TABLE IF EXISTS t_double_1_5_through_7_11
+-- !query analysis
+DropTable true, false
++- ResolvedIdentifier V2SessionCatalog(spark_catalog), default.t_double_1_5_through_7_11
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/interval.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/interval.sql.out
index c0196bbe118ef..259cb2bff5ef8 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/interval.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/interval.sql.out
@@ -61,15 +61,21 @@ Project [(INTERVAL '2147483647' MONTH / 0.5) AS (INTERVAL '2147483647' MONTH / 0
-- !query
select interval 2147483647 day * 2
-- !query analysis
-java.lang.ArithmeticException
-long overflow
+org.apache.spark.SparkArithmeticException
+{
+ "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+ "sqlState" : "22015"
+}
-- !query
select interval 2147483647 day / 0.5
-- !query analysis
-java.lang.ArithmeticException
-long overflow
+org.apache.spark.SparkArithmeticException
+{
+ "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+ "sqlState" : "22015"
+}
-- !query
@@ -3212,3 +3218,13 @@ SELECT width_bucket(INTERVAL '-59' MINUTE, INTERVAL -'1 01' DAY TO HOUR, INTERVA
-- !query analysis
Project [width_bucket(INTERVAL '-59' MINUTE, INTERVAL '-1 01' DAY TO HOUR, INTERVAL '1 02:03:04.001' DAY TO SECOND, cast(10 as bigint)) AS width_bucket(INTERVAL '-59' MINUTE, INTERVAL '-1 01' DAY TO HOUR, INTERVAL '1 02:03:04.001' DAY TO SECOND, 10)#xL]
+- OneRowRelation
+
+
+-- !query
+SELECT interval 106751991 day 4 hour 0 minute 54.776 second
+-- !query analysis
+org.apache.spark.SparkArithmeticException
+{
+ "errorClass" : "INTERVAL_ARITHMETIC_OVERFLOW.WITHOUT_SUGGESTION",
+ "sqlState" : "22015"
+}
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/st-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/st-functions.sql.out
index fe2dda3f1967b..c86d2454d759c 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/st-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/nonansi/st-functions.sql.out
@@ -66,6 +66,237 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
}
+-- !query
+SELECT hex(ST_AsBinary(CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040') AS GEOGRAPHY(ANY)))) AS result
+-- !query analysis
+Project [hex(st_asbinary(cast(st_geogfromwkb(0x0101000000000000000000F03F0000000000000040) as geography(any)))) AS result#x]
++- OneRowRelation
+
+
+-- !query
+SELECT CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040')::GEOGRAPHY(ANY) AS GEOGRAPHY(4326)) AS result
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.CAST_WITHOUT_SUGGESTION",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "sqlExpr" : "\"CAST(CAST(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040') AS GEOGRAPHY(ANY)) AS GEOGRAPHY(4326))\"",
+ "srcType" : "\"GEOGRAPHY(ANY)\"",
+ "targetType" : "\"GEOGRAPHY(4326)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 109,
+ "fragment" : "CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040')::GEOGRAPHY(ANY) AS GEOGRAPHY(4326))"
+ } ]
+}
+
+
+-- !query
+SELECT hex(ST_AsBinary(CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040') AS GEOMETRY(4326)))) AS result
+-- !query analysis
+Project [hex(st_asbinary(cast(st_geogfromwkb(0x0101000000000000000000F03F0000000000000040) as geometry(4326)))) AS result#x]
++- OneRowRelation
+
+
+-- !query
+SELECT CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040') AS GEOMETRY(ANY)) AS result
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.CAST_WITHOUT_SUGGESTION",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "sqlExpr" : "\"CAST(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040') AS GEOMETRY(ANY))\"",
+ "srcType" : "\"GEOGRAPHY(4326)\"",
+ "targetType" : "\"GEOMETRY(ANY)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 91,
+ "fragment" : "CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040') AS GEOMETRY(ANY))"
+ } ]
+}
+
+
+-- !query
+SELECT hex(ST_AsBinary(CAST(ST_GeomFromWKB(X'0101000000000000000000f03f0000000000000040') AS GEOMETRY(ANY)))) AS result
+-- !query analysis
+Project [hex(st_asbinary(cast(st_geomfromwkb(0x0101000000000000000000F03F0000000000000040) as geometry(any)))) AS result#x]
++- OneRowRelation
+
+
+-- !query
+SELECT CAST(ST_GeomFromWKB(X'0101000000000000000000f03f0000000000000040')::GEOMETRY(ANY) AS GEOMETRY(4326)) AS result
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.CAST_WITHOUT_SUGGESTION",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "sqlExpr" : "\"CAST(CAST(st_geomfromwkb(X'0101000000000000000000F03F0000000000000040') AS GEOMETRY(ANY)) AS GEOMETRY(4326))\"",
+ "srcType" : "\"GEOMETRY(ANY)\"",
+ "targetType" : "\"GEOMETRY(4326)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 107,
+ "fragment" : "CAST(ST_GeomFromWKB(X'0101000000000000000000f03f0000000000000040')::GEOMETRY(ANY) AS GEOMETRY(4326))"
+ } ]
+}
+
+
+-- !query
+SELECT typeof(array(ST_GeogFromWKB(wkb), ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(array(cast(st_geogfromwkb(wkb#x) as geography(any)), cast(st_geogfromwkb(wkb#x) as geography(any)))) AS typeof(array(st_geogfromwkb(wkb), CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(array(ST_GeomFromWKB(wkb), ST_GeomFromWKB(wkb)::GEOMETRY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(array(cast(st_geomfromwkb(wkb#x) as geometry(any)), cast(st_geomfromwkb(wkb#x) as geometry(any)))) AS typeof(array(st_geomfromwkb(wkb), CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(map('a', ST_GeogFromWKB(wkb), 'b', ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(map(a, cast(st_geogfromwkb(wkb#x) as geography(any)), b, cast(st_geogfromwkb(wkb#x) as geography(any)))) AS typeof(map(a, st_geogfromwkb(wkb), b, CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(map('a', ST_GeomFromWKB(wkb), 'b', ST_GeomFromWKB(wkb)::GEOMETRY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(map(a, cast(st_geomfromwkb(wkb#x) as geometry(any)), b, cast(st_geomfromwkb(wkb#x) as geometry(any)))) AS typeof(map(a, st_geomfromwkb(wkb), b, CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(array(named_struct('g1', ST_GeogFromWKB(wkb), 'g2', ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY)), named_struct('g1', ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY), 'g2', ST_GeogFromWKB(wkb)))) FROM geodata
+-- !query analysis
+Project [typeof(array(cast(named_struct(g1, st_geogfromwkb(wkb#x), g2, cast(st_geogfromwkb(wkb#x) as geography(any))) as struct), cast(named_struct(g1, cast(st_geogfromwkb(wkb#x) as geography(any)), g2, st_geogfromwkb(wkb#x)) as struct))) AS typeof(array(named_struct(g1, st_geogfromwkb(wkb), g2, CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY))), named_struct(g1, CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY)), g2, st_geogfromwkb(wkb))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(array(named_struct('g1', ST_GeomFromWKB(wkb), 'g2', ST_GeomFromWKB(wkb)::GEOMETRY(ANY)), named_struct('g1', ST_GeomFromWKB(wkb)::GEOMETRY(ANY), 'g2', ST_GeomFromWKB(wkb)))) FROM geodata
+-- !query analysis
+Project [typeof(array(cast(named_struct(g1, st_geomfromwkb(wkb#x), g2, cast(st_geomfromwkb(wkb#x) as geometry(any))) as struct), cast(named_struct(g1, cast(st_geomfromwkb(wkb#x) as geometry(any)), g2, st_geomfromwkb(wkb#x)) as struct))) AS typeof(array(named_struct(g1, st_geomfromwkb(wkb), g2, CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY))), named_struct(g1, CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY)), g2, st_geomfromwkb(wkb))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(named_struct('a', array(ST_GeogFromWKB(wkb), ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY)), 'b', map('g', ST_GeogFromWKB(wkb), 'h', ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY)))) FROM geodata
+-- !query analysis
+Project [typeof(named_struct(a, array(cast(st_geogfromwkb(wkb#x) as geography(any)), cast(st_geogfromwkb(wkb#x) as geography(any))), b, map(g, cast(st_geogfromwkb(wkb#x) as geography(any)), h, cast(st_geogfromwkb(wkb#x) as geography(any))))) AS typeof(named_struct(a, array(st_geogfromwkb(wkb), CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY))), b, map(g, st_geogfromwkb(wkb), h, CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY)))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(named_struct('a', array(ST_GeomFromWKB(wkb), ST_GeomFromWKB(wkb)::GEOMETRY(ANY)), 'b', map('g', ST_GeomFromWKB(wkb), 'h', ST_GeomFromWKB(wkb)::GEOMETRY(ANY)))) FROM geodata
+-- !query analysis
+Project [typeof(named_struct(a, array(cast(st_geomfromwkb(wkb#x) as geometry(any)), cast(st_geomfromwkb(wkb#x) as geometry(any))), b, map(g, cast(st_geomfromwkb(wkb#x) as geometry(any)), h, cast(st_geomfromwkb(wkb#x) as geometry(any))))) AS typeof(named_struct(a, array(st_geomfromwkb(wkb), CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY))), b, map(g, st_geomfromwkb(wkb), h, CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY)))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(nvl(ST_GeogFromWKB(wkb), ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(nvl(st_geogfromwkb(wkb#x), cast(st_geogfromwkb(wkb#x) as geography(any)))) AS typeof(nvl(st_geogfromwkb(wkb), CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(nvl(ST_GeomFromWKB(wkb), ST_GeomFromWKB(wkb)::GEOMETRY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(nvl(st_geomfromwkb(wkb#x), cast(st_geomfromwkb(wkb#x) as geometry(any)))) AS typeof(nvl(st_geomfromwkb(wkb), CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(nvl2(ST_GeogFromWKB(wkb), ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY), ST_GeogFromWKB(wkb))) FROM geodata
+-- !query analysis
+Project [typeof(nvl2(st_geogfromwkb(wkb#x), cast(st_geogfromwkb(wkb#x) as geography(any)), st_geogfromwkb(wkb#x))) AS typeof(nvl2(st_geogfromwkb(wkb), CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY)), st_geogfromwkb(wkb)))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(nvl2(ST_GeomFromWKB(wkb), ST_GeomFromWKB(wkb)::GEOMETRY(ANY), ST_GeomFromWKB(wkb))) FROM geodata
+-- !query analysis
+Project [typeof(nvl2(st_geomfromwkb(wkb#x), cast(st_geomfromwkb(wkb#x) as geometry(any)), st_geomfromwkb(wkb#x))) AS typeof(nvl2(st_geomfromwkb(wkb), CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY)), st_geomfromwkb(wkb)))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(CASE WHEN wkb IS NOT NULL THEN ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY) ELSE ST_GeogFromWKB(wkb) END) FROM geodata
+-- !query analysis
+Project [typeof(CASE WHEN isnotnull(wkb#x) THEN cast(st_geogfromwkb(wkb#x) as geography(any)) ELSE cast(st_geogfromwkb(wkb#x) as geography(any)) END) AS typeof(CASE WHEN (wkb IS NOT NULL) THEN CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY)) ELSE st_geogfromwkb(wkb) END)#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(CASE WHEN wkb IS NOT NULL THEN ST_GeomFromWKB(wkb)::GEOMETRY(ANY) ELSE ST_GeomFromWKB(wkb) END) FROM geodata
+-- !query analysis
+Project [typeof(CASE WHEN isnotnull(wkb#x) THEN cast(st_geomfromwkb(wkb#x) as geometry(any)) ELSE cast(st_geomfromwkb(wkb#x) as geometry(any)) END) AS typeof(CASE WHEN (wkb IS NOT NULL) THEN CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY)) ELSE st_geomfromwkb(wkb) END)#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(coalesce(ST_GeogFromWKB(wkb), ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(coalesce(cast(st_geogfromwkb(wkb#x) as geography(any)), cast(st_geogfromwkb(wkb#x) as geography(any)))) AS typeof(coalesce(st_geogfromwkb(wkb), CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(coalesce(ST_GeomFromWKB(wkb), ST_GeomFromWKB(wkb)::GEOMETRY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(coalesce(cast(st_geomfromwkb(wkb#x) as geometry(any)), cast(st_geomfromwkb(wkb#x) as geometry(any)))) AS typeof(coalesce(st_geomfromwkb(wkb), CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(IF(wkb IS NOT NULL, ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY), ST_GeogFromWKB(wkb))) FROM geodata
+-- !query analysis
+Project [typeof(if (isnotnull(wkb#x)) cast(st_geogfromwkb(wkb#x) as geography(any)) else cast(st_geogfromwkb(wkb#x) as geography(any))) AS typeof((IF((wkb IS NOT NULL), CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY)), st_geogfromwkb(wkb))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(IF(wkb IS NOT NULL, ST_GeomFromWKB(wkb)::GEOMETRY(ANY), ST_GeomFromWKB(wkb))) FROM geodata
+-- !query analysis
+Project [typeof(if (isnotnull(wkb#x)) cast(st_geomfromwkb(wkb#x) as geometry(any)) else cast(st_geomfromwkb(wkb#x) as geometry(any))) AS typeof((IF((wkb IS NOT NULL), CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY)), st_geomfromwkb(wkb))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
-- !query
SELECT hex(ST_AsBinary(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040'))) AS result
-- !query analysis
@@ -119,6 +350,90 @@ Aggregate [count(1) AS count(1)#xL]
+- Relation spark_catalog.default.geodata[wkb#x] parquet
+-- !query
+SELECT ST_Srid(ST_SetSrid(ST_GeogFromWKB(X'0101000000000000000000F03F0000000000000040'), 4326))
+-- !query analysis
+Project [st_srid(st_setsrid(st_geogfromwkb(0x0101000000000000000000F03F0000000000000040), 4326)) AS st_srid(st_setsrid(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040'), 4326))#x]
++- OneRowRelation
+
+
+-- !query
+SELECT ST_Srid(ST_SetSrid(ST_GeomFromWKB(X'0101000000000000000000F03F0000000000000040'), 3857))
+-- !query analysis
+Project [st_srid(st_setsrid(st_geomfromwkb(0x0101000000000000000000F03F0000000000000040), 3857)) AS st_srid(st_setsrid(st_geomfromwkb(X'0101000000000000000000F03F0000000000000040'), 3857))#x]
++- OneRowRelation
+
+
+-- !query
+SELECT ST_Srid(ST_SetSrid(ST_GeogFromWKB(X'0101000000000000000000F03F0000000000000040'), 3857))
+-- !query analysis
+org.apache.spark.SparkIllegalArgumentException
+{
+ "errorClass" : "ST_INVALID_SRID_VALUE",
+ "sqlState" : "22023",
+ "messageParameters" : {
+ "srid" : "3857"
+ }
+}
+
+
+-- !query
+SELECT ST_Srid(ST_SetSrid(ST_GeomFromWKB(X'0101000000000000000000F03F0000000000000040'), 9999))
+-- !query analysis
+org.apache.spark.SparkIllegalArgumentException
+{
+ "errorClass" : "ST_INVALID_SRID_VALUE",
+ "sqlState" : "22023",
+ "messageParameters" : {
+ "srid" : "9999"
+ }
+}
+
+
+-- !query
+SELECT COUNT(*) FROM geodata WHERE ST_Srid(ST_SetSrid(ST_GeogFromWKB(wkb), 4326)) <> 4326
+-- !query analysis
+Aggregate [count(1) AS count(1)#xL]
++- Filter NOT (st_srid(st_setsrid(st_geogfromwkb(wkb#x), 4326)) = 4326)
+ +- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT COUNT(*) FROM geodata WHERE ST_Srid(ST_SetSrid(ST_GeomFromWKB(wkb), 3857)) <> 3857
+-- !query analysis
+Aggregate [count(1) AS count(1)#xL]
++- Filter NOT (st_srid(st_setsrid(st_geomfromwkb(wkb#x), 3857)) = 3857)
+ +- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT COUNT(*) FROM geodata WHERE ST_Srid(ST_SetSrid(ST_GeogFromWKB(wkb), 3857)) IS NOT NULL
+-- !query analysis
+org.apache.spark.SparkIllegalArgumentException
+{
+ "errorClass" : "ST_INVALID_SRID_VALUE",
+ "sqlState" : "22023",
+ "messageParameters" : {
+ "srid" : "3857"
+ }
+}
+
+
+-- !query
+SELECT COUNT(*) FROM geodata WHERE ST_Srid(ST_SetSrid(ST_GeomFromWKB(wkb), 9999)) IS NOT NULL
+-- !query analysis
+org.apache.spark.SparkIllegalArgumentException
+{
+ "errorClass" : "ST_INVALID_SRID_VALUE",
+ "sqlState" : "22023",
+ "messageParameters" : {
+ "srid" : "9999"
+ }
+}
+
+
-- !query
DROP TABLE geodata
-- !query analysis
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/st-functions.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/st-functions.sql.out
index fe2dda3f1967b..c86d2454d759c 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/st-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/st-functions.sql.out
@@ -66,6 +66,237 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
}
+-- !query
+SELECT hex(ST_AsBinary(CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040') AS GEOGRAPHY(ANY)))) AS result
+-- !query analysis
+Project [hex(st_asbinary(cast(st_geogfromwkb(0x0101000000000000000000F03F0000000000000040) as geography(any)))) AS result#x]
++- OneRowRelation
+
+
+-- !query
+SELECT CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040')::GEOGRAPHY(ANY) AS GEOGRAPHY(4326)) AS result
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.CAST_WITHOUT_SUGGESTION",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "sqlExpr" : "\"CAST(CAST(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040') AS GEOGRAPHY(ANY)) AS GEOGRAPHY(4326))\"",
+ "srcType" : "\"GEOGRAPHY(ANY)\"",
+ "targetType" : "\"GEOGRAPHY(4326)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 109,
+ "fragment" : "CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040')::GEOGRAPHY(ANY) AS GEOGRAPHY(4326))"
+ } ]
+}
+
+
+-- !query
+SELECT hex(ST_AsBinary(CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040') AS GEOMETRY(4326)))) AS result
+-- !query analysis
+Project [hex(st_asbinary(cast(st_geogfromwkb(0x0101000000000000000000F03F0000000000000040) as geometry(4326)))) AS result#x]
++- OneRowRelation
+
+
+-- !query
+SELECT CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040') AS GEOMETRY(ANY)) AS result
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.CAST_WITHOUT_SUGGESTION",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "sqlExpr" : "\"CAST(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040') AS GEOMETRY(ANY))\"",
+ "srcType" : "\"GEOGRAPHY(4326)\"",
+ "targetType" : "\"GEOMETRY(ANY)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 91,
+ "fragment" : "CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040') AS GEOMETRY(ANY))"
+ } ]
+}
+
+
+-- !query
+SELECT hex(ST_AsBinary(CAST(ST_GeomFromWKB(X'0101000000000000000000f03f0000000000000040') AS GEOMETRY(ANY)))) AS result
+-- !query analysis
+Project [hex(st_asbinary(cast(st_geomfromwkb(0x0101000000000000000000F03F0000000000000040) as geometry(any)))) AS result#x]
++- OneRowRelation
+
+
+-- !query
+SELECT CAST(ST_GeomFromWKB(X'0101000000000000000000f03f0000000000000040')::GEOMETRY(ANY) AS GEOMETRY(4326)) AS result
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.CAST_WITHOUT_SUGGESTION",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "sqlExpr" : "\"CAST(CAST(st_geomfromwkb(X'0101000000000000000000F03F0000000000000040') AS GEOMETRY(ANY)) AS GEOMETRY(4326))\"",
+ "srcType" : "\"GEOMETRY(ANY)\"",
+ "targetType" : "\"GEOMETRY(4326)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 107,
+ "fragment" : "CAST(ST_GeomFromWKB(X'0101000000000000000000f03f0000000000000040')::GEOMETRY(ANY) AS GEOMETRY(4326))"
+ } ]
+}
+
+
+-- !query
+SELECT typeof(array(ST_GeogFromWKB(wkb), ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(array(cast(st_geogfromwkb(wkb#x) as geography(any)), cast(st_geogfromwkb(wkb#x) as geography(any)))) AS typeof(array(st_geogfromwkb(wkb), CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(array(ST_GeomFromWKB(wkb), ST_GeomFromWKB(wkb)::GEOMETRY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(array(cast(st_geomfromwkb(wkb#x) as geometry(any)), cast(st_geomfromwkb(wkb#x) as geometry(any)))) AS typeof(array(st_geomfromwkb(wkb), CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(map('a', ST_GeogFromWKB(wkb), 'b', ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(map(a, cast(st_geogfromwkb(wkb#x) as geography(any)), b, cast(st_geogfromwkb(wkb#x) as geography(any)))) AS typeof(map(a, st_geogfromwkb(wkb), b, CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(map('a', ST_GeomFromWKB(wkb), 'b', ST_GeomFromWKB(wkb)::GEOMETRY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(map(a, cast(st_geomfromwkb(wkb#x) as geometry(any)), b, cast(st_geomfromwkb(wkb#x) as geometry(any)))) AS typeof(map(a, st_geomfromwkb(wkb), b, CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(array(named_struct('g1', ST_GeogFromWKB(wkb), 'g2', ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY)), named_struct('g1', ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY), 'g2', ST_GeogFromWKB(wkb)))) FROM geodata
+-- !query analysis
+Project [typeof(array(cast(named_struct(g1, st_geogfromwkb(wkb#x), g2, cast(st_geogfromwkb(wkb#x) as geography(any))) as struct), cast(named_struct(g1, cast(st_geogfromwkb(wkb#x) as geography(any)), g2, st_geogfromwkb(wkb#x)) as struct))) AS typeof(array(named_struct(g1, st_geogfromwkb(wkb), g2, CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY))), named_struct(g1, CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY)), g2, st_geogfromwkb(wkb))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(array(named_struct('g1', ST_GeomFromWKB(wkb), 'g2', ST_GeomFromWKB(wkb)::GEOMETRY(ANY)), named_struct('g1', ST_GeomFromWKB(wkb)::GEOMETRY(ANY), 'g2', ST_GeomFromWKB(wkb)))) FROM geodata
+-- !query analysis
+Project [typeof(array(cast(named_struct(g1, st_geomfromwkb(wkb#x), g2, cast(st_geomfromwkb(wkb#x) as geometry(any))) as struct), cast(named_struct(g1, cast(st_geomfromwkb(wkb#x) as geometry(any)), g2, st_geomfromwkb(wkb#x)) as struct))) AS typeof(array(named_struct(g1, st_geomfromwkb(wkb), g2, CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY))), named_struct(g1, CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY)), g2, st_geomfromwkb(wkb))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(named_struct('a', array(ST_GeogFromWKB(wkb), ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY)), 'b', map('g', ST_GeogFromWKB(wkb), 'h', ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY)))) FROM geodata
+-- !query analysis
+Project [typeof(named_struct(a, array(cast(st_geogfromwkb(wkb#x) as geography(any)), cast(st_geogfromwkb(wkb#x) as geography(any))), b, map(g, cast(st_geogfromwkb(wkb#x) as geography(any)), h, cast(st_geogfromwkb(wkb#x) as geography(any))))) AS typeof(named_struct(a, array(st_geogfromwkb(wkb), CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY))), b, map(g, st_geogfromwkb(wkb), h, CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY)))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(named_struct('a', array(ST_GeomFromWKB(wkb), ST_GeomFromWKB(wkb)::GEOMETRY(ANY)), 'b', map('g', ST_GeomFromWKB(wkb), 'h', ST_GeomFromWKB(wkb)::GEOMETRY(ANY)))) FROM geodata
+-- !query analysis
+Project [typeof(named_struct(a, array(cast(st_geomfromwkb(wkb#x) as geometry(any)), cast(st_geomfromwkb(wkb#x) as geometry(any))), b, map(g, cast(st_geomfromwkb(wkb#x) as geometry(any)), h, cast(st_geomfromwkb(wkb#x) as geometry(any))))) AS typeof(named_struct(a, array(st_geomfromwkb(wkb), CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY))), b, map(g, st_geomfromwkb(wkb), h, CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY)))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(nvl(ST_GeogFromWKB(wkb), ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(nvl(st_geogfromwkb(wkb#x), cast(st_geogfromwkb(wkb#x) as geography(any)))) AS typeof(nvl(st_geogfromwkb(wkb), CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(nvl(ST_GeomFromWKB(wkb), ST_GeomFromWKB(wkb)::GEOMETRY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(nvl(st_geomfromwkb(wkb#x), cast(st_geomfromwkb(wkb#x) as geometry(any)))) AS typeof(nvl(st_geomfromwkb(wkb), CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(nvl2(ST_GeogFromWKB(wkb), ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY), ST_GeogFromWKB(wkb))) FROM geodata
+-- !query analysis
+Project [typeof(nvl2(st_geogfromwkb(wkb#x), cast(st_geogfromwkb(wkb#x) as geography(any)), st_geogfromwkb(wkb#x))) AS typeof(nvl2(st_geogfromwkb(wkb), CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY)), st_geogfromwkb(wkb)))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(nvl2(ST_GeomFromWKB(wkb), ST_GeomFromWKB(wkb)::GEOMETRY(ANY), ST_GeomFromWKB(wkb))) FROM geodata
+-- !query analysis
+Project [typeof(nvl2(st_geomfromwkb(wkb#x), cast(st_geomfromwkb(wkb#x) as geometry(any)), st_geomfromwkb(wkb#x))) AS typeof(nvl2(st_geomfromwkb(wkb), CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY)), st_geomfromwkb(wkb)))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(CASE WHEN wkb IS NOT NULL THEN ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY) ELSE ST_GeogFromWKB(wkb) END) FROM geodata
+-- !query analysis
+Project [typeof(CASE WHEN isnotnull(wkb#x) THEN cast(st_geogfromwkb(wkb#x) as geography(any)) ELSE cast(st_geogfromwkb(wkb#x) as geography(any)) END) AS typeof(CASE WHEN (wkb IS NOT NULL) THEN CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY)) ELSE st_geogfromwkb(wkb) END)#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(CASE WHEN wkb IS NOT NULL THEN ST_GeomFromWKB(wkb)::GEOMETRY(ANY) ELSE ST_GeomFromWKB(wkb) END) FROM geodata
+-- !query analysis
+Project [typeof(CASE WHEN isnotnull(wkb#x) THEN cast(st_geomfromwkb(wkb#x) as geometry(any)) ELSE cast(st_geomfromwkb(wkb#x) as geometry(any)) END) AS typeof(CASE WHEN (wkb IS NOT NULL) THEN CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY)) ELSE st_geomfromwkb(wkb) END)#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(coalesce(ST_GeogFromWKB(wkb), ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(coalesce(cast(st_geogfromwkb(wkb#x) as geography(any)), cast(st_geogfromwkb(wkb#x) as geography(any)))) AS typeof(coalesce(st_geogfromwkb(wkb), CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(coalesce(ST_GeomFromWKB(wkb), ST_GeomFromWKB(wkb)::GEOMETRY(ANY))) FROM geodata
+-- !query analysis
+Project [typeof(coalesce(cast(st_geomfromwkb(wkb#x) as geometry(any)), cast(st_geomfromwkb(wkb#x) as geometry(any)))) AS typeof(coalesce(st_geomfromwkb(wkb), CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(IF(wkb IS NOT NULL, ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY), ST_GeogFromWKB(wkb))) FROM geodata
+-- !query analysis
+Project [typeof(if (isnotnull(wkb#x)) cast(st_geogfromwkb(wkb#x) as geography(any)) else cast(st_geogfromwkb(wkb#x) as geography(any))) AS typeof((IF((wkb IS NOT NULL), CAST(st_geogfromwkb(wkb) AS GEOGRAPHY(ANY)), st_geogfromwkb(wkb))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT typeof(IF(wkb IS NOT NULL, ST_GeomFromWKB(wkb)::GEOMETRY(ANY), ST_GeomFromWKB(wkb))) FROM geodata
+-- !query analysis
+Project [typeof(if (isnotnull(wkb#x)) cast(st_geomfromwkb(wkb#x) as geometry(any)) else cast(st_geomfromwkb(wkb#x) as geometry(any))) AS typeof((IF((wkb IS NOT NULL), CAST(st_geomfromwkb(wkb) AS GEOMETRY(ANY)), st_geomfromwkb(wkb))))#x]
++- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
-- !query
SELECT hex(ST_AsBinary(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040'))) AS result
-- !query analysis
@@ -119,6 +350,90 @@ Aggregate [count(1) AS count(1)#xL]
+- Relation spark_catalog.default.geodata[wkb#x] parquet
+-- !query
+SELECT ST_Srid(ST_SetSrid(ST_GeogFromWKB(X'0101000000000000000000F03F0000000000000040'), 4326))
+-- !query analysis
+Project [st_srid(st_setsrid(st_geogfromwkb(0x0101000000000000000000F03F0000000000000040), 4326)) AS st_srid(st_setsrid(st_geogfromwkb(X'0101000000000000000000F03F0000000000000040'), 4326))#x]
++- OneRowRelation
+
+
+-- !query
+SELECT ST_Srid(ST_SetSrid(ST_GeomFromWKB(X'0101000000000000000000F03F0000000000000040'), 3857))
+-- !query analysis
+Project [st_srid(st_setsrid(st_geomfromwkb(0x0101000000000000000000F03F0000000000000040), 3857)) AS st_srid(st_setsrid(st_geomfromwkb(X'0101000000000000000000F03F0000000000000040'), 3857))#x]
++- OneRowRelation
+
+
+-- !query
+SELECT ST_Srid(ST_SetSrid(ST_GeogFromWKB(X'0101000000000000000000F03F0000000000000040'), 3857))
+-- !query analysis
+org.apache.spark.SparkIllegalArgumentException
+{
+ "errorClass" : "ST_INVALID_SRID_VALUE",
+ "sqlState" : "22023",
+ "messageParameters" : {
+ "srid" : "3857"
+ }
+}
+
+
+-- !query
+SELECT ST_Srid(ST_SetSrid(ST_GeomFromWKB(X'0101000000000000000000F03F0000000000000040'), 9999))
+-- !query analysis
+org.apache.spark.SparkIllegalArgumentException
+{
+ "errorClass" : "ST_INVALID_SRID_VALUE",
+ "sqlState" : "22023",
+ "messageParameters" : {
+ "srid" : "9999"
+ }
+}
+
+
+-- !query
+SELECT COUNT(*) FROM geodata WHERE ST_Srid(ST_SetSrid(ST_GeogFromWKB(wkb), 4326)) <> 4326
+-- !query analysis
+Aggregate [count(1) AS count(1)#xL]
++- Filter NOT (st_srid(st_setsrid(st_geogfromwkb(wkb#x), 4326)) = 4326)
+ +- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT COUNT(*) FROM geodata WHERE ST_Srid(ST_SetSrid(ST_GeomFromWKB(wkb), 3857)) <> 3857
+-- !query analysis
+Aggregate [count(1) AS count(1)#xL]
++- Filter NOT (st_srid(st_setsrid(st_geomfromwkb(wkb#x), 3857)) = 3857)
+ +- SubqueryAlias spark_catalog.default.geodata
+ +- Relation spark_catalog.default.geodata[wkb#x] parquet
+
+
+-- !query
+SELECT COUNT(*) FROM geodata WHERE ST_Srid(ST_SetSrid(ST_GeogFromWKB(wkb), 3857)) IS NOT NULL
+-- !query analysis
+org.apache.spark.SparkIllegalArgumentException
+{
+ "errorClass" : "ST_INVALID_SRID_VALUE",
+ "sqlState" : "22023",
+ "messageParameters" : {
+ "srid" : "3857"
+ }
+}
+
+
+-- !query
+SELECT COUNT(*) FROM geodata WHERE ST_Srid(ST_SetSrid(ST_GeomFromWKB(wkb), 9999)) IS NOT NULL
+-- !query analysis
+org.apache.spark.SparkIllegalArgumentException
+{
+ "errorClass" : "ST_INVALID_SRID_VALUE",
+ "sqlState" : "22023",
+ "messageParameters" : {
+ "srid" : "9999"
+ }
+}
+
+
-- !query
DROP TABLE geodata
-- !query analysis
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/thetasketch.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/thetasketch.sql.out
index 323084223d4bc..84fb8086151d1 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/thetasketch.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/thetasketch.sql.out
@@ -1068,6 +1068,49 @@ Aggregate [theta_union_agg(sketch#x, 27, 0, 0) AS theta_union_agg(sketch, 27)#x]
+- LocalRelation [col#x]
+-- !query
+SELECT theta_sketch_agg(col, CAST(NULL AS INT)) AS lg_nom_entries_is_null
+FROM VALUES (15), (16), (17) tab(col)
+-- !query analysis
+Aggregate [theta_sketch_agg(col#x, cast(null as int), 0, 0) AS lg_nom_entries_is_null#x]
++- SubqueryAlias tab
+ +- LocalRelation [col#x]
+
+
+-- !query
+SELECT theta_sketch_agg(col, CAST(col AS INT)) AS lg_nom_entries_non_constant
+FROM VALUES (15), (16), (17) tab(col)
+-- !query analysis
+Aggregate [theta_sketch_agg(col#x, cast(col#x as int), 0, 0) AS lg_nom_entries_non_constant#x]
++- SubqueryAlias tab
+ +- LocalRelation [col#x]
+
+
+-- !query
+SELECT theta_sketch_agg(col, '15')
+FROM VALUES (50), (60), (60) tab(col)
+-- !query analysis
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"15\"",
+ "inputType" : "\"STRING\"",
+ "paramIndex" : "second",
+ "requiredType" : "\"INT\"",
+ "sqlExpr" : "\"theta_sketch_agg(col, 15)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 34,
+ "fragment" : "theta_sketch_agg(col, '15')"
+ } ]
+}
+
+
-- !query
SELECT theta_union(1, 2)
FROM VALUES
diff --git a/sql/core/src/test/resources/sql-tests/inputs/execute-immediate.sql b/sql/core/src/test/resources/sql-tests/inputs/execute-immediate.sql
index 17fa47be4eec3..16e1850d5e59f 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/execute-immediate.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/execute-immediate.sql
@@ -289,3 +289,11 @@ EXECUTE IMMEDIATE 'SELECT typeof(:p) as type, :p as val' USING MAP('key1', 'valu
-- !query
EXECUTE IMMEDIATE 'SELECT typeof(:p) as type, :p as val' USING MAP(1, 'one', 2, 'two') AS p;
+-- !query
+-- Test unbound parameter markers without USING clause
+-- named parameter without USING clause should fail
+EXECUTE IMMEDIATE 'SELECT :param';
+
+-- !query
+-- positional parameter without USING clause should fail
+EXECUTE IMMEDIATE 'SELECT ?';
diff --git a/sql/core/src/test/resources/sql-tests/inputs/hll.sql b/sql/core/src/test/resources/sql-tests/inputs/hll.sql
index fbd82b936b776..35128da97fd61 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/hll.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/hll.sql
@@ -69,6 +69,15 @@ FROM VALUES (50), (60), (60) tab(col);
SELECT hll_sketch_agg(col, 40)
FROM VALUES (50), (60), (60) tab(col);
+SELECT hll_sketch_agg(col, CAST(NULL AS INT)) AS k_is_null
+FROM VALUES (15), (16), (17) tab(col);
+
+SELECT hll_sketch_agg(col, CAST(col AS INT)) AS k_non_constant
+FROM VALUES (15), (16), (17) tab(col);
+
+SELECT hll_sketch_agg(col, '15')
+FROM VALUES (50), (60), (60) tab(col);
+
SELECT hll_union(
hll_sketch_agg(col1, 12),
hll_sketch_agg(col2, 13))
diff --git a/sql/core/src/test/resources/sql-tests/inputs/identifier-clause-legacy.sql b/sql/core/src/test/resources/sql-tests/inputs/identifier-clause-legacy.sql
new file mode 100644
index 0000000000000..ae1f10f1af1f2
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/identifier-clause-legacy.sql
@@ -0,0 +1,2 @@
+--SET spark.sql.legacy.identifierClause = true
+--IMPORT identifier-clause.sql
diff --git a/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql b/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql
index 4aa8019097fdf..d9bafe7cc607e 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/identifier-clause.sql
@@ -119,6 +119,14 @@ VALUES(IDENTIFIER(1));
VALUES(IDENTIFIER(SUBSTR('HELLO', 1, RAND() + 1)));
SELECT `IDENTIFIER`('abs')(c1) FROM VALUES(-1) AS T(c1);
+CREATE TABLE t(col1 INT);
+SELECT * FROM IDENTIFIER((SELECT 't'));
+SELECT * FROM (SELECT IDENTIFIER((SELECT 'col1')) FROM IDENTIFIER((SELECT 't')));
+SELECT IDENTIFIER((SELECT 'col1')) FROM VALUES(1);
+SELECT col1, IDENTIFIER((SELECT col1)) FROM VALUES(1);
+SELECT IDENTIFIER((SELECT 'col1', 'col2')) FROM VALUES(1,2);
+DROP TABLE t;
+
CREATE TABLE IDENTIFIER(1)(c1 INT) USING csv;
CREATE TABLE IDENTIFIER('a.b.c')(c1 INT) USING csv;
CREATE VIEW IDENTIFIER('a.b.c')(c1) AS VALUES(1);
@@ -157,7 +165,6 @@ SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('T');
WITH ABC(c1, c2) AS (VALUES(1, 2), (2, 3))
SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('A' || 'BC');
--- Not supported
SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1);
SELECT T1.c1 FROM VALUES(1) AS T1(c1) JOIN VALUES(1) AS T2(c1) USING (IDENTIFIER('c1'));
SELECT IDENTIFIER('t').c1 FROM VALUES(1) AS T(c1);
@@ -168,10 +175,223 @@ SELECT * FROM IDENTIFIER('s').IDENTIFIER('tab');
SELECT * FROM IDENTIFIER('s').tab;
SELECT row_number() OVER IDENTIFIER('win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1);
SELECT row_number() OVER win FROM VALUES(1) AS T(c1) WINDOW IDENTIFIER('win') AS (ORDER BY c1);
+SELECT 1 AS IDENTIFIER('col1');
+SELECT my_table.* FROM VALUES (1, 2) AS IDENTIFIER('my_table')(IDENTIFIER('c1'), IDENTIFIER('c2'));
WITH identifier('v')(identifier('c1')) AS (VALUES(1)) (SELECT c1 FROM v);
-INSERT INTO tab(IDENTIFIER('c1')) VALUES(1);
CREATE OR REPLACE VIEW v(IDENTIFIER('c1')) AS VALUES(1);
+SELECT c1 FROM v;
+DROP VIEW IF EXISTS v;
CREATE TABLE tab(IDENTIFIER('c1') INT) USING CSV;
+INSERT INTO tab(IDENTIFIER('c1')) VALUES(1);
+SELECT c1 FROM tab;
+ALTER TABLE IDENTIFIER('tab') RENAME COLUMN IDENTIFIER('c1') TO IDENTIFIER('col1');
+SELECT col1 FROM tab;
+ALTER TABLE IDENTIFIER('tab') ADD COLUMN IDENTIFIER('c2') INT;
+SELECT c2 FROM tab;
+ALTER TABLE IDENTIFIER('tab') DROP COLUMN IDENTIFIER('c2');
+ALTER TABLE IDENTIFIER('tab') RENAME TO IDENTIFIER('tab_renamed');
+SELECT * FROM tab_renamed;
+DROP TABLE IF EXISTS tab_renamed;
+DROP TABLE IF EXISTS tab;
+
+-- Error because qualified names are not allowed
+CREATE TABLE test_col_with_dot(IDENTIFIER('`col.with.dot`') INT) USING CSV;
+DROP TABLE IF EXISTS test_col_with_dot;
+-- Identifier-lite: a table alias with a qualified name should error (the alias must be a single-part name)
+SELECT * FROM VALUES (1, 2) AS IDENTIFIER('schema.table')(c1, c2);
+-- Identifier-lite: a column alias with a qualified name should error (the alias must be a single-part name)
+SELECT 1 AS IDENTIFIER('col1.col2');
+
+-- Additional coverage: SHOW commands with identifier-lite
+CREATE SCHEMA identifier_clause_test_schema;
+USE identifier_clause_test_schema;
+CREATE TABLE test_show(c1 INT, c2 STRING) USING CSV;
+SHOW VIEWS IN IDENTIFIER('identifier_clause_test_schema');
+SHOW PARTITIONS IDENTIFIER('test_show');
+SHOW CREATE TABLE IDENTIFIER('test_show');
+DROP TABLE test_show;
+
+-- SET CATALOG with identifier-lite
+-- SET CATALOG IDENTIFIER('spark_catalog');
+
+-- DESCRIBE with different forms
+CREATE TABLE test_desc(c1 INT) USING CSV;
+DESCRIBE TABLE IDENTIFIER('test_desc');
+DESCRIBE FORMATTED IDENTIFIER('test_desc');
+DESCRIBE EXTENDED IDENTIFIER('test_desc');
+DESC IDENTIFIER('test_desc');
+DROP TABLE test_desc;
+
+-- COMMENT ON COLUMN with identifier-lite
+CREATE TABLE test_comment(c1 INT, c2 STRING) USING CSV;
+COMMENT ON TABLE IDENTIFIER('test_comment') IS 'table comment';
+ALTER TABLE test_comment ALTER COLUMN IDENTIFIER('c1') COMMENT 'column comment';
+DROP TABLE test_comment;
+
+-- Additional identifier tests with qualified table names in various commands
+CREATE TABLE identifier_clause_test_schema.test_table(c1 INT) USING CSV;
+ANALYZE TABLE IDENTIFIER('identifier_clause_test_schema.test_table') COMPUTE STATISTICS;
+REFRESH TABLE IDENTIFIER('identifier_clause_test_schema.test_table');
+DESCRIBE IDENTIFIER('identifier_clause_test_schema.test_table');
+SHOW COLUMNS FROM IDENTIFIER('identifier_clause_test_schema.test_table');
+DROP TABLE IDENTIFIER('identifier_clause_test_schema.test_table');
+
+-- Session variables with identifier-lite
+DECLARE IDENTIFIER('my_var') = 'value';
+SET VAR IDENTIFIER('my_var') = 'new_value';
+SELECT IDENTIFIER('my_var');
+DROP TEMPORARY VARIABLE IDENTIFIER('my_var');
+
+-- SQL UDF with identifier-lite in parameter names and return statement
+CREATE TEMPORARY FUNCTION test_udf(IDENTIFIER('param1') INT, IDENTIFIER('param2') STRING)
+RETURNS INT
+RETURN IDENTIFIER('param1') + length(IDENTIFIER('param2'));
+
+SELECT test_udf(5, 'hello');
+DROP TEMPORARY FUNCTION test_udf;
+
+-- SQL UDF with table return type using identifier-lite
+CREATE TEMPORARY FUNCTION test_table_udf(IDENTIFIER('input_val') INT)
+RETURNS TABLE(IDENTIFIER('col1') INT, IDENTIFIER('col2') STRING)
+RETURN SELECT IDENTIFIER('input_val'), 'result';
+
+SELECT * FROM test_table_udf(42);
+DROP TEMPORARY FUNCTION test_table_udf;
+
+-- Integration tests: Combining parameter markers, string coalescing, and IDENTIFIER
+-- These tests demonstrate the power of combining IDENTIFIER with parameters
+
+-- Test 1: IDENTIFIER with parameter marker for table name
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:tab \'b\').c1 FROM VALUES(1) AS tab(c1)' USING 'ta' AS tab;
+
+-- Test 2: IDENTIFIER with string coalescing for column name
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:col1 ''.c2'') FROM VALUES(named_struct(''c2'', 42)) AS T(c1)'
+ USING 'c1' AS col1;
+
+-- Test 3: IDENTIFIER with parameter and string literal coalescing for qualified table name
+CREATE TABLE integration_test(c1 INT, c2 STRING) USING CSV;
+INSERT INTO integration_test VALUES (1, 'a'), (2, 'b');
+EXECUTE IMMEDIATE 'SELECT * FROM IDENTIFIER(:schema ''.'' :table) ORDER BY ALL'
+ USING 'identifier_clause_test_schema' AS schema, 'integration_test' AS table;
+
+-- Test 4: IDENTIFIER in column reference with parameter and string coalescing
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:prefix ''1''), IDENTIFIER(:prefix ''2'') FROM integration_test ORDER BY ALL'
+ USING 'c' AS prefix;
+
+-- Test 5: IDENTIFIER in WHERE clause with parameters
+EXECUTE IMMEDIATE 'SELECT * FROM integration_test WHERE IDENTIFIER(:col) = :val'
+ USING 'c1' AS col, 1 AS val;
+
+-- Test 6: IDENTIFIER in JOIN with parameters for table and column names
+CREATE TABLE integration_test2(c1 INT, c3 STRING) USING CSV;
+INSERT INTO integration_test2 VALUES (1, 'x'), (2, 'y');
+EXECUTE IMMEDIATE 'SELECT t1.*, t2.* FROM IDENTIFIER(:t1) t1 JOIN IDENTIFIER(:t2) t2 USING (IDENTIFIER(:col)) ORDER BY ALL'
+ USING 'integration_test' AS t1, 'integration_test2' AS t2, 'c1' AS col;
+
+-- Test 7: IDENTIFIER in window function with parameter for partition column
+EXECUTE IMMEDIATE
+ 'SELECT IDENTIFIER(:col1), IDENTIFIER(:col2), row_number() OVER (PARTITION BY IDENTIFIER(:part) ORDER BY IDENTIFIER(:ord)) as rn FROM integration_test'
+ USING 'c1' AS col1, 'c2' AS col2, 'c2' AS part, 'c1' AS ord;
+
+-- Test 8: IDENTIFIER in aggregate function with string coalescing
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:prefix ''2''), IDENTIFIER(:agg)(IDENTIFIER(:col)) FROM integration_test GROUP BY IDENTIFIER(:prefix ''2'') ORDER BY ALL'
+ USING 'c' AS prefix, 'count' AS agg, 'c1' AS col;
+
+-- Test 9: IDENTIFIER in ORDER BY with multiple parameters
+EXECUTE IMMEDIATE 'SELECT * FROM integration_test ORDER BY IDENTIFIER(:col1) DESC, IDENTIFIER(:col2)'
+ USING 'c1' AS col1, 'c2' AS col2;
+
+-- Test 10: IDENTIFIER in INSERT with parameter for column name
+EXECUTE IMMEDIATE 'INSERT INTO integration_test(IDENTIFIER(:col1), IDENTIFIER(:col2)) VALUES (:val1, :val2)'
+ USING 'c1' AS col1, 'c2' AS col2, 3 AS val1, 'c' AS val2;
+
+-- Test 11: Complex - IDENTIFIER with nested string operations
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(concat(:schema, ''.'', :table, ''.c1'')) FROM VALUES(named_struct(''c1'', 100)) AS IDENTIFIER(:alias)(IDENTIFIER(:schema ''.'' :table))'
+ USING 'identifier_clause_test_schema' AS schema, 'my_table' AS table, 't' AS alias;
+
+-- Test 12: IDENTIFIER in CTE name with parameter
+EXECUTE IMMEDIATE 'WITH IDENTIFIER(:cte_name)(c1) AS (VALUES(1)) SELECT c1 FROM IDENTIFIER(:cte_name)'
+ USING 'my_cte' AS cte_name;
+
+-- Test 13: IDENTIFIER in view name with parameter
+EXECUTE IMMEDIATE 'CREATE OR REPLACE TEMPORARY VIEW IDENTIFIER(:view_name)(IDENTIFIER(:col_name)) AS VALUES(1)'
+ USING 'test_view' AS view_name, 'test_col' AS col_name;
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:col) FROM IDENTIFIER(:view)'
+ USING 'test_col' AS col, 'test_view' AS view;
+DROP VIEW test_view;
+
+-- Test 14: IDENTIFIER in ALTER TABLE with parameters
+EXECUTE IMMEDIATE 'ALTER TABLE IDENTIFIER(:tab) ADD COLUMN IDENTIFIER(:new_col) INT'
+ USING 'integration_test' AS tab, 'c4' AS new_col;
+EXECUTE IMMEDIATE 'ALTER TABLE IDENTIFIER(:tab) RENAME COLUMN IDENTIFIER(:old_col) TO IDENTIFIER(:new_col)'
+ USING 'integration_test' AS tab, 'c4' AS old_col, 'c5' AS new_col;
+
+-- Test 15: IDENTIFIER with dereference using parameters
+EXECUTE IMMEDIATE 'SELECT map(:key, :val).IDENTIFIER(:key) AS result'
+ USING 'mykey' AS key, 42 AS val;
+
+-- Test 16: IDENTIFIER in table alias with string coalescing
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:alias ''.c1'') FROM integration_test AS IDENTIFIER(:alias) ORDER BY ALL'
+ USING 't' AS alias;
+
+-- Test 17: Multiple IDENTIFIER clauses with different parameter combinations
+EXECUTE IMMEDIATE
+ 'SELECT IDENTIFIER(:col1), IDENTIFIER(:p ''2'') FROM IDENTIFIER(:schema ''.'' :tab) WHERE IDENTIFIER(:col1) > 0 ORDER BY IDENTIFIER(:p ''1'')'
+ USING 'c1' AS col1, 'c' AS p, 'identifier_clause_test_schema' AS schema, 'integration_test' AS tab;
+
+-- Test 19: IDENTIFIER with qualified name coalescing for schema.table.column pattern
+-- This should work for multi-part identifiers
+EXECUTE IMMEDIATE 'SELECT * FROM IDENTIFIER(:schema ''.'' :table) WHERE IDENTIFIER(concat(:tab_alias, ''.c1'')) > 0 ORDER BY ALL'
+ USING 'identifier_clause_test_schema' AS schema, 'integration_test' AS table, 'integration_test' AS tab_alias;
+
+-- Test 20: Error case - IDENTIFIER with too many parts from parameter coalescing
+-- This should error as column alias must be single identifier
+EXECUTE IMMEDIATE 'SELECT 1 AS IDENTIFIER(:schema ''.'' :col)'
+ USING 'identifier_clause_test_schema' AS schema, 'col1' AS col;
+
+-- Cleanup
+DROP TABLE integration_test;
+DROP TABLE integration_test2;
+
+-- LATERAL VIEW with IDENTIFIER() for table and column names
+CREATE TABLE lateral_test(arr ARRAY<INT>) USING PARQUET;
+INSERT INTO lateral_test VALUES (array(1, 2, 3));
+SELECT * FROM lateral_test LATERAL VIEW explode(arr) IDENTIFIER('tbl') AS IDENTIFIER('col') ORDER BY ALL;
+SELECT * FROM lateral_test LATERAL VIEW OUTER explode(arr) IDENTIFIER('my_table') AS IDENTIFIER('my_col') ORDER BY ALL;
+DROP TABLE lateral_test;
+
+-- UNPIVOT with IDENTIFIER() for value column alias
+CREATE TABLE unpivot_test(id INT, a INT, b INT, c INT) USING CSV;
+INSERT INTO unpivot_test VALUES (1, 10, 20, 30);
+SELECT * FROM unpivot_test UNPIVOT (val FOR col IN (a AS IDENTIFIER('col_a'), b AS IDENTIFIER('col_b'))) ORDER BY ALL;
+SELECT * FROM unpivot_test UNPIVOT ((v1, v2) FOR col IN ((a, b) AS IDENTIFIER('cols_ab'), (b, c) AS IDENTIFIER('cols_bc'))) ORDER BY ALL;
+DROP TABLE unpivot_test;
+
+-- DESCRIBE column with IDENTIFIER()
+CREATE TABLE describe_col_test(c1 INT, c2 STRING, c3 DOUBLE) USING CSV;
+DESCRIBE describe_col_test IDENTIFIER('c1');
+DESCRIBE describe_col_test IDENTIFIER('c2');
+DROP TABLE describe_col_test;
+
+-- All the following tests fail because they are not about "true" identifiers
+
+-- This should fail - named parameters don't support IDENTIFIER()
+SELECT :IDENTIFIER('param1') FROM VALUES(1) AS T(c1);
+
+-- Hint names use simpleIdentifier - these should fail
+CREATE TABLE hint_test(c1 INT, c2 INT) USING CSV;
+INSERT INTO hint_test VALUES (1, 2), (3, 4);
+SELECT /*+ IDENTIFIER('BROADCAST')(hint_test) */ * FROM hint_test;
+SELECT /*+ IDENTIFIER('MERGE')(hint_test) */ * FROM hint_test;
+DROP TABLE hint_test;
+
+-- These should fail - function scope doesn't support IDENTIFIER()
+SHOW IDENTIFIER('USER') FUNCTIONS;
+-- EXTRACT field name uses simpleIdentifier - should fail
+SELECT EXTRACT(IDENTIFIER('YEAR') FROM DATE'2024-01-15');
+-- TIMESTAMPADD unit is a token, not identifier - should fail
+SELECT TIMESTAMPADD(IDENTIFIER('YEAR'), 1, DATE'2024-01-15');
+DROP SCHEMA identifier_clause_test_schema;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/interval.sql b/sql/core/src/test/resources/sql-tests/inputs/interval.sql
index e4da28c2e7588..e8e10d089dec6 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/interval.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/interval.sql
@@ -385,3 +385,8 @@ SELECT width_bucket(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10
SELECT width_bucket(INTERVAL '-1' YEAR, INTERVAL -'1-2' YEAR TO MONTH, INTERVAL '1-2' YEAR TO MONTH, 10);
SELECT width_bucket(INTERVAL '0' DAY, INTERVAL '0' DAY, INTERVAL '10' DAY, 10);
SELECT width_bucket(INTERVAL '-59' MINUTE, INTERVAL -'1 01' DAY TO HOUR, INTERVAL '1 2:3:4.001' DAY TO SECOND, 10);
+
+-- interval overflow with large day values (SPARK-50072)
+-- This should throw INTERVAL_ARITHMETIC_OVERFLOW error
+SELECT interval 106751991 day 4 hour 0 minute 54.776 second;
+
diff --git a/sql/core/src/test/resources/sql-tests/inputs/kllquantiles.sql b/sql/core/src/test/resources/sql-tests/inputs/kllquantiles.sql
new file mode 100644
index 0000000000000..9300754d204a5
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/kllquantiles.sql
@@ -0,0 +1,712 @@
+-- Positive test cases
+-- Create tables with two columns for each data type
+
+-- Integer table
+DROP TABLE IF EXISTS t_int_1_5_through_7_11;
+CREATE TABLE t_int_1_5_through_7_11 AS
+VALUES
+ (1, 5), (2, 6), (3, 7), (4, 8), (5, 9), (6, 10), (7, 11) AS tab(col1, col2);
+
+-- Long table
+DROP TABLE IF EXISTS t_long_1_5_through_7_11;
+CREATE TABLE t_long_1_5_through_7_11 AS
+VALUES
+ (1L, 5L), (2L, 6L), (3L, 7L), (4L, 8L), (5L, 9L), (6L, 10L), (7L, 11L) AS tab(col1, col2);
+
+-- SMALLINT (ShortType) table
+DROP TABLE IF EXISTS t_short_1_5_through_7_11;
+CREATE TABLE t_short_1_5_through_7_11 AS
+VALUES
+ (CAST(1 AS SMALLINT), CAST(5 AS SMALLINT)),
+ (CAST(2 AS SMALLINT), CAST(6 AS SMALLINT)),
+ (CAST(3 AS SMALLINT), CAST(7 AS SMALLINT)),
+ (CAST(4 AS SMALLINT), CAST(8 AS SMALLINT)),
+ (CAST(5 AS SMALLINT), CAST(9 AS SMALLINT)),
+ (CAST(6 AS SMALLINT), CAST(10 AS SMALLINT)),
+ (CAST(7 AS SMALLINT), CAST(11 AS SMALLINT))
+ AS tab(col1, col2);
+
+-- TINYINT (ByteType) table
+DROP TABLE IF EXISTS t_byte_1_5_through_7_11;
+CREATE TABLE t_byte_1_5_through_7_11 AS
+VALUES
+ (CAST(1 AS TINYINT), CAST(5 AS TINYINT)),
+ (CAST(2 AS TINYINT), CAST(6 AS TINYINT)),
+ (CAST(3 AS TINYINT), CAST(7 AS TINYINT)),
+ (CAST(4 AS TINYINT), CAST(8 AS TINYINT)),
+ (CAST(5 AS TINYINT), CAST(9 AS TINYINT)),
+ (CAST(6 AS TINYINT), CAST(10 AS TINYINT)),
+ (CAST(7 AS TINYINT), CAST(11 AS TINYINT))
+ AS tab(col1, col2);
+
+-- Float table
+DROP TABLE IF EXISTS t_float_1_5_through_7_11;
+CREATE TABLE t_float_1_5_through_7_11 AS
+VALUES
+ (CAST(1 AS FLOAT), CAST(5 AS FLOAT)),
+ (CAST(2 AS FLOAT), CAST(6 AS FLOAT)),
+ (CAST(3 AS FLOAT), CAST(7 AS FLOAT)),
+ (CAST(4 AS FLOAT), CAST(8 AS FLOAT)),
+ (CAST(5 AS FLOAT), CAST(9 AS FLOAT)),
+ (CAST(6 AS FLOAT), CAST(10 AS FLOAT)),
+ (CAST(7 AS FLOAT), CAST(11 AS FLOAT)) AS tab(col1, col2);
+
+-- Double table
+DROP TABLE IF EXISTS t_double_1_5_through_7_11;
+CREATE TABLE t_double_1_5_through_7_11 AS
+VALUES
+ (CAST(1 AS DOUBLE), CAST(5 AS DOUBLE)),
+ (CAST(2 AS DOUBLE), CAST(6 AS DOUBLE)),
+ (CAST(3 AS DOUBLE), CAST(7 AS DOUBLE)),
+ (CAST(4 AS DOUBLE), CAST(8 AS DOUBLE)),
+ (CAST(5 AS DOUBLE), CAST(9 AS DOUBLE)),
+ (CAST(6 AS DOUBLE), CAST(10 AS DOUBLE)),
+ (CAST(7 AS DOUBLE), CAST(11 AS DOUBLE)) AS tab(col1, col2);
+
+-- BIGINT sketches
+SELECT lower(kll_sketch_to_string_bigint(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_bigint(agg, 0.5) - 4) < 1 AS median_close_to_4,
+ abs(kll_sketch_get_rank_bigint(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_byte_1_5_through_7_11
+);
+
+SELECT lower(kll_sketch_to_string_bigint(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_bigint(agg, 0.5) - 4) < 1 AS median_close_to_4,
+ abs(kll_sketch_get_rank_bigint(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_int_1_5_through_7_11
+);
+
+SELECT lower(kll_sketch_to_string_bigint(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_bigint(agg, 0.5) - 4) < 1 AS median_close_to_4,
+ abs(kll_sketch_get_rank_bigint(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_long_1_5_through_7_11
+);
+
+SELECT lower(kll_sketch_to_string_bigint(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_bigint(agg, 0.5) - 4) < 1 AS median_close_to_4,
+ abs(kll_sketch_get_rank_bigint(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_short_1_5_through_7_11
+);
+
+-- FLOAT sketches (only accepts float types to avoid precision loss)
+SELECT lower(kll_sketch_to_string_float(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_float(agg, 0.5) - 4.0) < 0.5 AS median_close_to_4,
+ abs(kll_sketch_get_rank_float(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS agg
+ FROM t_float_1_5_through_7_11
+);
+
+-- DOUBLE sketches (accepts float and double types to avoid precision loss from integer conversion)
+SELECT lower(kll_sketch_to_string_double(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_double(agg, 0.5) - 4.0) < 0.5 AS median_close_to_4,
+ abs(kll_sketch_get_rank_double(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_double(col1) AS agg
+ FROM t_double_1_5_through_7_11
+);
+
+-- Test float column with double sketch (valid type promotion)
+SELECT lower(kll_sketch_to_string_double(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_double(agg, 0.5) - 4.0) < 0.5 AS median_close_to_4,
+ abs(kll_sketch_get_rank_double(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_double(col1) AS agg
+ FROM t_float_1_5_through_7_11
+);
+
+-- Merging sketches and converting them to strings (scalar merge functions)
+SELECT
+ split(
+ kll_sketch_to_string_bigint(
+ kll_sketch_merge_bigint(
+ kll_sketch_agg_bigint(col1),
+ kll_sketch_agg_bigint(col1)
+ )
+ ),
+ '\n'
+ )[1] AS result
+ FROM t_byte_1_5_through_7_11;
+
+SELECT
+ split(
+ kll_sketch_to_string_float(
+ kll_sketch_merge_float(
+ kll_sketch_agg_float(col1),
+ kll_sketch_agg_float(col1)
+ )
+ ),
+ '\n'
+ )[1] AS result
+FROM t_byte_1_5_through_7_11;
+
+SELECT
+ split(
+ kll_sketch_to_string_double(
+ kll_sketch_merge_double(
+ kll_sketch_agg_double(col1),
+ kll_sketch_agg_double(col1)
+ )
+ ),
+ '\n'
+ )[1] AS result
+FROM t_byte_1_5_through_7_11;
+
+-- Tests for KllMergeAgg* aggregate functions
+-- These functions merge multiple binary sketch representations
+
+-- Test GROUP BY with kll_merge_agg_bigint and HAVING clause
+SELECT
+ parity,
+ kll_sketch_get_n_bigint(kll_merge_agg_bigint(sketch_col)) AS total_count
+FROM (
+ SELECT
+ col1 % 2 AS parity,
+ kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_int_1_5_through_7_11
+ GROUP BY col1 % 2
+) grouped_sketches
+GROUP BY parity
+HAVING kll_sketch_get_n_bigint(kll_merge_agg_bigint(sketch_col)) > 3;
+
+-- Test empty aggregation: zero rows input for kll_merge_agg_bigint
+SELECT kll_sketch_get_n_bigint(kll_merge_agg_bigint(sketch_col)) AS empty_merge_n
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_int_1_5_through_7_11
+ WHERE col1 > 1000
+) empty_sketches;
+
+-- Test empty aggregation: zero rows input for kll_merge_agg_float
+SELECT kll_sketch_get_n_float(kll_merge_agg_float(sketch_col)) AS empty_merge_n
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS sketch_col
+ FROM t_float_1_5_through_7_11
+ WHERE col1 > 1000.0
+) empty_sketches;
+
+-- Test empty aggregation: zero rows input for kll_merge_agg_double
+SELECT kll_sketch_get_n_double(kll_merge_agg_double(sketch_col)) AS empty_merge_n
+FROM (
+ SELECT kll_sketch_agg_double(col1) AS sketch_col
+ FROM t_double_1_5_through_7_11
+ WHERE col1 > 1000.0
+) empty_sketches;
+
+-- Test kll_merge_agg_bigint: merge bigint sketches from multiple rows
+SELECT lower(kll_sketch_to_string_bigint(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_bigint(agg, 0.5) - 4) < 1 AS median_close_to_4,
+ abs(kll_sketch_get_rank_bigint(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_merge_agg_bigint(sketch_col) AS agg
+ FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_int_1_5_through_7_11
+ UNION ALL
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_short_1_5_through_7_11
+ ) sketches
+);
+
+-- Test kll_merge_agg_float: merge float sketches from multiple rows
+-- Merging col1 (1-7) and col2 (5-11) gives combined data with median ~5.5
+SELECT lower(kll_sketch_to_string_float(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_float(agg, 0.5) - 5.5) < 1.0 AS median_close_to_5_5,
+ abs(kll_sketch_get_rank_float(agg, 5.0) - 0.35) < 0.15 AS rank5_close_to_0_35
+FROM (
+ SELECT kll_merge_agg_float(sketch_col) AS agg
+ FROM (
+ SELECT kll_sketch_agg_float(col1) AS sketch_col
+ FROM t_float_1_5_through_7_11
+ UNION ALL
+ SELECT kll_sketch_agg_float(col2) AS sketch_col
+ FROM t_float_1_5_through_7_11
+ ) sketches
+);
+
+-- Test kll_merge_agg_double: merge double sketches from multiple rows
+SELECT lower(kll_sketch_to_string_double(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_double(agg, 0.5) - 6.0) < 1.0 AS median_close_to_6,
+ abs(kll_sketch_get_rank_double(agg, 5.0) - 0.35) < 0.15 AS rank5_close_to_0_35
+FROM (
+ SELECT kll_merge_agg_double(sketch_col) AS agg
+ FROM (
+ SELECT kll_sketch_agg_double(col1) AS sketch_col
+ FROM t_double_1_5_through_7_11
+ UNION ALL
+ SELECT kll_sketch_agg_double(col2) AS sketch_col
+ FROM t_float_1_5_through_7_11
+ ) sketches
+);
+
+-- Test kll_merge_agg_bigint with custom k parameter
+SELECT LENGTH(kll_sketch_to_string_bigint(kll_merge_agg_bigint(sketch_col, 400))) > 0 AS merged_with_k
+FROM (
+ SELECT kll_sketch_agg_bigint(col1, 400) AS sketch_col
+ FROM t_long_1_5_through_7_11
+ UNION ALL
+ SELECT kll_sketch_agg_bigint(col2, 400) AS sketch_col
+ FROM t_byte_1_5_through_7_11
+) sketches;
+
+-- Test kll_merge_agg_float with custom k parameter
+SELECT LENGTH(kll_sketch_to_string_float(kll_merge_agg_float(sketch_col, 300))) > 0 AS merged_with_k
+FROM (
+ SELECT kll_sketch_agg_float(col1, 300) AS sketch_col
+ FROM t_float_1_5_through_7_11
+) sketches;
+
+-- Test kll_merge_agg_double with custom k parameter
+SELECT LENGTH(kll_sketch_to_string_double(kll_merge_agg_double(sketch_col, 500))) > 0 AS merged_with_k
+FROM (
+ SELECT kll_sketch_agg_double(col1, 500) AS sketch_col
+ FROM t_double_1_5_through_7_11
+) sketches;
+
+-- Test that kll_merge_agg functions ignore NULL sketch values
+SELECT abs(kll_sketch_get_quantile_bigint(agg_with_nulls, 0.5) -
+ kll_sketch_get_quantile_bigint(agg_without_nulls, 0.5)) < 1 AS medians_match
+FROM (
+ SELECT kll_merge_agg_bigint(sketch_col) AS agg_with_nulls
+ FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_long_1_5_through_7_11
+ UNION ALL
+ SELECT CAST(NULL AS BINARY) AS sketch_col
+ UNION ALL
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_byte_1_5_through_7_11
+ ) sketches_with_nulls
+) WITH_NULLS,
+(
+ SELECT kll_merge_agg_bigint(sketch_col) AS agg_without_nulls
+ FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_long_1_5_through_7_11
+ UNION ALL
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_byte_1_5_through_7_11
+ ) sketches_without_nulls
+) WITHOUT_NULLS;
+
+-- Tests verifying that NULL input values are ignored by aggregate functions
+
+-- Test BIGINT aggregate ignores NULL values
+-- Verify that the sketch computed with NULLs matches the sketch without NULLs
+-- Both should compute median of [1, 3, 5, 7] which is 4
+-- Input data: 1, NULL, 3, 5, NULL, 7
+SELECT abs(kll_sketch_get_quantile_bigint(agg_with_nulls, 0.5) -
+ kll_sketch_get_quantile_bigint(agg_without_nulls, 0.5)) < 1 AS medians_match,
+ abs(kll_sketch_get_rank_bigint(agg_with_nulls, 4) -
+ kll_sketch_get_rank_bigint(agg_without_nulls, 4)) < 0.1 AS ranks_match
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg_with_nulls
+ FROM (VALUES (1L), (CAST(NULL AS BIGINT)), (3L), (5L), (CAST(NULL AS BIGINT)), (7L)) AS tab(col1)
+) WITH_NULLS,
+(
+ SELECT kll_sketch_agg_bigint(col1) AS agg_without_nulls
+ FROM (VALUES (1L), (3L), (5L), (7L)) AS tab(col1)
+) WITHOUT_NULLS;
+
+-- Test FLOAT aggregate ignores NULL values
+-- Verify that the sketch computed with NULLs matches the sketch without NULLs
+-- Input data: 1.0, NULL, 3.0, 5.0, NULL, 7.0
+SELECT abs(kll_sketch_get_quantile_float(agg_with_nulls, 0.5) -
+ kll_sketch_get_quantile_float(agg_without_nulls, 0.5)) < 0.5 AS medians_match,
+ abs(kll_sketch_get_rank_float(agg_with_nulls, 4.0) -
+ kll_sketch_get_rank_float(agg_without_nulls, 4.0)) < 0.1 AS ranks_match
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS agg_with_nulls
+ FROM (VALUES (1.0F), (CAST(NULL AS FLOAT)), (3.0F), (5.0F), (CAST(NULL AS FLOAT)), (7.0F)) AS tab(col1)
+) WITH_NULLS,
+(
+ SELECT kll_sketch_agg_float(col1) AS agg_without_nulls
+ FROM (VALUES (1.0F), (3.0F), (5.0F), (7.0F)) AS tab(col1)
+) WITHOUT_NULLS;
+
+-- Test DOUBLE aggregate ignores NULL values
+-- Verify that the sketch computed with NULLs matches the sketch without NULLs
+-- Input data: 1.0, NULL, 3.0, 5.0, NULL, 7.0
+SELECT abs(kll_sketch_get_quantile_double(agg_with_nulls, 0.5) -
+ kll_sketch_get_quantile_double(agg_without_nulls, 0.5)) < 0.5 AS medians_match,
+ abs(kll_sketch_get_rank_double(agg_with_nulls, 4.0) -
+ kll_sketch_get_rank_double(agg_without_nulls, 4.0)) < 0.1 AS ranks_match
+FROM (
+ SELECT kll_sketch_agg_double(col1) AS agg_with_nulls
+ FROM (VALUES (1.0D), (CAST(NULL AS DOUBLE)), (3.0D), (5.0D), (CAST(NULL AS DOUBLE)), (7.0D)) AS tab(col1)
+) WITH_NULLS,
+(
+ SELECT kll_sketch_agg_double(col1) AS agg_without_nulls
+ FROM (VALUES (1.0D), (3.0D), (5.0D), (7.0D)) AS tab(col1)
+) WITHOUT_NULLS;
+
+-- Tests covering NULLs
+-- NULL sketch to get_quantile
+SELECT kll_sketch_get_quantile_bigint(CAST(NULL AS BINARY), 0.5) AS null_sketch;
+
+-- NULL sketch to get_rank
+SELECT kll_sketch_get_rank_float(CAST(NULL AS BINARY), 5.0) AS null_sketch;
+
+-- Tests for the optional k parameter
+-- Positive tests with valid k values
+SELECT LENGTH(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col1, 8))) > 0 AS k_min_value
+FROM t_long_1_5_through_7_11;
+
+SELECT LENGTH(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col1, 200))) > 0 AS k_default_value
+FROM t_long_1_5_through_7_11;
+
+SELECT LENGTH(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col1, 400))) > 0 AS k_custom_value
+FROM t_long_1_5_through_7_11;
+
+SELECT LENGTH(kll_sketch_to_string_bigint(kll_sketch_agg_bigint(col1, 65535))) > 0 AS k_max_value
+FROM t_long_1_5_through_7_11;
+
+SELECT LENGTH(kll_sketch_to_string_float(kll_sketch_agg_float(col1, 100))) > 0 AS k_float_sketch
+FROM t_float_1_5_through_7_11;
+
+SELECT LENGTH(kll_sketch_to_string_double(kll_sketch_agg_double(col1, 300))) > 0 AS k_double_sketch
+FROM t_double_1_5_through_7_11;
+
+-- Tests for kll_sketch_get_n functions
+-- BIGINT sketches
+SELECT kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col1)) AS n_bigint
+FROM t_long_1_5_through_7_11;
+
+SELECT kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col1)) AS n_byte
+FROM t_byte_1_5_through_7_11;
+
+SELECT kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col1)) AS n_short
+FROM t_short_1_5_through_7_11;
+
+SELECT kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col1)) AS n_int
+FROM t_int_1_5_through_7_11;
+
+-- FLOAT sketches
+SELECT kll_sketch_get_n_float(kll_sketch_agg_float(col1)) AS n_float
+FROM t_float_1_5_through_7_11;
+
+-- DOUBLE sketches
+SELECT kll_sketch_get_n_double(kll_sketch_agg_double(col1)) AS n_double
+FROM t_double_1_5_through_7_11;
+
+-- Test with different k values
+SELECT kll_sketch_get_n_bigint(kll_sketch_agg_bigint(col1, 100)) AS n_k_100
+FROM t_long_1_5_through_7_11;
+
+-- Negative tests
+-- These queries should fail with type mismatch or validation errors
+
+-- Type mismatch: BIGINT sketch does not accept DOUBLE columns
+SELECT lower(kll_sketch_to_string_bigint(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_bigint(agg, 0.5) - 4) < 1 AS median_close_to_4,
+ abs(kll_sketch_get_rank_bigint(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_double_1_5_through_7_11
+);
+
+-- Type mismatch: BIGINT sketch does not accept FLOAT columns
+SELECT lower(kll_sketch_to_string_bigint(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_bigint(agg, 0.5) - 4) < 1 AS median_close_to_4,
+ abs(kll_sketch_get_rank_bigint(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_float_1_5_through_7_11
+);
+
+-- Type mismatch: FLOAT sketch does not accept DOUBLE columns
+SELECT lower(kll_sketch_to_string_float(agg)) LIKE '%kll%' AS str_contains_kll,
+ abs(kll_sketch_get_quantile_float(agg, 0.5) - 4.0) < 0.5 AS median_close_to_4,
+ abs(kll_sketch_get_rank_float(agg, 3) - 0.4) < 0.1 AS rank3_close_to_0_4
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS agg
+ FROM t_double_1_5_through_7_11
+);
+
+-- Type mismatch: FLOAT sketch does not accept integer types (BIGINT) to avoid precision loss
+SELECT kll_sketch_agg_float(col1) AS invalid_float_bigint
+FROM t_long_1_5_through_7_11;
+
+-- Type mismatch: FLOAT sketch does not accept integer types (INT) to avoid precision loss
+SELECT kll_sketch_agg_float(col1) AS invalid_float_int
+FROM t_int_1_5_through_7_11;
+
+-- Type mismatch: FLOAT sketch does not accept integer types (SMALLINT) to avoid precision loss
+SELECT kll_sketch_agg_float(col1) AS invalid_float_short
+FROM t_short_1_5_through_7_11;
+
+-- Type mismatch: FLOAT sketch does not accept integer types (TINYINT) to avoid precision loss
+SELECT kll_sketch_agg_float(col1) AS invalid_float_byte
+FROM t_byte_1_5_through_7_11;
+
+-- Type mismatch: DOUBLE sketch does not accept integer types (BIGINT) to avoid precision loss
+SELECT kll_sketch_agg_double(col1) AS invalid_double_bigint
+FROM t_long_1_5_through_7_11;
+
+-- Type mismatch: DOUBLE sketch does not accept integer types (INT) to avoid precision loss
+SELECT kll_sketch_agg_double(col1) AS invalid_double_int
+FROM t_int_1_5_through_7_11;
+
+-- Type mismatch: DOUBLE sketch does not accept integer types (SMALLINT) to avoid precision loss
+SELECT kll_sketch_agg_double(col1) AS invalid_double_short
+FROM t_short_1_5_through_7_11;
+
+-- Type mismatch: DOUBLE sketch does not accept integer types (TINYINT) to avoid precision loss
+SELECT kll_sketch_agg_double(col1) AS invalid_double_byte
+FROM t_byte_1_5_through_7_11;
+
+-- Invalid quantile: quantile value must be between 0 and 1 (negative value)
+SELECT kll_sketch_get_quantile_bigint(agg, -0.5) AS invalid_quantile
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_long_1_5_through_7_11
+);
+
+-- Invalid quantile: quantile value must be between 0 and 1 (value > 1)
+SELECT kll_sketch_get_quantile_bigint(agg, 1.5) AS invalid_quantile
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_long_1_5_through_7_11
+);
+
+-- Invalid quantile: quantile array with out of range values
+SELECT kll_sketch_get_quantile_float(agg, array(-0.1, 0.5, 1.5)) AS invalid_quantiles
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS agg
+ FROM t_float_1_5_through_7_11
+);
+
+-- Type mismatch: wrong sketch type for get_rank function
+SELECT kll_sketch_get_rank_bigint(agg, 5) AS wrong_type
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS agg
+ FROM t_float_1_5_through_7_11
+);
+
+-- Type mismatch: incompatible sketches in merge (BIGINT and FLOAT)
+SELECT kll_sketch_merge_bigint(agg1, agg2) AS incompatible_merge
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg1,
+ kll_sketch_agg_float(CAST(col1 AS FLOAT)) AS agg2
+ FROM t_long_1_5_through_7_11
+);
+
+-- Invalid input: non-sketch binary data to get_quantile
+SELECT kll_sketch_get_quantile_bigint(CAST('not_a_sketch' AS BINARY), 0.5) AS invalid_binary;
+
+-- Note: get_quantile functions cannot detect sketch type mismatches at the binary level.
+-- This query succeeds even though we're using a FLOAT get_quantile on a BIGINT sketch,
+-- but it returns garbage values because it interprets the BIGINT binary data as FLOAT data.
+SELECT kll_sketch_get_quantile_float(agg, 0.5) IS NOT NULL AS returns_value
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_long_1_5_through_7_11
+);
+
+-- Note: to_string functions cannot detect sketch type mismatches because they just
+-- interpret the binary data. This query succeeds even though we're using a DOUBLE
+-- to_string function on a BIGINT sketch. The function reads the binary representation
+-- and produces output, but the numeric values will be incorrectly interpreted.
+SELECT lower(kll_sketch_to_string_double(agg)) LIKE '%kll%' AS contains_kll_header
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_long_1_5_through_7_11
+);
+
+-- Negative tests for k parameter
+-- k parameter too small (minimum is 8)
+SELECT kll_sketch_agg_bigint(col1, 7) AS k_too_small
+FROM t_long_1_5_through_7_11;
+
+-- k parameter too large (maximum is 65535)
+SELECT kll_sketch_agg_bigint(col1, 65536) AS k_too_large
+FROM t_long_1_5_through_7_11;
+
+-- k parameter is NULL
+SELECT kll_sketch_agg_float(col1, CAST(NULL AS INT)) AS k_is_null
+FROM t_float_1_5_through_7_11;
+
+-- k parameter is not foldable (non-constant)
+SELECT kll_sketch_agg_double(col1, CAST(col1 AS INT)) AS k_non_constant
+FROM t_double_1_5_through_7_11;
+
+-- k parameter has wrong type (STRING instead of INT)
+SELECT kll_sketch_agg_bigint(col1, '100') AS k_wrong_type
+FROM t_long_1_5_through_7_11;
+
+-- Negative tests for kll_merge_agg functions
+
+-- Test wrong sketch type: float sketch passed to kll_merge_agg_bigint (should fail)
+SELECT kll_merge_agg_bigint(sketch_col) AS wrong_type_merge
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS sketch_col
+ FROM t_float_1_5_through_7_11
+) float_sketches;
+
+-- Type mismatch: kll_merge_agg_bigint does not accept integer columns (needs binary)
+SELECT kll_merge_agg_bigint(col1) AS merge_wrong_type
+FROM t_long_1_5_through_7_11;
+
+-- Type mismatch: kll_merge_agg_float does not accept float columns (needs binary)
+SELECT kll_merge_agg_float(col1) AS merge_wrong_type
+FROM t_float_1_5_through_7_11;
+
+-- Type mismatch: kll_merge_agg_double does not accept double columns (needs binary)
+SELECT kll_merge_agg_double(col1) AS merge_wrong_type
+FROM t_double_1_5_through_7_11;
+
+-- Invalid binary data for kll_merge_agg_bigint
+SELECT kll_merge_agg_bigint(sketch_col) AS invalid_merge
+FROM (
+ SELECT CAST('not_a_sketch' AS BINARY) AS sketch_col
+) invalid_data;
+
+-- Invalid binary data for kll_merge_agg_float
+SELECT kll_merge_agg_float(sketch_col) AS invalid_merge
+FROM (
+ SELECT X'deadbeef' AS sketch_col
+) invalid_data;
+
+-- Invalid binary data for kll_merge_agg_double
+SELECT kll_merge_agg_double(sketch_col) AS invalid_merge
+FROM (
+ SELECT X'cafebabe' AS sketch_col
+) invalid_data;
+
+-- k parameter too small for kll_merge_agg_bigint
+SELECT kll_merge_agg_bigint(sketch_col, 7) AS k_too_small
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_long_1_5_through_7_11
+) sketches;
+
+-- k parameter too large for kll_merge_agg_float
+SELECT kll_merge_agg_float(sketch_col, 65536) AS k_too_large
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS sketch_col
+ FROM t_float_1_5_through_7_11
+) sketches;
+
+-- k parameter is NULL for kll_merge_agg_double
+SELECT kll_merge_agg_double(sketch_col, CAST(NULL AS INT)) AS k_is_null
+FROM (
+ SELECT kll_sketch_agg_double(col1) AS sketch_col
+ FROM t_double_1_5_through_7_11
+) sketches;
+
+-- k parameter is not foldable for kll_merge_agg_bigint (using a non-constant expression)
+SELECT kll_merge_agg_bigint(sketch_col, CAST(RAND() * 100 AS INT) + 200) AS k_non_constant
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS sketch_col
+ FROM t_long_1_5_through_7_11
+) sketches;
+
+-- Negative tests for kll_sketch_get_n functions
+-- Invalid binary data
+SELECT kll_sketch_get_n_bigint(X'deadbeef') AS invalid_binary_bigint;
+
+SELECT kll_sketch_get_n_float(X'cafebabe') AS invalid_binary_float;
+
+SELECT kll_sketch_get_n_double(X'12345678') AS invalid_binary_double;
+
+-- Wrong argument types
+SELECT kll_sketch_get_n_bigint(42) AS wrong_argument_type;
+
+SELECT kll_sketch_get_n_float(42.0) AS wrong_argument_type;
+
+SELECT kll_sketch_get_n_double(42.0D) AS wrong_argument_type;
+
+-- Negative tests for kll_sketch_get_quantile functions with invalid second argument types
+-- Invalid type: STRING instead of DOUBLE for quantile parameter
+SELECT kll_sketch_get_quantile_bigint(agg, 'invalid') AS quantile_string
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_long_1_5_through_7_11
+);
+
+-- Invalid type: BINARY instead of DOUBLE for quantile parameter
+SELECT kll_sketch_get_quantile_float(agg, X'deadbeef') AS quantile_binary
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS agg
+ FROM t_float_1_5_through_7_11
+);
+
+-- Invalid type: BOOLEAN instead of DOUBLE for quantile parameter
+SELECT kll_sketch_get_quantile_double(agg, true) AS quantile_boolean
+FROM (
+ SELECT kll_sketch_agg_double(col1) AS agg
+ FROM t_double_1_5_through_7_11
+);
+
+-- Negative tests for kll_sketch_get_rank functions with invalid second argument types
+-- Invalid type: STRING instead of BIGINT for rank value parameter
+SELECT kll_sketch_get_rank_bigint(agg, 'invalid') AS rank_string
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg
+ FROM t_long_1_5_through_7_11
+);
+
+-- Invalid type: BINARY instead of FLOAT for rank value parameter
+SELECT kll_sketch_get_rank_float(agg, X'cafebabe') AS rank_binary
+FROM (
+ SELECT kll_sketch_agg_float(col1) AS agg
+ FROM t_float_1_5_through_7_11
+);
+
+-- Invalid type: BOOLEAN instead of DOUBLE for rank value parameter
+SELECT kll_sketch_get_rank_double(agg, false) AS rank_boolean
+FROM (
+ SELECT kll_sketch_agg_double(col1) AS agg
+ FROM t_double_1_5_through_7_11
+);
+
+-- Negative tests for non-foldable (non-constant) rank/quantile arguments
+-- These tests verify that get_quantile and get_rank functions require compile-time constant arguments
+
+-- Non-foldable scalar rank argument to get_quantile (column reference)
+SELECT kll_sketch_get_quantile_bigint(agg, CAST(col1 AS DOUBLE) / 10.0) AS non_foldable_scalar_rank
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+ FROM t_long_1_5_through_7_11
+ GROUP BY col1
+);
+
+-- Non-foldable array rank argument to get_quantile (array containing column reference)
+SELECT kll_sketch_get_quantile_bigint(agg, array(0.25, CAST(col1 AS DOUBLE) / 10.0, 0.75)) AS non_foldable_array_rank
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+ FROM t_long_1_5_through_7_11
+ GROUP BY col1
+);
+
+-- Non-foldable scalar quantile argument to get_rank (column reference)
+SELECT kll_sketch_get_rank_bigint(agg, col1) AS non_foldable_scalar_quantile
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+ FROM t_long_1_5_through_7_11
+ GROUP BY col1
+);
+
+-- Non-foldable array quantile argument to get_rank (array containing column reference)
+SELECT kll_sketch_get_rank_bigint(agg, array(1L, col1, 5L)) AS non_foldable_array_quantile
+FROM (
+ SELECT kll_sketch_agg_bigint(col1) AS agg, col1
+ FROM t_long_1_5_through_7_11
+ GROUP BY col1
+);
+
+-- Clean up
+DROP TABLE IF EXISTS t_int_1_5_through_7_11;
+DROP TABLE IF EXISTS t_long_1_5_through_7_11;
+DROP TABLE IF EXISTS t_short_1_5_through_7_11;
+DROP TABLE IF EXISTS t_byte_1_5_through_7_11;
+DROP TABLE IF EXISTS t_float_1_5_through_7_11;
+DROP TABLE IF EXISTS t_double_1_5_through_7_11;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/st-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/st-functions.sql
index dc688e4a89941..70fbdef533303 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/st-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/st-functions.sql
@@ -13,6 +13,52 @@ INSERT INTO geodata VALUES
SELECT CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040') AS STRING) AS result;
SELECT CAST(X'0101000000000000000000f03f0000000000000040' AS GEOMETRY(4326)) AS result;
+-- Casting GEOGRAPHY(<srid>) to GEOGRAPHY(ANY) is allowed.
+SELECT hex(ST_AsBinary(CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040') AS GEOGRAPHY(ANY)))) AS result;
+-- Casting GEOGRAPHY(ANY) to GEOGRAPHY(<srid>) is not allowed.
+SELECT CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040')::GEOGRAPHY(ANY) AS GEOGRAPHY(4326)) AS result;
+
+-- Casting GEOGRAPHY to GEOMETRY is allowed only if SRIDs match.
+SELECT hex(ST_AsBinary(CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040') AS GEOMETRY(4326)))) AS result;
+-- Error handling: mismatched SRIDs.
+SELECT CAST(ST_GeogFromWKB(X'0101000000000000000000f03f0000000000000040') AS GEOMETRY(ANY)) AS result;
+
+-- Casting GEOMETRY(<srid>) to GEOMETRY(ANY) is allowed.
+SELECT hex(ST_AsBinary(CAST(ST_GeomFromWKB(X'0101000000000000000000f03f0000000000000040') AS GEOMETRY(ANY)))) AS result;
+-- Casting GEOMETRY(ANY) to GEOMETRY(<srid>) is not allowed.
+SELECT CAST(ST_GeomFromWKB(X'0101000000000000000000f03f0000000000000040')::GEOMETRY(ANY) AS GEOMETRY(4326)) AS result;
+
+---- Geospatial type coercion
+
+-- Array
+SELECT typeof(array(ST_GeogFromWKB(wkb), ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY))) FROM geodata;
+SELECT typeof(array(ST_GeomFromWKB(wkb), ST_GeomFromWKB(wkb)::GEOMETRY(ANY))) FROM geodata;
+-- Map
+SELECT typeof(map('a', ST_GeogFromWKB(wkb), 'b', ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY))) FROM geodata;
+SELECT typeof(map('a', ST_GeomFromWKB(wkb), 'b', ST_GeomFromWKB(wkb)::GEOMETRY(ANY))) FROM geodata;
+-- Struct
+SELECT typeof(array(named_struct('g1', ST_GeogFromWKB(wkb), 'g2', ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY)), named_struct('g1', ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY), 'g2', ST_GeogFromWKB(wkb)))) FROM geodata;
+SELECT typeof(array(named_struct('g1', ST_GeomFromWKB(wkb), 'g2', ST_GeomFromWKB(wkb)::GEOMETRY(ANY)), named_struct('g1', ST_GeomFromWKB(wkb)::GEOMETRY(ANY), 'g2', ST_GeomFromWKB(wkb)))) FROM geodata;
+-- Nested
+SELECT typeof(named_struct('a', array(ST_GeogFromWKB(wkb), ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY)), 'b', map('g', ST_GeogFromWKB(wkb), 'h', ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY)))) FROM geodata;
+SELECT typeof(named_struct('a', array(ST_GeomFromWKB(wkb), ST_GeomFromWKB(wkb)::GEOMETRY(ANY)), 'b', map('g', ST_GeomFromWKB(wkb), 'h', ST_GeomFromWKB(wkb)::GEOMETRY(ANY)))) FROM geodata;
+
+-- NVL
+SELECT typeof(nvl(ST_GeogFromWKB(wkb), ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY))) FROM geodata;
+SELECT typeof(nvl(ST_GeomFromWKB(wkb), ST_GeomFromWKB(wkb)::GEOMETRY(ANY))) FROM geodata;
+-- NVL2
+SELECT typeof(nvl2(ST_GeogFromWKB(wkb), ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY), ST_GeogFromWKB(wkb))) FROM geodata;
+SELECT typeof(nvl2(ST_GeomFromWKB(wkb), ST_GeomFromWKB(wkb)::GEOMETRY(ANY), ST_GeomFromWKB(wkb))) FROM geodata;
+-- CASE WHEN
+SELECT typeof(CASE WHEN wkb IS NOT NULL THEN ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY) ELSE ST_GeogFromWKB(wkb) END) FROM geodata;
+SELECT typeof(CASE WHEN wkb IS NOT NULL THEN ST_GeomFromWKB(wkb)::GEOMETRY(ANY) ELSE ST_GeomFromWKB(wkb) END) FROM geodata;
+-- COALESCE
+SELECT typeof(coalesce(ST_GeogFromWKB(wkb), ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY))) FROM geodata;
+SELECT typeof(coalesce(ST_GeomFromWKB(wkb), ST_GeomFromWKB(wkb)::GEOMETRY(ANY))) FROM geodata;
+-- IF
+SELECT typeof(IF(wkb IS NOT NULL, ST_GeogFromWKB(wkb)::GEOGRAPHY(ANY), ST_GeogFromWKB(wkb))) FROM geodata;
+SELECT typeof(IF(wkb IS NOT NULL, ST_GeomFromWKB(wkb)::GEOMETRY(ANY), ST_GeomFromWKB(wkb))) FROM geodata;
+
---- ST reader/writer expressions
-- WKB (Well-Known Binary) round-trip tests for GEOGRAPHY and GEOMETRY types.
@@ -32,5 +78,23 @@ SELECT ST_Srid(ST_GeomFromWKB(X'0101000000000000000000F03F0000000000000040'));
SELECT COUNT(*) FROM geodata WHERE ST_Srid(ST_GeogFromWKB(wkb)) <> 4326;
SELECT COUNT(*) FROM geodata WHERE ST_Srid(ST_GeomFromWKB(wkb)) <> 0;
+------ ST modifier expressions
+
+---- ST_SetSrid
+
+-- 1. Driver-level queries.
+SELECT ST_Srid(ST_SetSrid(ST_GeogFromWKB(X'0101000000000000000000F03F0000000000000040'), 4326));
+SELECT ST_Srid(ST_SetSrid(ST_GeomFromWKB(X'0101000000000000000000F03F0000000000000040'), 3857));
+-- Error handling: invalid SRID.
+SELECT ST_Srid(ST_SetSrid(ST_GeogFromWKB(X'0101000000000000000000F03F0000000000000040'), 3857));
+SELECT ST_Srid(ST_SetSrid(ST_GeomFromWKB(X'0101000000000000000000F03F0000000000000040'), 9999));
+
+-- 2. Table-level queries.
+SELECT COUNT(*) FROM geodata WHERE ST_Srid(ST_SetSrid(ST_GeogFromWKB(wkb), 4326)) <> 4326;
+SELECT COUNT(*) FROM geodata WHERE ST_Srid(ST_SetSrid(ST_GeomFromWKB(wkb), 3857)) <> 3857;
+-- Error handling: invalid SRID.
+SELECT COUNT(*) FROM geodata WHERE ST_Srid(ST_SetSrid(ST_GeogFromWKB(wkb), 3857)) IS NOT NULL;
+SELECT COUNT(*) FROM geodata WHERE ST_Srid(ST_SetSrid(ST_GeomFromWKB(wkb), 9999)) IS NOT NULL;
+
-- Drop the test table.
DROP TABLE geodata;
diff --git a/sql/core/src/test/resources/sql-tests/inputs/thetasketch.sql b/sql/core/src/test/resources/sql-tests/inputs/thetasketch.sql
index d270442b50499..4782d2017f2a6 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/thetasketch.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/thetasketch.sql
@@ -457,6 +457,18 @@ FROM (SELECT theta_sketch_agg(col, 12) as sketch
SELECT theta_sketch_agg(col, 20) as sketch
FROM VALUES (1) AS tab(col));
+-- lgNomEntries parameter is NULL
+SELECT theta_sketch_agg(col, CAST(NULL AS INT)) AS lg_nom_entries_is_null
+FROM VALUES (15), (16), (17) tab(col);
+
+-- lgNomEntries parameter is not foldable (non-constant)
+SELECT theta_sketch_agg(col, CAST(col AS INT)) AS lg_nom_entries_non_constant
+FROM VALUES (15), (16), (17) tab(col);
+
+-- lgNomEntries parameter has wrong type (STRING instead of INT)
+SELECT theta_sketch_agg(col, '15')
+FROM VALUES (50), (60), (60) tab(col);
+
-- Test theta_union with integers (1, 2) instead of binary sketch data - should fail
SELECT theta_union(1, 2)
FROM VALUES
diff --git a/sql/core/src/test/resources/sql-tests/results/execute-immediate.sql.out b/sql/core/src/test/resources/sql-tests/results/execute-immediate.sql.out
index 06adf4435046a..dd1207b4f2be5 100644
--- a/sql/core/src/test/resources/sql-tests/results/execute-immediate.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/execute-immediate.sql.out
@@ -1211,3 +1211,47 @@ EXECUTE IMMEDIATE 'SELECT typeof(:p) as type, :p as val' USING MAP(1, 'one', 2,
struct<type:string,val:map<int,string>>
-- !query output
map {1:"one",2:"two"}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT :param'
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNBOUND_SQL_PARAMETER",
+ "sqlState" : "42P02",
+ "messageParameters" : {
+ "name" : "param"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 13,
+ "fragment" : ":param"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT ?'
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNBOUND_SQL_PARAMETER",
+ "sqlState" : "42P02",
+ "messageParameters" : {
+ "name" : "_7"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 8,
+ "fragment" : "?"
+ } ]
+}
diff --git a/sql/core/src/test/resources/sql-tests/results/hll.sql.out b/sql/core/src/test/resources/sql-tests/results/hll.sql.out
index ecdfcbcc791a3..908221f0e7c40 100644
--- a/sql/core/src/test/resources/sql-tests/results/hll.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/hll.sql.out
@@ -256,6 +256,68 @@ org.apache.spark.SparkRuntimeException
}
+-- !query
+SELECT hll_sketch_agg(col, CAST(NULL AS INT)) AS k_is_null
+FROM VALUES (15), (16), (17) tab(col)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "HLL_INVALID_LG_K",
+ "sqlState" : "22546",
+ "messageParameters" : {
+ "function" : "`hll_sketch_agg`",
+ "max" : "21",
+ "min" : "4",
+ "value" : "0"
+ }
+}
+
+
+-- !query
+SELECT hll_sketch_agg(col, CAST(col AS INT)) AS k_non_constant
+FROM VALUES (15), (16), (17) tab(col)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkRuntimeException
+{
+ "errorClass" : "HLL_K_MUST_BE_CONSTANT",
+ "sqlState" : "42K0E",
+ "messageParameters" : {
+ "function" : "`hll_sketch_agg`"
+ }
+}
+
+
+-- !query
+SELECT hll_sketch_agg(col, '15')
+FROM VALUES (50), (60), (60) tab(col)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ "sqlState" : "42K09",
+ "messageParameters" : {
+ "inputSql" : "\"15\"",
+ "inputType" : "\"STRING\"",
+ "paramIndex" : "second",
+ "requiredType" : "\"INT\"",
+ "sqlExpr" : "\"hll_sketch_agg(col, 15)\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 32,
+ "fragment" : "hll_sketch_agg(col, '15')"
+ } ]
+}
+
+
-- !query
SELECT hll_union(
hll_sketch_agg(col1, 12),
diff --git a/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out b/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out
new file mode 100644
index 0000000000000..6a99be0570100
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause-legacy.sql.out
@@ -0,0 +1,2923 @@
+-- Automatically generated by SQLQueryTestSuite
+-- !query
+SET hivevar:colname = 'c'
+-- !query schema
+struct<key:string,value:string>
+-- !query output
+hivevar:colname 'c'
+
+
+-- !query
+SELECT IDENTIFIER(${colname} || '_1') FROM VALUES(1) AS T(c_1)
+-- !query schema
+struct<c_1:int>
+-- !query output
+1
+
+
+-- !query
+SELECT IDENTIFIER('c1') FROM VALUES(1) AS T(c1)
+-- !query schema
+struct<c1:int>
+-- !query output
+1
+
+
+-- !query
+SELECT IDENTIFIER('t.c1') FROM VALUES(1) AS T(c1)
+-- !query schema
+struct<c1:int>
+-- !query output
+1
+
+
+-- !query
+SELECT IDENTIFIER('`t`.c1') FROM VALUES(1) AS T(c1)
+-- !query schema
+struct<c1:int>
+-- !query output
+1
+
+
+-- !query
+SELECT IDENTIFIER('`c 1`') FROM VALUES(1) AS T(`c 1`)
+-- !query schema
+struct<c 1:int>
+-- !query output
+1
+
+
+-- !query
+SELECT IDENTIFIER('``') FROM VALUES(1) AS T(``)
+-- !query schema
+struct<:int>
+-- !query output
+1
+
+
+-- !query
+SELECT IDENTIFIER('c' || '1') FROM VALUES(1) AS T(c1)
+-- !query schema
+struct<c1:int>
+-- !query output
+1
+
+
+-- !query
+CREATE SCHEMA IF NOT EXISTS s
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TABLE s.tab(c1 INT) USING CSV
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+USE SCHEMA s
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+INSERT INTO IDENTIFIER('ta' || 'b') VALUES(1)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DELETE FROM IDENTIFIER('ta' || 'b') WHERE 1=0
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ "sqlState" : "0A000",
+ "messageParameters" : {
+ "operation" : "DELETE",
+ "tableName" : "`spark_catalog`.`s`.`tab`"
+ }
+}
+
+
+-- !query
+UPDATE IDENTIFIER('ta' || 'b') SET c1 = 2
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkUnsupportedOperationException
+{
+ "errorClass" : "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ "sqlState" : "0A000",
+ "messageParameters" : {
+ "operation" : "UPDATE TABLE",
+ "tableName" : "`spark_catalog`.`s`.`tab`"
+ }
+}
+
+
+-- !query
+MERGE INTO IDENTIFIER('ta' || 'b') AS t USING IDENTIFIER('ta' || 'b') AS s ON s.c1 = t.c1
+ WHEN MATCHED THEN UPDATE SET c1 = 3
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.SparkUnsupportedOperationException
+{
+ "errorClass" : "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ "sqlState" : "0A000",
+ "messageParameters" : {
+ "operation" : "MERGE INTO TABLE",
+ "tableName" : "`spark_catalog`.`s`.`tab`"
+ }
+}
+
+
+-- !query
+SELECT * FROM IDENTIFIER('tab')
+-- !query schema
+struct<c1:int>
+-- !query output
+1
+
+
+-- !query
+SELECT * FROM IDENTIFIER('s.tab')
+-- !query schema
+struct<c1:int>
+-- !query output
+1
+
+
+-- !query
+SELECT * FROM IDENTIFIER('`s`.`tab`')
+-- !query schema
+struct<c1:int>
+-- !query output
+1
+
+
+-- !query
+SELECT * FROM IDENTIFIER('t' || 'a' || 'b')
+-- !query schema
+struct<c1:int>
+-- !query output
+1
+
+
+-- !query
+USE SCHEMA default
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TABLE s.tab
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP SCHEMA s
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT IDENTIFIER('COAL' || 'ESCE')(NULL, 1)
+-- !query schema
+struct<coalesce(NULL, 1):int>
+-- !query output
+1
+
+
+-- !query
+SELECT IDENTIFIER('abs')(c1) FROM VALUES(-1) AS T(c1)
+-- !query schema
+struct<abs(c1):int>
+-- !query output
+1
+
+
+-- !query
+SELECT * FROM IDENTIFIER('ra' || 'nge')(0, 1)
+-- !query schema
+struct<id:bigint>
+-- !query output
+0
+
+
+-- !query
+CREATE TABLE IDENTIFIER('tab')(c1 INT) USING CSV
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TABLE IF EXISTS IDENTIFIER('ta' || 'b')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE SCHEMA identifier_clauses
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+USE identifier_clauses
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TABLE IDENTIFIER('ta' || 'b')(c1 INT) USING CSV
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TABLE IF EXISTS IDENTIFIER('identifier_clauses.' || 'tab')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TABLE IDENTIFIER('identifier_clauses.' || 'tab')(c1 INT) USING CSV
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+REPLACE TABLE IDENTIFIER('identifier_clauses.' || 'tab')(c1 INT) USING CSV
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ "sqlState" : "0A000",
+ "messageParameters" : {
+ "operation" : "REPLACE TABLE",
+ "tableName" : "`spark_catalog`.`identifier_clauses`.`tab`"
+ }
+}
+
+
+-- !query
+CACHE TABLE IDENTIFIER('ta' || 'b')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+UNCACHE TABLE IDENTIFIER('ta' || 'b')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TABLE IF EXISTS IDENTIFIER('ta' || 'b')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+USE default
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP SCHEMA identifier_clauses
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TABLE tab(c1 INT) USING CSV
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+INSERT INTO tab VALUES (1)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT c1 FROM tab
+-- !query schema
+struct<c1:int>
+-- !query output
+1
+
+
+-- !query
+DESCRIBE IDENTIFIER('ta' || 'b')
+-- !query schema
+struct<col_name:string,data_type:string,comment:string>
+-- !query output
+c1 int
+
+
+-- !query
+ANALYZE TABLE IDENTIFIER('ta' || 'b') COMPUTE STATISTICS
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+ALTER TABLE IDENTIFIER('ta' || 'b') ADD COLUMN c2 INT
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SHOW TBLPROPERTIES IDENTIFIER('ta' || 'b')
+-- !query schema
+struct<key:string,value:string>
+-- !query output
+
+
+
+-- !query
+SHOW COLUMNS FROM IDENTIFIER('ta' || 'b')
+-- !query schema
+struct<col_name:string>
+-- !query output
+c1
+c2
+
+
+-- !query
+COMMENT ON TABLE IDENTIFIER('ta' || 'b') IS 'hello'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+REFRESH TABLE IDENTIFIER('ta' || 'b')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+REPAIR TABLE IDENTIFIER('ta' || 'b')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_PARTITIONED_TABLE",
+ "sqlState" : "42809",
+ "messageParameters" : {
+ "operation" : "MSCK REPAIR TABLE",
+ "tableIdentWithDB" : "`spark_catalog`.`default`.`tab`"
+ }
+}
+
+
+-- !query
+TRUNCATE TABLE IDENTIFIER('ta' || 'b')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TABLE IF EXISTS tab
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE OR REPLACE VIEW IDENTIFIER('v')(c1) AS VALUES(1)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT * FROM v
+-- !query schema
+struct<c1:int>
+-- !query output
+1
+
+
+-- !query
+ALTER VIEW IDENTIFIER('v') AS VALUES(2)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW IDENTIFIER('v')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TEMPORARY VIEW IDENTIFIER('v')(c1) AS VALUES(1)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP VIEW IDENTIFIER('v')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE SCHEMA IDENTIFIER('id' || 'ent')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+ALTER SCHEMA IDENTIFIER('id' || 'ent') SET PROPERTIES (somekey = 'somevalue')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+ALTER SCHEMA IDENTIFIER('id' || 'ent') SET LOCATION 'someloc'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+COMMENT ON SCHEMA IDENTIFIER('id' || 'ent') IS 'some comment'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DESCRIBE SCHEMA IDENTIFIER('id' || 'ent')
+-- !query schema
+struct<info_name:string,info_value:string>
+-- !query output
+Catalog Name spark_catalog
+Comment some comment
+Location [not included in comparison]/{warehouse_dir}/someloc
+Namespace Name ident
+Owner [not included in comparison]
+
+
+-- !query
+SHOW TABLES IN IDENTIFIER('id' || 'ent')
+-- !query schema
+struct<namespace:string,tableName:string,isTemporary:boolean>
+-- !query output
+
+
+
+-- !query
+SHOW TABLE EXTENDED IN IDENTIFIER('id' || 'ent') LIKE 'hello'
+-- !query schema
+struct<namespace:string,tableName:string,isTemporary:boolean,information:string>
+-- !query output
+
+
+
+-- !query
+USE IDENTIFIER('id' || 'ent')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SHOW CURRENT SCHEMA
+-- !query schema
+struct<catalog:string,namespace:string>
+-- !query output
+spark_catalog ident
+
+
+-- !query
+USE SCHEMA IDENTIFIER('id' || 'ent')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+USE SCHEMA default
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP SCHEMA IDENTIFIER('id' || 'ent')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE SCHEMA ident
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE FUNCTION IDENTIFIER('ident.' || 'myDoubleAvg') AS 'test.org.apache.spark.sql.MyDoubleAvg'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DESCRIBE FUNCTION IDENTIFIER('ident.' || 'myDoubleAvg')
+-- !query schema
+struct<function_desc:string>
+-- !query output
+Class: test.org.apache.spark.sql.MyDoubleAvg
+Function: spark_catalog.ident.mydoubleavg
+Usage: N/A.
+
+
+-- !query
+REFRESH FUNCTION IDENTIFIER('ident.' || 'myDoubleAvg')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP FUNCTION IDENTIFIER('ident.' || 'myDoubleAvg')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP SCHEMA ident
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TEMPORARY FUNCTION IDENTIFIER('my' || 'DoubleAvg') AS 'test.org.apache.spark.sql.MyDoubleAvg'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TEMPORARY FUNCTION IDENTIFIER('my' || 'DoubleAvg')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DECLARE var = 'sometable'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TABLE IDENTIFIER(var)(c1 INT) USING CSV
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SET VAR var = 'c1'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT IDENTIFIER(var) FROM VALUES(1) AS T(c1)
+-- !query schema
+struct<c1:int>
+-- !query output
+1
+
+
+-- !query
+SET VAR var = 'some'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TABLE IDENTIFIER(var || 'table')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT IDENTIFIER('c 1') FROM VALUES(1) AS T(`c 1`)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'1'",
+ "hint" : ": extra input '1'"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 24,
+ "fragment" : "IDENTIFIER('c 1')"
+ } ]
+}
+
+
+-- !query
+SELECT IDENTIFIER('') FROM VALUES(1) AS T(``)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_EMPTY_STATEMENT",
+ "sqlState" : "42617",
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 21,
+ "fragment" : "IDENTIFIER('')"
+ } ]
+}
+
+
+-- !query
+VALUES(IDENTIFIER(CAST(NULL AS STRING)))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NULL",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "CAST(NULL AS STRING)",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 19,
+ "stopIndex" : 38,
+ "fragment" : "CAST(NULL AS STRING)"
+ } ]
+}
+
+
+-- !query
+VALUES(IDENTIFIER(1))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.WRONG_TYPE",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "dataType" : "int",
+ "expr" : "1",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 19,
+ "stopIndex" : 19,
+ "fragment" : "1"
+ } ]
+}
+
+
+-- !query
+VALUES(IDENTIFIER(SUBSTR('HELLO', 1, RAND() + 1)))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "substr('HELLO', 1, CAST((rand() + CAST(1 AS DOUBLE)) AS INT))",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 19,
+ "stopIndex" : 48,
+ "fragment" : "SUBSTR('HELLO', 1, RAND() + 1)"
+ } ]
+}
+
+
+-- !query
+SELECT `IDENTIFIER`('abs')(c1) FROM VALUES(-1) AS T(c1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNRESOLVED_ROUTINE",
+ "sqlState" : "42883",
+ "messageParameters" : {
+ "routineName" : "`IDENTIFIER`",
+ "searchPath" : "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 26,
+ "fragment" : "`IDENTIFIER`('abs')"
+ } ]
+}
+
+
+-- !query
+CREATE TABLE t(col1 INT)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT * FROM IDENTIFIER((SELECT 't'))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery()",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 26,
+ "stopIndex" : 37,
+ "fragment" : "(SELECT 't')"
+ } ]
+}
+
+
+-- !query
+SELECT * FROM (SELECT IDENTIFIER((SELECT 'col1')) FROM IDENTIFIER((SELECT 't')))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery()",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 67,
+ "stopIndex" : 78,
+ "fragment" : "(SELECT 't')"
+ } ]
+}
+
+
+-- !query
+SELECT IDENTIFIER((SELECT 'col1')) FROM VALUES(1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery()",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 19,
+ "stopIndex" : 33,
+ "fragment" : "(SELECT 'col1')"
+ } ]
+}
+
+
+-- !query
+SELECT col1, IDENTIFIER((SELECT col1)) FROM VALUES(1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery(col1)",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 25,
+ "stopIndex" : 37,
+ "fragment" : "(SELECT col1)"
+ } ]
+}
+
+
+-- !query
+SELECT IDENTIFIER((SELECT 'col1', 'col2')) FROM VALUES(1,2)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "UNSUPPORTED_TYPED_LITERAL",
+ "sqlState" : "0A000",
+ "messageParameters" : {
+ "supportedTypes" : "\"DATE\", \"TIMESTAMP_NTZ\", \"TIMESTAMP_LTZ\", \"TIMESTAMP\", \"INTERVAL\", \"X\", \"TIME\"",
+ "unsupportedType" : "\"SELECT\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 20,
+ "stopIndex" : 32,
+ "fragment" : "SELECT 'col1'"
+ } ]
+}
+
+
+-- !query
+DROP TABLE t
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TABLE IDENTIFIER(1)(c1 INT) USING csv
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.WRONG_TYPE",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "dataType" : "int",
+ "expr" : "1",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 25,
+ "stopIndex" : 25,
+ "fragment" : "1"
+ } ]
+}
+
+
+-- !query
+CREATE TABLE IDENTIFIER('a.b.c')(c1 INT) USING csv
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+ "sqlState" : "42K05",
+ "messageParameters" : {
+ "namespace" : "`a`.`b`",
+ "sessionCatalog" : "spark_catalog"
+ }
+}
+
+
+-- !query
+CREATE VIEW IDENTIFIER('a.b.c')(c1) AS VALUES(1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+ "sqlState" : "42K05",
+ "messageParameters" : {
+ "namespace" : "`a`.`b`",
+ "sessionCatalog" : "spark_catalog"
+ }
+}
+
+
+-- !query
+DROP TABLE IDENTIFIER('a.b.c')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+ "sqlState" : "42K05",
+ "messageParameters" : {
+ "namespace" : "`a`.`b`",
+ "sessionCatalog" : "spark_catalog"
+ }
+}
+
+
+-- !query
+DROP VIEW IDENTIFIER('a.b.c')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+ "sqlState" : "42K05",
+ "messageParameters" : {
+ "namespace" : "`a`.`b`",
+ "sessionCatalog" : "spark_catalog"
+ }
+}
+
+
+-- !query
+COMMENT ON TABLE IDENTIFIER('a.b.c.d') IS 'hello'
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "REQUIRES_SINGLE_PART_NAMESPACE",
+ "sqlState" : "42K05",
+ "messageParameters" : {
+ "namespace" : "`a`.`b`.`c`",
+ "sessionCatalog" : "spark_catalog"
+ }
+}
+
+
+-- !query
+VALUES(IDENTIFIER(1)())
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.WRONG_TYPE",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "dataType" : "int",
+ "expr" : "1",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 19,
+ "stopIndex" : 19,
+ "fragment" : "1"
+ } ]
+}
+
+
+-- !query
+VALUES(IDENTIFIER('a.b.c.d')())
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "identifier" : "`a`.`b`.`c`.`d`",
+ "limit" : "2"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 30,
+ "fragment" : "IDENTIFIER('a.b.c.d')()"
+ } ]
+}
+
+
+-- !query
+CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.apache.spark.sql.MyDoubleAvg'
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "INVALID_SQL_SYNTAX.CREATE_TEMP_FUNC_WITH_DATABASE",
+ "sqlState" : "42000",
+ "messageParameters" : {
+ "database" : "`default`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 108,
+ "fragment" : "CREATE TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg') AS 'test.org.apache.spark.sql.MyDoubleAvg'"
+ } ]
+}
+
+
+-- !query
+DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "INVALID_SQL_SYNTAX.MULTI_PART_NAME",
+ "sqlState" : "42000",
+ "messageParameters" : {
+ "name" : "`default`.`myDoubleAvg`",
+ "statement" : "DROP TEMPORARY FUNCTION"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 63,
+ "fragment" : "DROP TEMPORARY FUNCTION IDENTIFIER('default.my' || 'DoubleAvg')"
+ } ]
+}
+
+
+-- !query
+CREATE TEMPORARY VIEW IDENTIFIER('default.v')(c1) AS VALUES(1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "TEMP_VIEW_NAME_TOO_MANY_NAME_PARTS",
+ "sqlState" : "428EK",
+ "messageParameters" : {
+ "actualName" : "`default`.`v`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 62,
+ "fragment" : "CREATE TEMPORARY VIEW IDENTIFIER('default.v')(c1) AS VALUES(1)"
+ } ]
+}
+
+
+-- !query
+create temporary view identifier('v1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+cache table identifier('t1') as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+create table identifier('t2') using csv as (select my_col from (values (1), (2), (1) as (my_col)) group by 1)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+insert into identifier('t2') select my_col from (values (3) as (my_col)) group by 1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+drop view v1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+drop table t1
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+drop table t2
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DECLARE agg = 'max'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DECLARE col = 'c1'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DECLARE tab = 'T'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
+ T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
+SELECT IDENTIFIER(agg)(IDENTIFIER(col)) FROM IDENTIFIER(tab)
+-- !query schema
+struct
+-- !query output
+c
+
+
+-- !query
+WITH S(c1, c2) AS (VALUES(1, 2), (2, 3)),
+ T(c1, c2) AS (VALUES ('a', 'b'), ('c', 'd'))
+SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('T')
+-- !query schema
+struct
+-- !query output
+c
+
+
+-- !query
+WITH ABC(c1, c2) AS (VALUES(1, 2), (2, 3))
+SELECT IDENTIFIER('max')(IDENTIFIER('c1')) FROM IDENTIFIER('A' || 'BC')
+-- !query schema
+struct
+-- !query output
+2
+
+
+-- !query
+SELECT row_number() OVER IDENTIFIER('x.win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''x.win''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT T1.c1 FROM VALUES(1) AS T1(c1) JOIN VALUES(1) AS T2(c1) USING (IDENTIFIER('c1'))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT IDENTIFIER('t').c1 FROM VALUES(1) AS T(c1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION",
+ "sqlState" : "42703",
+ "messageParameters" : {
+ "objectName" : "`t`",
+ "proposal" : "`c1`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 22,
+ "fragment" : "IDENTIFIER('t')"
+ } ]
+}
+
+
+-- !query
+SELECT map('a', 1).IDENTIFIER('a') FROM VALUES(1) AS T(c1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''a''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT named_struct('a', 1).IDENTIFIER('a') FROM VALUES(1) AS T(c1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''a''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT * FROM s.IDENTIFIER('tab')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT * FROM IDENTIFIER('s').IDENTIFIER('tab')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'IDENTIFIER'",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT * FROM IDENTIFIER('s').tab
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'IDENTIFIER'",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT row_number() OVER IDENTIFIER('win') FROM VALUES(1) AS T(c1) WINDOW win AS (ORDER BY c1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''win''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT row_number() OVER win FROM VALUES(1) AS T(c1) WINDOW IDENTIFIER('win') AS (ORDER BY c1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ": missing 'AS'"
+ }
+}
+
+
+-- !query
+SELECT 1 AS IDENTIFIER('col1')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT my_table.* FROM VALUES (1, 2) AS IDENTIFIER('my_table')(IDENTIFIER('c1'), IDENTIFIER('c2'))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''my_table''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+WITH identifier('v')(identifier('c1')) AS (VALUES(1)) (SELECT c1 FROM v)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''v''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+CREATE OR REPLACE VIEW v(IDENTIFIER('c1')) AS VALUES(1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT c1 FROM v
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`v`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 16,
+ "stopIndex" : 16,
+ "fragment" : "v"
+ } ]
+}
+
+
+-- !query
+DROP VIEW IF EXISTS v
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TABLE tab(IDENTIFIER('c1') INT) USING CSV
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+INSERT INTO tab(IDENTIFIER('c1')) VALUES(1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ": missing ')'"
+ }
+}
+
+
+-- !query
+SELECT c1 FROM tab
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`tab`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 16,
+ "stopIndex" : 18,
+ "fragment" : "tab"
+ } ]
+}
+
+
+-- !query
+ALTER TABLE IDENTIFIER('tab') RENAME COLUMN IDENTIFIER('c1') TO IDENTIFIER('col1')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT col1 FROM tab
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`tab`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 18,
+ "stopIndex" : 20,
+ "fragment" : "tab"
+ } ]
+}
+
+
+-- !query
+ALTER TABLE IDENTIFIER('tab') ADD COLUMN IDENTIFIER('c2') INT
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT c2 FROM tab
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`tab`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 16,
+ "stopIndex" : 18,
+ "fragment" : "tab"
+ } ]
+}
+
+
+-- !query
+ALTER TABLE IDENTIFIER('tab') DROP COLUMN IDENTIFIER('c2')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+ALTER TABLE IDENTIFIER('tab') RENAME TO IDENTIFIER('tab_renamed')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT * FROM tab_renamed
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`tab_renamed`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 15,
+ "stopIndex" : 25,
+ "fragment" : "tab_renamed"
+ } ]
+}
+
+
+-- !query
+DROP TABLE IF EXISTS tab_renamed
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TABLE IF EXISTS tab
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TABLE test_col_with_dot(IDENTIFIER('`col.with.dot`') INT) USING CSV
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+DROP TABLE IF EXISTS test_col_with_dot
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT * FROM VALUES (1, 2) AS IDENTIFIER('schema.table')(c1, c2)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''schema.table''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT 1 AS IDENTIFIER('col1.col2')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+CREATE SCHEMA identifier_clause_test_schema
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+USE identifier_clause_test_schema
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TABLE test_show(c1 INT, c2 STRING) USING CSV
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SHOW VIEWS IN IDENTIFIER('identifier_clause_test_schema')
+-- !query schema
+struct
+-- !query output
+
+
+
+-- !query
+SHOW PARTITIONS IDENTIFIER('test_show')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "INVALID_PARTITION_OPERATION.PARTITION_SCHEMA_IS_EMPTY",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "name" : "`spark_catalog`.`identifier_clause_test_schema`.`test_show`"
+ }
+}
+
+
+-- !query
+SHOW CREATE TABLE IDENTIFIER('test_show')
+-- !query schema
+struct
+-- !query output
+CREATE TABLE spark_catalog.identifier_clause_test_schema.test_show (
+ c1 INT,
+ c2 STRING)
+USING CSV
+
+
+-- !query
+DROP TABLE test_show
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TABLE test_desc(c1 INT) USING CSV
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DESCRIBE TABLE IDENTIFIER('test_desc')
+-- !query schema
+struct
+-- !query output
+c1 int
+
+
+-- !query
+DESCRIBE FORMATTED IDENTIFIER('test_desc')
+-- !query schema
+struct
+-- !query output
+c1 int
+
+# Detailed Table Information
+Catalog spark_catalog
+Database identifier_clause_test_schema
+Table test_desc
+Created Time [not included in comparison]
+Last Access [not included in comparison]
+Created By [not included in comparison]
+Type MANAGED
+Provider CSV
+Location [not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/test_desc
+
+
+-- !query
+DESCRIBE EXTENDED IDENTIFIER('test_desc')
+-- !query schema
+struct
+-- !query output
+c1 int
+
+# Detailed Table Information
+Catalog spark_catalog
+Database identifier_clause_test_schema
+Table test_desc
+Created Time [not included in comparison]
+Last Access [not included in comparison]
+Created By [not included in comparison]
+Type MANAGED
+Provider CSV
+Location [not included in comparison]/{warehouse_dir}/identifier_clause_test_schema.db/test_desc
+
+
+-- !query
+DESC IDENTIFIER('test_desc')
+-- !query schema
+struct
+-- !query output
+c1 int
+
+
+-- !query
+DROP TABLE test_desc
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TABLE test_comment(c1 INT, c2 STRING) USING CSV
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+COMMENT ON TABLE IDENTIFIER('test_comment') IS 'table comment'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+ALTER TABLE test_comment ALTER COLUMN IDENTIFIER('c1') COMMENT 'column comment'
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+DROP TABLE test_comment
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TABLE identifier_clause_test_schema.test_table(c1 INT) USING CSV
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+ANALYZE TABLE IDENTIFIER('identifier_clause_test_schema.test_table') COMPUTE STATISTICS
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+REFRESH TABLE IDENTIFIER('identifier_clause_test_schema.test_table')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DESCRIBE IDENTIFIER('identifier_clause_test_schema.test_table')
+-- !query schema
+struct
+-- !query output
+c1 int
+
+
+-- !query
+SHOW COLUMNS FROM IDENTIFIER('identifier_clause_test_schema.test_table')
+-- !query schema
+struct
+-- !query output
+c1
+
+
+-- !query
+DROP TABLE IDENTIFIER('identifier_clause_test_schema.test_table')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DECLARE IDENTIFIER('my_var') = 'value'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SET VAR IDENTIFIER('my_var') = 'new_value'
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ": missing EQ"
+ }
+}
+
+
+-- !query
+SELECT IDENTIFIER('my_var')
+-- !query schema
+struct
+-- !query output
+value
+
+
+-- !query
+DROP TEMPORARY VARIABLE IDENTIFIER('my_var')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TEMPORARY FUNCTION test_udf(IDENTIFIER('param1') INT, IDENTIFIER('param2') STRING)
+RETURNS INT
+RETURN IDENTIFIER('param1') + length(IDENTIFIER('param2'))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT test_udf(5, 'hello')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNRESOLVED_ROUTINE",
+ "sqlState" : "42883",
+ "messageParameters" : {
+ "routineName" : "`test_udf`",
+ "searchPath" : "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`identifier_clause_test_schema`]"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 27,
+ "fragment" : "test_udf(5, 'hello')"
+ } ]
+}
+
+
+-- !query
+DROP TEMPORARY FUNCTION test_udf
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.analysis.NoSuchTempFunctionException
+{
+ "errorClass" : "ROUTINE_NOT_FOUND",
+ "sqlState" : "42883",
+ "messageParameters" : {
+ "routineName" : "`test_udf`"
+ }
+}
+
+
+-- !query
+CREATE TEMPORARY FUNCTION test_table_udf(IDENTIFIER('input_val') INT)
+RETURNS TABLE(IDENTIFIER('col1') INT, IDENTIFIER('col2') STRING)
+RETURN SELECT IDENTIFIER('input_val'), 'result'
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT * FROM test_table_udf(42)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNRESOLVABLE_TABLE_VALUED_FUNCTION",
+ "sqlState" : "42883",
+ "messageParameters" : {
+ "name" : "`test_table_udf`"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 15,
+ "stopIndex" : 32,
+ "fragment" : "test_table_udf(42)"
+ } ]
+}
+
+
+-- !query
+DROP TEMPORARY FUNCTION test_table_udf
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.analysis.NoSuchTempFunctionException
+{
+ "errorClass" : "ROUTINE_NOT_FOUND",
+ "sqlState" : "42883",
+ "messageParameters" : {
+ "routineName" : "`test_table_udf`"
+ }
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:tab \'b\').c1 FROM VALUES(1) AS tab(c1)' USING 'ta' AS tab
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "INVALID_EXTRACT_BASE_FIELD_TYPE",
+ "sqlState" : "42000",
+ "messageParameters" : {
+ "base" : "\"variablereference(system.session.tab='T')\"",
+ "other" : "\"STRING\""
+ }
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:col1 ''.c2'') FROM VALUES(named_struct(''c2'', 42)) AS T(c1)'
+ USING 'c1' AS col1
+-- !query schema
+struct
+-- !query output
+42
+
+
+-- !query
+CREATE TABLE integration_test(c1 INT, c2 STRING) USING CSV
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+INSERT INTO integration_test VALUES (1, 'a'), (2, 'b')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT * FROM IDENTIFIER(:schema ''.'' :table) ORDER BY ALL'
+ USING 'identifier_clause_test_schema' AS schema, 'integration_test' AS table
+-- !query schema
+struct
+-- !query output
+1 a
+2 b
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:prefix ''1''), IDENTIFIER(:prefix ''2'') FROM integration_test ORDER BY ALL'
+ USING 'c' AS prefix
+-- !query schema
+struct
+-- !query output
+1 a
+2 b
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT * FROM integration_test WHERE IDENTIFIER(:col) = :val'
+ USING 'c1' AS col, 1 AS val
+-- !query schema
+struct
+-- !query output
+1 a
+
+
+-- !query
+CREATE TABLE integration_test2(c1 INT, c3 STRING) USING CSV
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+INSERT INTO integration_test2 VALUES (1, 'x'), (2, 'y')
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT t1.*, t2.* FROM IDENTIFIER(:t1) t1 JOIN IDENTIFIER(:t2) t2 USING (IDENTIFIER(:col)) ORDER BY ALL'
+ USING 'integration_test' AS t1, 'integration_test2' AS t2, 'c1' AS col
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 103,
+ "fragment" : "SELECT t1.*, t2.* FROM IDENTIFIER(:t1) t1 JOIN IDENTIFIER(:t2) t2 USING (IDENTIFIER(:col)) ORDER BY ALL"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE
+ 'SELECT IDENTIFIER(:col1), IDENTIFIER(:col2), row_number() OVER (PARTITION BY IDENTIFIER(:part) ORDER BY IDENTIFIER(:ord)) as rn FROM integration_test'
+ USING 'c1' AS col1, 'c2' AS col2, 'c2' AS part, 'c1' AS ord
+-- !query schema
+struct
+-- !query output
+1 a 1
+2 b 1
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:prefix ''2''), IDENTIFIER(:agg)(IDENTIFIER(:col)) FROM integration_test GROUP BY IDENTIFIER(:prefix ''2'') ORDER BY ALL'
+ USING 'c' AS prefix, 'count' AS agg, 'c1' AS col
+-- !query schema
+struct
+-- !query output
+a 1
+b 1
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT * FROM integration_test ORDER BY IDENTIFIER(:col1) DESC, IDENTIFIER(:col2)'
+ USING 'c1' AS col1, 'c2' AS col2
+-- !query schema
+struct
+-- !query output
+2 b
+1 a
+
+
+-- !query
+EXECUTE IMMEDIATE 'INSERT INTO integration_test(IDENTIFIER(:col1), IDENTIFIER(:col2)) VALUES (:val1, :val2)'
+ USING 'c1' AS col1, 'c2' AS col2, 3 AS val1, 'c' AS val2
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ": missing ')'"
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 88,
+ "fragment" : "INSERT INTO integration_test(IDENTIFIER(:col1), IDENTIFIER(:col2)) VALUES (:val1, :val2)"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(concat(:schema, ''.'', :table, ''.c1'')) FROM VALUES(named_struct(''c1'', 100)) AS IDENTIFIER(:alias)(IDENTIFIER(:schema ''.'' :table))'
+ USING 'identifier_clause_test_schema' AS schema, 'my_table' AS table, 't' AS alias
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "':'",
+ "hint" : ": extra input ':'"
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 145,
+ "fragment" : "SELECT IDENTIFIER(concat(:schema, '.', :table, '.c1')) FROM VALUES(named_struct('c1', 100)) AS IDENTIFIER(:alias)(IDENTIFIER(:schema '.' :table))"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'WITH IDENTIFIER(:cte_name)(c1) AS (VALUES(1)) SELECT c1 FROM IDENTIFIER(:cte_name)'
+ USING 'my_cte' AS cte_name
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "':'",
+ "hint" : ""
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 82,
+ "fragment" : "WITH IDENTIFIER(:cte_name)(c1) AS (VALUES(1)) SELECT c1 FROM IDENTIFIER(:cte_name)"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'CREATE OR REPLACE TEMPORARY VIEW IDENTIFIER(:view_name)(IDENTIFIER(:col_name)) AS VALUES(1)'
+ USING 'test_view' AS view_name, 'test_col' AS col_name
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 91,
+ "fragment" : "CREATE OR REPLACE TEMPORARY VIEW IDENTIFIER(:view_name)(IDENTIFIER(:col_name)) AS VALUES(1)"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:col) FROM IDENTIFIER(:view)'
+ USING 'test_col' AS col, 'test_view' AS view
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`test_view`"
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 30,
+ "stopIndex" : 46,
+ "fragment" : "IDENTIFIER(:view)"
+ } ]
+}
+
+
+-- !query
+DROP VIEW test_view
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+{
+ "errorClass" : "TABLE_OR_VIEW_NOT_FOUND",
+ "sqlState" : "42P01",
+ "messageParameters" : {
+ "relationName" : "`spark_catalog`.`identifier_clause_test_schema`.`test_view`"
+ }
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'ALTER TABLE IDENTIFIER(:tab) ADD COLUMN IDENTIFIER(:new_col) INT'
+ USING 'integration_test' AS tab, 'c4' AS new_col
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 64,
+ "fragment" : "ALTER TABLE IDENTIFIER(:tab) ADD COLUMN IDENTIFIER(:new_col) INT"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'ALTER TABLE IDENTIFIER(:tab) RENAME COLUMN IDENTIFIER(:old_col) TO IDENTIFIER(:new_col)'
+ USING 'integration_test' AS tab, 'c4' AS old_col, 'c5' AS new_col
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 87,
+ "fragment" : "ALTER TABLE IDENTIFIER(:tab) RENAME COLUMN IDENTIFIER(:old_col) TO IDENTIFIER(:new_col)"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT map(:key, :val).IDENTIFIER(:key) AS result'
+ USING 'mykey' AS key, 42 AS val
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "':'",
+ "hint" : ""
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 49,
+ "fragment" : "SELECT map(:key, :val).IDENTIFIER(:key) AS result"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT IDENTIFIER(:alias ''.c1'') FROM integration_test AS IDENTIFIER(:alias) ORDER BY ALL'
+ USING 't' AS alias
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "':'",
+ "hint" : ": extra input ':'"
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 88,
+ "fragment" : "SELECT IDENTIFIER(:alias '.c1') FROM integration_test AS IDENTIFIER(:alias) ORDER BY ALL"
+ } ]
+}
+
+
+-- !query
+EXECUTE IMMEDIATE
+ 'SELECT IDENTIFIER(:col1), IDENTIFIER(:p ''2'') FROM IDENTIFIER(:schema ''.'' :tab) WHERE IDENTIFIER(:col1) > 0 ORDER BY IDENTIFIER(:p ''1'')'
+ USING 'c1' AS col1, 'c' AS p, 'identifier_clause_test_schema' AS schema, 'integration_test' AS tab
+-- !query schema
+struct
+-- !query output
+1 a
+2 b
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT * FROM IDENTIFIER(:schema ''.'' :table) WHERE IDENTIFIER(concat(:tab_alias, ''.c1'')) > 0 ORDER BY ALL'
+ USING 'identifier_clause_test_schema' AS schema, 'integration_test' AS table, 'integration_test' AS tab_alias
+-- !query schema
+struct
+-- !query output
+1 a
+2 b
+
+
+-- !query
+EXECUTE IMMEDIATE 'SELECT 1 AS IDENTIFIER(:schema ''.'' :col)'
+ USING 'identifier_clause_test_schema' AS schema, 'col1' AS col
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ },
+ "queryContext" : [ {
+ "objectType" : "EXECUTE IMMEDIATE",
+ "objectName" : "",
+ "startIndex" : 1,
+ "stopIndex" : 40,
+ "fragment" : "SELECT 1 AS IDENTIFIER(:schema '.' :col)"
+ } ]
+}
+
+
+-- !query
+DROP TABLE integration_test
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DROP TABLE integration_test2
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TABLE lateral_test(arr ARRAY<INT>) USING PARQUET
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+INSERT INTO lateral_test VALUES (array(1, 2, 3))
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT * FROM lateral_test LATERAL VIEW explode(arr) IDENTIFIER('tbl') AS IDENTIFIER('col') ORDER BY ALL
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT * FROM lateral_test LATERAL VIEW OUTER explode(arr) IDENTIFIER('my_table') AS IDENTIFIER('my_col') ORDER BY ALL
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+DROP TABLE lateral_test
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TABLE unpivot_test(id INT, a INT, b INT, c INT) USING CSV
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+INSERT INTO unpivot_test VALUES (1, 10, 20, 30)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT * FROM unpivot_test UNPIVOT (val FOR col IN (a AS IDENTIFIER('col_a'), b AS IDENTIFIER('col_b'))) ORDER BY ALL
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT * FROM unpivot_test UNPIVOT ((v1, v2) FOR col IN ((a, b) AS IDENTIFIER('cols_ab'), (b, c) AS IDENTIFIER('cols_bc'))) ORDER BY ALL
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+DROP TABLE unpivot_test
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+CREATE TABLE describe_col_test(c1 INT, c2 STRING, c3 DOUBLE) USING CSV
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+DESCRIBE describe_col_test IDENTIFIER('c1')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+DESCRIBE describe_col_test IDENTIFIER('c2')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+DROP TABLE describe_col_test
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT :IDENTIFIER('param1') FROM VALUES(1) AS T(c1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "''param1''",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+CREATE TABLE hint_test(c1 INT, c2 INT) USING CSV
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+INSERT INTO hint_test VALUES (1, 2), (3, 4)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT /*+ IDENTIFIER('BROADCAST')(hint_test) */ * FROM hint_test
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT /*+ IDENTIFIER('MERGE')(hint_test) */ * FROM hint_test
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+DROP TABLE hint_test
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SHOW IDENTIFIER('USER') FUNCTIONS
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT EXTRACT(IDENTIFIER('YEAR') FROM DATE'2024-01-15')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "PARSE_SYNTAX_ERROR",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "error" : "'('",
+ "hint" : ""
+ }
+}
+
+
+-- !query
+SELECT TIMESTAMPADD(IDENTIFIER('YEAR'), 1, DATE'2024-01-15')
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "UNRESOLVED_ROUTINE",
+ "sqlState" : "42883",
+ "messageParameters" : {
+ "routineName" : "`TIMESTAMPADD`",
+ "searchPath" : "[`system`.`builtin`, `system`.`session`, `spark_catalog`.`identifier_clause_test_schema`]"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 8,
+ "stopIndex" : 60,
+ "fragment" : "TIMESTAMPADD(IDENTIFIER('YEAR'), 1, DATE'2024-01-15')"
+ } ]
+}
+
+
+-- !query
+DROP SCHEMA identifier_clause_test_schema
+-- !query schema
+struct<>
+-- !query output
+
diff --git a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
index 596745b4ba5d8..0c0473791201f 100644
--- a/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/identifier-clause.sql.out
@@ -118,9 +118,11 @@ struct<>
-- !query output
org.apache.spark.SparkUnsupportedOperationException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_2096",
+ "errorClass" : "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ "sqlState" : "0A000",
"messageParameters" : {
- "ddl" : "UPDATE TABLE"
+ "operation" : "UPDATE TABLE",
+ "tableName" : "`spark_catalog`.`s`.`tab`"
}
}
@@ -133,9 +135,11 @@ struct<>
-- !query output
org.apache.spark.SparkUnsupportedOperationException
{
- "errorClass" : "_LEGACY_ERROR_TEMP_2096",
+ "errorClass" : "UNSUPPORTED_FEATURE.TABLE_OPERATION",
+ "sqlState" : "0A000",
"messageParameters" : {
- "ddl" : "MERGE INTO TABLE"
+ "operation" : "MERGE INTO TABLE",
+ "tableName" : "`spark_catalog`.`s`.`tab`"
}
}
@@ -842,6 +846,137 @@ org.apache.spark.sql.AnalysisException
}
+-- !query
+CREATE TABLE t(col1 INT)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT * FROM IDENTIFIER((SELECT 't'))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery()",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 26,
+ "stopIndex" : 37,
+ "fragment" : "(SELECT 't')"
+ } ]
+}
+
+
+-- !query
+SELECT * FROM (SELECT IDENTIFIER((SELECT 'col1')) FROM IDENTIFIER((SELECT 't')))
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.ExtendedAnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery()",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 67,
+ "stopIndex" : 78,
+ "fragment" : "(SELECT 't')"
+ } ]
+}
+
+
+-- !query
+SELECT IDENTIFIER((SELECT 'col1')) FROM VALUES(1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery()",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 19,
+ "stopIndex" : 33,
+ "fragment" : "(SELECT 'col1')"
+ } ]
+}
+
+
+-- !query
+SELECT col1, IDENTIFIER((SELECT col1)) FROM VALUES(1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+{
+ "errorClass" : "NOT_A_CONSTANT_STRING.NOT_CONSTANT",
+ "sqlState" : "42601",
+ "messageParameters" : {
+ "expr" : "scalarsubquery(col1)",
+ "name" : "IDENTIFIER"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 25,
+ "stopIndex" : 37,
+ "fragment" : "(SELECT col1)"
+ } ]
+}
+
+
+-- !query
+SELECT IDENTIFIER((SELECT 'col1', 'col2')) FROM VALUES(1,2)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.catalyst.parser.ParseException
+{
+ "errorClass" : "UNSUPPORTED_TYPED_LITERAL",
+ "sqlState" : "0A000",
+ "messageParameters" : {
+ "supportedTypes" : "\"DATE\", \"TIMESTAMP_NTZ\", \"TIMESTAMP_LTZ\", \"TIMESTAMP\", \"INTERVAL\", \"X\", \"TIME\"",
+ "unsupportedType" : "\"SELECT\""
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 20,
+ "stopIndex" : 32,
+ "fragment" : "SELECT 'col1'"
+ } ]
+}
+
+
+-- !query
+DROP TABLE t
+-- !query schema
+struct<>
+-- !query output
+
+
+
-- !query
CREATE TABLE IDENTIFIER(1)(c1 INT) USING csv
-- !query schema
@@ -980,7 +1115,8 @@ org.apache.spark.sql.AnalysisException
"errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
"sqlState" : "42601",
"messageParameters" : {
- "identifier" : "`a`.`b`.`c`.`d`"
+ "identifier" : "`a`.`b`.`c`.`d`",
+ "limit" : "2"
},
"queryContext" : [ {
"objectType" : "",
@@ -1175,29 +1311,28 @@ struct<>
-- !query output
org.apache.spark.sql.catalyst.parser.ParseException
{
- "errorClass" : "PARSE_SYNTAX_ERROR",
+ "errorClass" : "IDENTIFIER_TOO_MANY_NAME_PARTS",
"sqlState" : "42601",
"messageParameters" : {
- "error" : "''x.win''",
- "hint" : ""
- }
+ "identifier" : "`x`.`win`",
+ "limit" : "1"
+ },
+ "queryContext" : [ {
+ "objectType" : "",
+ "objectName" : "",
+ "startIndex" : 26,
+ "stopIndex" : 44,
+ "fragment" : "IDENTIFIER('x.win')"
+ } ]
}
-- !query
SELECT T1.c1 FROM VALUES(1) AS T1(c1) JOIN VALUES(1) AS T2(c1) USING (IDENTIFIER('c1'))
-- !query schema
-struct<>
+struct
-- !query output
-org.apache.spark.sql.catalyst.parser.ParseException
-{
- "errorClass" : "PARSE_SYNTAX_ERROR",
- "sqlState" : "42601",
- "messageParameters" : {
- "error" : "'('",
- "hint" : ""
- }
-}
+1
-- !query
@@ -1226,33 +1361,17 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
-- !query
SELECT map('a', 1).IDENTIFIER('a') FROM VALUES(1) AS T(c1)
-- !query schema
-struct<>
+struct