
Commit 8e75fc9

Juliusz Sompolski authored and dongjoon-hyun committed
[SPARK-54637][CONNECT][TESTS] Add SQL API test helpers to SparkConnectServerTest
### What changes were proposed in this pull request?

Add testing helpers to SparkConnectServerTest that enable using the connect Spark SQL APIs in tests based on that trait.

### Why are the changes needed?

In Spark 3.5, the testing trait SparkConnectServerTest was introduced to test Spark Connect Service with a SparkConnectClient in the same JVM process, exercising real Spark Connect code paths (the SparkConnectClient communicating with the server over an actual connection to a localhost server). Before that, with RemoteSparkSession, the server was started in a separate process. The in-process setup helped with:

* testability: operations can be triggered from the client and then verified by server-side checks, and tests can perform additional server-side setup to exercise specific scenarios;
* debugging: both the client and the server can easily be attached to with a debugger.

At that time, it was impossible to test the Spark Connect client SQL APIs (SparkSession, Dataset) this way, because they lived in the same namespace as the server and hence could not be classloaded together. Since Spark 4.0, a new API layer makes it possible for the connect and classic implementations of the interfaces to coexist. With that, testing can be extended to use the actual SparkSession and other APIs, instead of constructing tests from more raw APIs.

### Does this PR introduce _any_ user-facing change?

No. It is testing only.

### How was this patch tested?

Added SparkConnectServerTestSuite showcasing the new APIs.

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Claude Code opus 4.5

Closes #53384 from juliuszsompolski/spark-connect-server-client-test.

Authored-by: Juliusz Sompolski <Juliusz Sompolski>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
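To make the testability point above concrete, here is a minimal sketch (not part of this commit; the suite name `MyConnectSuite` and the temp view are hypothetical) of a test that triggers an operation through the connect client API and then verifies it server-side, using the `withSession` and `getServerSession` helpers added in the diff below:

```scala
class MyConnectSuite extends SparkConnectServerTest {
  test("client action is visible server-side") {
    withSession { clientSession =>
      // Trigger an operation from the client: create a temp view over the connection.
      clientSession.sql("CREATE TEMP VIEW v AS SELECT 1 AS id").collect()

      // Verify server-side: look up the classic session backing this connect
      // session and check that the view was registered there.
      val serverSession = getServerSession(clientSession)
      assert(serverSession.catalog.tableExists("v"))
    }
  }
}
```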
1 parent 0824934 · commit 8e75fc9

File tree

2 files changed (+278, -1 lines)

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTest.scala

Lines changed: 71 additions & 1 deletion
```diff
@@ -26,14 +26,17 @@ import org.scalatest.time.Span
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.classic
+import org.apache.spark.sql.connect
 import org.apache.spark.sql.connect.client.{CloseableIterator, CustomSparkConnectBlockingStub, ExecutePlanResponseReattachableIterator, RetryPolicy, SparkConnectClient, SparkConnectStubState}
 import org.apache.spark.sql.connect.client.arrow.ArrowSerializer
 import org.apache.spark.sql.connect.common.config.ConnectCommon
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.dsl.MockRemoteSession
 import org.apache.spark.sql.connect.dsl.plans._
-import org.apache.spark.sql.connect.service.{ExecuteHolder, SparkConnectService}
+import org.apache.spark.sql.connect.service.{ExecuteHolder, SessionKey, SparkConnectService}
 import org.apache.spark.sql.test.SharedSparkSession
 
 /**
@@ -320,4 +323,71 @@ trait SparkConnectServerTest extends SharedSparkSession {
     val plan = buildPlan(query)
     runQuery(plan, queryTimeout, iterSleep)
   }
+
+  /**
+   * Helper method to create a connect SparkSession that connects to the localhost server. Similar
+   * to withClient, but provides a full SparkSession API instead of just a client.
+   *
+   * @param sessionId
+   *   Optional session ID (defaults to defaultSessionId)
+   * @param userId
+   *   Optional user ID (defaults to defaultUserId)
+   * @param f
+   *   Function to execute with the session
+   */
+  protected def withSession(sessionId: String = defaultSessionId, userId: String = defaultUserId)(
+      f: SparkSession => Unit): Unit = {
+    withSession(f, sessionId, userId)
+  }
+
+  /**
+   * Helper method to create a connect SparkSession with default session and user IDs.
+   *
+   * @param f
+   *   Function to execute with the session
+   */
+  protected def withSession(f: SparkSession => Unit): Unit = {
+    withSession(f, defaultSessionId, defaultUserId)
+  }
+
+  private def withSession(f: SparkSession => Unit, sessionId: String, userId: String): Unit = {
+    val client = SparkConnectClient
+      .builder()
+      .port(serverPort)
+      .sessionId(sessionId)
+      .userId(userId)
+      .build()
+
+    val session = connect.SparkSession
+      .builder()
+      .client(client)
+      .create()
+    try f(session)
+    finally {
+      session.close()
+    }
+  }
+
+  /**
+   * Get the server-side SparkSession corresponding to a client SparkSession.
+   *
+   * This helper takes a sql.SparkSession (which is assumed to be a connect.SparkSession),
+   * extracts the userId and sessionId from it, and looks up the corresponding server-side classic
+   * SparkSession using SparkConnectSessionManager.
+   *
+   * @param clientSession
+   *   The client SparkSession (must be a connect.SparkSession)
+   * @return
+   *   The server-side classic SparkSession
+   */
+  protected def getServerSession(clientSession: SparkSession): classic.SparkSession = {
+    val connectSession = clientSession.asInstanceOf[connect.SparkSession]
+    val userId = connectSession.client.userId
+    val sessionId = connectSession.sessionId
+    val key = SessionKey(userId, sessionId)
+    SparkConnectService.sessionManager
+      .getIsolatedSessionIfPresent(key)
+      .get
+      .session
+  }
 }
```
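A usage note inferred from the implementation above (hedged; the commit itself does not call this out): `getServerSession` ends in `getIsolatedSessionIfPresent(key).get`, so it assumes the server-side session already exists. Running at least one RPC first, as the new test suite below does, guarantees that:

```scala
withSession { clientSession =>
  // Reach the server at least once so that the session manager has created
  // the isolated session; before that, getIsolatedSessionIfPresent may
  // return None and the .get inside getServerSession would throw.
  clientSession.sql("SELECT 1").collect()

  val serverSession = getServerSession(clientSession)
  assert(serverSession != null)
}
```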
sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkConnectServerTestSuite.scala

Lines changed: 207 additions & 0 deletions (new file)
```scala
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.connect

import org.scalatest.time.SpanSugar._

/**
 * Test suite showcasing the APIs provided by SparkConnectServerTest trait.
 *
 * This suite demonstrates:
 *   - Session and client helper methods (withSession, withClient, getServerSession)
 *   - Low-level stub helpers (withRawBlockingStub, withCustomBlockingStub)
 *   - Plan building helpers (buildPlan, buildExecutePlanRequest, etc.)
 *   - Assertion helpers for execution state
 */
class SparkConnectServerTestSuite extends SparkConnectServerTest {

  test("withSession: execute SQL and collect results") {
    withSession { session =>
      val df = session.sql("SELECT 1 as value")
      val result = df.collect()
      assert(result.length == 1)
      assert(result(0).getInt(0) == 1)
    }
  }

  test("withSession: with custom session and user IDs") {
    val customSessionId = java.util.UUID.randomUUID().toString
    val customUserId = "test-user"
    withSession(sessionId = customSessionId, userId = customUserId) { session =>
      val df = session.sql("SELECT 'hello' as greeting")
      val result = df.collect()
      assert(result.length == 1)
      assert(result(0).getString(0) == "hello")
    }
  }

  test("withSession: DataFrame operations") {
    withSession { session =>
      val df = session.range(10)
      assert(df.count() == 10)

      val sum = df.selectExpr("sum(id)").collect()(0).getLong(0)
      assert(sum == 45) // 0 + 1 + ... + 9 = 45
    }
  }

  test("withClient: execute plan and iterate results") {
    withClient { client =>
      val plan = buildPlan("SELECT 1 as x, 2 as y")
      val iter = client.execute(plan)
      var hasResults = false
      while (iter.hasNext) {
        iter.next()
        hasResults = true
      }
      assert(hasResults)
    }
  }

  test("withClient: with custom session and user IDs") {
    val customSessionId = java.util.UUID.randomUUID().toString
    val customUserId = "custom-user"
    withClient(sessionId = customSessionId, userId = customUserId) { client =>
      val plan = buildPlan("SELECT 42")
      val iter = client.execute(plan)
      while (iter.hasNext) iter.next()
    }
  }

  test("getServerSession: returns server-side classic session") {
    withSession { clientSession =>
      clientSession.sql("SELECT 1").collect()

      val serverSession = getServerSession(clientSession)

      assert(serverSession != null)
      assert(serverSession.sparkContext != null)
    }
  }

  test("getServerSession: client and server share configuration") {
    withSession { clientSession =>
      clientSession.sql("SET spark.sql.shuffle.partitions=17").collect()

      val serverSession = getServerSession(clientSession)
      assert(serverSession.conf.get("spark.sql.shuffle.partitions") == "17")
    }
  }

  test("getServerSession: register and use temporary view from server") {
    withSession { clientSession =>
      clientSession.sql("SELECT 1 as a, 2 as b").collect()

      val serverSession = getServerSession(clientSession)

      // Create a temp view on the server side
      import serverSession.implicits._
      val serverDf = Seq((100, "server"), (200, "side")).toDF("num", "source")
      serverDf.createOrReplaceTempView("server_view")

      // Access the view from the client
      val result = clientSession.sql("SELECT * FROM server_view ORDER BY num").collect()
      assert(result.length == 2)
      assert(result(0).getInt(0) == 100)
      assert(result(0).getString(1) == "server")
      assert(result(1).getInt(0) == 200)
      assert(result(1).getString(1) == "side")
    }
  }

  test("withRawBlockingStub: execute plan via raw gRPC stub") {
    withRawBlockingStub { stub =>
      val request = buildExecutePlanRequest(buildPlan("SELECT 'raw' as mode"))
      val iter = stub.executePlan(request)
      assert(iter.hasNext)
      while (iter.hasNext) iter.next()
    }
  }

  test("withCustomBlockingStub: execute plan via custom blocking stub") {
    withCustomBlockingStub() { stub =>
      val request = buildExecutePlanRequest(buildPlan("SELECT 'custom' as mode"))
      val iter = stub.executePlan(request)
      while (iter.hasNext) iter.next()
    }
  }

  test("buildPlan: creates plan from SQL query") {
    val plan = buildPlan("SELECT 1, 2, 3")
    assert(plan.hasRoot)
  }

  test("buildSqlCommandPlan: creates command plan") {
    val plan = buildSqlCommandPlan("SET spark.sql.adaptive.enabled=true")
    assert(plan.hasCommand)
    assert(plan.getCommand.hasSqlCommand)
  }

  test("buildLocalRelation: creates plan from local data") {
    val data = Seq((1, "a"), (2, "b"), (3, "c"))
    val plan = buildLocalRelation(data)
    assert(plan.hasRoot)
    assert(plan.getRoot.hasLocalRelation)
  }

  test("buildExecutePlanRequest: creates request with options") {
    val plan = buildPlan("SELECT 1")
    val request = buildExecutePlanRequest(plan)
    assert(request.hasPlan)
    assert(request.hasUserContext)
    assert(request.getSessionId == defaultSessionId)
  }

  test("buildExecutePlanRequest: with custom session and operation IDs") {
    val plan = buildPlan("SELECT 1")
    val customSessionId = "my-session"
    val customOperationId = "my-operation"
    val request =
      buildExecutePlanRequest(plan, sessionId = customSessionId, operationId = customOperationId)
    assert(request.getSessionId == customSessionId)
    assert(request.getOperationId == customOperationId)
  }

  test("runQuery: executes query string with timeout") {
    runQuery("SELECT * FROM range(100)", 30.seconds)
  }

  test("runQuery: executes plan with timeout and iter sleep") {
    val plan = buildPlan("SELECT * FROM range(10)")
    runQuery(plan, 30.seconds, iterSleep = 10)
  }

  test("assertNoActiveExecutions: verifies clean state") {
    assertNoActiveExecutions()
  }

  test("assertNoActiveRpcs: verifies no active RPCs") {
    assertNoActiveRpcs()
  }

  test("eventuallyGetExecutionHolder: retrieves active execution") {
    withRawBlockingStub { stub =>
      val request = buildExecutePlanRequest(buildPlan("SELECT * FROM range(1000000)"))
      val iter = stub.executePlan(request)
      iter.hasNext // trigger execution

      val holder = eventuallyGetExecutionHolder
      assert(holder != null)
      assert(holder.operationId == request.getOperationId)
    }
  }
}
```
