Spark MLlib快速入门(1)逻辑回归、Kmeans、决策树、Pipeline、交叉验证

Spark MLlib快速入门(1)逻辑回归、Kmeans、决策树案例

除了scikit-learn外,在spark中也提供了机器学习库,即Spark MLlib。

在Spark MLlib机器学习库提供两套算法实现的API:基于RDD API和基于DataFrame API。今天,主要介绍下DataFrame API的使用,不涉及算法的原理。

主要提供的算法如下:

  • 分类

    • 逻辑回归、贝叶斯支持向量机
  • 聚类

    • K-均值
  • 推荐

    • 交替最小二乘法
  • 回归

    • 线性回归
    • 决策树、随机森林

1 Spark MLlib中逻辑回归在鸢尾花数据集上的应用

鸢尾花数据集,总共150条数据,分为三种类别的鸢尾花。

鸢尾花数据集属于分类算法,构建分类模型,此处使用逻辑回归分类算法构建分类模型,进行预测。

全部基于DataFrame API算法库和特征工程函数使用。

使用的spark版本为2.3。

1.1 读取数据

package com.yyds.tags.ml.classification

import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.feature.{Normalizer, StringIndexer, StringIndexerModel, VectorAssembler}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StringType, StructType}
import org.apache.spark.storage.StorageLevel

object IrisClassification {

  def main(args: Array[String]): Unit = {

    // 构建SparkSession实例对象
    val spark: SparkSession = SparkSession.builder()
      .appName(this.getClass.getSimpleName.stripSuffix("$"))
      .master("local[4]")
      .config("spark.sql.shuffle.partitions",4)
      .getOrCreate()

    import spark.implicits._

    // TODO step1 -> 读取数据
    val isrsSchema: StructType = new StructType()
      .add("sepal_length",DoubleType,nullable = true)
      .add("sepal_width",DoubleType,nullable = true)
      .add("petal_length",DoubleType,nullable = true)
      .add("petal_width",DoubleType,nullable = true)
      .add("category",StringType, nullable = true)

    val rawIrisDF: DataFrame =  spark.read
      .option("sep",",")
      // 当首行不是列名称时候,需要自动设置schema
      .option("header","false")
      .option("inferSchema","false")
      .schema(isrsSchema)
      .csv("datas/iris/iris.data")

    rawIrisDF.printSchema()
    rawIrisDF.show(10,truncate = false)

  }

}
root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- category: string (nullable = true)
 
 
+------------+-----------+------------+-----------+-----------+
|sepal_length|sepal_width|petal_length|petal_width|category   |
+------------+-----------+------------+-----------+-----------+
|5.1         |3.5        |1.4         |0.2        |Iris-setosa|
|4.9         |3.0        |1.4         |0.2        |Iris-setosa|
|4.7         |3.2        |1.3         |0.2        |Iris-setosa|
|4.6         |3.1        |1.5         |0.2        |Iris-setosa|
|5.0         |3.6        |1.4         |0.2        |Iris-setosa|
|5.4         |3.9        |1.7         |0.4        |Iris-setosa|
|4.6         |3.4        |1.4         |0.3        |Iris-setosa|
|5.0         |3.4        |1.5         |0.2        |Iris-setosa|
|4.4         |2.9        |1.4         |0.2        |Iris-setosa|
|4.9         |3.1        |1.5         |0.1        |Iris-setosa|
+------------+-----------+------------+-----------+-----------+

1.2 特征工程

    // TODO step2 -> 特征工程
    /*
      1、类别转换数值类型
         类别特征索引化 -> label
      2、组合特征值
         features: Vector
    */
    // 1、类别特征转换 StringIndexer
    val indexerModel: StringIndexerModel = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("label")
      .fit(rawIrisDF)

    val df1: DataFrame = indexerModel.transform(rawIrisDF)

    // 2、组合特征值 VectorAssembler
    val assembler: VectorAssembler = new VectorAssembler()
      // 设置特征列名称
      .setInputCols(rawIrisDF.columns.dropRight(1))
      .setOutputCol("raw_features")

    val rawFeaturesDF: DataFrame = assembler.transform(df1)

    
    // 3、特征值正则化,使用L2正则
    val normalizer: Normalizer = new Normalizer()
      .setInputCol("raw_features")
      .setOutputCol("features")
      .setP(2.0)

    val featuresDF: DataFrame = normalizer.transform(rawFeaturesDF)
    
    // 将数据集缓存,LR算法属于迭代算法,使用多次
    featuresDF.persist(StorageLevel.MEMORY_AND_DISK).count()

    featuresDF.printSchema()
    featuresDF.show(10, truncate = false)
root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- category: string (nullable = true)
 |-- label: double (nullable = true)
 |-- raw_features: vector (nullable = true)
 |-- features: vector (nullable = true)

在这里插入图片描述

1.3 训练模型

    // TODO step3 -> 模型训练
    val lr: LogisticRegression = new LogisticRegression()
      // 设置列名称
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setPredictionCol("prediction")
      // 设置迭代次数
      .setMaxIter(10)
      .setRegParam(0.3) // 正则化参数
      .setElasticNetParam(0.8) // 弹性网络参数:L1正则和L2正则联合使用



    val lrModel: LogisticRegressionModel = lr.fit(featuresDF)

