spark源码阅读之shuffle模块②

Java SenLin 4年前 (2020-02-14) 498次浏览 已收录 0个评论

在spark源码阅读之shuffle模块①中,介绍了spark版本shuffle的演化史,提到了主要的两个shuffle策略:HashBasedShuffle和SortedBasedShuffle,分别分析了它们的原理以及shuffle write过程,而中间的过程,也就是shuffleMapTask运算结果的处理过程在spark源码阅读之executor模块③文章中也已经分析过,本章继续分析下游的shuffle read过程,本篇文章源码基于spark 1.6.3

shuffle read

shuffle read的起点应该是下游的Reducer来读取中间落地文件,而除了需要从外部存储取数据和已经cache或者checkpoint的RDD之外,一般的Task都是通过ShuffledRDD的shuffle read开始reduce之旅的。

首先可以看一下ShuffledRDD的compute()方法

override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
  val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
  SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
    .read()
    .asInstanceOf[Iterator[(K, C)]]
}

调用ShuffleManager的getReader方法回去一个reader,之前说过,ShuffleManager这里有两个实现类,HashShuffleManager和SortShuffleManager了,分别对应两种不同的策略,但在shuffle read的过程中,他们的getReader方法都创建了同一个BlockStoreShuffleReader对象,也就是他们的shuffle read过程相同,接着应该点入BlockStoreShuffleReader的read()方法:

// shuffle read的核心实现,读取map out结果并做聚合
  override def read(): Iterator[Product2[K, C]] = {
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

    // Wrap the streams for compression based on configuration
    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
      blockManager.wrapForCompression(blockId, inputStream)   //将输入根据参数进行压缩
    }

    val ser = Serializer.getSerializer(dep.serializer)
    val serializerInstance = ser.newInstance()    //获取序列化工具

    // Create a key/value iterator for each stream
    val recordIter = wrappedStreams.flatMap { wrappedStream =>
      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
      // NextIterator. The NextIterator makes sure that close() is called on the
      // underlying InputStream when all records have been read.
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator    //将输入反序列化为KeyValueIterator
    }

    // Update the context task metrics for each record read.
    // 更新Task context的元数据信息
    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
      recordIter.map(record => {
        readMetrics.incRecordsRead(1)
        record
      }),
      context.taskMetrics().updateShuffleReadMetrics())

    // An interruptible iterator must be used here in order to support task cancellation
    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)  //可取消的iter

    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { //需要聚合
      if (dep.mapSideCombine) { //读取map端已聚合过的数据
        // We are reading values that are already combined
        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
      } else {    //仅需要reduce端的聚合
        // We don't know the value type, but also don't care -- the dependency *should*
        // have made sure its compatible w/ this aggregator, which will convert the value
        // type to the combined type C
        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
      }
    } else {  //不需要聚合
      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
    }

    // Sort the output if there is a sort ordering defined.
    dep.keyOrdering match {   //判断是否需要排序
      case Some(keyOrd: Ordering[K]) => //如果需要排序
        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
        // the ExternalSorter won't spill to disk.
        // 使用ExternalSorter进行排序,如果spark.shuffle.spill没有开启,那么数据是不会写入硬盘的
        val sorter =
          new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
        sorter.insertAll(aggregatedIter)
        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
        context.internalMetricsToAccumulators(
          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
        CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
      case None =>
        aggregatedIter
    }
  }

这段代码中已经做了注释,切分一下有三块功能:

  1. 用序列化工具读取文件成为一个key/value iterator并更新Task context的元数据信息
  2. 根据传入的Dependency中是否有聚合动作来对数据进行聚合处理
  3. 根据Dependency中是否存在key的排序器来对数据进行排序处理

其中,aggregator和keyOrdering对应着shuffle write过程中的相应参数,实现比较简单,这里不做具体分析,我们主要关注下游是如何获取数据的,这样可以与上一篇文章一起形成关于shuffle整个过程的闭环。

block fetch

在第一部分中,首先创建了一个ShuffleBlockFetcherIterator对象,这个对象会创建一个(BlockID, InputStream)形式的Iterator来拉取中间文件的multiple blocks,这个对象在实例化的过程中首先会调用initialize()方法,以下是其源码:

