前言

本文隶属于专栏《1000个问题搞定大数据技术体系》,该专栏为笔者原创,引用请注明来源,不足和错误之处请在评论区帮忙指出,谢谢!

本专栏目录结构和参考文献请见1000个问题搞定大数据技术体系


关联

Spark RDD 论文详解(三)Spark 编程接口

正文

模型的创建与使用

第 1 步,数据准备。

在 MLlib 中, LinearRegressionWithSGD 需要一个 LabeledPoint 类型的 RDD 作为训练集。
训练集中 label 字段的值可以是任意实数。

第 2 步,训练模型。

LinearRegressionWithSGD 伴生对象中提供了用于训练模型的 train 方法。

train 方法通过设置训练参数进行模型训练,其主要参数如下:

  1. input : 训练集, LabeledPoint 类型的 RDD
  2. numIterations : 迭代次数,默认为 100
  3. stepSize : 每次迭代步长,默认值为 1 。
  4. miniBatchFraction : 每次送代参与计算的样本比例,默认值为 1 . 0 ,表示全部样本参与计算。
  5. initialWeights :初始化权重

第 3 步,使用模型。

LinearRegressionWithSGD 中 train 方法的返回值为 LinearRegressionModel 类型,其中定义了对特征的目标值进行预测的 predict 方法。

案例

使用下面链接中提供的数据集:

lpsa.data

启动 Intellij IDEA, 编辑如下的 Spark MLlib 应用程序:

package com.shockang.study.spark.mllib

import com.shockang.study.spark.internal.Logging
import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{LabeledPoint, LinearRegressionWithSGD}
import org.apache.spark.{SparkConf, SparkContext}

/**
 * JDK: 8
 * Scala: 2.11.8
 * spark-mllib_2.11: 2.2.0
 *
 * @author Shockang
 */
object LinearRegressionExample extends Logging {
  val DATA_PATH = "/Users/shockang/code/spark-examples/data/simple/mllib/lpsa.data"

  def main(args: Array[String]): Unit = {
    // 关闭 Spark 内部的日志打印,只关注结果日志
    Logger.getLogger("org").setLevel(Level.OFF)
    val conf = new SparkConf().setAppName("LinearRegressionExample").setMaster("local[*]")
    val sc = new SparkContext(conf)
    // 读取并创建训练数据
    val data = sc.textFile(DATA_PATH)
    val parsedData = data.map(line => {
      val parts = line.split(",")
      LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(" ").map(_.toDouble)))
    })
    // 训练模型
    val numIterations = 100
    val model = LinearRegressionWithSGD.train(parsedData, numIterations)
    // 预测并统计回归错误的样本比例
    val valuesAndPreds = parsedData.map(point => {
      val prediction = model.predict(point.features)
      (point.label, prediction)
    })
    val MSE = valuesAndPreds.map {
      case (v, p) => math.pow(v - p, 2)
    }.reduce(_ + _) / valuesAndPreds.count()
    logInfo(s"Training Mean Squared Error = $MSE")
  }
}
上一篇 下一篇