1.4 模型预测

    // TODO step4 -> 使用模型预测
    val predictionDF: DataFrame = lrModel.transform(featuresDF)


    predictionDF
       // 获取真实标签类别和预测标签类别
      .select("label", "prediction")
      .show(10)

在这里插入图片描述

1.5 模型评估

 // TODO step5 -> 模型评估:准确度 = 预测正确的样本数 / 所有的样本数
    import  org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")

    # accuracy = 0.9466666666666667
    println(s"accuracy = ${evaluator.evaluate(predictionDF)}")

1.6 模型的保存与加载

   // TODO step6 ->  模型调优,此处省略


    // TODO step7 ->  模型保存与加载
    val modelPath = s"datas/models/lrModel-${System.currentTimeMillis()}"
    // 保存模型
    lrModel.save(modelPath)
    // 加载模型
    val loadLrModel = LogisticRegressionModel.load(modelPath)
    // 模型预测
    loadLrModel.transform(
      Seq(
        Vectors.dense(Array(5.1,3.5,1.4,0.2))
      )
        .map(x => Tuple1.apply(x))
        .toDF("features")
    ).show(10, truncate = false)

    // 应用结束,关闭资源
    spark.stop()

在这里插入图片描述

2 Spark MLlib中KMeans在鸢尾花数据集上的应用

2.1 读取数据集

iris_kmeans.txt数据如下

