Spark之combineByKey学习理解

combineByKey()是最为常用的基于键进行聚合的函数。大多数基于键聚合的函数都是用它实现的。和aggregate()一样,combineByKey()可以让用户返回与输入数据的类型不同的返回值。


要理解combineByKey(),要先理解它在处理数据时是如何处理每个元素的。由于combineByKey()会遍历分区中的所有元素,因此每个元素的键要么还没有遇到过,要么就和之前的某个元素的键相同。


如果这是一个新的元素,combineByKey()会使用一个叫做createCombiner()的函数来创建那个键对应的累加器的初始值。需要注意的是,这个过程会在每个分区中第一次出现各个键时发生,而不是在整个RDD中第一次出现一个键时发生。


如果这是一个在处理当前分区之前已经遇到的键,它会使用mergeValue()方法将该键的累加器对应的当前值与这个新的值进行合并。
【Spark之combineByKey学习理解】

由于每个分区都是独立处理的,因此对于同一个键可以有多个累加器。如果有两个或者更多的分区都有对应同一个键的累加器,就需要使用用户提供的mergeCombiners()方法将各个分区的结果合并。


combineByKey() 有多个参数分别对应聚合操作的各个阶段,因而非常适合用来解释聚合操作各个阶段的功能划分。为了更好地演示combineByKey() 是如何工作的,下面来看看如何
计算各键对应的平均值。
Spark之combineByKey学习理解
文章图片
Spark之combineByKey学习理解
文章图片
下面来看看如何计算各键对应的平均值


package spark; import java.util.Map; import java.util.Map.Entry; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; import scala.Serializable; import scala.Tuple2; public class AvgCount implements Serializable { /** * */ private static final long serialVersionUID = 4702514714336992425L; private int total; private int num; static SparkConf conf = new SparkConf().setMaster("local").setAppName("wordCount"); // 创建一个java版本的Spark Context static JavaSparkContext sc = new JavaSparkContext(conf); public AvgCount(int total, int num) { this.total = total; this.num = num; } public float avg() { return total / (float) num; } public void testCombineByKey() { Function createCombiner = new Function() { private static final long serialVersionUID = 1L; @Override public AvgCount call(Integer x) throws Exception { return new AvgCount(x, 1); } }; Function2 mergeValue = https://www.it610.com/article/new Function2() { private static final long serialVersionUID = 1L; @Override public AvgCount call(AvgCount a, Integer x) throws Exception { a.total += x; a.num += 1; return a; } }; Function2 mergeCombiners = new Function2() { private static final long serialVersionUID = 1L; @Override public AvgCount call(AvgCount a1, AvgCount a2) throws Exception { a1.total += a2.total; a1.num += a2.num; return a1; } }; JavaRDD input = sc.textFile("F:\\spark\\spark-2.2.1-bin-hadoop2.7\\README.md"); // 转换为键值对并计数 JavaPairRDD counts = input.mapToPair(new PairFunction() { private static final long serialVersionUID = 1L; @SuppressWarnings("unchecked") @Override public Tuple2 call(String t) throws Exception { return new Tuple2(t, 1); } }).reduceByKey(new Function2() { private static final long serialVersionUID = 1L; public Integer call(Integer v1, Integer v2) throws Exception { return v1 + v2; } }); JavaPairRDD avgCounts = counts.combineByKey(createCombiner, mergeValue, mergeCombiners); Map countMap = avgCounts.collectAsMap(); for (Entry entry : countMap.entrySet()) { System.out.println(entry.getKey() + ":" + entry.getValue().avg()); } }}



    推荐阅读