大数据|spark UDAF 自定义聚合函数 UserDefinedAggregateFunction 带条件的去重操作

需求：按餐品（pid）分组，统计无优惠金额（yhPrice <= 0）的订单数，同一订单号只计一次（去重）。

package cd.custom.jde.job.udf

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Conditional distinct count: counts the distinct `orderid` values among input rows whose
 * `price` is `<= 0` (orders without an offer/discount amount).
 *
 * Input columns: `(orderid: String, price: Double)`. Result: `Int`.
 * Rows where `orderid` or `price` is null are ignored, matching SQL `COUNT(DISTINCT ...)`,
 * which never counts nulls.
 *
 * Created by roy 2020-02-12.
 *
 * NOTE(review): `UserDefinedAggregateFunction` is deprecated since Spark 3.0; new code should
 * prefer `org.apache.spark.sql.expressions.Aggregator` registered via `functions.udaf`.
 */
class CountDistinctAndIf extends UserDefinedAggregateFunction {

  override def inputSchema: StructType =
    new StructType()
      .add("orderid", StringType, nullable = true)
      .add("price", DoubleType, nullable = true)

  /** Single buffer column holding the distinct qualifying order ids seen so far. */
  override def bufferSchema: StructType =
    new StructType().add("items", ArrayType(StringType, containsNull = true), nullable = true)

  override def dataType: DataType = IntegerType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer(0) = Seq.empty[String]

  /** Adds the row's order id to the buffer when its price is <= 0 and the id is not yet present. */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // Both input columns are nullable: getDouble on a null value throws, and a null
    // order id must not be counted as an order.
    if (!input.isNullAt(0) && !input.isNullAt(1) && input.getDouble(1) <= 0) {
      val orderId = input.getString(0)
      val seen = buffer.getSeq[String](0)
      // Rewrite the buffer only when the id is new; duplicates leave it untouched.
      if (!seen.contains(orderId)) buffer(0) = seen :+ orderId
    }
  }

  /** Unions two partial buffers, keeping each order id once. */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit =
    buffer1(0) = (buffer1.getSeq[String](0) ++ buffer2.getSeq[String](0)).distinct

  /** Number of distinct qualifying order ids. */
  override def evaluate(buffer: Row): Any = buffer.getSeq[String](0).length
}

使用示例：
package spark.udf

import cd.custom.jde.job.udf.CountDistinctAndIf
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.types.{DoubleType, StringType, StructType}
import org.apache.spark.sql.{Row, SparkSession}

/**
 * Local demo of `CountDistinctAndIf`: groups sample order lines by product id (`pid`) and
 * compares a plain conditional row count (`zk_all_order_num`) with the distinct order count
 * produced by the UDAF (`zk_order_num`).
 */
object MyOrderTest {
  Logger.getRootLogger.setLevel(Level.WARN)

  def main(args: Array[String]): Unit = {
    val data = Seq(
      Row("a", "a100", 0.0, "300"),
      Row("a", "a100", 7.0, "300"),
      Row("a", "a101", 6.0, "300"),
      Row("a", "a101", 5.0, "301"),
      Row("a", "a100", 0.0, "300")
    )
    val schema = new StructType()
      .add("storeid", StringType)
      .add("orderid", StringType)
      .add("yhPrice", DoubleType)
      .add("pid", StringType)

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    try {
      val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
      df.show()
      df.createOrReplaceTempView("tab_tmp")

      spark.udf.register("cCountDistinct2", new CountDistinctAndIf)
      spark.sql(
        """
          |select pid,count(1) pid_num,
          |sum(if(yhPrice<=0,1,0)) as zk_all_order_num,
          |cCountDistinct2(orderid,yhPrice) as zk_order_num
          |from tab_tmp group by pid
        """.stripMargin).show()
    } finally {
      // Release the local Spark context even if the query fails.
      spark.stop()
    }
  }

  /* Expected output (row order is not guaranteed):
  +---+-------+----------------+------------+
  |pid|pid_num|zk_all_order_num|zk_order_num|
  +---+-------+----------------+------------+
  |300|      4|               2|           1|
  |301|      1|               0|           0|
  +---+-------+----------------+------------+
  */
}

【大数据|spark UDAF 自定义聚合函数 UserDefinedAggregateFunction 带条件的去重操作】

    推荐阅读