1 1:5.1 2:3.5 3:1.4 4:0.2
1 1:4.9 2:3.0 3:1.4 4:0.2
1 1:4.7 2:3.2 3:1.3 4:0.2
1 1:4.6 2:3.1 3:1.5 4:0.2
1 1:5.0 2:3.6 3:1.4 4:0.2
1 1:5.4 2:3.9 3:1.7 4:0.4
1 1:4.6 2:3.4 3:1.4 4:0.3
1 1:5.0 2:3.4 3:1.5 4:0.2
1 1:4.4 2:2.9 3:1.4 4:0.2
1 1:4.9 2:3.1 3:1.5 4:0.1
1 1:5.4 2:3.7 3:1.5 4:0.2
1 1:4.8 2:3.4 3:1.6 4:0.2
1 1:4.8 2:3.0 3:1.4 4:0.1
1 1:4.3 2:3.0 3:1.1 4:0.1
1 1:5.8 2:4.0 3:1.2 4:0.2
1 1:5.7 2:4.4 3:1.5 4:0.4
1 1:5.4 2:3.9 3:1.3 4:0.4
1 1:5.1 2:3.5 3:1.4 4:0.3
1 1:5.7 2:3.8 3:1.7 4:0.3
1 1:5.1 2:3.8 3:1.5 4:0.3
1 1:5.4 2:3.4 3:1.7 4:0.2
1 1:5.1 2:3.7 3:1.5 4:0.4
1 1:4.6 2:3.6 3:1.0 4:0.2
1 1:5.1 2:3.3 3:1.7 4:0.5
1 1:4.8 2:3.4 3:1.9 4:0.2
1 1:5.0 2:3.0 3:1.6 4:0.2
1 1:5.0 2:3.4 3:1.6 4:0.4
1 1:5.2 2:3.5 3:1.5 4:0.2
1 1:5.2 2:3.4 3:1.4 4:0.2
1 1:4.7 2:3.2 3:1.6 4:0.2
1 1:4.8 2:3.1 3:1.6 4:0.2
1 1:5.4 2:3.4 3:1.5 4:0.4
1 1:5.2 2:4.1 3:1.5 4:0.1
1 1:5.5 2:4.2 3:1.4 4:0.2
1 1:4.9 2:3.1 3:1.5 4:0.1
1 1:5.0 2:3.2 3:1.2 4:0.2
1 1:5.5 2:3.5 3:1.3 4:0.2
1 1:4.9 2:3.1 3:1.5 4:0.1
1 1:4.4 2:3.0 3:1.3 4:0.2
1 1:5.1 2:3.4 3:1.5 4:0.2
1 1:5.0 2:3.5 3:1.3 4:0.3
1 1:4.5 2:2.3 3:1.3 4:0.3
1 1:4.4 2:3.2 3:1.3 4:0.2
1 1:5.0 2:3.5 3:1.6 4:0.6
1 1:5.1 2:3.8 3:1.9 4:0.4
1 1:4.8 2:3.0 3:1.4 4:0.3
1 1:5.1 2:3.8 3:1.6 4:0.2
1 1:4.6 2:3.2 3:1.4 4:0.2
1 1:5.3 2:3.7 3:1.5 4:0.2
1 1:5.0 2:3.3 3:1.4 4:0.2
2 1:7.0 2:3.2 3:4.7 4:1.4
2 1:6.4 2:3.2 3:4.5 4:1.5
2 1:6.9 2:3.1 3:4.9 4:1.5
2 1:5.5 2:2.3 3:4.0 4:1.3
2 1:6.5 2:2.8 3:4.6 4:1.5
2 1:5.7 2:2.8 3:4.5 4:1.3
2 1:6.3 2:3.3 3:4.7 4:1.6
2 1:4.9 2:2.4 3:3.3 4:1.0
2 1:6.6 2:2.9 3:4.6 4:1.3
2 1:5.2 2:2.7 3:3.9 4:1.4
2 1:5.0 2:2.0 3:3.5 4:1.0
2 1:5.9 2:3.0 3:4.2 4:1.5
2 1:6.0 2:2.2 3:4.0 4:1.0
2 1:6.1 2:2.9 3:4.7 4:1.4
2 1:5.6 2:2.9 3:3.6 4:1.3
2 1:6.7 2:3.1 3:4.4 4:1.4
2 1:5.6 2:3.0 3:4.5 4:1.5
2 1:5.8 2:2.7 3:4.1 4:1.0
2 1:6.2 2:2.2 3:4.5 4:1.5
2 1:5.6 2:2.5 3:3.9 4:1.1
2 1:5.9 2:3.2 3:4.8 4:1.8
2 1:6.1 2:2.8 3:4.0 4:1.3
2 1:6.3 2:2.5 3:4.9 4:1.5
2 1:6.1 2:2.8 3:4.7 4:1.2
2 1:6.4 2:2.9 3:4.3 4:1.3
2 1:6.6 2:3.0 3:4.4 4:1.4
2 1:6.8 2:2.8 3:4.8 4:1.4
2 1:6.7 2:3.0 3:5.0 4:1.7
2 1:6.0 2:2.9 3:4.5 4:1.5
2 1:5.7 2:2.6 3:3.5 4:1.0
2 1:5.5 2:2.4 3:3.8 4:1.1
2 1:5.5 2:2.4 3:3.7 4:1.0
2 1:5.8 2:2.7 3:3.9 4:1.2
2 1:6.0 2:2.7 3:5.1 4:1.6
2 1:5.4 2:3.0 3:4.5 4:1.5
2 1:6.0 2:3.4 3:4.5 4:1.6
2 1:6.7 2:3.1 3:4.7 4:1.5
2 1:6.3 2:2.3 3:4.4 4:1.3
2 1:5.6 2:3.0 3:4.1 4:1.3
2 1:5.5 2:2.5 3:4.0 4:1.3
2 1:5.5 2:2.6 3:4.4 4:1.2
2 1:6.1 2:3.0 3:4.6 4:1.4
2 1:5.8 2:2.6 3:4.0 4:1.2
2 1:5.0 2:2.3 3:3.3 4:1.0
2 1:5.6 2:2.7 3:4.2 4:1.3
2 1:5.7 2:3.0 3:4.2 4:1.2
2 1:5.7 2:2.9 3:4.2 4:1.3
2 1:6.2 2:2.9 3:4.3 4:1.3
2 1:5.1 2:2.5 3:3.0 4:1.1
2 1:5.7 2:2.8 3:4.1 4:1.3
3 1:6.3 2:3.3 3:6.0 4:2.5
3 1:5.8 2:2.7 3:5.1 4:1.9
3 1:7.1 2:3.0 3:5.9 4:2.1
3 1:6.3 2:2.9 3:5.6 4:1.8
3 1:6.5 2:3.0 3:5.8 4:2.2
3 1:7.6 2:3.0 3:6.6 4:2.1
3 1:4.9 2:2.5 3:4.5 4:1.7
3 1:7.3 2:2.9 3:6.3 4:1.8
3 1:6.7 2:2.5 3:5.8 4:1.8
3 1:7.2 2:3.6 3:6.1 4:2.5
3 1:6.5 2:3.2 3:5.1 4:2.0
3 1:6.4 2:2.7 3:5.3 4:1.9
3 1:6.8 2:3.0 3:5.5 4:2.1
3 1:5.7 2:2.5 3:5.0 4:2.0
3 1:5.8 2:2.8 3:5.1 4:2.4
3 1:6.4 2:3.2 3:5.3 4:2.3
3 1:6.5 2:3.0 3:5.5 4:1.8
3 1:7.7 2:3.8 3:6.7 4:2.2
3 1:7.7 2:2.6 3:6.9 4:2.3
3 1:6.0 2:2.2 3:5.0 4:1.5
3 1:6.9 2:3.2 3:5.7 4:2.3
3 1:5.6 2:2.8 3:4.9 4:2.0
3 1:7.7 2:2.8 3:6.7 4:2.0
3 1:6.3 2:2.7 3:4.9 4:1.8
3 1:6.7 2:3.3 3:5.7 4:2.1
3 1:7.2 2:3.2 3:6.0 4:1.8
3 1:6.2 2:2.8 3:4.8 4:1.8
3 1:6.1 2:3.0 3:4.9 4:1.8
3 1:6.4 2:2.8 3:5.6 4:2.1
3 1:7.2 2:3.0 3:5.8 4:1.6
3 1:7.4 2:2.8 3:6.1 4:1.9
3 1:7.9 2:3.8 3:6.4 4:2.0
3 1:6.4 2:2.8 3:5.6 4:2.2
3 1:6.3 2:2.8 3:5.1 4:1.5
3 1:6.1 2:2.6 3:5.6 4:1.4
3 1:7.7 2:3.0 3:6.1 4:2.3
3 1:6.3 2:3.4 3:5.6 4:2.4
3 1:6.4 2:3.1 3:5.5 4:1.8
3 1:6.0 2:3.0 3:4.8 4:1.8
3 1:6.9 2:3.1 3:5.4 4:2.1
3 1:6.7 2:3.1 3:5.6 4:2.4
3 1:6.9 2:3.1 3:5.1 4:2.3
3 1:5.8 2:2.7 3:5.1 4:1.9
3 1:6.8 2:3.2 3:5.9 4:2.3
3 1:6.7 2:3.3 3:5.7 4:2.5
3 1:6.7 2:3.0 3:5.2 4:2.3
3 1:6.3 2:2.5 3:5.0 4:1.9
3 1:6.5 2:3.0 3:5.2 4:2.0
3 1:6.2 2:3.4 3:5.4 4:2.3
3 1:5.9 2:3.0 3:5.1 4:1.8
package com.yyds.tags.ml.clustering

