(1)自定义UDF
/**
 * Demo: registering a plain Scala method as a Spark SQL UDF and calling it from SQL.
 */
object SparkSqlTest {

  /**
   * Builds a local SparkSession, registers [[strLen]] under the SQL name `strLen`,
   * runs a one-line query that calls it and prints the result.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    // Suppress noisy framework logging
    Logger.getLogger("org.apache.hadoop").setLevel(Level.WARN)
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.project-spark").setLevel(Level.WARN)

    // Build the programming entry point
    val conf: SparkConf = new SparkConf()
      .setAppName("SparkSqlTest")
      .setMaster("local[2]")
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

    try {
      val sqlContext: SQLContext = spark.sqlContext

      // Register the UDF. The type parameters are inferred from the function
      // (return type java.lang.Integer, argument type String); the boxed return
      // type lets the UDF hand a SQL NULL back instead of failing on unboxing.
      sqlContext.udf.register("strLen", (s: String) => strLen(s))

      val sql =
        """
          |select strLen("zhangsan")
          |""".stripMargin
      spark.sql(sql).show()
    } finally {
      // Always release the session, even when the query fails
      spark.stop()
    }
  }

  /**
   * UDF body: length of a string.
   *
   * @param str the input string; may be null when the SQL argument is NULL
   * @return the number of characters in `str`, or null when `str` is null
   */
  def strLen(str: String): Integer =
    Option(str).map(s => Integer.valueOf(s.length)).orNull
}
(2) 自定义UDAF
这里举的例子是实现一个count:
自定义UDAF类:
/**
 * A user-defined aggregate function that counts rows: every input row adds one
 * to an integer counter, and partial counters are summed when buffers are merged.
 *
 * NOTE(review): the counter is a 32-bit Int (IntegerType / getInt), so it would
 * wrap once a group holds more than Int.MaxValue rows.
 */
class MyCountUDAF extends UserDefinedAggregateFunction {

  /** Schema of the arguments passed to the function: one integer column. */
  override def inputSchema: StructType =
    new StructType().add("age", DataTypes.IntegerType)

  /** Schema of the intermediate aggregation buffer: one integer slot holding the running count. */
  override def bufferSchema: StructType =
    new StructType().add("age", DataTypes.IntegerType)

  /** Type of the final result produced by [[evaluate]]. */
  override def dataType: DataType = DataTypes.IntegerType

  /** This function is declared deterministic. */
  override def deterministic: Boolean = true

  /**
   * Resets the aggregation buffer: the running count starts at zero.
   *
   * @param buffer the temporary buffer that stores the intermediate result
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0
  }

  /**
   * Folds one input row into the buffer.
   *
   * @param buffer the buffer set up by [[initialize]]
   * @param input  the incoming row; its content is not read, each row simply adds one
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getInt(0) + 1
  }

  /**
   * Combines two partial results by adding their counts; the sum is written to `buffer1`.
   *
   * @param buffer1 the partial result that receives the merged count
   * @param buffer2 the other partial result
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val merged = buffer1.getInt(0) + buffer2.getInt(0)
    buffer1(0) = merged
  }

  /**
   * Produces the final value of the aggregation.
   *
   * @param buffer the fully merged buffer
   * @return the count stored in slot 0
   */
  override def evaluate(buffer: Row): Any = buffer(0)
}
调用:
/**
 * Demo: registering the MyCountUDAF aggregate under the SQL name `myCount`
 * and using it in a GROUP BY query over a small in-memory Dataset of Student.
 */
object SparkSqlTest {

  /**
   * Builds a local SparkSession, registers the UDAF, exposes a list of students
   * as the temp view `student`, then prints the per-age counts.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    // Suppress noisy framework logging
    Logger.getLogger("org.apache.hadoop").setLevel(Level.WARN)
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    Logger.getLogger("org.project-spark").setLevel(Level.WARN)

    // Build the programming entry point
    val conf: SparkConf = new SparkConf()
      .setAppName("SparkSqlTest")
      .setMaster("local[2]")
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .registerKryoClasses(Array(classOf[Student]))
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()

    try {
      val sqlContext: SQLContext = spark.sqlContext

      // Register the UDAF
      sqlContext.udf.register("myCount", new MyCountUDAF())

      val stuList = List(
        Student("委xx", 18),
        Student("吴xx", 18),
        Student("戚xx", 18),
        Student("王xx", 19),
        Student("薛xx", 19)
      )

      import spark.implicits._
      val stuDS: Dataset[Student] = sqlContext.createDataset(stuList)

      // getOrCreate() may hand back an existing session in which the view is
      // already registered; replacing it avoids failing on a second run.
      stuDS.createOrReplaceTempView("student")

      val sql =
        """
          |select myCount(1) counts
          |from student
          |group by age
          |order by counts
          |""".stripMargin
      spark.sql(sql).show()
    } finally {
      // Always release the session, even when the query fails
      spark.stop()
    }
  }
}
/**
 * A student record.
 *
 * @param name the student's name
 * @param age  the student's age
 */
case class Student(name: String, age: Int)
原创文章,作者:奋斗,如若转载,请注明出处:https://blog.ytso.com/tech/opensource/190562.html