Contents
- Using the built-in avg function
- Weakly typed custom UDAF (AVG)
- Strongly typed custom UDAF (AVG)
Weakly typed (UserDefinedAggregateFunction): available in 2.x, deprecated in 3.x
Strongly typed (Aggregator registered via udaf): available in 3.x only, not in 2.x
Using the built-in avg function
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
object UserDefinedUDAF {
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession.builder().appName("test").master("local[4]").getOrCreate()
import spark.implicits._
val list = List(
("zhangsan",20,"北京"),
("sd",30,"深圳"),
("asd",40,"北京"),
("asd",50,"深圳"),
("asdad",60,"深圳"),
("gfds",70,"北京"),
("dfg",60,"深圳"),
("erw",80,"上海"),
("asd",18,"广州"),
("sdassws",20,"广州"),
)
val rdd: RDD[(String, Int, String)] = spark.sparkContext.parallelize(list, 2)
val df: DataFrame = rdd.toDF("name", "age", "region")
df.createOrReplaceTempView("person")
spark.sql(
"""
|select
|region,
|avg(age)
|from person group by region
|""".stripMargin).show()
}
}
Result:
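Based on the sample data above, the output should look roughly like the following (row order may vary):
+------+------------------+
|region|          avg(age)|
+------+------------------+
|  北京|43.333333333333336|
|  深圳|              50.0|
|  上海|              80.0|
|  广州|              19.0|
+------+------------------+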
Weakly typed custom UDAF (AVG)
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructType}
/**
 * Custom weakly typed UDAF
 * 1. Create a class extending UserDefinedAggregateFunction
 * 2. Override its abstract methods
 */
class WeakAvgUDAF extends UserDefinedAggregateFunction{
/**
 * Declares the input type of the UDAF [our custom avg takes the age column, which is an Int]
 * @return
 */
override def inputSchema: StructType = {
new StructType()
.add("input", IntegerType)
}
/**
 * Declares the schema of the intermediate buffer [averaging ages per region requires the running total of ages and the number of people] (the final average is the total divided by the count)
 * @return
 */
override def bufferSchema: StructType = {
new StructType()
.add("sum", IntegerType)
.add("count", IntegerType)
}
/**
 * Declares the type of the UDAF's final result
 * @return
 */
override def dataType: DataType = DoubleType
/**
 * Whether the function is deterministic, i.e. always returns the same output for the same input
 * @return
 */
override def deterministic: Boolean = true
/**
 * Initializes the intermediate buffer [sum = 0, count = 0]
 * @param buffer
 */
override def initialize(buffer: MutableAggregationBuffer): Unit = {
//sum = 0
buffer(0) = 0
//count = 0
buffer(1) = 0
}
/**
 * Combiner-like step: folds a single age value of the group into the buffer
 * @param buffer the intermediate buffer [sum, count]
 * @param input one value from the group (age)
 */
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(0) = buffer.getAs[Int](0) + input.getAs[Int](0)
buffer(1) = buffer.getAs[Int](1) + 1
}
/**
 * Reducer-like step: merges a combiner's partial buffer into this one
 * @param buffer1 the intermediate buffer [sum, count]
 * @param buffer2 the combiner's partial result [sum, count]
 */
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
//sum = sum + combiner_sum
buffer1(0) = buffer1.getAs[Int](0) + buffer2.getAs[Int](0)
//count = count + combiner_count
buffer1(1) = buffer1.getAs[Int](1) + buffer2.getAs[Int](1)
}
/**
 * Computes the final result
 * @param buffer the intermediate buffer [sum, count]
 * @return
 */
override def evaluate(buffer: Row): Any = {
buffer.getAs[Int](0).toDouble / buffer.getAs[Int](1)
}
}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
object UserDefinedUDAF {
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession.builder().appName("test").master("local[4]").getOrCreate()
import spark.implicits._
val list = List(
("zhangsan",20,"北京"),
("sd",30,"深圳"),
("asd",40,"北京"),
("asd",50,"深圳"),
("asdad",60,"深圳"),
("gfds",70,"北京"),
("dfg",60,"深圳"),
("erw",80,"上海"),
("asd",18,"广州"),
("sdassws",20,"广州"),
)
val rdd: RDD[(String, Int, String)] = spark.sparkContext.parallelize(list, 2)
val df: DataFrame = rdd.toDF("name", "age", "region")
df.createOrReplaceTempView("person")
spark.udf.register("myavg",new WeakAvgUDAF)
spark.sql(
"""
|select
|region,
|myavg(age)
|from person group by region
|""".stripMargin).show()
}
}
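Besides registering it for SQL, a weakly typed UDAF can also be invoked through the DataFrame API, since UserDefinedAggregateFunction exposes an apply(Column*) method. A minimal sketch (the myAvg and avg_age names are just for illustration):
import org.apache.spark.sql.functions.col
// Instantiate the UDAF and use it as a Column expression,
// equivalent to the myavg(age) call in the SQL above
val myAvg = new WeakAvgUDAF
df.groupBy("region")
  .agg(myAvg(col("age")).as("avg_age"))
  .show()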
Strongly typed custom UDAF (AVG)
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
/**
 * Custom strongly typed UDAF
 * 1. Define a class extending Aggregator[IN, BUFF, OUT]
 *    IN:   the type of the UDAF's input parameter
 *    BUFF: the type of the intermediate buffer used during the computation
 *    OUT:  the type of the final result
 * 2. Override its abstract methods
 * Using a strongly typed UDAF:
 * 1. Create an instance of the custom UDAF: val obj = new Xxx
 * 2. Import the conversion function: import org.apache.spark.sql.functions._
 * 3. Convert it: val function = udaf(obj)
 * 4. Register it: spark.udf.register("name", function)
 */
case class AvgBuff(sum:Int,count:Int)
class StrongAvgUDAF extends Aggregator[Int,AvgBuff,Double]{
/**
 * Initializes the intermediate buffer
 * @return
 */
override def zero: AvgBuff = AvgBuff(0,0)
/**
 * Combiner-like step: folds one input value into the buffer
 * @param buff the intermediate buffer
 * @param age one input value of the UDAF
 * @return the buffer after accumulating this value
 */
override def reduce(buff: AvgBuff, age: Int): AvgBuff = AvgBuff(buff.sum + age, buff.count + 1)
/**
 * Reducer-like step: merges two partial buffers
 * @param b1 the intermediate buffer
 * @param b2 a combiner's partial result
 * @return the merged buffer
 */
override def merge(b1: AvgBuff, b2: AvgBuff): AvgBuff = AvgBuff(b1.sum + b2.sum, b1.count + b2.count)
/**
 * Computes the final result
 * @param buff the intermediate buffer
 * @return
 */
override def finish(buff: AvgBuff): Double = buff.sum.toDouble / buff.count
/**
 * Specifies the Encoder used to serialize the intermediate buffer
 * @return
 */
override def bufferEncoder: Encoder[AvgBuff] = Encoders.product[AvgBuff]
/**
 * Specifies the Encoder for the final result type
 * @return
 */
override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
object UserDefinedUDAF {
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession.builder().appName("test").master("local[4]").getOrCreate()
import spark.implicits._
val list = List(
("zhangsan",20,"北京"),
("sd",30,"深圳"),
("asd",40,"北京"),
("asd",50,"深圳"),
("asdad",60,"深圳"),
("gfds",70,"北京"),
("dfg",60,"深圳"),
("erw",80,"上海"),
("asd",18,"广州"),
("sdassws",20,"广州"),
)
val rdd: RDD[(String, Int, String)] = spark.sparkContext.parallelize(list, 2)
val df: DataFrame = rdd.toDF("name", "age", "region")
df.createOrReplaceTempView("person")
//TODO register the weakly typed UDAF
spark.udf.register("myavg",new WeakAvgUDAF)
//TODO register the strongly typed UDAF (3.x: convert the Aggregator with functions.udaf)
import org.apache.spark.sql.functions._
spark.udf.register("myavg2",udaf(new StrongAvgUDAF))
spark.sql(
"""
|select
|region,
|myavg2(age)
|from person group by region
|""".stripMargin).show()
}
}
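Since StrongAvgUDAF is an Aggregator, it can also be used on a typed Dataset without any SQL registration, via its toColumn method. A minimal sketch to place inside the same main method (this computes the average age over all rows rather than per region; the ages and overallAvg names are just for illustration):
import org.apache.spark.sql.Dataset
// Turn the age column into a typed Dataset[Int] (needs the spark.implicits._ import above)
val ages: Dataset[Int] = df.select($"age").as[Int]
// Apply the Aggregator as a TypedColumn; head() pulls out the single result row
val overallAvg: Double = ages.select(new StrongAvgUDAF().toColumn).head()
println(overallAvg)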