import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * 使用KMeans算法对鸢尾花数据进行聚类操作
 */
object IrisClusterTest {
  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder()
      .appName(this.getClass.getSimpleName.stripSuffix("$"))
      .master("local[2]")
      .config("spark.sql.shuffle.partitions", "2")
      .getOrCreate()

    import org.apache.spark.sql.functions._
    import spark.implicits._


    // 1. 读取鸢尾花数据集
    val irisDF: DataFrame = spark.read
      .format("libsvm")
      .load("datas/iris/iris_kmeans.txt")
    irisDF.printSchema()
    irisDF.show(10, truncate = false)
  }

}

root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)



+-----+-------------------------------+
|label|features                       |
+-----+-------------------------------+
|1.0  |(4,[0,1,2,3],[5.1,3.5,1.4,0.2])|
|1.0  |(4,[0,1,2,3],[4.9,3.0,1.4,0.2])|
|1.0  |(4,[0,1,2,3],[4.7,3.2,1.3,0.2])|
|1.0  |(4,[0,1,2,3],[4.6,3.1,1.5,0.2])|
|1.0  |(4,[0,1,2,3],[5.0,3.6,1.4,0.2])|
|1.0  |(4,[0,1,2,3],[5.4,3.9,1.7,0.4])|
|1.0  |(4,[0,1,2,3],[4.6,3.4,1.4,0.3])|
|1.0  |(4,[0,1,2,3],[5.0,3.4,1.5,0.2])|
|1.0  |(4,[0,1,2,3],[4.4,2.9,1.4,0.2])|
|1.0  |(4,[0,1,2,3],[4.9,3.1,1.5,0.1])|
+-----+-------------------------------+
only showing top 10 rows

2.2 模型训练

// 2. 构建KMeans算法
    val kmeans: KMeans = new KMeans()
      // 设置输入特征列名称和输出列的名名称
      .setFeaturesCol("features")
      .setPredictionCol("prediction")
      // 设置K值为3
      .setK(3)
      // 设置最大的迭代次数
      .setMaxIter(20)


    // 3. 应用数据集训练模型, 获取转换器
    val kMeansModel: KMeansModel = kmeans.fit(irisDF)

    // 获取聚类的簇中心点
    kMeansModel.clusterCenters.foreach(println)
[5.88360655737705,2.7409836065573776,4.388524590163936,1.4344262295081969]
[5.005999999999999,3.4180000000000006,1.4640000000000002,0.2439999999999999]
[6.853846153846153,3.0769230769230766,5.715384615384615,2.053846153846153]

2.3 模型评估和预测

   // 4. 模型评估
    val wssse: Double = kMeansModel.computeCost(irisDF)
    println(s"WSSSE = ${wssse}")


    // 5. 使用模型预测
    val predictionDF: DataFrame = kMeansModel.transform(irisDF)

    predictionDF.show(10, truncate = false)

    // 应用结束,关闭资源
    spark.stop()
+-----+-------------------------------+----------+
|label|features                       |prediction|
+-----+-------------------------------+----------+
|1.0  |(4,[0,1,2,3],[5.1,3.5,1.4,0.2])|1         |
|1.0  |(4,[0,1,2,3],[4.9,3.0,1.4,0.2])|1         |
|1.0  |(4,[0,1,2,3],[4.7,3.2,1.3,0.2])|1         |
|1.0  |(4,[0,1,2,3],[4.6,3.1,1.5,0.2])|1         |
|1.0  |(4,[0,1,2,3],[5.0,3.6,1.4,0.2])|1         |
|1.0  |(4,[0,1,2,3],[5.4,3.9,1.7,0.4])|1         |
|1.0  |(4,[0,1,2,3],[4.6,3.4,1.4,0.3])|1         |
|1.0  |(4,[0,1,2,3],[5.0,3.4,1.5,0.2])|1         |
|1.0  |(4,[0,1,2,3],[4.4,2.9,1.4,0.2])|1         |
|1.0  |(4,[0,1,2,3],[4.9,3.1,1.5,0.1])|1         |
+-----+-------------------------------+----------+

3 Spark MLlib中决策树入门案例

