
Spark Core: Task Scheduling (4) - Executing Tasks in Detail


The previous article analyzed everything the Driver does before a Task is executed. That work happens mainly in TaskScheduler and SchedulerBackend; once it is finished, the Task is handed to an Executor on a Worker by sending the task and its data through the executorEndpoint that was saved for that executor.
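
To make the hand-off concrete, the sketch below condenses the driver-side launchTasks logic of CoarseGrainedSchedulerBackend's DriverEndpoint (simplified from the Spark source: the RPC message-size check and the bookkeeping of free cores and resources are omitted here):

  // Simplified: encode each TaskDescription and send it to the chosen executor's
  // saved RPC endpoint as a LaunchTask message.
  private def launchTasks(tasks: Seq[Seq[TaskDescription]]): Unit = {
    for (task <- tasks.flatten) {
      val serializedTask = TaskDescription.encode(task)
      val executorData = executorDataMap(task.executorId)
      // (accounting for the executor's free cores and resources happens here)
      // The message is picked up by the ExecutorBackend's receive method shown below.
      executorData.executorEndpoint.send(LaunchTask(new SerializableBuffer(serializedTask)))
    }
  }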

On the Executor side, the ExecutorBackend is responsible for receiving this message; its receive method is what gets called:

  override def receive: PartialFunction[Any, Unit] = {
    ...
    // Handle the LaunchTask message
    case LaunchTask(data) =>
      if (executor == null) {
        exitExecutor(1, "Received LaunchTask command but executor was null")
      } else {
        // Decode the data (see the previous article); the encoding keeps the transferred payload small
        val taskDesc = TaskDescription.decode(data.value)
        logInfo("Got assigned task " + taskDesc.taskId)
        // Record the resources assigned to this task
        taskResources(taskDesc.taskId) = taskDesc.resources
        // Finally hand the task to the Executor to run
        executor.launchTask(this, taskDesc)
      }
    ...
  }

When the backend receives the LaunchTask message, it calls executor.launchTask to handle it. That method wraps the task in a TaskRunner object, which manages the details of the run, and then submits the runner to the threadPool for execution.

  def launchTask(context: ExecutorBackend, taskDescription: TaskDescription): Unit = {
    // Wrap the task description in a TaskRunner
    val tr = new TaskRunner(context, taskDescription)
    // Track this task as running
    runningTasks.put(taskDescription.taskId, tr)
    // Hand it to the thread pool for execution
    threadPool.execute(tr)
  }

In TaskRunner's run method, deserialization comes first: the Task that was sent over, together with the jars and files it depends on, is fetched and deserialized. The deserialized task's run method is then called, which delegates to Task.runTask; Task itself is an abstract class, so the concrete logic is implemented by its subclasses ShuffleMapTask and ResultTask.
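
As a rough sketch of that split (simplified; the real Task.run does extra bookkeeping for kill handling, metrics and cleanup, and createTaskContext below is just a hypothetical stand-in for the context-construction code), run prepares a TaskContext and then delegates to the abstract runTask:

  // Simplified sketch: prepare a TaskContext for this attempt, expose it through
  // TaskContext.get(), run the subclass's runTask, then clean up.
  final def run(taskAttemptId: Long, attemptNumber: Int,
                metricsSystem: MetricsSystem,
                resources: Map[String, ResourceInformation]): T = {
    // createTaskContext is a hypothetical helper standing in for the real setup code
    val context = createTaskContext(taskAttemptId, attemptNumber, metricsSystem, resources)
    TaskContext.setTaskContext(context)
    try {
      runTask(context)   // implemented by ShuffleMapTask / ResultTask
    } finally {
      TaskContext.unset()
    }
  }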

Let's now look at TaskRunner's run method:

  override def run(): Unit = {
    val threadMXBean = ManagementFactory.getThreadMXBean
    // Create the memory manager used while this task executes
    val taskMemoryManager = new TaskMemoryManager(env.memoryManager, taskId)
    val deserializeStartTimeNs = System.nanoTime()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    Thread.currentThread.setContextClassLoader(replClassLoader)
    val ser = env.closureSerializer.newInstance()
    logInfo(s"Running $taskName (TID $taskId)")
    // Tell the driver that the task has started running
    execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
    ......
    try {
      // Must be set before updateDependencies() is called, in case fetching dependencies
      // requires access to properties contained within (e.g. for access control).
      Executor.taskDeserializationProps.set(taskDescription.properties)

      updateDependencies(taskDescription.addedFiles, taskDescription.addedJars)
      task = ser.deserialize[Task[Any]](
        taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
      task.localProperties = taskDescription.properties
      task.setTaskMemoryManager(taskMemoryManager)

      val killReason = reasonIfKilled
      // If the task was killed before it finished deserializing, throw and quit now
      if (killReason.isDefined) {
        throw new TaskKilledException(killReason.get)
      }
      if (!isLocal) {
        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
        env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
      }

      metricsPoller.onTaskStart(taskId, task.stageId, task.stageAttemptId)
      ......
      // tryWithSafeFinally runs the first block and always runs the second block as a
      // finally clause; an exception thrown there will not mask the original exception.
      val value = Utils.tryWithSafeFinally {
        // Run the task and get its result. Task.run ultimately calls runTask,
        // which is implemented by the two subclasses.
        val res = task.run(
          taskAttemptId = taskId,
          attemptNumber = taskDescription.attemptNumber,
          metricsSystem = env.metricsSystem,
          resources = taskDescription.resources)
        threwException = false
        res
      } {
        // The "finally" block: always executed, it releases this task's block locks
        // and managed memory and checks for leaks
        val releasedLocks = env.blockManager.releaseAllLocksForTask(taskId)
        val freedMemory = taskMemoryManager.cleanUpAllAllocatedMemory()

        if (freedMemory > 0 && !threwException) {
          val errMsg = s"Managed memory leak detected; size = $freedMemory bytes, TID = $taskId"
          if (conf.get(UNSAFE_EXCEPTION_ON_MEMORY_LEAK)) {
            throw new SparkException(errMsg)
          } else {
            logWarning(errMsg)
          }
        }

        if (releasedLocks.nonEmpty && !threwException) {
          val errMsg =
            s"${releasedLocks.size} block locks were not released by TID = $taskId:\n" +
              releasedLocks.mkString("[", ", ", "]")
          if (conf.get(STORAGE_EXCEPTION_PIN_LEAK)) {
            throw new SparkException(errMsg)
          } else {
            logInfo(errMsg)
          }
        }
      }
      .....
      // Collect accumulator updates
      val accumUpdates = task.collectAccumulatorUpdates()
      val metricPeaks = metricsPoller.getTaskMetricPeaks(taskId)
      // Build a DirectTaskResult containing the task result, accumulator updates and metric peaks
      val directResult = new DirectTaskResult(valueBytes, accumUpdates, metricPeaks)
      val serializedDirectResult = ser.serialize(directResult)
      val resultSize = serializedDirectResult.limit()

      // Decide in what form the result is sent back to the driver
      val serializedResult: ByteBuffer = {
        // If the result exceeds maxResultSize, drop it and only report its size
        if (maxResultSize > 0 && resultSize > maxResultSize) {
          logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
            s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
            s"dropping it.")
          ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
        } else if (resultSize > maxDirectResultSize) {
          // If the result exceeds maxDirectResultSize, store it in the block manager
          // and send back only an IndirectTaskResult pointing at that block
          val blockId = TaskResultBlockId(taskId)
          env.blockManager.putBytes(
            blockId,
            new ChunkedByteBuffer(serializedDirectResult.duplicate()),
            StorageLevel.MEMORY_AND_DISK_SER)
          logInfo(
            s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
          ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
        } else {
          logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
          serializedDirectResult
        }
      }
    }
    // exception handling
    .....
    finally {
      // Remove this task from runningTasks
      runningTasks.remove(taskId)
      if (taskStarted) {
        // This means the task was successfully deserialized, its stageId and stageAttemptId
        // are known, and metricsPoller.onTaskStart was called.
        metricsPoller.onTaskCompletion(taskId, task.stageId, task.stageAttemptId)
      }
    }
  }
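
One piece not shown in the excerpt above is how the result actually leaves the executor: at the end of the happy path the TaskRunner reports it through the backend, and on the driver side the TaskResultGetter either deserializes the DirectTaskResult directly or first fetches the block referenced by an IndirectTaskResult from the executor's BlockManager. Roughly:

  // End of the happy path in TaskRunner.run (elided in the excerpt above):
  // report the (possibly indirect) serialized result back to the driver.
  execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)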

For ShuffleMapTask's implementation of runTask, the computed output is handed to the BlockManager via the shuffle writer, and what ultimately gets back to the DAGScheduler is a MapStatus object. MapStatus records where the output handed to the BlockManager lives, and that information is used as the input for the next stage.

  override def runTask(context: TaskContext): MapStatus = {
    // Deserialize the RDD using the broadcast variable.
    val threadMXBean = ManagementFactory.getThreadMXBean
    val deserializeStartTimeNs = System.nanoTime()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    val ser = SparkEnv.get.closureSerializer.newInstance()
    // Deserialize the RDD and its ShuffleDependency from the broadcast task binary
    val rddAndDep = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    // Record how long deserialization took
    _executorDeserializeTimeNs = System.nanoTime() - deserializeStartTimeNs
    _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
    } else 0L

    val rdd = rddAndDep._1
    val dep = rddAndDep._2
    // Write this partition's records through the shuffle writer; the returned
    // MapStatus describes where the shuffle blocks were stored
    dep.shuffleWriterProcessor.write(rdd, dep, partitionId, context, partition)
  }
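
For reference, MapStatus itself is a small interface; the condensed view below shows the two pieces of information the next stage cares about: which BlockManager holds the map output, and roughly how large the block for each reduce partition is (the exact members vary a little between Spark versions):

  // Condensed view of MapStatus (not the full trait)
  private[spark] sealed trait MapStatus {
    // Where this map task's shuffle output lives
    def location: BlockManagerId
    // Estimated size of the block destined for the given reduce partition
    def getSizeForBlock(reduceId: Int): Long
  }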

The runTask implemented in ResultTask, by contrast, ultimately returns the result of applying func to the partition's data.

  override def runTask(context: TaskContext): U = {
    // Deserialize the RDD and the func using the broadcast variables.
    val threadMXBean = ManagementFactory.getThreadMXBean
    val deserializeStartTimeNs = System.nanoTime()
    val deserializeStartCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime
    } else 0L
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
    _executorDeserializeTimeNs = System.nanoTime() - deserializeStartTimeNs
    _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
      threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
    } else 0L
    // Return the result of applying func to this partition's iterator
    func(context, rdd.iterator(partition, context))
  }
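
Where does that func come from? It is the closure of the action that triggered the job; SparkContext.runJob wraps it into the (TaskContext, Iterator[T]) => U shape that ResultTask deserializes above. RDD.collect is a good example (shown here essentially as in the Spark source): its per-partition function simply materializes the iterator into an array, and the driver concatenates the per-partition results.

  def collect(): Array[T] = withScope {
    // The per-partition func: materialize this partition's iterator into an array
    val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
    // Concatenate the per-partition arrays on the driver
    Array.concat(results: _*)
  }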

Summary

In short, the work at this stage is not especially complex: the executor receives and runs Tasks, handles their results, and keeps metrics updated so that task progress shows up in real time in the Web UI. The trickier part is what happens to the task results afterwards, which involves the BlockManager; that is left for a later analysis of Spark's storage internals.

Written by sev7e0