[SPARK-50648][CORE] Cleanup zombie tasks in non-running stages when the job is cancelled

### What changes were proposed in this pull request?

This fixes a problem that Spark has always had; see the next section for the scenario in which it occurs.

When a job is cancelled, some of its tasks may keep running.
The reason is that when `DAGScheduler#handleTaskCompletion` encounters a `FetchFailed`, it calls `markStageAsFinished` to remove the stage from `DAGScheduler#runningStages` (see https://github.com/apache/spark/blob/7cd5c4a1d1eb56fa92c10696bdbd8450d357b128/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala#L2059) without calling `killAllTaskAttempts`.
But `DAGScheduler#cancelRunningIndependentStages` only searches `runningStages`, so cancellation leaves behind zombie shuffle tasks that occupy cluster resources; a toy model of the two cancellation conditions is sketched below.
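For illustration, here is a minimal, self-contained sketch of the pre-fix and post-fix kill conditions (a toy model with hypothetical names such as `ZombieTaskSketch` and `liveTasks`; it paraphrases, and does not reuse, the actual `DAGScheduler` code):

```scala
import scala.collection.mutable

// Toy model of the bug and the fix (hypothetical names; not Spark source code).
object ZombieTaskSketch {
  case class Stage(id: Int, failedAttemptIds: Set[Int])

  // Scheduler state: stages currently marked as running, and live tasks per stage.
  val runningStages = mutable.Set[Stage]()
  val liveTasks = mutable.Map[Int, Int]() // stageId -> number of running tasks

  def killAllTaskAttempts(stageId: Int): Unit = liveTasks.remove(stageId)

  // Pre-fix cancellation: consults runningStages only, so a stage that
  // markStageAsFinished already removed (after a FetchFailed) keeps its tasks.
  def cancelPreFix(stages: Seq[Stage]): Unit =
    stages.filter(runningStages.contains).foreach(s => killAllTaskAttempts(s.id))

  // Post-fix cancellation: also kills stages with failed attempts, whose
  // earlier attempts may still have tasks running.
  def cancelPostFix(stages: Seq[Stage]): Unit =
    stages
      .filter(s => runningStages.contains(s) || s.failedAttemptIds.nonEmpty)
      .foreach(s => killAllTaskAttempts(s.id))

  def main(args: Array[String]): Unit = {
    // stage2 hit a FetchFailed: it left runningStages, but one task survives.
    val stage2 = Stage(id = 2, failedAttemptIds = Set(0))
    liveTasks(2) = 1
    cancelPreFix(Seq(stage2))
    println(s"pre-fix zombie stages:  ${liveTasks.keySet}") // Set(2) -- zombie!
    liveTasks(2) = 1
    cancelPostFix(Seq(stage2))
    println(s"post-fix zombie stages: ${liveTasks.keySet}") // Set() -- cleaned up
  }
}
```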

### Why are the changes needed?

Assume a job is stage1 -> stage2. When a FetchFailed occurs during stage2, both stage1 and stage2 are resubmitted. stage2 may still have some tasks running even after it is resubmitted; this is as expected, because those tasks may eventually succeed and make a retry unnecessary.

But if the SQL query is cancelled while stage1-retry is executing, the tasks in stage1 and stage1-retry are all killed, while the tasks previously running in stage2 keep running and cannot be killed. These zombie tasks can greatly affect cluster stability and occupy resources. A schematic timeline follows below.
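The timeline of that scenario, step by step (an illustrative sketch in comments; the one assertion merely encodes the patched kill condition, and the names are paraphrases of `DAGScheduler` internals):

```scala
// Hypothetical walk-through of the scenario above; not Spark source code.
object CancellationTimeline extends App {
  // t0: stage1 completes; stage2 starts tasks 0.0 and 1.0
  // t1: task 0.0 hits FetchFailed
  //     -> markStageAsFinished(stage2): stage2 leaves runningStages
  //     -> task 1.0 keeps running (it may still succeed, which is intended)
  // t2: stage1-retry is resubmitted and starts running
  // t3: the job is cancelled -> cancelRunningIndependentStages
  //     pre-fix : only stage1-retry is in runningStages, so only it is killed;
  //               stage2's task 1.0 lives on as a zombie
  //     post-fix: stage2.failedAttemptIds is non-empty, so stage2 is killed too
  val stage2InRunningStages = false // removed by markStageAsFinished at t1
  val stage2FailedAttempts = 1      // the FetchFailed attempt
  assert(stage2InRunningStages || stage2FailedAttempts > 0) // patched condition
  println("stage2 now qualifies for killAllTaskAttempts on cancellation")
}
```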

### Does this PR introduce _any_ user-facing change?
No

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #49270 from yabola/zombie-task-when-shuffle-retry.

Authored-by: chenliang.lu <chenlianglu@tencent.com>
Signed-off-by: Yi Wu <yi.wu@databricks.com>
yabola authored and Ngone51 committed Dec 31, 2024
1 parent e1fb18d commit 5334494
Showing 2 changed files with 54 additions and 1 deletion.
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

```diff
@@ -2937,7 +2937,8 @@ private[spark] class DAGScheduler(
       } else {
         // This stage is only used by the job, so finish the stage if it is running.
         val stage = stageIdToStage(stageId)
-        if (runningStages.contains(stage)) {
+        // Stages with failedAttemptIds may have tasks that are running
+        if (runningStages.contains(stage) || stage.failedAttemptIds.nonEmpty) {
           try { // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask
             taskScheduler.killAllTaskAttempts(stageId, shouldInterruptTaskThread(job), reason)
             if (legacyAbortStageAfterKillTasks) {
```
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala

```diff
@@ -185,6 +185,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
   private var firstInit: Boolean = _
   /** Set of TaskSets the DAGScheduler has requested executed. */
   val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+  /** Track running tasks: the key is the task's stageId, the value is the set of its partitionIds. */
+  var runningTaskInfos = new HashMap[Int, HashSet[Int]]()
 
   /** Stages for which the DAGScheduler has called TaskScheduler.killAllTaskAttempts(). */
   val cancelledStages = new HashSet[Int]()
@@ -206,12 +208,14 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
       // normally done by TaskSetManager
       taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
       taskSets += taskSet
+      runningTaskInfos.put(taskSet.stageId, new HashSet[Int]() ++ taskSet.tasks.map(_.partitionId))
     }
     override def killTaskAttempt(
         taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
     override def killAllTaskAttempts(
         stageId: Int, interruptThread: Boolean, reason: String): Unit = {
       cancelledStages += stageId
+      runningTaskInfos.remove(stageId)
     }
     override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
       taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
@@ -393,6 +397,14 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
         handleShuffleMergeFinalized(shuffleMapStage, shuffleMapStage.shuffleDep.shuffleMergeId)
       }
     }
+
+    override private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = {
+      super.handleTaskCompletion(event)
+      runningTaskInfos.get(event.task.stageId).foreach { partitions =>
+        partitions -= event.task.partitionId
+        if (partitions.isEmpty) runningTaskInfos.remove(event.task.stageId)
+      }
+    }
   }
 
   override def beforeEach(): Unit = {
@@ -2252,6 +2264,46 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(scheduler.activeJobs.isEmpty)
   }
 
+  test("SPARK-50648: when job is cancelled during shuffle retry in parent stage, " +
+    "should kill all running tasks") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
+    val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduceRdd, Array(0, 1))
+    completeShuffleMapStageSuccessfully(0, 0, 2)
+    sc.listenerBus.waitUntilEmpty()
+
+    val info = new TaskInfo(
+      3, index = 1, attemptNumber = 1,
+      partitionId = taskSets(1).tasks(0).partitionId, 0L, "", "", TaskLocality.ANY, true)
+    // result task 0.0 hit a fetch failure, but result task 1.0 is still running
+    runEvent(makeCompletionEvent(taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0L, 0, 1, "ignored"),
+      null,
+      Seq.empty,
+      Array.empty,
+      info))
+    sc.listenerBus.waitUntilEmpty()
+
+    Thread.sleep(DAGScheduler.RESUBMIT_TIMEOUT * 2)
+    // the map stage has been resubmitted and is running, the result stage is waiting;
+    // the map tasks and the original result task 1.0 are running
+    assert(scheduler.runningStages.size == 1, "Map stage should be running")
+    val mapStage = scheduler.runningStages.head
+    assert(mapStage.id === 0)
+    assert(mapStage.latestInfo.failureReason.isEmpty)
+    assert(scheduler.waitingStages.size == 1, "Result stage should be waiting")
+    assert(runningTaskInfos.size == 2)
+    assert(runningTaskInfos(taskSets(1).stageId).size == 1,
+      "original result task 1.0 should be running")
+
+    scheduler.doCancelAllJobs()
+    // all tasks should be killed
+    assert(runningTaskInfos.isEmpty)
+    assert(scheduler.runningStages.isEmpty)
+    assert(scheduler.waitingStages.isEmpty)
+  }
+
   test("misbehaved accumulator should not crash DAGScheduler and SparkContext") {
     val acc = new LongAccumulator {
       override def add(v: java.lang.Long): Unit = throw new DAGSchedulerSuiteDummyException
```
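For reference, the new case can be run on its own with the usual Spark sbt workflow, e.g. `build/sbt "core/testOnly *DAGSchedulerSuite -- -z SPARK-50648"` (assuming an sbt-based checkout of apache/spark; `-z` is the standard ScalaTest substring filter).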