决策树学习采用的是 自顶向下 的递归方法 ,其基本思想是以信息熵为度量构造一颗熵值下降最快的树,到叶子节点处,熵值为0。其具有可读性、分类速度快的优点,是一种有监督学习。

最早提及决策树思想的是Quinlan在1986年提出的ID3算法和1993年提出的C4.5算法,以及Breiman等人在1984年提出的CART算法。

决策树算法是机器学习算法中非常重要的算法之一,既可以分类又可以回归,其中还可以构建出集成学习算法。

由于决策树分类模型 DecisionTreeClassificationModel 属于概率分类模型ProbabilisticClassificationModel ,所以构建模型时要求数据集中标签label必须从0开始

在这里插入图片描述

上述数据集中特征:退款和婚姻状态,都是类别类型特征,需要将其转换为数值特征,数值从0开始计算。

针对 特征:退款 来说,将其转换为【0,1】两个值,不能是【1,2】数值。

3.1 读取数据

package com.yyds.tags.ml.classification

import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel, VectorIndexer, VectorIndexerModel}
import org.apache.spark.sql.{DataFrame, SparkSession}

object DecisionTreeTest {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder()
      .appName(this.getClass.getSimpleName.stripSuffix("$"))
      .master("local[4]")
      .getOrCreate()


    import org.apache.spark.sql.functions._
    import spark.implicits._

    // 1. 加载数据
    val dataframe: DataFrame = spark.read
      .format("libsvm")
      .load("datas/iris/sample_libsvm_data.txt")


    dataframe.printSchema()
    dataframe.show(10, truncate = false)

    spark.stop()
  }

}

在这里插入图片描述

3.2 特征工程

    // 2. 特征工程:特征提取、特征转换及特征选择

    // a. 将标签值label,转换为索引,从0开始,到 K-1
    val labelIndexer: StringIndexerModel = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("index_label")
      .fit(dataframe)
    val df1: DataFrame = labelIndexer.transform(dataframe)

    // b. 对类别特征数据进行特殊处理, 当每列的值的个数小于设置K,那么此列数据被当做类别特征,自动进行索引转换
    val featureIndexer: VectorIndexerModel = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("index_features")
      .setMaxCategories(4)
      .fit(df1)


    val df2: DataFrame = featureIndexer.transform(df1)

    df2.printSchema()
    df2.show(10, truncate = false)
root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)
 |-- index_label: double (nullable = true)
 |-- index_features: vector (nullable = true)

3.3 训练模型

    // 3. 划分数据集:训练数据和测试数据
    val Array(trainingDF, testingDF) = df2.randomSplit(Array(0.8, 0.2))


    // 4. 使用决策树算法构建分类模型
    val dtc: DecisionTreeClassifier = new DecisionTreeClassifier()
      .setLabelCol("index_label")
      .setFeaturesCol("index_features")
      // 设置决策树算法相关超参数
      .setMaxDepth(5)
      .setMaxBins(32)       // 此值必须大于等于类别特征类别个数
      .setImpurity("gini")  // 也可以是香农熵:entropy


    val dtcModel: DecisionTreeClassificationModel = dtc.fit(trainingDF)

    println(dtcModel.toDebugString)
DecisionTreeClassificationModel (uid=dtc_338073100075) of depth 1 with 3 nodes
  If (feature 406 <= 72.0)
   Predict: 1.0
  Else (feature 406 > 72.0)
   Predict: 0.0

3.4 模型评估

    // 5. 模型评估,计算准确度
    val predictionDF: DataFrame = dtcModel.transform(testingDF)
    predictionDF.printSchema()
    predictionDF
      .select($"label", $"index_label", $"probability", $"prediction")
      .show(10, truncate = false)


    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("index_label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")


    val accuracy: Double = evaluator.evaluate(predictionDF)
    println(s"Accuracy = $accuracy")
Accuracy = 0.8823529411764706

4、ML Pipeline

管道 Pipeline 概念:将多个Transformer转换器Estimators模型学习器按照 依赖顺序 组工作流WorkFlow形式,方面数据集的特征转换和模型训练及预测。

将上面的决策树分类代码,改为使用 Pipeline 构建模型与预测。

在这里插入图片描述

package com.yyds.tags.ml.classification


import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, StringIndexerModel, VectorIndexer, VectorIndexerModel}
import org.apache.spark.sql.{DataFrame, SparkSession}


object PipelineTest {

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder()
      .appName(this.getClass.getSimpleName.stripSuffix("$"))
      .master("local[4]")
      .getOrCreate()

    import org.apache.spark.sql.functions._
    import spark.implicits._


    // 1. 加载数据
    val dataframe: DataFrame = spark.read
      .format("libsvm")
      .load("datas/iris/sample_libsvm_data.txt")
    
    //dataframe.printSchema()
    //dataframe.show(10, truncate = false)


    // 划分数据集:训练集和测试集
    val Array(trainingDF, testingDF) = dataframe.randomSplit(Array(0.8, 0.2))

    // 2. 构建管道Pipeline

    // a. 将标签值label,转换为索引,从0开始,到 K-1
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("index_label")
      .fit(dataframe)

