ICode9

精准搜索请尝试: 精确搜索
首页 > 数据库> 文章详细

sparksql系列(六) SparkSql中UDF、UDAF、UDTF

2019-11-24 20:58:01  阅读:279  来源: 互联网

标签:val sql UDTF UDAF UDF org apache import spark


RDD没有可以这种可以注册的方法。

在使用sparksql过程中发现UDF还是有点用的所以,还是单独写一篇博客记录一下。

UDF=》一个输入一个输出。相当于map

UDAF=》多个输入一个输出。相当于reduce

UDTF=》一个输入多个输出。相当于flatMap。(需要hive环境,暂时未测试)

UDF

        其实就是在sql语句中注册函数,不要想得太难了。给大家写一个case when的语句

        import java.util.Arrays

        import org.apache.spark.SparkConf
        import org.apache.spark.api.java.JavaSparkContext
        import org.apache.spark.sql.{ DataFrame, Row, SparkSession, functions }
        import org.apache.spark.sql.functions.{ col, desc, length, row_number, trim, when }
        import org.apache.spark.sql.functions.{ countDistinct, sum, count, avg }
        import org.apache.spark.sql.functions.concat
        import org.apache.spark.sql.types.{ LongType, StringType, StructField, StructType }
        import org.apache.spark.sql.expressions.Window
        import org.apache.spark.storage.StorageLevel
        import org.apache.spark.sql.SaveMode
        import java.util.ArrayList

        object WordCount {

                def main(args: Array[String]): Unit = {
                        val sparkSession = SparkSession.builder().master("local").getOrCreate()
                        val javasc = new JavaSparkContext(sparkSession.sparkContext)

                        val nameRDD1 = javasc.parallelize(Arrays.asList("{'id':'7'}", "{'id':'8'}",
                                "{'id':'9'}","{'id':'10'}"));
                        val nameRDD1df = sparkSession.read.json(nameRDD1)

                        nameRDD1df.createTempView("idList")
        
                        sparkSession.udf.register("idParse",(str:String)=>{//注册一个函数,实现case when的函数
                                str match{
                                        case "7" => "id7"
                                        case "8" => "id8"
                                        case "9" => "id9"
                                        case _=>"others"
                                }
                        })
                        val data = sparkSession.sql("select idParse(id) from idList").show(100)
                }
        }

UDAF

        import java.util.Arrays

        import org.apache.spark.SparkConf
        import org.apache.spark.api.java.JavaSparkContext
        import org.apache.spark.sql.{ DataFrame, Row, SparkSession, functions }
        import org.apache.spark.sql.functions.{ col, desc, length, row_number, trim, when }
        import org.apache.spark.sql.functions.{ countDistinct, sum, count, avg }
        import org.apache.spark.sql.functions.concat
        import org.apache.spark.sql.types.{ LongType, StringType, StructField, StructType }
        import org.apache.spark.sql.expressions.Window
        import org.apache.spark.storage.StorageLevel
        import org.apache.spark.sql.SaveMode
        import java.util.ArrayList
        import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
        import org.apache.spark.sql.expressions.MutableAggregationBuffer
        import org.apache.spark.sql.types.IntegerType
        import org.apache.spark.sql.types.DataType

        class MyMax extends UserDefinedAggregateFunction{
                //定义输入数据的类型,两种写法都可以
                //override def inputSchema: StructType = StructType(Array(StructField("input", IntegerType, true)))
                override def inputSchema: StructType = StructType(StructField("input", IntegerType) :: Nil)
                //定义聚合过程中所处理的数据类型
                // override def bufferSchema: StructType = StructType(Array(StructField("cache", IntegerType, true)))
                override def bufferSchema: StructType = StructType(StructField("max", IntegerType) :: Nil)
                //定义输入数据的类型
                override def dataType: DataType = IntegerType
                //规定一致性
                override def deterministic: Boolean = true
                //在聚合之前,每组数据的初始化操作
                override def initialize(buffer: MutableAggregationBuffer): Unit = {buffer(0) =0}
                //每组数据中,当新的值进来的时候,如何进行聚合值的计算
                override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
                        if(input.getInt(0)> buffer.getInt(0))
                                buffer(0)=input.getInt(0)
                }
                //合并各个分组的结果
                override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
                        if(buffer2.getInt(0)> buffer1.getInt(0)){
                                buffer1(0)=buffer2.getInt(0)
                        }
                }
                //返回最终结果
                override def evaluate(buffer: Row): Any = {buffer.getInt(0)}
        }


        class MyAvg extends UserDefinedAggregateFunction{
                //输入数据的类型
                override def inputSchema: StructType = StructType(StructField("input", IntegerType) :: Nil)
                //中间结果数据的类型
                override def bufferSchema: StructType = StructType(
                        StructField("sum", IntegerType) :: StructField("count", IntegerType) :: Nil)
                //定义输入数据的类型
                override def dataType: DataType = IntegerType
                //规定一致性
                override def deterministic: Boolean = true
                //初始化操作
                override def initialize(buffer: MutableAggregationBuffer): Unit = {buffer(0) =0;buffer(1) =0;}

                //map端reduce,所有数据必须过这一段代码
                override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
                        buffer.update(0, buffer.getInt(0)+input.getInt(0))
                        buffer.update(1, buffer.getInt(1)+1)
                }
                //reduce数据,update里面Row,没有第二个字段,这时候就有了第二个字段
                override def merge(buffer: MutableAggregationBuffer, input: Row): Unit = {
                        buffer.update(0, buffer.getInt(0)+input.getInt(0))
                        buffer.update(1, buffer.getInt(1)+input.getInt(1))
                }
                //返回最终结果
                override def evaluate(finalVaue: Row): Int = {finalVaue.getInt(0)/finalVaue.getInt(1)}
                }

                object WordCount {

                        def main(args: Array[String]): Unit = {
                                val sparkSession = SparkSession.builder().master("local").getOrCreate()
                                val javasc = new JavaSparkContext(sparkSession.sparkContext)

                                val nameRDD1 = javasc.parallelize(Arrays.asList("{'id':'7'}"));
                                val nameRDD1df = sparkSession.read.json(nameRDD1)
                                val nameRDD2 = javasc.parallelize(Arrays.asList( "{'id':'8'}"));
                                val nameRDD2df = sparkSession.read.json(nameRDD2)
                                val nameRDD3 = javasc.parallelize(Arrays.asList("{'id':'9'}"));
                                val nameRDD3df = sparkSession.read.json(nameRDD3)
                                val nameRDD4 = javasc.parallelize(Arrays.asList("{'id':'10'}"));
                                val nameRDD4df = sparkSession.read.json(nameRDD4)

                                nameRDD1df.union(nameRDD2df).union(nameRDD3df).union(nameRDD4df).registerTempTable("idList")

                                // sparkSession.udf.register("myMax",new MyMax)
                                sparkSession.udf.register("myAvg",new MyAvg)

                                val data = sparkSession.sql("select myAvg(id) from idList").show(100)


                }
        }

