Skip to content

Commit 78d0bdd

Browse files
committed
Add base session created in SparkConnectService
1 parent 20af57c commit 78d0bdd

File tree

6 files changed

+80
-4
lines changed

6 files changed

+80
-4
lines changed

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ object SparkConnectService extends Logging {
436436
return
437437
}
438438

439+
sessionManager.initializeBaseSession(sc)
439440
startGRPCService()
440441
createListenerAndUI(sc)
441442

sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import scala.util.control.NonFatal
2727

2828
import com.google.common.cache.CacheBuilder
2929

30-
import org.apache.spark.{SparkEnv, SparkSQLException}
30+
import org.apache.spark.{SparkContext, SparkEnv, SparkSQLException}
3131
import org.apache.spark.internal.Logging
3232
import org.apache.spark.internal.LogKeys.{INTERVAL, SESSION_HOLD_INFO}
3333
import org.apache.spark.sql.classic.SparkSession
@@ -39,6 +39,9 @@ import org.apache.spark.util.ThreadUtils
3939
*/
4040
class SparkConnectSessionManager extends Logging {
4141

42+
// Base SparkSession created from the SparkContext, used to create new isolated sessions
43+
@volatile private var baseSession: Option[SparkSession] = None
44+
4245
private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] =
4346
new ConcurrentHashMap[SessionKey, SessionHolder]()
4447

@@ -48,6 +51,16 @@ class SparkConnectSessionManager extends Logging {
4851
.maximumSize(SparkEnv.get.conf.get(CONNECT_SESSION_MANAGER_CLOSED_SESSIONS_TOMBSTONES_SIZE))
4952
.build[SessionKey, SessionHolderInfo]()
5053

54+
/**
55+
* Initialize the base SparkSession from the provided SparkContext.
56+
* This should be called once during SparkConnectService startup.
57+
*/
58+
def initializeBaseSession(sc: SparkContext): Unit = {
59+
if (baseSession.isEmpty) {
60+
baseSession = Some(SparkSession.builder().sparkContext(sc).getOrCreate().newSession())
61+
}
62+
}
63+
5164
/** Executor for the periodic maintenance */
5265
private val scheduledExecutor: AtomicReference[ScheduledExecutorService] =
5366
new AtomicReference[ScheduledExecutorService]()
@@ -333,12 +346,12 @@ class SparkConnectSessionManager extends Logging {
333346
}
334347

335348
private def newIsolatedSession(): SparkSession = {
336-
val active = SparkSession.active
337-
if (active.sparkContext.isStopped) {
349+
val session = baseSession.get
350+
if (session.sparkContext.isStopped) {
338351
assert(SparkSession.getDefaultSession.nonEmpty)
339352
SparkSession.getDefaultSession.get.newSession()
340353
} else {
341-
active.newSession()
354+
session.newSession()
342355
}
343356
}
344357

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ class SparkConnectServiceSuite
6565
with Logging
6666
with SparkConnectPlanTest {
6767

68+
override def beforeEach(): Unit = {
69+
super.beforeEach()
70+
SparkConnectService.sessionManager.invalidateAllSessions()
71+
SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
72+
}
73+
6874
private def sparkSessionHolder = SparkConnectTestUtils.createDummySessionHolder(spark)
6975
private def DEFAULT_UUID = UUID.fromString("89ea6117-1f45-4c03-ae27-f47c6aded093")
7076

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ class ArtifactStatusesHandlerSuite extends SharedSparkSession with ResourceHelpe
4242

4343
val sessionId = UUID.randomUUID().toString
4444

45+
override def beforeEach(): Unit = {
46+
super.beforeEach()
47+
SparkConnectService.sessionManager.invalidateAllSessions()
48+
SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
49+
}
50+
4551
def getStatuses(names: Seq[String], exist: Set[String]): ArtifactStatusesResponse = {
4652
val promise = Promise[ArtifactStatusesResponse]()
4753
val handler = new SparkConnectArtifactStatusesHandler(new DummyStreamObserver(promise)) {

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectCloneSessionSuite.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class SparkConnectCloneSessionSuite extends SharedSparkSession with BeforeAndAft
2929
override def beforeEach(): Unit = {
3030
super.beforeEach()
3131
SparkConnectService.sessionManager.invalidateAllSessions()
32+
SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
3233
}
3334

3435
test("clone session with invalid target session ID format") {

sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.scalatest.BeforeAndAfterEach
2323
import org.scalatest.time.SpanSugar._
2424

2525
import org.apache.spark.SparkSQLException
26+
import org.apache.spark.sql.SparkSession
2627
import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl}
2728
import org.apache.spark.sql.pipelines.logging.PipelineEvent
2829
import org.apache.spark.sql.test.SharedSparkSession
@@ -32,6 +33,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
3233
override def beforeEach(): Unit = {
3334
super.beforeEach()
3435
SparkConnectService.sessionManager.invalidateAllSessions()
36+
SparkConnectService.sessionManager.initializeBaseSession(spark.sparkContext)
3537
}
3638

3739
test("sessionId needs to be an UUID") {
@@ -171,4 +173,51 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
171173
sessionHolder.getPipelineExecution(graphId).isEmpty,
172174
"pipeline execution was not removed")
173175
}
176+
177+
test("baseSession allows creating sessions after default session is cleared") {
178+
// Create a new session manager to test initialization
179+
val sessionManager = new SparkConnectSessionManager()
180+
181+
// Initialize the base session with the test SparkContext
182+
sessionManager.initializeBaseSession(spark.sparkContext)
183+
184+
// Clear the default and active sessions to simulate the scenario where
185+
// SparkSession.active or SparkSession.getDefaultSession would fail
186+
SparkSession.clearDefaultSession()
187+
SparkSession.clearActiveSession()
188+
189+
// Create an isolated session - this should still work because we have baseSession
190+
val key = SessionKey("user", UUID.randomUUID().toString)
191+
val sessionHolder = sessionManager.getOrCreateIsolatedSession(key, None)
192+
193+
// Verify the session was created successfully
194+
assert(sessionHolder != null)
195+
assert(sessionHolder.session != null)
196+
197+
// Clean up
198+
sessionManager.closeSession(key)
199+
}
200+
201+
test("initializeBaseSession is idempotent") {
202+
// Create a new session manager to test initialization
203+
val sessionManager = new SparkConnectSessionManager()
204+
205+
// Initialize the base session multiple times
206+
sessionManager.initializeBaseSession(spark.sparkContext)
207+
val key1 = SessionKey("user1", UUID.randomUUID().toString)
208+
val sessionHolder1 = sessionManager.getOrCreateIsolatedSession(key1, None)
209+
val baseSessionUUID1 = sessionHolder1.session.sessionUUID
210+
211+
// Initialize again - should not change the base session
212+
sessionManager.initializeBaseSession(spark.sparkContext)
213+
val key2 = SessionKey("user2", UUID.randomUUID().toString)
214+
val sessionHolder2 = sessionManager.getOrCreateIsolatedSession(key2, None)
215+
216+
// Both sessions should be isolated from each other
217+
assert(sessionHolder1.session.sessionUUID != sessionHolder2.session.sessionUUID)
218+
219+
// Clean up
220+
sessionManager.closeSession(key1)
221+
sessionManager.closeSession(key2)
222+
}
174223
}

0 commit comments

Comments
 (0)