    // b. 对类别特征数据进行特殊处理, 当每列的值的个数小于设置K,那么此列数据被当做类别特征,自动进行索引转换
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("index_features")
      .setMaxCategories(4)
      .fit(dataframe)

    // c. 使用决策树算法构建分类模型
    val dtc: DecisionTreeClassifier = new DecisionTreeClassifier()
      .setLabelCol("index_label")
      .setFeaturesCol("index_features")
      // 设置决策树算法相关超参数
      .setMaxDepth(5)
      .setMaxBins(32) // 此值必须大于等于类别特征类别个数
      .setImpurity("gini")

    // d. 创建Pipeline,设置Stage(转换器和模型学习器)
    val pipeline: Pipeline = new Pipeline().setStages(
      Array(labelIndexer, featureIndexer, dtc)
    )


    // 3. 训练模型
    val pipelineModel: PipelineModel = pipeline.fit(trainingDF)

    // 获取决策树分类模型
    val dtcModel: DecisionTreeClassificationModel =
         pipelineModel.stages(2)
        .asInstanceOf[DecisionTreeClassificationModel]

    println(dtcModel.toDebugString)


    // 4. 模型评估
    val predictionDF: DataFrame = pipelineModel.transform(testingDF)

    predictionDF.printSchema()

    predictionDF
      .select($"label", $"index_label", $"probability", $"prediction")
      .show(20, truncate = false)


    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("index_label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")

    val accuracy: Double = evaluator.evaluate(predictionDF)

    println(s"Accuracy = $accuracy")

    // 应用结束,关闭资源
    spark.stop()

  }

}

5、模型调优

使用决策树算法训练模型时,可以调整相关超参数,结合训练验证(Train-Validation Split)交叉验证(Cross-Validation),获取最佳模型。

5.1 训练验证

将数据集划分为两个部分 ,静态的划分,一个用于训练模型,一个用于验证模型

通过评估指标,获取最佳模型,超参数设置比较好。

在这里插入图片描述

// 无论使用何种验证方式通过调整算法超参数来进行模型调优,需要使用工具类ParamGridBuilder 将 超参数封装到Map集合中
import org.apache.spark.ml.tuning.ParamGridBuilder


val paramGrid: Array[ParamMap] = new ParamGridBuilder()
            .addGrid(lr.regParam, Array(0.1, 0.01))
            .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
            .build()

// 使用训练验证 TrainValidationSplit 方式获取最佳模型
val trainValidationSplit = new TrainValidationSplit()
        .setEstimator(lr)                      // 也可以是pipeline
        .setEvaluator(new RegressionEvaluator) // 评估器
        .setEstimatorParamMaps(paramGrid)      // 超参数
        // 80% of the data will be used for training and the remaining 20% for validation.
        .setTrainRatio(0.8)

训练验证的使用

package com.yyds.tags.ml.classification

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{VectorAssembler, VectorIndexer}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit, TrainValidationSplitModel}
import org.apache.spark.sql.{DataFrame, SparkSession}

object HPO {

  /**
   * 调整算法超参数,找出最优模型
   * @param dataframe 数据集
   * @return
   */
  def trainBestModel(dataframe: DataFrame): PipelineModel = {
    // a. 特征向量化
    val assembler: VectorAssembler = new VectorAssembler()
      .setInputCols(Array("color", "product"))
      .setOutputCol("raw_features")

    // b. 类别特征进行索引
    val indexer: VectorIndexer = new VectorIndexer()
      .setInputCol("raw_features")
      .setOutputCol("features")
      .setMaxCategories(30)
    // .fit(dataframe)

    // c. 构建决策树分类器
    val dtc: DecisionTreeClassifier = new DecisionTreeClassifier()
      .setFeaturesCol("features")
      .setLabelCol("label")
      .setPredictionCol("prediction")

    // d. 构建Pipeline管道流实例对象
    val pipeline: Pipeline = new Pipeline().setStages(
      Array(assembler, indexer, dtc)
    )

    // e. 构建参数网格,设置超参数的值
    val paramGrid: Array[ParamMap] = new ParamGridBuilder()
      .addGrid(dtc.maxDepth, Array(5, 10))
      .addGrid(dtc.impurity, Array("gini", "entropy"))
      .addGrid(dtc.maxBins, Array(32, 64))
      .build()

    // f. 多分类评估器
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      // 指标名称,支持:f1、weightedPrecision、weightedRecall、accuracy
      .setMetricName("accuracy")

    // g. 训练验证
    val trainValidationSplit = new TrainValidationSplit()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)

      // 80% of the data will be used for training and the remaining 20% for validation.
      .setTrainRatio(0.8)


    // h. 训练模型
    val model: TrainValidationSplitModel =
      trainValidationSplit.fit(dataframe)
    // i. 获取最佳模型返回
    model.bestModel.asInstanceOf[PipelineModel]
  }


}

5.2 交叉验证(K折)

将数据集划分为两个部分 ,动态的划分为K个部分数据集,其中1份数据集为验证数据集,其他K-1分数据为训练数据集,调整参数训练模型。

在这里插入图片描述

