ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

Spark记录(四):Dataset.count()方法源码剖析

2022-05-30 00:34:53  阅读:170  来源: 互联网

标签:count RelationalGroupedDataset val df Dataset queryExecution 源码 方法


因最近工作中涉及较多的Spark相关功能,所以趁周末闲来无事,研读一下Dataset的count方法。Spark版本3.2.0

1、方法入口:

  def count(): Long = withAction("count", groupBy().count().queryExecution) { plan =>
    plan.executeCollect().head.getLong(0)
  }

可以看到,count方法调用的是withAction方法,入参有三个:字符串count、调用方法获取到的QueryExecution、一个函数。注:此处就是对Scala函数式编程的应用,将函数作为参数来传递

2、第二个参数QueryExecution的获取流程

 2.1、首先看groupBy()方法:

1   def groupBy(cols: Column*): RelationalGroupedDataset = {
2     RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType)
3   }

groupBy方法是用于分组聚合的,一般用法是groupBy之后加上agg聚合函数,对分组之后的每组数据进行聚合,入参为Column类型的可变长度参数。

但上面count方法中调用时未传任何入参,产生的效果就是****

groupBy方法只有一行代码,生成并返回了一个RelationalGroupedDataset的对象,而且此处是用伴生对象的简略写法创建出来的,该行代码其实质是调用了RelationalGroupedDataset的伴生对象中的apply方法,三个入参。

注:RelationalGroupedDataset 类是用于处理聚合操作的,内部封装了对agg方法的处理,以及一些统计函数sum、max等的实现。

2.1.1、逐一看下RelationalGroupedDataset的三个入参:

首先是toDF()方法,方法体如下,可见就是重新创建了一个Dataset[Row]对象,即DataFrame

  def toDF(): DataFrame = new Dataset[Row](queryExecution, RowEncoder(schema))

然后是cols.map(_.expr),即遍历执行每个Column的expr表达式,因为此处未传入cols,故可忽略。

最后传入的是 RelationalGroupedDataset.GroupByType,起了标识的作用。因为RelationalGroupedDataset类的方法除了groupBy调用之外,还有Cube、Rollup、Pivot等都会调用,为与其他几种区别开,故传入了GroupByType。

2.1.2、初探 RelationalGroupedDataset 类

apply方法:

1   def apply(
2       df: DataFrame,
3       groupingExprs: Seq[Expression],
4       groupType: GroupType): RelationalGroupedDataset = {
5     new RelationalGroupedDataset(df, groupingExprs, groupType: GroupType)
6   }

类的定义:

class RelationalGroupedDataset protected[sql](
    private[sql] val df: DataFrame,
    private[sql] val groupingExprs: Seq[Expression],
    groupType: RelationalGroupedDataset.GroupType) {
......
}

可见没有多余的逻辑,只是单纯的创建了一个对象。至于这个对象如何使用的,还需继续追溯它里面的count方法,即Dataset.count()中调用的groupBy().count()。

2.2、groupBy().count(),即 RelationalGroupedDataset.count():

  def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)).toAggregateExpression(), "count")()))

2.2.1、其中Alias(Count(Literal(1)).toAggregateExpression(), "count")的作用,就是生成 count(1) as count 这样的一个统计函数的表达式。

2.2.2、然后toDF方法,如下所示:

 1 private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
 2     val aggregates = if (df.sparkSession.sessionState.conf.dataFrameRetainGroupColumns) { // 是否保留分组的主键列,默认true
 3       groupingExprs match { // 若保留,则将分组的主键列拼到聚合表达式的前面
 4         // call `toList` because `Stream` can't serialize in scala 2.13
 5         case s: Stream[Expression] => s.toList ++ aggExprs
 6         case other => other ++ aggExprs
 7       }
 8     } else {
 9       aggExprs
10     }
11 
12     val aliasedAgg = aggregates.map(alias) // 处理设置别名的表达式
13 
14     groupType match {
15       case RelationalGroupedDataset.GroupByType =>
16         Dataset.ofRows(df.sparkSession, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) // ***
17       case RelationalGroupedDataset.RollupType =>
18         Dataset.ofRows(
19           df.sparkSession, Aggregate(Seq(Rollup(groupingExprs.map(Seq(_)))),
20             aliasedAgg, df.logicalPlan))
21       case RelationalGroupedDataset.CubeType =>
22         Dataset.ofRows(
23           df.sparkSession, Aggregate(Seq(Cube(groupingExprs.map(Seq(_)))),
24             aliasedAgg, df.logicalPlan))
25       case RelationalGroupedDataset.PivotType(pivotCol, values) =>
26         val aliasedGrps = groupingExprs.map(alias)
27         Dataset.ofRows(
28           df.sparkSession, Pivot(Some(aliasedGrps), pivotCol, values, aggExprs, df.logicalPlan))
29     }
30   }