UDTF 暂时没测试,家里没有hive环境

       import java.util.Arrays

       import org.apache.spark.SparkConf
       import org.apache.spark.api.java.JavaSparkContext
       import org.apache.spark.sql.{ DataFrame, Row, SparkSession, functions }
       import org.apache.spark.sql.functions.{ col, desc, length, row_number, trim, when }
       import org.apache.spark.sql.functions.{ countDistinct, sum, count, avg }
       import org.apache.spark.sql.functions.concat
       import org.apache.spark.sql.types.{ LongType, StringType, StructField, StructType }
       import org.apache.spark.sql.expressions.Window
       import org.apache.spark.storage.StorageLevel
       import org.apache.spark.sql.SaveMode
       import java.util.ArrayList
       import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
       import org.apache.spark.sql.expressions.MutableAggregationBuffer
       import org.apache.spark.sql.types.IntegerType
       import org.apache.spark.sql.types.DataType
       import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
       import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
       import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector
       import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory
       import org.apache.hadoop.hive.ql.exec.UDFArgumentException
       import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException
       import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

       class MyFloatMap extends GenericUDTF{
              override def close(): Unit = {}
              //这个方法的作用:1.输入参数校验 2. 输出列定义,可以多于1列,相当于可以生成多行多列数据
              override def initialize(args:Array[ObjectInspector]): StructObjectInspector = {
                     if (args.length != 1) {
                            throw new UDFArgumentLengthException("UserDefinedUDTF takes only one argument")
                     }
                     if (args(0).getCategory() != ObjectInspector.Category.PRIMITIVE) {
                            throw new UDFArgumentException("UserDefinedUDTF takes string as a parameter")
                     }

                     val fieldNames = new java.util.ArrayList[String]
                     val fieldOIs = new java.util.ArrayList[ObjectInspector]

                     //这里定义的是输出列默认字段名称
                     fieldNames.add("col1")
                     //这里定义的是输出列字段类型
                     fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)

                     ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
              }

              //这是处理数据的方法,入参数组里只有1行数据,即每次调用process方法只处理一行数据
              override def process(args: Array[AnyRef]): Unit = {
                     //将字符串切分成单个字符的数组
                     val strLst = args(0).toString.split("")
                     for(i <- strLst){
                            var tmp:Array[String] = new Array[String](1)
                            tmp(0) = i
                            //调用forward方法,必须传字符串数组,即使只有一个元素
                            forward(tmp)
                     }
              }
       }

       object WordCount {

              def main(args: Array[String]): Unit = {
                     val sparkSession = SparkSession.builder().master("local").getOrCreate()
                     val javasc = new JavaSparkContext(sparkSession.sparkContext)

                     val nameRDD1 = javasc.parallelize(Arrays.asList("{'id':'7'}"));
                     val nameRDD1df = sparkSession.read.json(nameRDD1)

                     nameRDD1df.createOrReplaceTempView("idList")

                     sparkSession.sql("create temporary function myFloatMap as 'MyFloatMap'")

                     val data = sparkSession.sql("select myFloatMap(id) from idList").show(100)

              }
       }

标签:val,sql,UDTF,UDAF,UDF,org,apache,import,spark
来源: https://www.cnblogs.com/wuxiaolong4/p/11924172.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有