/**
   * 采用K-Fold交叉验证方式,调整超参数获取最佳PipelineModel模型
   * @param dataframe 数据集
   * @return
   */
  def trainBestPipelineModel(dataframe: DataFrame): PipelineModel = {
    // a. 特征向量化
    val assembler: VectorAssembler = new VectorAssembler()
      .setInputCols(Array("color", "product"))
      .setOutputCol("raw_features")
    
    
    // b. 类别特征进行索引
    val indexer: VectorIndexer = new VectorIndexer()
      .setInputCol("raw_features")
      .setOutputCol("features")
      .setMaxCategories(30)
    // .fit(dataframe)
    
    
    // c. 构建决策树分类器
    val dtc: DecisionTreeClassifier = new DecisionTreeClassifier()
      .setFeaturesCol("features")
      .setLabelCol("label")
      .setPredictionCol("prediction")
    
    
    // d. 构建Pipeline管道流实例对象
    val pipeline: Pipeline = new Pipeline().setStages(
      Array(assembler, indexer, dtc)
    )
    
    
    // e. 构建参数网格,设置超参数的值
    val paramGrid: Array[ParamMap] = new ParamGridBuilder()
      .addGrid(dtc.maxDepth, Array(5, 10))
      .addGrid(dtc.impurity, Array("gini", "entropy"))
      .addGrid(dtc.maxBins, Array(32, 64))
      .build()
    
    
    // f. 多分类评估器
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      // 指标名称,支持:f1、weightedPrecision、weightedRecall、accuracy
      .setMetricName("accuracy")
    // g. 构建交叉验证实例对象
    val crossValidator: CrossValidator = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(evaluator)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3)

    // h. 训练模式
    val crossValidatorModel: CrossValidatorModel =  crossValidator.fit(dataframe)
    
    // i. 获取最佳模型
    val pipelineModel: PipelineModel = crossValidatorModel.bestModel.asInstanceOf[PipelineModel]
    
    
    // j. 返回模型
    pipelineModel

  }

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/38527.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

FPGA实验四:交通灯控制器设计

目录 一、实验目的 二、设计要求 三、实验代码 1.design source文件代码 2.仿真文件代码 3.代码原理分析 四、实验结果及分析 1、引脚锁定 2、仿真波形及分析 3、下载测试结果及分析 五、实验心得 1.解决实验中遇见的问题及解决 2.实验完成的心得 一、实验目的 &a…

Linux使用HTTP隧道代理代码示例模版

下面是一个在Linux上使用HTTP隧道代理的代码示例模板&#xff0c;可以根据自己的实际情况进行修改和配置&#xff1a; #!/bin/bash# 配置代理服务器信息 proxy_server"代理服务器IP或域名" proxy_port"代理服务器端口号" proxy_username"代理服务器用…

《动手学深度学习》——线性神经网络

参考资料&#xff1a; 《动手学深度学习》 3.1 线性回归 3.1.1 线性回归的基本元素 样本&#xff1a; n n n 表示样本数&#xff0c; x ( i ) [ x 1 ( i ) , x 2 ( i ) , ⋯ , x d ( i ) ] x^{(i)}[x^{(i)}_1,x^{(i)}_2,\cdots,x^{(i)}_d] x(i)[x1(i)​,x2(i)​,⋯,xd(i)​…

《实战AI低代码》:普元智能化低代码开发平台发布,结合专有模型大幅提升软件生产力

在7月6日举办的“低代码+AI”产品战略发布会上,普元智能化低代码开发平台正式发布。该平台融合了普元自主研发的专有模型,同时也接入了多款AI大模型的功能。它提供了一系列低代码产品,包括中间件、业务分析、应用开发、数据中台和业务流程自动化等,旨在简化企业的复杂软件生…

Nginx学习

文章目录 Nginx什么是NginxLinux安装与配置Nginx编译安装Nginxnignx使用nginx默认首页配置案例 localtion的匹配规则Nginx虚拟主机基于多IP的虚拟主机基于多端口的虚拟主机基于域名的虚拟机主机 反向代理案例①案例② 负载均衡案例①案例②分配策略 动静分离案例 配置Nginx网关…

文心一言 VS 讯飞星火 VS chatgpt (58)-- 算法导论6.4 2题

文心一言 VS 讯飞星火 VS chatgpt &#xff08;58&#xff09;-- 算法导论6.4 2题 二、试分析在使用下列循环不变量时&#xff0c;HEAPSORT 的正确性&#xff1a;在算法的第 2~5行 for 循环每次迭代开始时&#xff0c;子数组 A[1…i]是一个包含了数组A[1…n]中第i小元素的最大…

【Distributed】zookeeper+kafka的应用及部署

文章目录 一、zookeeper1. zookeeper的概述1.1 Zookeeper 定义1.2 Zookeeper 工作机制1.3 Zookeeper 特点1.4 Zookeeper 数据结构1.5 Zookeeper 应用场景1.6 Zookeeper 选举机制第一次启动选举机制非第一次启动选举机制选举Leader规则 2. 部署 Zookeeper 集群2.1 安装前准备2.2…

day52

思维导图 比较指令结果的条件码 练习 汇编实现1-100的累加 .text .global _strat _start: mov r0,#0mov r1,#0 add_fun:add r0,r0,#1cmp r0,#100addls r1,r1,r0bls add_fun .end