重点是第16行,进入ofRows方法中可以看到,其实就是又新建了一个Dataset[Row],并将加上count(1)表达式之后新生成的Aggregate执行计划传入。

至此,groupBy().count().queryExecution得到的就是一个count(1)的执行计划了。

3、第三个参数,也是一个函数式参数:

{ plan =>
    plan.executeCollect().head.getLong(0)
  }

该参数入参是一个plan,返回值long类型,推测是获取最终count值的,暂时放一放,后面调用到的时候再来研究。

4、看完三个参数,下面进入withAction方法:

1 private def withAction[U](name: String, qe: QueryExecution)(action: SparkPlan => U) = {
2     SQLExecution.withNewExecutionId(qe, Some(name)) {
3       qe.executedPlan.resetMetrics()
4       action(qe.executedPlan)
5     }
6   }

又是使用了科里化传参,第三个参数同样是一个函数,在里面调用了action这个函数参数。继续追踪withNewExecutionId方法:

 5、SQLExecution.withNewExecutionId

该方法代码较多,下面先看一下它的主体结构。里面省略的若干行代码,实际是作为一个函数参数传入了withActive方法。

def withNewExecutionId[T](
      queryExecution: QueryExecution,
      name: Option[String] = None)(body: => T): T = queryExecution.sparkSession.withActive {
... // 省略若干代码
}

而withActive方法如下,实际是将当前的SparkSession存入了本地线程变量中,方便后面的获取。然后执行了函数block,而返回值就是外层withNewExecutionId方法中函数体的返回值。

private[sql] def withActive[T](block: => T): T = {
    val old = SparkSession.activeThreadSession.get()
    SparkSession.setActiveSession(this)
    try block finally {
      SparkSession.setActiveSession(old)
    }
  }

 下面回到外层的函数体:

5.1、SQLExecution.withNewExecutionId函数体第一部分

1     val sparkSession = queryExecution.sparkSession
2     val sc = sparkSession.sparkContext
3     val oldExecutionId = sc.getLocalProperty(EXECUTION_ID_KEY)
4     val executionId = SQLExecution.nextExecutionId
5     sc.setLocalProperty(EXECUTION_ID_KEY, executionId.toString)
6     executionIdToQueryExecution.put(executionId, queryExecution)

先设置了一下executionId,该ID是一个线程安全的自增序列,每次加1,。设置给SparkContext之后,又将id与QueryExecution的映射关系存入Map中。

5.2、SQLExecution.withNewExecutionId函数体第二部分

第二部分主要是判断若sql长度过长,需要进行截断处理,无甚要点。

5.3、SQLExecution.withNewExecutionId函数体第三部分,代码如下:

 1       withSQLConfPropagated(sparkSession) {
 2         var ex: Option[Throwable] = None
 3         val startTime = System.nanoTime()
 4         try {
 5           sc.listenerBus.post(SparkListenerSQLExecutionStart(
 6             executionId = executionId,
 7             description = desc,
 8             details = callSite.longForm,
 9             physicalPlanDescription = queryExecution.explainString(planDescriptionMode),
10             sparkPlanInfo = SparkPlanInfo.fromSparkPlan(queryExecution.executedPlan),
11             time = System.currentTimeMillis()))
12           body
13         } catch {
14           case e: Throwable =>
15             ex = Some(e)
16             throw e
17         } finally {
18           val endTime = System.nanoTime()
19           val event = SparkListenerSQLExecutionEnd(executionId, System.currentTimeMillis())
20           event.executionName = name
21           event.duration = endTime - startTime
22           event.qe = queryExecution
23           event.executionFailure = ex
24           sc.listenerBus.post(event)
25         }
26       }

起头的 withSQLConfPropagated 方法,同样还是科里化的方式传参,方法里面将配置参数替换为新的配置参数,执行完之后再将老参数存回去。

再然后是try里面的一个post方法,finally里面一个post方法,用于发送SQLExecution执行开始和结束的通知消息。

最后是核心函数调用,body。即前面一直引而未看的方法。

下面再返回头来好好研究一下此处的body函数,函数体是:

{
      qe.executedPlan.resetMetrics()
      action(qe.executedPlan)
    }

qe变量即上面2.2中返回的groupBy().count().queryExecution

而action的函数体是:

{ plan =>
    plan.executeCollect().head.getLong(0)
  }

那么内部具体是怎么实现的呢?今天时间不早了,改日再搞它。

标签:count,RelationalGroupedDataset,val,df,Dataset,queryExecution,源码,方法
来源: https://www.cnblogs.com/zzq6032010/p/16323297.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有