private[this] def initialize(): Unit = {
  // Add a task completion callback (called in both success case and failure case) to cleanup.
  context.addTaskCompletionListener(_ => cleanup())
  // Split local and remote blocks.
  // 如果数据从其他节点上获取,那么需要通过网络
  val remoteRequests: ArrayBuffer[FetchRequest] = splitLocalRemoteBlocks()
  // Add the remote requests into our queue in a random order
  fetchRequests ++= Utils.randomize(remoteRequests)
  // Send out initial requests for blocks, up to our maxBytesInFlight
  // sendFetchRequests发送请求,每次请求最大值为maxBytesInFlight(默认48MB),5个线程到5个节点
  fetchUpToMaxBytes()
  val numFetches = remoteRequests.size - fetchRequests.size
  logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
  // Get Local Blocks
  // 如果数据在本地,直接获取即可
  fetchLocalBlocks()
  logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}

代码中拉取数据有两种,一种是remoteBlocks另一种localBlocks,如果数据不在本地节点上,那么就要通过网络去获取数据,通过网络拉取就会占用网络带宽,所以系统提供了两种策略,具体实现在splitLocalRemoteBlocks方法中:

private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
    // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them
    // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
    // nodes, rather than blocking on reading output from one node.
    // 每次最多启动5个线程到最多5个节点上读取数据
    // 每次请求的数据大小不会超过maxBytesInFlight的1/5
    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
    logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)

    // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
    // at most maxBytesInFlight in order to limit the amount of data in flight.
    val remoteRequests = new ArrayBuffer[FetchRequest]

    // Tracks total number of blocks (including zero sized blocks)
    var totalBlocks = 0
    for ((address, blockInfos) <- blocksByAddress) {
      totalBlocks += blockInfos.size
      if (address.executorId == blockManager.blockManagerId.executorId) {
        // Filter out zero-sized blocks
        // 需要过滤大小为0的本地block
        localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
        numBlocksToFetch += localBlocks.size
      } else {    // 需要远程获取的block
        val iterator = blockInfos.iterator
        var curRequestSize = 0L
        var curBlocks = new ArrayBuffer[(BlockId, Long)]
        while (iterator.hasNext) {
          val (blockId, size) = iterator.next()
          // Skip empty blocks
          if (size > 0) {
            curBlocks += ((blockId, size))
            remoteBlocks += blockId
            numBlocksToFetch += 1
            curRequestSize += size
          } else if (size < 0) {
            throw new BlockException(blockId, "Negative block size " + size)
          }
          if (curRequestSize >= targetRequestSize) {
            // Add this FetchRequest
            remoteRequests += new FetchRequest(address, curBlocks)
            curBlocks = new ArrayBuffer[(BlockId, Long)]
            logDebug(s"Creating fetch request of $curRequestSize at $address")
            curRequestSize = 0
          }
        }
        // Add in the final request
        if (curBlocks.nonEmpty) {
          remoteRequests += new FetchRequest(address, curBlocks)
        }
      }
    }
    logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
    remoteRequests
  }

从代码逻辑中可以得出通过网络了拉取数据blocks的策略:

  1. 每次最多启动5个线程到最多5个节点上读取数据
  2. 每次请求数据的大小不会超过spark.reducer.maxMbInFlight(默认48MB)的五分之一

这么做的目的一个是减少占用带宽,另一个是使用并行化请求数据减少请求时间。

请求已经切分好了,接下来通过调用fetchUpToMaxBytes()方法来发送请求:

private def fetchUpToMaxBytes(): Unit = {
  // Send fetch requests up to maxBytesInFlight
  while (fetchRequests.nonEmpty &&
    (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
    sendRequest(fetchRequests.dequeue())
  }
}

当请求大小不超过maxBytesInFlight,发送请求sendRequest()

private[this] def sendRequest(req: FetchRequest) {
  logDebug("Sending request for %d blocks (%s) from %s".format(
    req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
  bytesInFlight += req.size
  // so we can look up the size of each blockID
  val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
  val blockIds = req.blocks.map(_._1.toString)
  val address = req.address
  // 通过网络fetchBlocks的实现类为:NettyBlockTransferService,本地的fetchBlocks实现类为:BlockTransferService
  shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
    new BlockFetchingListener {
      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
        // Only add the buffer to results queue if the iterator is not zombie,
        // i.e. cleanup() has not been called yet.
        if (!isZombie) {
          // Increment the ref count because we need to pass this to a different thread.
          // This needs to be released after use.
          buf.retain()
          results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf))
          shuffleMetrics.incRemoteBytesRead(buf.size)
          shuffleMetrics.incRemoteBlocksFetched(1)
        }
        logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
      }
      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
        logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
        results.put(new FailureFetchResult(BlockId(blockId), address, e))
      }
    }
  )
}

