Spark Aggregate Operator

Source code definition

/**
 * Aggregate the elements of each partition, and then the results for all the partitions, using
 * given combine functions and a neutral "zero value". This function can return a different result
 * type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U
 * and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are
 * allowed to modify and return their first argument instead of creating a new U to avoid memory
 * allocation.
 *
 * @param zeroValue the initial value for the accumulated result of each partition for the
 *                  `seqOp` operator, and also the initial value for the combine results from
 *                  different partitions for the `combOp` operator - this will typically be the
 *                  neutral element (e.g. `Nil` for list concatenation or `0` for summation)
 * @param seqOp an operator used to accumulate results within a partition
 * @param combOp an associative operator used to combine results from different partitions
 */
def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = withScope {
  // Clone the zero value since we will also be serializing it as part of tasks
  var jobResult = Utils.clone(zeroValue, sc.env.serializer.newInstance())
  val cleanSeqOp = sc.clean(seqOp)
  val cleanCombOp = sc.clean(combOp)
  val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
  val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
  sc.runJob(this, aggregatePartition, mergeResult)
  jobResult
}
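Before walking through a full program, a minimal sketch of the call shape may help (assuming a SparkContext named sc already exists; the data is made up). With U and T both Int, aggregate degenerates into a plain sum:

// Minimal sketch, assuming an existing SparkContext named sc.
val sum = sc.parallelize(1 to 10, 2).aggregate(0)(
  (acc, x) => acc + x, // seqOp: fold the elements within each partition
  (a, b) => a + b      // combOp: merge the per-partition results on the driver
)
// sum == 55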

When I first read this source code I was completely lost and didn't really understand it; only after going through several excellent blog posts did it finally make sense. Here is a link to one of them: https://www.cnblogs.com/Gxiaobai/p/11437739.html
First, a note on each parameter of aggregate:
  • (zeroValue: U): the initial ("zero") value. Whatever type U you supply here is also the type the operator finally returns.
  • (seqOp: (U, T) => U): the function applied inside each partition. U is the accumulated value (starting from zeroValue) and T is an element of the RDD.
  • (combOp: (U, U) => U): the function that merges the results of the individual partitions (a small simulation follows this list).
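To see how the three pieces fit together without involving Spark at all, here is a small, purely illustrative simulation on plain Scala collections (the partition split and variable names are invented for the sketch): the inner foldLeft plays the role of seqOp inside each partition, and the outer foldLeft plays the role of combOp on the driver.

// Hypothetical simulation of aggregate on plain Scala collections (no Spark involved).
val partitions = List(List(1, 2, 3), List(4, 5)) // pretend these are two RDD partitions
val zeroValue  = (0, 0)                          // (count, sum)
val seqOp      = (u: (Int, Int), t: Int) => (u._1 + 1, u._2 + t)
val combOp     = (a: (Int, Int), b: (Int, Int)) => (a._1 + b._1, a._2 + b._2)

// Step 1: each "partition" is folded with seqOp, starting from zeroValue.
val partials = partitions.map(_.foldLeft(zeroValue)(seqOp)) // List((3,6), (2,9))
// Step 2: the partial results are merged with combOp, again starting from zeroValue.
val result = partials.foldLeft(zeroValue)(combOp)           // (5, 15)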
The full Spark example below should make this concrete.
Example: counting the elements of an RDD and summing their values
import org.apache.spark.{SparkConf, SparkContext}

object aggregate {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("aggregate").setMaster("local")
    val sc = new SparkContext(conf)
    sc.setLogLevel("Error")
    val rdd1 = sc.parallelize(
      List(("a", 2), ("a", 5), ("a", 4), ("b", 5), ("c", 3), ("b", 3), ("c", 6), ("a", 8)), 4)
    val result = rdd1.aggregate((0, 0))(
      (u, c) => (u._1 + 1, u._2 + c._2),         // seqOp: count elements and sum values within a partition
      (r1, r2) => (r1._1 + r2._1, r1._2 + r2._2) // combOp: merge the per-partition (count, sum) pairs
    )
    println(result)
  }
}

Output:
(8,36)
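The explanation below walks through how this result comes about. The per-partition intermediate values it lists can also be inspected directly; a hypothetical check using mapPartitionsWithIndex on the rdd1 from the example might look like this:

// Hypothetical check of the per-partition intermediates (rdd1 from the example above).
val perPartition = rdd1.mapPartitionsWithIndex { (idx, it) =>
  Iterator((idx, it.foldLeft((0, 0))((u, c) => (u._1 + 1, u._2 + c._2))))
}.collect()
// Expected, given the partitioning described below:
// Array((0,(2,7)), (1,(2,9)), (2,(2,6)), (3,(2,14)))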

Explanation:
  • rdd1 is an RDD with 4 partitions, split as follows:
    part0: ("a", 2), ("a", 5)
    part1: ("a", 4), ("b", 5)
    part2: ("c", 3), ("b", 3)
    part3: ("c", 6), ("a", 8)

  • (u, c) => (u._1 + 1, u._2 + c._2): this function runs inside each partition; the first component counts elements and the second sums their values. part0: (0 + 1 + 1, 0 + 2 + 5) = (2, 7), part1: (0 + 1 + 1, 0 + 4 + 5) = (2, 9), part2: (0 + 1 + 1, 0 + 3 + 3) = (2, 6), part3: (0 + 1 + 1, 0 + 6 + 8) = (2, 14).
  • (r1, r2) => (r1._1 + r2._1, r1._2 + r2._2): this function merges the per-partition results: part0._1 + part1._1, part0._2 + part1._2, and so on, which gives the final result (8, 36): 8 elements in total, whose values sum to 36. A related pitfall with the zero value is sketched right after this list.
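One detail worth noting: as the docstring above says, zeroValue is applied once inside every partition and once more when the driver combines the partial results, so a non-neutral zero value is counted numPartitions + 1 times. A hypothetical variation of the example (the result is derived from these semantics, not from an actual run):

// Hypothetical variation: a non-neutral zero value is counted numPartitions + 1 times.
val skewed = rdd1.aggregate((1, 0))(
  (u, c) => (u._1 + 1, u._2 + c._2),
  (r1, r2) => (r1._1 + r2._1, r1._2 + r2._2)
)
// With 4 partitions this should give (13, 36): 8 elements plus 5 extra copies of the 1.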
