ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

SparkMllib分类问题的模板代码

2020-03-02 17:55:43  阅读:226  来源: 互联网

标签:feeature val petal 代码 SparkMllib sepal width length 模板


  • 需求:对数据进行分类问题的处理

  • 开发步骤:

    • 1-准备SparkSession的环境
    • 2-准备大数据的数据
    • 3-读取数据并进行解析
    • 4-数据的基本信息的查看
    • 5-特征工程
    • 6-准备算法
    • 7-模型训练
    • 8-模型预测
    • 9-模型校验
    • 10-模型保存
    • 11-新数据预测
  • 代码模板:
import org.apache.spark.SparkConf
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}

/**
  * DESC: 对分类问题的模板的代码
  * Complete data processing and modeling process steps:
  *- 1-准备SparkSession的环境
  *- 2-准备大数据的数据
  *- 3-读取数据并进行解析
  *- 4-数据的基本信息的查看
  *- 5-特征工程
  *- 6-准备算法
  *- 7-模型训练
  *- 8-模型预测
  *- 9-模型校验
  *- 10-模型保存
  *- 11-新数据预测
  *
  */
object ClassficationModelTest {

  var datapath = "D:\\BigData\\Workspace\\SparkMachineLearningTest\\SparkMllib_BigData32\\src\\main\\resources\\iris.csv"

  def main(args: Array[String]): Unit = {
    //    - 1-准备SparkSession的环境
    val conf: SparkConf = new SparkConf().setAppName("ClassficationModelTest").setMaster("local[*]")
    val spark: SparkSession = SparkSession.builder().config(conf).getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    //    - 2-准备大数据的数据
    val irisDF: DataFrame = spark.read.format("csv")
      .option("header", true)
      .option("inferschema", true)
      .option("sep", ",")
      .load(datapath)
    //    - 3-读取数据并进行解析
    irisDF.show(10, false)
    //    +------------+-----------+------------+-----------+-----------+
    //    |sepal_length|sepal_width|petal_length|petal_width|class      |
    //    +------------+-----------+------------+-----------+-----------+
    //    |5.1         |3.5        |1.4         |0.2        |Iris-setosa|
    //      |4.9         |3.0        |1.4         |0.2        |Iris-setosa|
    //      |4.7         |3.2        |1.3         |0.2        |Iris-setosa|
    //      |4.6         |3.1        |1.5         |0.2        |Iris-setosa|
    //    - 4-数据的基本信息的查看
    irisDF.printSchema()
    // 因为在写各种string类型数据的时候可能会有一些单词拼写错误,可以实现定义
    val sepal_length_feeature = "sepal_length"
    val sepal_width_feeature = "sepal_width"
    val petal_length_feeature = "petal_length"
    val petal_width_feeature = "petal_width"
    val class_label = "class"
    //    root
    //    |-- sepal_length: double (nullable = true)
    //    |-- sepal_width: double (nullable = true)
    //    |-- petal_length: double (nullable = true)
    //    |-- petal_width: double (nullable = true)
    //    |-- class: string (nullable = true)
    //    - 5-特征工程
    //5-1处理类别型的数据class
    val stringIndexer: StringIndexer = new StringIndexer()
      .setInputCol(class_label)
      .setOutputCol("classlabel")
    val stringIndexerModel: StringIndexerModel = stringIndexer.fit(irisDF)
    val indexDF: DataFrame = stringIndexerModel.transform(irisDF)
    //5-2处理分散的特征整合为特征向量
    val vectorAssembler: VectorAssembler = new VectorAssembler()
      .setInputCols(Array(sepal_length_feeature, sepal_width_feeature, petal_length_feeature, petal_width_feeature))
      .setOutputCol("features")
    val vecDF: DataFrame = vectorAssembler.transform(indexDF)
    //5-3VectorIndexer对类别值的索引化,加速构建决策树
    val vectorIndexer: VectorIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("vecindexFeatures")
      .setMaxCategories(20)
    val vectorIndexerModel: VectorIndexerModel = vectorIndexer.fit(vecDF)
    val vecindexerDF: DataFrame = vectorIndexerModel.transform(vecDF)
    vecindexerDF.show(10, false)
    //    - 6-准备算法
    val classifier: DecisionTreeClassifier = new DecisionTreeClassifier()
      .setLabelCol("classlabel")
      .setPredictionCol("prces")
      .setFeaturesCol("vecindexFeatures")
      .setMaxDepth(5)
      .setImpurity("gini")
    val Array(trainingSet, testSet): Array[Dataset[Row]] = vecindexerDF.randomSplit(Array(0.8, 0.2), seed = 1234L)
    //    - 7-模型训练
    val model: DecisionTreeClassificationModel = classifier.fit(trainingSet)
    //    - 8-模型预测
    val y_pred_train: DataFrame = model.transform(trainingSet)
    val y_pred_test: DataFrame = model.transform(testSet)
    y_pred_train.show(10, false)
    //    - 9-模型校验
    val evaluator: MulticlassClassificationEvaluator = new MulticlassClassificationEvaluator()
      //"(f1|weightedPrecision|weightedRecall|accuracy)"
      .setMetricName("accuracy")
      .setPredictionCol("prces")
      .setLabelCol("classlabel")
    val acc_test: Double = evaluator.evaluate(y_pred_test)
    val acc_train: Double = evaluator.evaluate(y_pred_train)
    println("acc in trainset score is:", acc_train)
    println("acc in testset score is:", acc_test)
    //    (acc in trainset score is:,0.9920634920634921)
    //    (acc in testset score is:,0.9583333333333334)
    //    //    - 10-模型保存
    //    val datapath="D:\\BigData\\Workspace\\SparkMachineLearningTest\\SparkMllib_BigData32\\src\\main\\resources\\model1"
    //    model.save(datapath)
    //    //    - 11-新数据预测
    //    DecisionTreeClassificationModel.load(datapath)

  }
}

标签:feeature,val,petal,代码,SparkMllib,sepal,width,length,模板
来源: https://www.cnblogs.com/haojia/p/12396975.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有