@@ -36,7 +36,7 @@ import org.apache.spark.rdd.{DeterministicLevel, RDD}
3636import org .apache .spark .scheduler .SchedulingMode .SchedulingMode
3737import org .apache .spark .shuffle .{FetchFailedException , MetadataFetchFailedException }
3838import org .apache .spark .storage .{BlockId , BlockManagerId , BlockManagerMaster }
39- import org .apache .spark .util .{ AccumulatorContext , AccumulatorV2 , CallSite , LongAccumulator , Utils }
39+ import org .apache .spark .util ._
4040
4141class DAGSchedulerEventProcessLoopTester (dagScheduler : DAGScheduler )
4242 extends DAGSchedulerEventProcessLoop (dagScheduler) {
@@ -2195,6 +2195,102 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
21952195 assertDataStructuresEmpty()
21962196 }
21972197
test("stage level active shuffle tracking") {
  // Verifies that a shuffle is reported active by the MapOutputTracker while a stage of the
  // running job still needs it, and inactive afterwards.
  //
  // Three stages chained by two shuffles:
  //   stage 0: shuffleMapRdd1                 --shuffleDep1-->
  //   stage 1: shuffleMapRdd2                 --shuffleDep2-->
  //   stage 2: intermediateRdd -> reduceRdd   (result stage)
  // The result stage is composed of 2 RDDs joined by a narrow dependency, to check that
  // shuffle tracking follows the dependency chain beyond the stage's final RDD.
  val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
  val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(1))
  val shuffleId1 = shuffleDep1.shuffleId
  val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
  val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(1))
  val shuffleId2 = shuffleDep2.shuffleId
  val intermediateRdd = new MyRDD(sc, 1, List(shuffleDep2), tracker = mapOutputTracker)
  val intermediateDep = new OneToOneDependency(intermediateRdd)
  val reduceRdd = new MyRDD(sc, 1, List(intermediateDep), tracker = mapOutputTracker)

  // Asserts the tracker's `isActive` flag for every (shuffleId -> expected) pair; the clue
  // names the offending shuffle so a failure does not need to be located by line number.
  def assertActive(expected: (Int, Boolean)*): Unit = expected.foreach {
    case (shuffleId, active) =>
      assert(mapOutputTracker.shuffleStatuses(shuffleId).isActive === active,
        s"unexpected active state for shuffle $shuffleId")
  }

  // Submit the job.
  // Both shuffles should become active.
  submit(reduceRdd, Array(0))
  assertActive(shuffleId1 -> true, shuffleId2 -> true)

  // Complete the first stage.
  // Both shuffles remain active.
  completeShuffleMapStageSuccessfully(0, 0, 2)
  assertActive(shuffleId1 -> true, shuffleId2 -> true)

  // Complete the second stage.
  // Shuffle 1 is no longer needed and should become inactive.
  completeShuffleMapStageSuccessfully(1, 0, 1)
  assertActive(shuffleId1 -> false, shuffleId2 -> true)

  // Complete the result stage.
  // Neither shuffle is needed any more, so both should be inactive.
  completeNextResultStageWithSuccess(2, 0)
  assertActive(shuffleId1 -> false, shuffleId2 -> false)

  // Double check results.
  assert(results === Map(0 -> 42))
  results.clear()
  assertDataStructuresEmpty()
}
2240+
test("stage level active shuffle tracking with multiple dependents") {
  // Verifies that a shuffle consumed by two different stages stays active until both of
  // those stages have completed.
  //
  // Diamond-shaped dependency: one shuffle feeds two intermediate RDDs, each of which is
  // shuffled again into the single reduce RDD.
  val shuffleMapRdd = new MyRDD(sc, 2, Nil)
  val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
  val shuffleId = shuffleDep.shuffleId
  val intermediateRdd1 = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
  val intermediateRdd2 = new MyRDD(sc, 1, List(shuffleDep), tracker = mapOutputTracker)
  val intermediateDep1 = new ShuffleDependency(intermediateRdd1, new HashPartitioner(1))
  val intermediateDep2 = new ShuffleDependency(intermediateRdd2, new HashPartitioner(1))
  val reduceRdd =
    new MyRDD(sc, 1, List(intermediateDep1, intermediateDep2), tracker = mapOutputTracker)

  // Current active flag of the shared shuffle, re-read from the tracker on every call.
  def shuffleIsActive: Boolean = mapOutputTracker.shuffleStatuses(shuffleId).isActive

  // Successfully completes attempt 0 of an intermediate map stage, reporting one map status
  // per task (hostA, hostB, ...) with a single reduce partition. In this test the index into
  // `taskSets` and the stage id are the same number, so one parameter serves for both.
  // NOTE(review): the task set is picked by explicit index rather than through
  // completeShuffleMapStageSuccessfully, presumably because both intermediate stages are
  // pending at the same time -- confirm before swapping in that helper.
  def completeIntermediateStage(stageId: Int): Unit = {
    val stageAttempt = taskSets(stageId)
    checkStageId(stageId, 0, stageAttempt)
    val outcomes = stageAttempt.tasks.zipWithIndex.map {
      case (_, idx) => (Success, makeMapStatus(s"host${('A' + idx).toChar}", 1))
    }.toSeq
    complete(stageAttempt, outcomes)
  }

  // Submit the job.
  // Shuffle becomes active.
  submit(reduceRdd, Array(0))
  assert(shuffleIsActive === true)

  // Complete the shuffle stage.
  // Shuffle remains active.
  completeShuffleMapStageSuccessfully(0, 0, 2)
  assert(shuffleIsActive === true)

  // Complete the first intermediate stage.
  // Shuffle is still active because the second intermediate stage has not completed yet.
  completeIntermediateStage(1)
  assert(shuffleIsActive === true)

  // Complete the second intermediate stage.
  // Shuffle is no longer active.
  completeIntermediateStage(2)
  assert(shuffleIsActive === false)

  // Complete the result stage.
  // Shuffle is still inactive.
  completeNextResultStageWithSuccess(3, 0)
  assert(shuffleIsActive === false)

  // Double check results.
  assert(results === Map(0 -> 42))
  results.clear()
  assertDataStructuresEmpty()
}
2293+
21982294 test(" map stage submission with fetch failure" ) {
21992295 val shuffleMapRdd = new MyRDD (sc, 2 , Nil )
22002296 val shuffleDep = new ShuffleDependency (shuffleMapRdd, new HashPartitioner (2 ))
0 commit comments