Source definition
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using
* given combine functions and a neutral "zero value". This function can return a different result
* type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U
* and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are
* allowed to modify and return their first argument instead of creating a new U to avoid memory
* allocation.
*
* @param zeroValue the initial value for the accumulated result of each partition for the
 *                  `seqOp` operator, and also the initial value for the combine results from
 *                  different partitions for the `combOp` operator - this will typically be the
 *                  neutral element (e.g. `Nil` for list concatenation or `0` for summation)
* @param seqOp an operator used to accumulate results within a partition
* @param combOp an associative operator used to combine results from different partitions
*/
def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = withScope {
  // Clone the zero value since we will also be serializing it as part of tasks
  var jobResult = Utils.clone(zeroValue, sc.env.serializer.newInstance())
  val cleanSeqOp = sc.clean(seqOp)
  val cleanCombOp = sc.clean(combOp)
  val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
  val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
  sc.runJob(this, aggregatePartition, mergeResult)
  jobResult
}
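Reading the body: each task folds its own partition with `seqOp`, starting from a (cloned) `zeroValue`, and the driver then merges the per-task results into `jobResult` with `combOp`, which also starts from `zeroValue`. Below is a minimal plain-Scala sketch of that flow (no Spark involved; the `partitions` layout and all names here are made up for illustration):

object AggregateSimulation {
  def main(args: Array[String]): Unit = {
    // Pretend each inner list is one partition of an RDD[(String, Int)].
    val partitions: List[List[(String, Int)]] = List(
      List(("a", 2), ("a", 5)),
      List(("a", 4), ("b", 5))
    )
    val zeroValue = (0, 0) // U = (element count, value sum)
    val seqOp = (u: (Int, Int), t: (String, Int)) => (u._1 + 1, u._2 + t._2)
    val combOp = (a: (Int, Int), b: (Int, Int)) => (a._1 + b._1, a._2 + b._2)

    // "aggregatePartition": every task folds its partition from its own copy of zeroValue.
    val taskResults = partitions.map(_.foldLeft(zeroValue)(seqOp))

    // "mergeResult": the driver starts jobResult at zeroValue and folds in each task result.
    val jobResult = taskResults.foldLeft(zeroValue)(combOp)
    println(jobResult) // (4, 16): 4 elements, value sum 2 + 5 + 4 + 5 = 16
  }
}

One consequence worth noticing: zeroValue is applied once per partition and once more in the driver-side merge, so a non-neutral zero value (for example (1, 100)) would be counted numPartitions + 1 times in the final result.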
When I first looked at this source code I was completely lost and didn't really understand it; I had to read quite a few good blog posts before it clicked. Here is a link to an excellent one: https://www.cnblogs.com/Gxiaobai/p/11437739.html
Let me first explain each parameter of aggregate:
- zeroValue: U: the initial value. Whatever type U you give it here is also the type the operator returns.
- seqOp: (U, T) => U: the function applied inside each partition. U is the accumulated value (starting from zeroValue) and T is each element of the RDD.
- combOp: (U, U) => U: the function that merges the results produced by the different partitions.
Here is an example: count the elements in an RDD and sum their values.
import org.apache.spark.{SparkConf, SparkContext}

object aggregate {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("aggregate").setMaster("local")
    val sc = new SparkContext(conf)
    sc.setLogLevel("ERROR")

    // 8 elements spread across 4 partitions
    val rdd1 = sc.parallelize(List(("a", 2), ("a", 5), ("a", 4), ("b", 5), ("c", 3), ("b", 3), ("c", 6), ("a", 8)), 4)

    val result = rdd1.aggregate((0, 0))(
      (u, c) => (u._1 + 1, u._2 + c._2),          // within each partition: count elements, sum values
      (r1, r2) => (r1._1 + r2._1, r1._2 + r2._2)  // merge the partition results
    )
    println(result)

    sc.stop()
  }
}
Output:
(8,36)
Explanation:
- rdd1 is an RDD with 4 partitions. The elements are distributed as follows:
part0: ("a", 2), ("a", 5); part1: ("a", 4), ("b", 5); part2: ("c", 3), ("b", 3); part3: ("c", 6), ("a", 8)
- (u, c) => (u._1 + 1, u._2 + c._2): this function runs inside each partition, starting from the zero value (0, 0): part0 (0 + 1 + 1, 0 + 2 + 5) = (2, 7), part1 (0 + 1 + 1, 0 + 4 + 5) = (2, 9), part2 (0 + 1 + 1, 0 + 3 + 3) = (2, 6), part3 (0 + 1 + 1, 0 + 6 + 8) = (2, 14).
- (r1, r2) => (r1._1 + r2._1, r1._2 + r2._2): this function merges all the partition results: part0._1 + part1._1 + ..., part0._2 + part1._2 + ..., and so on, which gives the final result (8, 36). You can check the partition layout yourself, as shown below.
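To verify how parallelize actually laid out the partitions, RDD.glom() turns each partition into an array so you can print them. A quick check, reusing the sc and rdd1 from the example above:

rdd1.glom().collect().zipWithIndex.foreach { case (part, i) =>
  println(s"part$i: ${part.mkString(", ")}")
}

This prints one line per partition and should match the distribution listed above.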