4. 用户自定义函数
通过spark.udf功能用户可以自定义函数。
4.1用户自定义UDF函数
Shell scala> val df = spark.read.json("examples/src/main/resources/people.json") df: org.apache.spark.sql.DataFrame = [age: bigint, name: string] scala> df.show() +----+-------+ | age| name| +----+-------+ |null|Michael| | 30| Andy| | 19| Justin| +----+-------+ scala> spark.udf.register("addName", (x:String)=> "Name:"+x) res5: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType))) scala> df.createOrReplaceTempView("people") scala> spark.sql("Select addName(name), age from people").show() +-----------------+----+ |UDF:addName(name)| age| +-----------------+----+ | Name:Michael|null| | Name:Andy| 30| | Name:Justin| 19| +-----------------+----+ |
UDF案例2
需求,有如下数据
Plain Text id,name,age,height,weight,yanzhi,score 1,a,18,172,120,98,68.8 2,b,28,175,120,97,68.8 3,c,30,180,130,94,88.8 4,d,18,168,110,98,68.8 5,e,26,165,120,98,68.8 6,f,27,182,135,95,89.8 7,g,19,171,122,99,68.8 |
需要计算每一个人和其他人之间的余弦相似度(特征向量之间的余弦相似度)
代码实现:
Scala package cn.doitedu.sparksql.udf import cn.doitedu.sparksql.dataframe.SparkUtil import org.apache.spark.sql.Row import org.apache.spark.sql.expressions.UserDefinedFunction import scala.collection.mutable /** * UDF 案例2 : 用一个自定义函数实现两个向量之间的余弦相似度计算 */ case class Human(id: Int, name: String, features: Array[Double]) object CosinSimilarity { def main(args: Array[String]): Unit = { val spark = SparkUtil.getSpark() import spark.implicits._ import spark.sql // 加载用户特征数据 val df = spark.read.option("inferSchema", true).option("header", true).csv("data/features.csv") df.show()
// id,name,age,height,weight,yanzhi,score // 将用户特征数据组成一个向量(数组) // 方式1: df.rdd.map(row => { val id = row.getAs[Int]("id") val name = row.getAs[String]("name") val age = row.getAs[Double]("age") val height = row.getAs[Double]("height") val weight = row.getAs[Double]("weight") val yanzhi = row.getAs[Double]("yanzhi") val score = row.getAs[Double]("score") (id, name, Array(age, height, weight, yanzhi, score)) }).toDF("id", "name", "features") // 方式2: df.rdd.map({ case Row(id: Int, name: String, age: Double, height: Double, weight: Double, yanzhi: Double, score: Double) => (id, name, Array(age, height, weight, yanzhi, score)) }) .toDF("id", "name", "features") // 方式3: 直接利用sql中的函数array来生成一个数组 df.selectExpr("id", "name", "array(age,height,weight,yanzhi,score) as features") import org.apache.spark.sql.functions._ df.select('id, 'name, array('age, 'height, 'weight, 'yanzhi, 'score) as "features") // 方式4:返回case class val features = df.rdd.map({ case Row(id: Int, name: String, age: Double, height: Double, weight: Double, yanzhi: Double, score: Double) => Human(id, name, Array(age, height, weight, yanzhi, score)) }) .toDF() // 将表自己和自己join,得到每个人和其他所有人的连接行 val joined = features.join(features.toDF("bid","bname","bfeatures"),'id < 'bid) joined.show(100,false) // 定义一个计算余弦相似度的函数 // val cosinSim = (f1:Array[Double],f2:Array[Double])=>{ /* 余弦相似度 */ } // 开根号的api: Math.pow(4.0,0.5) val cosinSim = (f1:mutable.WrappedArray[Double], f2:mutable.WrappedArray[Double])=>{ val fenmu1 = Math.pow(f1.map(Math.pow(_,2)).sum,0.5) val fenmu2 = Math.pow(f2.map(Math.pow(_,2)).sum,0.5) val fenzi = f1.zip(f2).map(tp=>tp._1*tp._2).sum fenzi/(fenmu1*fenmu2) } // 注册到sql引擎: spark.udf.register("cosin_sim",consinSim) spark.udf.register("cos_sim",cosinSim) joined.createTempView("temp") // 然后在这个表上计算两人之间的余弦相似度 sql("select id,bid,cos_sim(features,bfeatures) as cos_similary from temp").show() // 可以自定义函数简单包装一下,就成为一个能生成column结果的dsl风格函数了 val cossim2: UserDefinedFunction = udf(cosinSim) joined.select('id,'bid,cossim2('features,'bfeatures) as "cos_sim").show() spark.close() } } |
4.2用户自定义聚合函数UDAF
弱类型的DataFrame和强类型的Dataset都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()。
除此之外,用户可以设定自己的自定义UDAF聚合函数。
UDAF的编程模板:
/**
* @date: 2019/10/12
* @site: www.doitedu.cn
* @author: hunter.d 涛哥
* @qq: 657270652
* @description:
* 用户自定义UDAF入门示例:求薪资的平均值
*/
object MyAvgUDAF extends UserDefinedAggregateFunction{
// 函数输入的字段schema(字段名-字段类型)
override def inputSchema: StructType = ???
// 聚合过程中,用于存储局部聚合结果的schema
// 比如求平均薪资,中间缓存(局部数据薪资总和,局部数据人数总和)
override def bufferSchema: StructType = ???
// 函数的最终返回结果数据类型
override def dataType: DataType = ???
// 你这个函数是否是稳定一致的?(对一组相同的输入,永远返回相同的结果),只要是确定的,就写true
override def deterministic: Boolean = true
// 对局部聚合缓存的初始化方法
override def initialize(buffer: MutableAggregationBuffer): Unit = ???
// 聚合逻辑所在方法,框架会不断地传入一个新的输入row,来更新你的聚合缓存数据
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = ???
// 全局聚合:将多个局部缓存中的数据,聚合成一个缓存
// 比如:薪资和薪资累加,人数和人数累加
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???
// 最终输出
// 比如:从全局缓存中取薪资总和/人数总和
override def evaluate(buffer: Row): Any = ???
}
核心要义:
聚合是分步骤进行: 先局部聚合,再全局聚合
局部聚合(update)的结果是保存在一个局部buffer中的
全局聚合(merge)就是将多个局部buffer再聚合成一个buffer
最后通过evaluate将全局聚合的buffer中的数据做一个运算得出你要的结果
如下图所示:
4.2.1弱类型用户自定义聚合函数UDAF
(1)需求说明
示例数据:
+---+----------------+------+---------+------+----------+
| id| name | sales|discount |state | saleDate|
+---+----------------+------+---------+------+----------+
| 1| Widget Co|1000.0| 0.0| AZ|2014-01-01|
| 2| Acme Widgets |2000.0| 500.0| CA|2014-02-01|
| 3| Widgetry|1000.0| 200.0| CA|2015-01-11|
| 4| Widgets R Us |2000.0| 0.0| CA|2015-02-19|
| 5|Ye Olde Widgete |3000.0| 0.0| MA|2015-02-28|
+---+---------------+------+--------+-----+-------------+
需求:计算x年份的同比上一年份的总销售增长率;比如2015 vs 2014的同比增长
显然,没有任何一个内置聚合函数可以完成上述需求;
可以多写一些sql逻辑来实现,但如果能自定义一个聚合函数,当然更方便高效!
Select yearOnyear(saleDate,sales) from t
(2)自定义UDAF实现销售额同比计算
通过继承UserDefinedAggregateFunction来实现用户自定义聚合函数。
自定义UDAF的代码骨架如下:
class UdfMy extends UserDefinedAggregateFunction{
override def inputSchema: StructType = ???
override def bufferSchema: StructType = ???
override def dataType: DataType = ???
override def deterministic: Boolean = ???
override def initialize(buffer: MutableAggregationBuffer): Unit = ???
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = ???
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???
override def evaluate(buffer: Row): Any = ???
}
完整实现代码如下:
/**
* 工具类
* @param startDate
* @param endDate
*/
case class DateRange(startDate: Timestamp, endDate: Timestamp) {
def contain(targetDate: Date): Boolean = {
targetDate.before(endDate) && targetDate.after(startDate)
}
}
/**
* @date: 2019/10/10
* @site: www.doitedu.cn
* @author: hunter.d 涛哥
* @qq: 657270652
* @description: 自定义UDAF实现年份销售额同比增长计算
*/
class YearOnYearBasis(current: DateRange) extends UserDefinedAggregateFunction{
// 聚合函数输入参数的数据类型
override def inputSchema: StructType = {
StructType(StructField("metric", DoubleType) :: StructField("timeCategory", DateType) :: Nil)
}
// 聚合缓冲区中值得数据类型
override def bufferSchema: StructType = {
StructType(StructField("sumOfCurrent", DoubleType) :: StructField("sumOfPrevious", DoubleType) :: Nil)
}
// 返回值的数据类型
override def dataType: DataType = DoubleType
// 对于相同的输入是否一直返回相同的输出。
override def deterministic: Boolean = true
// 初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0, 0.0)
buffer.update(1, 0.0)
}
// 相同Execute间的数据合并。
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (current.contain(input.getAs[Date](1))) {
buffer(0) = buffer.getAs[Double](0) + input.getAs[Double](0)
}
val previous = DateRange(subtractOneYear(current.startDate), subtractOneYear(current.endDate))
if (previous.contain(input.getAs[Date](1))) {
buffer(1) = buffer.getAs[Double](0) + input.getAs[Double](0)
}
}
// 不同Execute间的数据合并
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(0) = buffer1.getAs[Double](0) + buffer2.getAs[Double](0)
buffer1(1) = buffer1.getAs[Double](1) + buffer2.getAs[Double](1)
}
// 计算最终结果
override def evaluate(buffer: Row): Any = {
if (buffer.getDouble(1) == 0.0)
0.0
else
(buffer.getDouble(0) - buffer.getDouble(1)) / buffer.getDouble(1) * 100
}
def subtractOneYear(d:Timestamp):Timestamp={
Timestamp.valueOf(d.toLocalDateTime.minusYears(1))
}
}
(3)补充示例:自定义UDAF实现平均薪资计算
下面展示一个求平均工资的自定义聚合函数。
Scala package cn.doitedu.sparksql.udf import org.apache.spark.sql.Row import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType} /** * @description: * 用户自定义UDAF入门示例:求薪资的平均值 */ object MyAvgUDAF extends UserDefinedAggregateFunction { // 函数输入的字段schema(字段名-字段类型) override def inputSchema: StructType = StructType(Seq(StructField("salary", DataTypes.DoubleType))) // 聚合过程中,用于存储局部聚合结果的schema // 比如求平均薪资,中间缓存(局部数据薪资总和,局部数据人数总和) override def bufferSchema: StructType = StructType(Seq( StructField("sum", DataTypes.DoubleType), StructField("cnts", DataTypes.LongType) )) // 函数的最终返回结果数据类型 override def dataType: DataType = DataTypes.DoubleType // 你这个函数是否是稳定一致的?(对一组相同的输入,永远返回相同的结果),只要是确定的,就写true override def deterministic: Boolean = true // 对局部聚合缓存的初始化方法 override def initialize(buffer: MutableAggregationBuffer): Unit = { buffer.update(0, 0.0) buffer.update(1, 0L) } // 聚合逻辑所在方法,框架会不断地传入一个新的输入row,来更新你的聚合缓存数据 override def update(buffer: MutableAggregationBuffer, input: Row): Unit = { // 从输入中获取那个人的薪资,加到buffer的第一个字段上 buffer.update(0, buffer.getDouble(0) + input.getDouble(0)) // 给buffer的第2个字段加1 buffer.update(1, buffer.getLong(1) + 1) } // 全局聚合:将多个局部缓存中的数据,聚合成一个缓存 // 比如:薪资和薪资累加,人数和人数累加 override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { // 把两个buffer的字段1(薪资和)累加到一起,并更新回buffer1 buffer1.update(0, buffer1.getDouble(0) + buffer2.getDouble(0)) // 更新人数 buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1)) } // 最终输出 // 比如:从全局缓存中取薪资总和/人数总和 override def evaluate(buffer: Row): Any = { if (buffer.getLong(1) != 0) buffer.getDouble(0) / buffer.getLong(1) else 0.0 } } |
4.2.2强类型用户自定义聚合函数
通过继承Aggregator来实现强类型自定义聚合函数,同样是求平均工资
Scala import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.Encoder import org.apache.spark.sql.Encoders import org.apache.spark.sql.SparkSession
// 既然是强类型,可能有case类 case class Employee(name: String, salary: Long) case class Average(var sum: Long, var count: Long) object MyAverage extends Aggregator[Employee, Average, Double] { // 定义一个数据结构,保存工资总数和工资总个数,初始都为0 def zero: Average = Average(0L, 0L) // Combine two values to produce a new value. For performance, the function may modify `buffer` // and return it instead of constructing a new object def reduce(buffer: Average, employee: Employee): Average = { buffer.sum += employee.salary buffer.count += 1 buffer } // 聚合不同execute的结果 def merge(b1: Average, b2: Average): Average = { b1.sum += b2.sum b1.count += b2.count b1 } // 计算输出 def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count // 设定之间值类型的编码器,要转换成case类 // Encoders.product是进行scala元组和case类转换的编码器 def bufferEncoder: Encoder[Average] = Encoders.product // 设定最终输出值的编码器 def outputEncoder: Encoder[Double] = Encoders.scalaDouble } import spark.implicits._ val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee] ds.show() // +-------+------+ // | name|salary| // +-------+------+ // |Michael| 3000| // | Andy| 4500| // | Justin| 3500| // | Berta| 4000| // +-------+------+ // Convert the function to a `TypedColumn` and give it a name val averageSalary = MyAverage.toColumn.name("average_salary") val result = ds.select(averageSalary) result.show() // +--------------+ // |average_salary| // +--------------+ // | 3750.0| // +--------------+ } |
5. Spark SQL 的运行原理
正常的 SQL 执行先会经过 SQL Parser 解析 SQL,然后经过 Catalyst 优化器处理,最后到 Spark 执行。而 Catalyst 的过程又分为很多个过程,其中包括:
- Analysis:主要利用 Catalog 信息将 Unresolved Logical Plan 解析成 Analyzed logical plan;
- Logical Optimizations:利用一些 Rule (规则)将 Analyzed logical plan 解析成 Optimized Logical Plan;
- Physical Planning:前面的 logical plan 不能被 Spark 执行,而这个过程是把 logical plan 转换成多个 physical plans,然后利用代价模型(cost model)选择最佳的 physical plan;
- Code Generation:这个过程会把 SQL逻辑生成Java字节码。
所以整个 SQL 的执行过程可以使用下图表示:
其中蓝色部分就是 Catalyst 优化器处理的部分,也是本章主要讲解的内容。
5.1 元数据管理SessionCatalog
SessionCatalog 主要用于各种函数资源信息和元数据信息(数据库、数据表、数据视图、数据分区与函数等)的统一管理。
创建临时表或者视图,其实是往SessionCatalog注册;
Analyzer在进行逻辑计划元数据绑定时,也是从catalog中获取元数据;
5.2 SQL解析成逻辑执行计划
当调用SparkSession的sql或者SQLContext的sql方法,就会使用SparkSqlParser进行SQL解析。
Spark 2.0.0开始引入了第三方语法解析器工具 ANTLR,对 SQL 进行词法分析并构建语法树。
(Antlr 是一款强大的语法生成器工具,可用于读取、处理、执行和翻译结构化的文本或二进制文件,是当前 Java 语言中使用最为广泛的语法生成器工具,我们常见的大数据 SQL 解析都用到了这个工具,包括 Hive、Cassandra、Phoenix、Pig 以及 presto 等)目前最新版本的 Spark 使用的是 ANTLR4)
它分为2个步骤来生成Unresolved LogicalPlan:
- 词法分析(SqlBaseLexer):Lexical Analysis,负责将token分组成符号类
- 语法分析(SqlBaseParser):构建一棵分析树(parse tree)或者抽象语法树AST(abstract syntax tree)
Scala /** * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or * TableIdentifier. */ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging { import ParserUtils._ def this() = this(new SQLConf()) protected def typedVisit[T](ctx: ParseTree): T = { ... } } |
具体来说,Spark 基于presto的语法文件定义了Spark SQL语法文件SqlBase.g4
(路径 spark-2.4.3\sql\catalyst\src\main\antlr4\org\apache\spark\sql\catalyst\parser\SqlBase.g4)
这个文件定义了 Spark SQL 支持的 SQL 语法。
如果我们需要自定义新的语法,需要在这个文件定义好相关语法。然后使用 ANTLR4 对 SqlBase.g4 文件自动解析生成几个 Java 类,其中就包含重要的词法分析器 SqlBaseLexer.java 和语法分析器SqlBaseParser.java。运行上面的 SQL 会使用 SqlBaseLexer 来解析关键词以及各种标识符等;然后使用 SqlBaseParser 来构建语法树。
下面以一条简单的 SQL 为例进行分析
SQL SELECT sum(v) FROM ( SELECT t1.id, 1 + 2 + t1.value AS v FROM t1 JOIN t2 WHERE t1.id = t2.id AND t1.cid = 1 AND t1.did = t1.cid + 1 AND t2.id > 5) o |
整个过程就类似于下图。
生成语法树之后,使用 AstBuilder 将语法树转换成 LogicalPlan,这个 LogicalPlan 也被称为 Unresolved LogicalPlan。解析后的逻辑计划如下:
Plain Text == Parsed Logical Plan == 'Project [unresolvedalias('sum('v), None)] +- 'SubqueryAlias `doitedu_stu` +- 'Project ['t1.id, ((1 + 2) + 't1.value) AS v#16] +- 'Filter ((('t1.id = 't2.id) && ('t1.cid = 1)) && (('t1.did = ('t1.cid + 1)) && ('t2.id > 5))) +- 'Join Inner :- 'UnresolvedRelation `t1` +- 'UnresolvedRelation `t2` |
图片表示如下:
Unresolved LogicalPlan 是从下往上看的,t1 和 t2 两张表被生成了 UnresolvedRelation,过滤的条件、选择的列以及聚合字段都知道了。
Unresolved LogicalPlan 仅仅是一种数据结构,不包含任何数据信息,比如不知道数据源、数据类型,不同的列来自于哪张表等。
5.3 Analyzer绑定逻辑计划
Analyzer 阶段会使用事先定义好的 Rule 以及 SessionCatalog 等信息对 Unresolved LogicalPlan 进行元数据绑定。
Scala /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]]. */ class Analyzer( catalog: SessionCatalog, conf: SQLConf, maxIterations: Int) extends RuleExecutor[LogicalPlan] with CheckAnalysis {
class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser(conf) { val astBuilder = new SparkSqlAstBuilder(conf)
override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => astBuilder.visitSingleStatement(parser.singleStatement()) match { case plan: LogicalPlan => plan case _ => val position = Origin(None, None) throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position) } }
Rule 是定义在 Analyzer 里面的,具体如下: lazy val batches: Seq[Batch] = Seq( Batch("Hints", fixedPoint, new ResolveHints.ResolveBroadcastHints(conf), ResolveHints.ResolveCoalesceHints, ResolveHints.RemoveAllHints), Batch("Simple Sanity Check", Once, LookupFunctions), Batch("Substitution", fixedPoint, CTESubstitution, WindowsSubstitution, EliminateUnions, new SubstituteUnresolvedOrdinals(conf)), Batch("Resolution", fixedPoint, ResolveTableValuedFunctions :: //解析表的函数 ResolveRelations :: //解析表或视图 ResolveReferences :: //解析列 ResolveCreateNamedStruct :: ResolveDeserializer :: //解析反序列化操作类 ResolveNewInstance :: ResolveUpCast :: //解析类型转换 ResolveGroupingAnalytics :: ResolvePivot :: ResolveOrdinalInOrderByAndGroupBy :: ResolveAggAliasInGroupBy :: ResolveMissingReferences :: ExtractGenerator :: ResolveGenerate :: ResolveFunctions :: //解析函数 ResolveAliases :: //解析表别名 ResolveSubquery :: //解析子查询 ResolveSubqueryColumnAliases :: ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: ResolveOutputRelation :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: TimeWindowing :: ResolveInlineTables(conf) :: ResolveHigherOrderFunctions(catalog) :: ResolveLambdaVariables(conf) :: ResolveTimeZone(conf) :: ResolveRandomSeed :: TypeCoercion.typeCoercionRules(conf) ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), Batch("View", Once, AliasViewChild(conf)), Batch("Nondeterministic", Once, PullOutNondeterministic), Batch("UDF", Once, HandleNullInputsForUDF), Batch("FixNullability", Once, FixNullability), Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, CleanupAliases) ) |
从上面代码可以看出,多个性质类似的 Rule 组成一个 Batch;而多个 Batch 构成一个 batches。这些 batches 会由 RuleExecutor 执行,先按一个一个 Batch 顺序执行,然后对 Batch 里面的每个 Rule 顺序执行。每个 Batch 会执行一次(Once)或多次(FixedPoint,由spark.sql.optimizer.maxIterations 参数决定),执行过程如下:
5.4 Optimizer优化逻辑计划
优化器也是会定义一套Rules,利用这些Rule对逻辑计划和Exepression进行迭代处理,从而使得树的节点进行合并和优化
在前文的绑定逻辑计划阶段对 Unresolved LogicalPlan 进行相关 transform 操作得到了 Analyzed Logical Plan,这个 Analyzed Logical Plan 是可以直接转换成 Physical Plan 然后在spark中执行。但是如果直接这么弄的话,得到的 Physical Plan 很可能不是最优的,因为在实际应用中,很多低效的写法会带来执行效率的问题,需要进一步对Analyzed Logical Plan 进行处理,得到更优的逻辑算子树。于是,针对SQL 逻辑算子树的优化器 Optimizer 应运而生。
这个阶段的优化器主要是基于规则的(Rule-based Optimizer,简称 RBO),而绝大部分的规则都是启发式规则,也就是基于直观或经验而得出的规则,比如列裁剪(过滤掉查询不需要使用到的列)、谓词下推(将过滤尽可能地下沉到数据源端)、常量累加(比如 1 + 2 这种事先计算好) 以及常量替换(比如 SELECT * FROM table WHERE i = 5 AND j = i + 3 可以转换成 SELECT * FROM table WHERE i = 5 AND j = 8)等等。
与绑定逻辑计划阶段类似,这个阶段所有的规则也是实现 Rule 抽象类,多个规则组成一个 Batch,多个 Batch 组成一个 batches,同样也是在 RuleExecutor 中进行执行。
核心源码骨架如下列截图所示:
那么针对前文的 SQL 语句,这个过程都会执行哪些优化呢?下文举例说明。
5.4.1谓词下推
谓词下推在 Spark SQL 是由 PushDownPredicate 实现的,这个过程主要将过滤条件尽可能地下推到底层,最好是数据源。上面介绍的 SQL,使用谓词下推优化得到的逻辑计划如下:
从上图可以看出,谓词下推将 Filter 算子直接下推到 Join 之前了(注意,上图是从下往上看的)。也就是在扫描 t1 表的时候会先使用 ((((isnotnull(cid#2) && isnotnull(did#3)) && (cid#2 = 1)) && (did#3 = 2)) && (id#0 > 50000)) && isnotnull(id#0) 过滤条件过滤出满足条件的数据;同时在扫描 t2 表的时候会先使用 isnotnull(id#8) && (id#8 > 50000) 过滤条件过滤出满足条件的数据。经过这样的操作,可以大大减少 Join 算子处理的数据量,从而加快计算速度。
5.4.2列裁剪
列裁剪在 Spark SQL 是由 ColumnPruning 实现的。因为我们查询的表可能有很多个字段,但是每次查询我们很大可能不需要扫描出所有的字段,这个时候利用列裁剪可以把那些查询不需要的字段过滤掉,使得扫描的数据量减少。所以针对我们上面介绍的 SQL,使用列裁剪优化得到的逻辑计划如下:
从上图可以看出,经过列裁剪后,t1 表只需要查询 id 和 value 两个字段;t2 表只需要查询 id 字段。这样减少了数据的传输,而且如果底层的文件格式为列存(比如 Parquet),可以大大提高数据的扫描速度的。
5.4.3常量替换
常量替换在 Spark SQL 是由 ConstantPropagation 实现的。也就是将变量替换成常量,比如 SELECT * FROM table WHERE i = 5 AND j = i + 3 可以转换成 SELECT * FROM table WHERE i = 5 AND j = 8。这个看起来好像没什么的,但是如果扫描的行数非常多可以减少很多的计算时间的开销的。经过这个优化,得到的逻辑计划如下:
我们的查询中有 t1.cid = 1 AND t1.did = t1.cid + 1 查询语句,从里面可以看出 t1.cid 其实已经是确定的值了,所以我们完全可以使用它计算出 t1.did。
5.4.4常量累加
常量累加在 Spark SQL 是由 ConstantFolding 实现的。这个和常量替换类似,也是在这个阶段把一些常量表达式事先计算好。这个看起来改动的不大,但是在数据量非常大的时候可以减少大量的计算,减少 CPU 等资源的使用。经过这个优化,得到的逻辑计划如下:
经过上面四个步骤的优化之后,得到的优化之后的逻辑计划为:
Plain Text == Optimized Logical Plan == Aggregate [sum(cast(v#16 as bigint)) AS sum(v)#22L] +- Project [(3 + value#1) AS v#16] +- Join Inner, (id#0 = id#8) :- Project [id#0, value#1] : +- Filter (((((isnotnull(cid#2) && isnotnull(did#3)) && (cid#2 = 1)) && (did#3 = 2)) && (id#0 > 5)) && isnotnull(id#0)) : +- Relation[id#0,value#1,cid#2,did#3] csv +- Project [id#8] +- Filter (isnotnull(id#8) && (id#8 > 5)) +- Relation[id#8,value#9,cid#10,did#11] csv |
对应的图如下:
到这里,优化逻辑计划阶段就算完成了。另外,Spark 内置提供了多达70个优化 Rule,详情请参见
https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala#L59
5.5使用SparkPlanner生成物理计划
SparkSpanner使用Planning Strategies,对优化后的逻辑计划进行转换,生成可以执行的物理计划SparkPlan.
Scala /** * 将逻辑计划转成物理计划的抽象类. * 各实现类通过各种GenericStrategy来生成各种可行的待选物理计划. * 如一个策略无法对逻辑计划树的所有操作转换,则会调用[GenericStrategy#planLater planLater]], 来获得 一个“占位符”对象暂时填充;之后由[[collectPlaceholders collected]]收集并使用其他策略进行转换
* TODO: 目前为止,永远只生成一个物理计划 * 后续迭代中会对“多计划”予以实现 */ abstract class QueryPlanner[PhysicalPlan <: TreeNode[PhysicalPlan]] { /** A list of execution strategies that can be used by the planner */ def strategies: Seq[GenericStrategy[PhysicalPlan]] def plan(plan: LogicalPlan): Iterator[PhysicalPlan] = { // 显然,此处还有大量工作需要做,可依然... // 收集所有可选的物理计划. val candidates = strategies.iterator.flatMap(_(plan))
abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SparkPlanner => /** * Plans special cases of limit operators. */ object SpecialLimits extends Strategy { class SparkPlanner( val sparkContext: SparkContext, val conf: SQLConf, val experimentalMethods: ExperimentalMethods) extends SparkStrategies { |
逻辑计划翻译成物理计划时,使用的是策略(Strategy);
前面介绍的逻辑计划绑定和优化经过 Transformations 动作之后,树的类型并没有改变,
Logical Plan转化成物理计划后,树的类型改变了,由 Logical Plan 转换成 Physical Plan 了。
一个逻辑计划(Logical Plan)经过一系列的策略处理之后,得到多个物理计划(Physical Plans),物理计划在 Spark 是由 SparkPlan 实现的。
多个物理计划经过代价模型(Cost Model)得到选择后的物理计划(Selected Physical Plan),整个过程如下所示:
Cost Model 对应的就是基于代价的优化(Cost-based Optimizations,CBO,主要由华为的大佬们实现的,详见 SPARK-16026 ),核心思想是计算每个物理计划的代价,然后得到最优的物理计划。目前,这一部分并没有实现,直接返回多个物理计划列表的第一个作为最优的物理计划,如下:
Scala lazy val sparkPlan: SparkPlan = { SparkSession.setActiveSession(sparkSession) // TODO: We use next(), i.e. take the first plan returned by the planner, here for now, // but we will implement to choose the best plan. planner.plan(ReturnAnswer(optimizedPlan)).next() } |
而 SPARK-16026 引入的 CBO 优化主要是在前面介绍的优化逻辑计划阶段 - Optimizer 阶段进行的,对应的 Rule 为 CostBasedJoinReorder,并且默认是关闭的,需要通过 spark.sql.cbo.enabled 或 spark.sql.cbo.joinReorder.enabled 参数开启。
所以到了这个节点,最后得到的物理计划如下:
Plain Text == Physical Plan == *(3) HashAggregate(keys=[], functions=[sum(cast(v#16 as bigint))], output=[sum(v)#22L]) +- Exchange SinglePartition +- *(2) HashAggregate(keys=[], functions=[partial_sum(cast(v#16 as bigint))], output=[sum#24L]) +- *(2) Project [(3 + value#1) AS v#16] +- *(2) BroadcastHashJoin [id#0], [id#8], Inner, BuildRight :- *(2) Project [id#0, value#1] : +- *(2) Filter (((((isnotnull(cid#2) && isnotnull(did#3)) && (cid#2 = 1)) && (did#3 = 2)) && (id#0 > 5)) && isnotnull(id#0)) : +- *(2) FileScan csv [id#0,value#1,cid#2,did#3] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/iteblog/t1.csv], PartitionFilters: [], PushedFilters: [IsNotNull(cid), IsNotNull(did), EqualTo(cid,1), EqualTo(did,2), GreaterThan(id,5), IsNotNull(id)], ReadSchema: struct<id:int,value:int,cid:int,did:int> +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint))) +- *(1) Project [id#8] +- *(1) Filter (isnotnull(id#8) && (id#8 > 5)) +- *(1) FileScan csv [id#8] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/iteblog/t2.csv], PartitionFilters: [], PushedFilters: [IsNotNull(id), GreaterThan(id,5)], ReadSchema: struct<id:int> |
从上面的结果可以看出,物理计划阶段已经知道数据源是从 csv 文件里面读取了,也知道文件的路径,数据类型等。而且在读取文件的时候,直接将过滤条件(PushedFilters)加进去了。
同时,这个 Join 变成了 BroadcastHashJoin,也就是将 t2 表的数据 Broadcast 到 t1 表所在的节点。图表示如下:
到这里, Physical Plan 就完全生成了。
5.6从物理执行计划获取inputRdd执行
从物理计划上,获取inputRdd
从物理计划上,生成全阶段代码,并编译反射出迭代器newBiIterator的Clazz
[真名:BufferedRowIterator]
然后将inputRDD做一个transformation得到最终要执行的rdd
Scala inputRdd.mapPartitionsWithIndex((index,iter)=>{ new newBiIterator(){ hasNext(){ iter.hasNext } next(){ processNext(iter.next()) } } })
然后,对最后返回的rdd,执行你所需要的行动算子 rdd.collect().foreach(println) |