通过ShuffleClient实例去拉取Blocks,这里的ShuffleClient有多种实现,其中通过网络获取Blocks的实现类为:NettyBlockTransferService,而本地获取Blocks的实现类为:BlockTransferService,fetchBlocks方法中根据传入的host地址端口和executorId,然后使用Netty协议去获取数据。

接下来,我们再来看一下本地的数据拉取方法:

private[this] def fetchLocalBlocks() {
  val iter = localBlocks.iterator
  while (iter.hasNext) {
    val blockId = iter.next()
    try {
      val buf = blockManager.getBlockData(blockId)
      shuffleMetrics.incLocalBlocksFetched(1)
      shuffleMetrics.incLocalBytesRead(buf.size)
      buf.retain()
      results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
    } catch {
      case e: Exception =>
        // If we see an exception, stop immediately.
        logError(s"Error occurred while fetching local blocks", e)
        results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
        return
    }
  }
}

可以看出,本地的Blocks直接通过blockManager的getBlockData方法去获取数据,而如果数据是通过shuffle过程获取的,getBlockData就有两种实现:Hash和Sort
Hash的实现类为:FileShuffleBlockResolver
Sort的实现类为:IndexShuffleBlockResolver

其中的不同就是Sort策略的getBlockData需要先通过IndexFile定位到数据对应的FileSegment,而Hash则可以直接通过blockId直接获取文件.
以下是IndexShuffleBlockResolver的getBlockData方法:

override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
  // The block is actually going to be a range of a single map output file for this map, so
  // find out the consolidated file, then the offset within that from our index
  val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
  val in = new DataInputStream(new FileInputStream(indexFile))
  try {
    ByteStreams.skipFully(in, blockId.reduceId * 8)   //跳到本次block的数据区
    val offset = in.readLong()    // 数据文件中的开始位置
    val nextOffset = in.readLong()    // 数据文件中的结束位置
    new FileSegmentManagedBuffer(
      transportConf,
      getDataFile(blockId.shuffleId, blockId.mapId),
      offset,
      nextOffset - offset)
  } finally {
    in.close()
  }
}
性能调优

通过两篇对于shuffle的架构和源码实现的分析,可以得出shuffle是Spark Core中比较复杂的模块,也很影响性能,这里总结一下shuffle模块中对性能有影响的系统配置:

spark.shuffle.manager

这个参数用来选择shuffle的机制:Hash还是Sort,在spark 1.2版本后默认的机制已从Hash变成了Sort,而在2.0版本后,Hash机制已经退出历史舞台。那么选择Hash还是Sort主要是取决于内存、排序和文件操作等多方面因素,如果产生的中间文件不是很多,那么采用Hash模式来避免不必要的排序可能是更好地选择

spark.shuffle.sort.BypassMergeThreshold

这个配置的默认值是200,用于设置在Reducer的partitions数目少于这个值时,Sort Based Shuffle内部使用归并排序的方式处理数据,而是直接将每个Partition写入单独的文件。这种方式可以看作Sort Based Shuffle在Shuffle量比较小的时候对Hash Based Shuffle的一种折中,当然它也存在中间文件过多的问题,如果GC或者内存使用比较紧张的话,可以适当降低这个值

spark.shuffle.compress和spark.shuffle.spill.compress

这两个参数的默认配置都是true,前者是设置shuffle最终输出到文件系统的文件是否压缩,后者是在shuffle过程中处理数据写入外部存储的数据是否压缩。
spark.shuffle.compress
如果下游的Task读取上游结果的网络IO成为瓶颈,那么可以考虑启用压缩来减少网络IO,如果计算是CPU密集型的,那么将这个选项设置为false更为合适。
spark.shuffle.spill.compress
如果在处理中间结果spill到本地硬盘时,出现Disk IO,那么设置为true启用压缩可能会比较合适,如果本地硬盘是SSD的,那么设置为false会比较合适。

简单来说,需要在项目中衡量压缩、解压缩带来的时间消耗与磁盘、带宽IO之间的利弊,具体情况,具体对待。

spark.reducer.maxSizeInFlight

这个参数用于限制一个Reducer Task向其他的Executor请求shuffle数据是所占用的最大内存数,默认值为48MB,如果带宽限制较大,那么适当调小这个值,如果是万兆网卡,可以考虑增大这个值。


top8488大数据 , 版权所有丨如未注明 , 均为原创丨本网站采用BY-NC-SA协议进行授权
转载请注明原文链接:spark源码阅读之shuffle模块②
喜欢 (0)
[]
分享 (0)

您必须 登录 才能发表评论!