机器学习技术(三)——机器学习实践案例总体流程

机器学习实践案例总体流程 文章目录 机器学习实践案例总体流程一、引言二、案例1、决策树对鸢尾花分类1.数据来源2.数据导入及描述3.数据划分与特征处理4.建模预测 2、各类回归波士顿房价预测1.案例数据2.导入所需的包和数据集3.载入数据集&#xff0c;查看数据属性&#xff0c…

JVM重点整理

一、虚拟机架构图 二、类加载过程 类加载器的作用&#xff1a;负责把class文件加载到内存中 类加载过程&#xff1a; 加载&#xff1a; 通过类的全限定名获取此类的二进制字节流文件的编码结构---->运行时的内存结构内存中生成一个class对象 链接&#xff1a; 验证&#x…

【网络】socket——预备知识 | 套接字 | UDP网络通信

&#x1f431;作者&#xff1a;一只大喵咪1201 &#x1f431;专栏&#xff1a;《网络》 &#x1f525;格言&#xff1a;你只管努力&#xff0c;剩下的交给时间&#xff01; 在前面本喵对网络的整体轮廓做了一个大概的介绍&#xff0c;比如分层&#xff0c;协议等等内容&#x…

【QT】元对象系统学习笔记(一)

QT元对象系统 01、元对象系统1.1、 元对象运行原则1.2、 Q_OBJECT宏1.3、 Qt Creator启动元对象系统1.4、 命令行启动元对象&#xff08;不常用&#xff09; 02、反射机制2.1、 Qt实现反射机制2.2、 反射机制获取类中成员函数的信息2.1.1、 QMetaMethon类2.1.2、QMetaObject类 …

【UE4 塔防游戏系列】07-子弹对敌人造成伤害

目录 效果 步骤 一、让子弹拥有不同伤害 二、敌人拥有不同血量 三、修改“BP_TowerBase”逻辑 四、发射的子弹对敌人造成伤害 效果 步骤 一、让子弹拥有不同伤害 为了让每一种子弹拥有不同的伤害值&#xff0c;打开“TotalBulletsCategory”&#xff08;所有子弹的父类…

架构训练营:3-3设计备选方案与架构细化

3架构中期 什么是备选架构&#xff1f; 备选架构定义了系统可行的架构模式和技术选型 备选方案筛选过程 头脑风暴 &#xff1a;对可选技术进行排列组合&#xff0c;得到可能的方案 红线筛选&#xff1a;根据系统明确的约束和限定&#xff0c;一票否决某些方案&#xff08;主要…

为 GitHub 设置 SSH 密钥

1. 起因 给自己的 github 改个名&#xff0c;顺便就给原来 Hexo 对应的仓库也改了个名。然后发现 ubhexo clean && hexo generate && hexo deploy 失败了&#xff0c;报错如下&#xff1a; INFO Deploying: git INFO Clearing .deploy_git folder... INFO …

Hive自定义函数

本文章主要分享单行函数UDF&#xff08;一进一出&#xff09; 现在前面大体总结&#xff0c;后边文章详细介绍 自定义函数分为临时函数与永久函数 需要创建Java项目&#xff0c;导入hive依赖 创建类继承 GenericUDF&#xff08;自定义函数的抽象类&#xff09;&#xff08;实现…

仓库管理软件有哪些功能?2023仓库管理软件该如何选?

对于现代企业或批发零售商&#xff0c;高效的仓库管理是确保供应链运作顺畅、库存控制精准的关键要素。在数字化时代&#xff0c;越来越多的企业和商户意识到采用仓库管理软件的重要性。 无论您是中小型企业还是中小商户&#xff0c;仓库管理都是不可忽视的一环。 一、选择仓库…

边缘计算在智慧校园应用,实现校园智能化管理

随着科技的发展和互联网技术进步&#xff0c;校园管理正逐步实现数字化、智能化转型。边缘计算作为一种新兴技术&#xff0c;通过在离数据源较近的地方进行数据处理&#xff0c;实现了实时性分析与响应&#xff0c;为校园带来了更智能、安全的管理方式。 学生学习状态监控 AI动…

AI Chat 设计模式:8. 门面(外观)模式

本文是该系列的第八篇&#xff0c;采用问答式的方式展开&#xff0c;问题由我提出&#xff0c;答案由 Chat AI 作出&#xff0c;灰色背景的文字则主要是我的一些思考和补充。 问题列表 Q.1 请介绍一下门面模式A.1Q.2 该模式由哪些角色组成呢A.2Q.3 举一个门面模式的例子A.3Q.4…

串口wifi6+蓝牙二合一系列模块选型参考和外围电路参考设计-WG236/WG237

针对物联网数据传输&#xff0c;智能控制等应用场景研发推出的高集成小尺寸串口WiFi串口蓝牙的二合一组合模块。WiFi符合802.11a/b/g/n无线标准&#xff0c;蓝牙支持低功耗蓝牙V4.2/V5.0 BLE/V2.1和EDR&#xff0c;WiFi部分的接口是UART&#xff0c;蓝牙部分是UART/PCM 接口。模…