Spark源码分析 – DAGScheduler
1. eventQueue：所有需要 DAGScheduler 处理的事情，都以 event 的形式发送到 eventQueue 中

2. eventLoop 线程：不断地从 eventQueue 中取出 event 并进行处理

3. 实现 TaskSchedulerListener，并把自身注册到 TaskScheduler 中，这样 TaskScheduler 就可以随时通过 TaskSchedulerListener 的接口（例如 taskStarted）报告状态变更


/**
 * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
 * stages for each job, keeps track of which RDDs and stage outputs are materialized, and finds a
 * minimal schedule to run the job. It then submits stages as TaskSets to an underlying
 * TaskScheduler implementation that runs them on the cluster.
 *
 * In addition to coming up with a DAG of stages, this class also determines the preferred
 * locations to run each task on, based on the current cache status, and passes these to the
 * low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being
 * lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are
 * not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task
 * a small number of times before cancelling the whole stage.
 *
 * THREADING: This class runs all its logic in a single thread executing the run() method, to which
 * events are submitted using a synchronized queue (eventQueue). The public API methods, such as
 * runJob, taskEnded and executorLost, post events asynchronously to this queue. All other methods
 * should be private.
 *
 * NOTE(review): this is an excerpt of the class; members not shown here (e.g. `waiting`,
 * `running`, `failed`, `idToActiveJob`, `POLL_TIMEOUT`) are defined elsewhere.
 */
private[spark]
class DAGScheduler(
    taskSched: TaskScheduler, // the TaskScheduler this DAGScheduler is bound to
    mapOutputTracker: MapOutputTracker,
    blockManagerMaster: BlockManagerMaster,
    env: SparkEnv)
  extends TaskSchedulerListener with Logging {

  def this(taskSched: TaskScheduler) {
    this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
  }

  // The core event queue of the DAGScheduler: all work for it arrives here as events.
  private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]

  val nextJobId = new AtomicInteger(0)

  val nextStageId = new AtomicInteger(0)

  val stageIdToStage = new TimeStampedHashMap[Int, Stage]

  val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]

  private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]

  // The DAGScheduler also owns a SparkListenerBus so that other modules can listen to it.
  private val listenerBus = new SparkListenerBus()

  // Contains the locations that each RDD's partitions are cached on
  private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]

  // Tasks must report their execution status back to the DAGScheduler, so register it as the
  // TaskScheduler's listener. This is done after the fields above are initialized so that the
  // object is as fully initialized as possible before `this` is handed to another component.
  taskSched.setListener(this)

  // TaskSchedulerListener callbacks, invoked by the TaskScheduler when task state changes.
  // Like the other public entry points described in the class doc, this only enqueues an event;
  // the event-loop thread does the actual handling.

  // Called by TaskScheduler to report task's starting.
  override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
    eventQueue.put(BeginEvent(task, taskInfo))
  }

  // Start a thread to run the DAGScheduler event loop
  def start() {
    // The event-processing thread is a daemon, so it does not by itself keep the JVM alive.
    new Thread("DAGScheduler") {
      setDaemon(true)
      override def run() {
        DAGScheduler.this.run()
      }
    }.start()
  }

  /**
   * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
   * events and responds by launching tasks. This runs in a dedicated thread and receives events
   * via the eventQueue.
   */
  private def run() {
    SparkEnv.set(env)

    while (true) {
      // Wait at most POLL_TIMEOUT ms; on timeout `event` is null, so the periodic stage checks
      // below still run while the queue is idle.
      val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
      if (event != null) {
        logDebug("Got event of type " + event.getClass.getName)
      }
      this.synchronized { // needed in case other threads make calls into methods of this class
        if (event != null) {
          if (processEvent(event)) {
            return
          }
        }

        val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
        // Periodically resubmit failed stages if some map output fetches have failed and we have
        // waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
        // tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
        // the same time, so we want to make sure we've identified all the reduce tasks that depend
        // on the failed node.
        if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
          resubmitFailedStages()
        } else {
          submitWaitingStages()
        }
      }
    }
  }

  /**
   * Process one event retrieved from the event queue.
   * Returns true if we should stop the event loop.
   */
  private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
    event match {
      case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) =>
        // Allocate a fresh job id; nextJobId is an AtomicInteger.
        val jobId = nextJobId.getAndIncrement()
        // Build the final stage from finalRDD. Any parent stages are presumably derived from the
        // RDD's dependencies inside newStage (not shown in this excerpt).
        val finalStage = newStage(finalRDD, None, jobId, Some(callSite))
        // Create the job that tracks this submission around its final stage.
        val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
        clearCacheLocs()
        if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
          // Compute very short actions like first() or take() with no parent stages locally.
          runLocally(job)
        } else {
          listenerBus.post(SparkListenerJobStart(job, properties))
          idToActiveJob(jobId) = job
          activeJobs += job
          resultStageToJob(finalStage) = job
          submitStage(finalStage)
        }

      // Handling of the other event types is omitted in this excerpt; only JobSubmitted is shown.
    }
    // Handling a JobSubmitted event never stops the event loop.
    false
  }
}


1. dagScheduler.runJob

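接上文，在 SparkContext 中调用 runJob，最终调用的就是 dagScheduler.runJob。

dagScheduler.runJob 的工作很简单：先通过 prepareJob 生成一个 JobSubmitted 事件（代码中的 toSubmit）和一个 JobWaiter，把 toSubmit 放入 eventQueue，然后调用 waiter.awaitResult() 等待这个 Job 结束；如果 Job 失败，就把失败原因对应的异常重新抛出。（如果 partitions 为空，则直接返回，不提交 Job。）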

/**
 * Runs `func` on the given partitions of `finalRdd` and blocks until the job finishes.
 *
 * Builds a JobSubmitted event plus a JobWaiter (via prepareJob), posts the event to the
 * eventQueue for the event-loop thread, and then waits on the JobWaiter. Task results are passed
 * to `resultHandler` (presumably by the JobWaiter that prepareJob creates). An empty partition
 * list submits nothing and returns immediately.
 *
 * @throws Exception the job's failure cause, rethrown when the job fails
 */
def runJob[T, U: ClassManifest](
    finalRdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U,
    partitions: Seq[Int],
    callSite: String,
    allowLocal: Boolean,
    resultHandler: (Int, U) => Unit,
    properties: Properties = null) {
  // With no partitions there is nothing to compute, so no job is submitted at all.
  if (partitions.nonEmpty) {
    val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob(
      finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
    // Hand the job to the event-loop thread, then block until the waiter reports an outcome.
    eventQueue.put(toSubmit)
    waiter.awaitResult() match {
      case JobSucceeded => {}
      case JobFailed(exception: Exception, _) =>
        logInfo("Failed to run " + callSite)
        throw exception
    }
  }
}


1.1 JobWaiter

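JobWaiter 比较简单。它实现了 JobListener 的 taskSucceeded 和 jobFailed 两个回调：DAGScheduler 收到 task 成功或 Job 失败的 event 时，就会调用相应的函数。

taskSucceeded 每次先用 resultHandler 处理一个 task 的结果；当所有 task 都成功后（finishedTasks == totalTasks），就把 jobFinished 置为 true 并调用 notifyAll 唤醒等待者。
而 awaitResult 就是在 jobFinished 被置位之前一直 wait，然后返回 jobResult。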

/**
 * A JobListener that feeds each successful task's result to `resultHandler` and lets a caller
 * block until the job finishes, either because all `totalTasks` tasks succeeded or because the
 * job failed.
 *
 * All mutable state is guarded by this object's monitor: the callbacks and `awaitResult` are
 * `synchronized`, and the callbacks `notifyAll` once the job is finished.
 *
 * @param totalTasks number of tasks that must succeed for the job to succeed
 * @param resultHandler called with (task index, task result) for every successful task
 */
private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
  extends JobListener {

  // Number of tasks that have reported success so far.
  private var finishedTasks = 0

  // Whether the job as a whole is finished (succeeded or failed). A job with zero tasks is
  // finished from the start; otherwise awaitResult() would wait forever for a callback that
  // can never arrive.
  private var jobFinished = totalTasks == 0

  // The job's result once it is finished; a zero-task job has trivially succeeded.
  private var jobResult: JobResult = if (jobFinished) JobSucceeded else null

  override def taskSucceeded(index: Int, result: Any) {
    synchronized {
      if (jobFinished) {
        throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
      }
      // Hand this task's result to the caller-supplied handler.
      resultHandler(index, result.asInstanceOf[T])
      finishedTasks += 1
      if (finishedTasks == totalTasks) {
        jobFinished = true
        jobResult = JobSucceeded
        this.notifyAll()
      }
    }
  }

  override def jobFailed(exception: Exception) {
    synchronized {
      if (jobFinished) {
        throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter")
      }
      jobFinished = true
      jobResult = JobFailed(exception, None)
      this.notifyAll()
    }
  }

  /** Blocks until the job is finished, then returns its result. */
  def awaitResult(): JobResult = synchronized {
    // Loop to guard against spurious wake-ups.
    while (!jobFinished) {
      this.wait()
    }
    jobResult
  }
}


1.2 JobSubmitted

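JobSubmitted 只是 DAGSchedulerEvent 的一种。DAGSchedulerEvent 是一个 sealed trait，各种事件都定义为它的 case class 或 case object，processEvent 再通过 pattern matching 对不同的事件分别处理，这是 Scala 中典型的 pattern matching 使用场景。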


/**
 * Events processed by the DAGScheduler's event loop. The trait is sealed, so every subtype must
 * be declared in this source file, which lets the compiler check pattern matches over it for
 * exhaustiveness.
 */
private[spark] sealed trait DAGSchedulerEvent

/** A request to run `func` over `partitions` of `finalRDD`; posted to the queue by runJob. */
private[spark] case class JobSubmitted(
    finalRDD: RDD[_],
    func: (TaskContext, Iterator[_]) => _,
    partitions: Array[Int],
    allowLocal: Boolean,
    callSite: String,
    listener: JobListener,
    properties: Properties = null)
  extends DAGSchedulerEvent

/** A task has started; posted by the DAGScheduler's taskStarted callback. */
private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent

/** A task has ended (presumably posted from the task-ended callback) -- TODO confirm. */
private[spark] case class CompletionEvent(
    task: Task[_],
    reason: TaskEndReason,
    result: Any,
    accumUpdates: Map[Long, Any],
    taskInfo: TaskInfo,
    taskMetrics: TaskMetrics)
  extends DAGSchedulerEvent

/** Presumably reports a newly available executor on `host`. */
private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent

/** Presumably reports that executor `execId` was lost. */
private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent

/** Presumably reports that a whole TaskSet failed, with a human-readable `reason`. */
private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent

/** Presumably asks the event loop to shut down (processEvent returning true ends the loop). */
private[spark] case object StopDAGScheduler extends DAGSchedulerEvent


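2. processEvent.JobSubmitted

处理 JobSubmitted 时，首先分配新的 jobId，并用 finalRDD 创建 final stage 和对应的 ActiveJob。对于允许本地执行、没有 parent stage 且只有一个 partition 的简单 Job（如 first()、take()），直接在本地执行（runLocally）；否则就 submit final stage。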


2.1 submitStage

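submitStage 先通过 getMissingParentStages 找出当前 stage 尚未完成的 parent stage：如果没有，就直接提交当前 stage 的 tasks；否则先递归地 submit 这些 parent stage，并把当前 stage 放入 waiting 列表。这样就沿着 Stage 的 DAG，按照依赖的先后顺序提交每个 stage 的 tasks。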

/**
 * Submits stage, but first recursively submits any missing parents.
 *
 * Stages that are already waiting, running or failed are left alone. A stage whose parents are
 * all available has its tasks submitted right away; otherwise its missing parents are submitted
 * first and the stage itself is parked in `waiting`.
 */
private def submitStage(stage: Stage) {
  logDebug("submitStage(" + stage + ")")
  if (!waiting(stage) && !running(stage) && !failed(stage)) {
    // Parent stages of this stage that still need to be computed, in stage-id order.
    val missing = getMissingParentStages(stage).sortBy(_.id)
    logDebug("missing: " + missing)
    if (missing.isEmpty) {
      logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
      // Mark the stage as running *before* submitting its tasks. When submitMissingTasks finds
      // nothing left to compute, it removes the stage from `running`; adding the stage after
      // that call would leave an already-finished stage stuck in `running`.
      running += stage
      submitMissingTasks(stage)
    } else {
      // Stages run in dependency order, so the parents must be submitted first.
      for (parent <- missing) {
        submitStage(parent)
      }
      waiting += stage
    }
  }
}


2.2 submitMissingTasks


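可见无论是 ShuffleMap stage 还是 Result stage，都是为该 stage 中每个尚未计算完成的 partition 创建一个 task（分别是 ShuffleMapTask 和 ResultTask），

并最终把这些 task 封装成 TaskSet，提交给 TaskScheduler。如果没有需要计算的 partition，说明该 stage 实际上已经完成，就把它从 running 中移除。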

/**
 * Called when stage's parents are available and we can now do its tasks.
 *
 * Creates one task per partition of `stage` that still needs computing. A shuffle-map stage gets
 * a ShuffleMapTask for each partition with no registered output location. A result stage gets a
 * ResultTask for each partition its job has not finished. The tasks are then submitted to the
 * TaskScheduler as a single TaskSet. If nothing needs computing, the stage is removed from
 * `running` instead.
 *
 * NOTE(review): this is an excerpt. The `properties` lookup and the `tasks.size > 0` guard are
 * reconstructed, because the excerpt referenced an undefined `properties` and had an `else` with
 * no matching `if`. Confirm them against the full source.
 */
private def submitMissingTasks(stage: Stage) {
  // Get our pending tasks
  val tasks = ArrayBuffer[Task[_]]()
  if (stage.isShuffleMap) {
    // Shuffle-map stage: a partition is missing when no output location is recorded for it.
    for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
      val locs = getPreferredLocs(stage.rdd, p)
      tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs)
    }
  } else {
    // This is a final (result) stage; figure out its job's missing partitions
    val job = resultStageToJob(stage)
    for (id <- 0 until job.numPartitions if !job.finished(id)) {
      val partition = job.partitions(id)
      val locs = getPreferredLocs(stage.rdd, partition)
      tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id)
    }
  }
  // Properties of the job this stage belongs to, forwarded with the TaskSet. Falls back to null
  // (the same default JobSubmitted uses) if that job is no longer registered as active.
  val properties = idToActiveJob.get(stage.jobId).map(_.properties).orNull
  if (tasks.size > 0) {
    taskSched.submitTasks(
      new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
    // Only the first submission sets the timestamp; later resubmissions keep it.
    if (!stage.submissionTime.isDefined) {
      stage.submissionTime = Some(System.currentTimeMillis())
    }
  } else {
    logDebug("Stage " + stage + " is actually done; %b %d %d".format(
      stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
    running -= stage
  }
}


