LoginSignup
3
1

More than 3 years have passed since last update.

続・Spark機械学習モデルのWebAPI化(Mleapを使って、Sparkに依存せずモデルを実行する)

Posted at

はじめに

前回はSparkで作ったモデルをhttp4sで無理やりWebAPIにした。

が、実行(予測)にSparkクラスタ(WebAPサーバと同じマシンで1台構成だけど)が必要である点が、どうにも気持ち悪い。もっとシンプルに、軽量に、ポータブルに、予測だけしたい。

で、調べてみるとMleapというOSSプロジェクトが、Spark MLlibなどの機械学習モデルをSparkなしで動かすという、まさにドンピシャなOSSだったので、これを使ってみる。

まとめ

  • Mleapを用いれば、予測時にSparkに依存せずにモデルを実行できる。
  • MleapのランタイムはほとんどSparkと同じようなデータフレームの機能をもっているので、処理を大きく変更する必要もない

準備

必要なライブラリは以下。(ちょっと手抜きで、学習と予測を同一プロジェクトでやっているけれど)

build.sbt
name := "MleapSample"
version := "0.1"
scalaVersion := "2.11.12"

// 機械学習するときにはSparkが必要(学習時のみ必要)
libraryDependencies += "org.apache.spark" %% "spark-core" % "2.4.3"
libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.3"
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "2.4.3"

// Mleap(予測時、学習時ともに利用)
libraryDependencies += "ml.combust.mleap" %% "mleap-spark" % "0.14.0"

// WebAPI用にAPサーバとJSONパーサ(予測時のみ)
libraryDependencies += "org.http4s" %% "http4s-dsl" % "0.20.10"
libraryDependencies += "org.http4s" %% "http4s-blaze-server" % "0.20.10"
libraryDependencies += "org.http4s" %% "http4s-circe" % "0.20.10"
libraryDependencies += "io.circe" %% "circe-generic" % "0.11.1"
libraryDependencies += "io.circe" %% "circe-literal" % "0.11.1"

機械学習モデルのシリアライズ化

機械学習モデルは、前回と同じだが、最後のシリアライズ化の部分だけが異なりMleapを利用する。やっていることはMleapのマニュアル通り。

Training.scala
// (Sparkの機械学習に必要な依存は省略)
import ml.combust.mleap.spark.SparkSupport._
import ml.combust.bundle.BundleFile
import org.apache.spark.ml.bundle.SparkBundleContext
import resource.managed

object Training{
  def main(args:Array[String]):Unit = {
    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    train("/tmp/model.zip")(spark)
  }

  def train(path:String)(implicit spark:SparkSession):Unit = {
    //(機械学習モデルの構築は中略。要はPipeLineを構築するだけ)

    //学習
    val model = validator.fit(train)
    //検証
    val predict = model.transform(test)
    // ↓をやらないと、モデルを保存できてもロードで失敗した。
    val sbc = SparkBundleContext().withDataset(predict)
    //マニュアル通りに保存
    for (bundle <- managed(BundleFile("jar:file:" + path))) {
      model.writeBundle.save(bundle)(sbc)
    }
  }
}

予測処理

モデルがZIP形式で保存できたら、あとはこれを Sparkに依存せず 予測処理を行う。ほとんどSparkと同じ書き方でできる。

Predictor.scala
import Training.IrisData
import ml.combust.bundle.BundleFile
import ml.combust.mleap.runtime.MleapSupport._
import ml.combust.mleap.runtime.frame.{DefaultLeapFrame, Row}
import ml.combust.mleap.core.types._
import resource.managed
import scala.util.Try

object Predictor {
  //モデルを読み込み
  val bundle = (for (bundleFile <- managed(BundleFile("jar:file:/tmp/model.zip"))) yield {
    bundleFile.loadMleapBundle().get
  }).tried.get

  // このStructTypeは`org.apache.spark.sql.types.StructType`とは別で、mleapのStructType
  val schema = StructType(
    StructField("f0", ScalarType.Double),
    StructField("f1", ScalarType.Double),
    StructField("f2", ScalarType.Double),
    StructField("f3", ScalarType.Double),
    StructField("target", ScalarType.Int)).get

  def predict(iris:IrisData):Try[Int] = {
    //Mleapのデータフレームを作成して
    val data = DefaultLeapFrame(schema, Seq(
      Row(iris.f0,iris.f1,iris.f2,iris.f3,iris.target)
    ))

    //モデル
    val pipeline = bundle.root

    //予測処理
    for {
      transformed <- pipeline.transform(data) //予測処理
      prediction <- transformed.select("prediction") //予測結果列だけ取り出し
    } yield {
      prediction.dataset(0).getDouble(0).toInt //先頭の1行1列目を取り出してInt変換
    }
  }
}

WebAPI側

こちらは、特に変わらず。(前回記事よりちょこっと修正しているけれど)

WebApp.scala
import Training.IrisData
import org.http4s.HttpRoutes
import org.http4s.implicits._
import org.http4s.dsl.io._
import scala.util.{Failure,Success}

object IrisPredictorService {
  case class ReqJson(f0:Double,f1:Double,f2:Double,f3:Double){
    def toIrisData():IrisData = IrisData(f0,f1,f2,f3,0)
  }
  case class ResJson(target:Int)

  val service = HttpRoutes.of[IO]{
    case req @ POST -> Root / "iris" / "predict" =>
      import org.http4s.circe.{jsonOf,jsonEncoderOf}
      import io.circe.generic.auto._
      implicit val decoder = jsonOf[IO,ReqJson]
      implicit val encoder = jsonEncoderOf[IO,ResJson]

      for {
        inputJson <- req.as[ReqJson]
        test = inputJson.toIrisData
        res <- Predictor.predict(test) match {
          case Success(value) => Ok(ResJson(value))
          case Failure(_) => InternalServerError()
        }
      } yield (res)
  }.orNotFound
}

サーバ起動処理

前回から変わらず。相変わらず猫がモナモナしているところは、よく解ってないけど。。。

import cats.effect.{ContextShift, IO,Timer}
import scala.concurrent.ExecutionContext
import org.http4s.server.blaze.BlazeServerBuilder

object WebApp {
  def main(args:Array[String]):Unit = {
    implicit val cs:ContextShift[IO] = IO.contextShift(ExecutionContext.global)
    implicit val timer:Timer[IO] = IO.timer(ExecutionContext.global)

    val serverBuilder = BlazeServerBuilder[IO]
      .bindLocal(9999)
      .withHttpApp(IrisPredictorService.service)

    val fiber = serverBuilder
      .resource
      .use(_ => IO.never)
      .start
      .unsafeRunSync()

    scala.io.StdIn.readLine()
    fiber.cancel.unsafeRunSync()
  }
}
3
1
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1