diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 4f7338f74e298..e06b7d86e1db0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -2937,7 +2937,8 @@ private[spark] class DAGScheduler(
         } else {
           // This stage is only used by the job, so finish the stage if it is running.
           val stage = stageIdToStage(stageId)
-          if (runningStages.contains(stage)) {
+          // Stages with failedAttemptIds may still have running tasks
+          if (runningStages.contains(stage) || stage.failedAttemptIds.nonEmpty) {
             try { // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask
               taskScheduler.killAllTaskAttempts(stageId, shouldInterruptTaskThread(job), reason)
               if (legacyAbortStageAfterKillTasks) {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 243d33fe55a79..3e507df706ba5 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -185,6 +185,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
   private var firstInit: Boolean = _
   /** Set of TaskSets the DAGScheduler has requested executed. */
   val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+  /** Track running tasks: the key is the stage id, the value is the set of running partition ids. */
+  var runningTaskInfos = new HashMap[Int, HashSet[Int]]()
 
   /** Stages for which the DAGScheduler has called TaskScheduler.killAllTaskAttempts(). */
   val cancelledStages = new HashSet[Int]()
@@ -206,12 +208,14 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
       // normally done by TaskSetManager
       taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
       taskSets += taskSet
+      runningTaskInfos.put(taskSet.stageId, new HashSet[Int]() ++ taskSet.tasks.map(_.partitionId))
     }
     override def killTaskAttempt(
      taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
     override def killAllTaskAttempts(
      stageId: Int, interruptThread: Boolean, reason: String): Unit = {
       cancelledStages += stageId
+      runningTaskInfos.remove(stageId)
     }
     override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
       taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
@@ -393,6 +397,14 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
         handleShuffleMergeFinalized(shuffleMapStage, shuffleMapStage.shuffleDep.shuffleMergeId)
       }
     }
+
+    override private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = {
+      super.handleTaskCompletion(event)
+      runningTaskInfos.get(event.task.stageId).foreach { partitions =>
+        partitions -= event.task.partitionId
+        if (partitions.isEmpty) runningTaskInfos.remove(event.task.stageId)
+      }
+    }
   }
 
   override def beforeEach(): Unit = {
@@ -2252,6 +2264,46 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(scheduler.activeJobs.isEmpty)
   }
 
+  test("SPARK-50648: when job is cancelled during shuffle retry in parent stage, " +
+    "should kill all running tasks") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+    completeShuffleMapStageSuccessfully(0, 0, 2)
+    sc.listenerBus.waitUntilEmpty()
+
+    val info = new TaskInfo(
+      3, index = 1, attemptNumber = 1,
+      partitionId = taskSets(1).tasks(0).partitionId, 0L, "", "", TaskLocality.ANY, true)
+    // result task 0.0 fetch failed, but result task 1.0 is still running
+    runEvent(makeCompletionEvent(taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0L, 0, 1, "ignored"),
+      null,
+      Seq.empty,
+      Array.empty,
+      info))
+    sc.listenerBus.waitUntilEmpty()
+
+    Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
+    // the map stage has been resubmitted and is running, the result stage is waiting;
+    // the map tasks and the original result task 1.0 are still running
+    assert(scheduler.runningStages.size == 1, "Map stage should be running")
+    val mapStage = scheduler.runningStages.head
+    assert(mapStage.id === 0)
+    assert(mapStage.latestInfo.failureReason.isEmpty)
+    assert(scheduler.waitingStages.size == 1, "Result stage should be waiting")
+    assert(runningTaskInfos.size == 2)
+    assert(runningTaskInfos(taskSets(1).stageId).size == 1,
+      "original result task 1.0 should be running")
+
+    scheduler.doCancelAllJobs()
+    // all tasks should be killed
+    assert(runningTaskInfos.isEmpty)
+    assert(scheduler.runningStages.isEmpty)
+    assert(scheduler.waitingStages.isEmpty)
+  }
+
   test("misbehaved accumulator should not crash DAGScheduler and SparkContext") {
     val acc = new LongAccumulator {
       override def add(v: java.lang.Long): Unit = throw new DAGSchedulerSuiteDummyException