はじめに
前回はSparkで作ったモデルをhttp4s
で無理やりWebAPIにした。
が、実行(予測)にSparkクラスタ(WebAPサーバと同じマシンで1台構成だけど)が必要である点が、どうにも気持ち悪い。もっとシンプルに、軽量に、ポータブルに、予測だけしたい。
で、調べてみるとMleap
というOSSプロジェクトが、Spark MLlibなどの機械学習モデルをSparkなしで動かすという、まさにドンピシャなOSSだったので、これを使ってみる。
まとめ
- Mleapを用いれば、予測時にSparkに依存せずにモデルを実行できる。
- MleapのランタイムはほとんどSparkと同じようなデータフレームの機能をもっているので、処理を大きく変更する必要もない
準備
必要なライブラリは以下。(ちょっと手抜きで、学習と予測を同一プロジェクトでやっているけれど)
build.sbt
name := "MleapSample"
version := "0.1"
scalaVersion := "2.11.12"
// 機械学習するときにはSparkが必要(学習時のみ必要)
libraryDependencies += "org.apache.spark" %% "spark-core" % "2.4.3"
libraryDependencies += "org.apache.spark" %% "spark-sql" % "2.4.3"
libraryDependencies += "org.apache.spark" %% "spark-mllib" % "2.4.3"
// Mleap(予測時、学習時ともに利用)
libraryDependencies += "ml.combust.mleap" %% "mleap-spark" % "0.14.0"
// WebAPI用にAPサーバとJSONパーサ(予測時のみ)
libraryDependencies += "org.http4s" %% "http4s-dsl" % "0.20.10"
libraryDependencies += "org.http4s" %% "http4s-blaze-server" % "0.20.10"
libraryDependencies += "org.http4s" %% "http4s-circe" % "0.20.10"
libraryDependencies += "io.circe" %% "circe-generic" % "0.11.1"
libraryDependencies += "io.circe" %% "circe-literal" % "0.11.1"
機械学習モデルのシリアライズ化
機械学習モデルは、前回と同じだが、最後のシリアライズ化の部分だけが異なりMleapを利用する。やっていることはMleapのマニュアル通り。
Training.scala
// (Sparkの機械学習に必要な依存は省略)
import ml.combust.mleap.spark.SparkSupport._
import ml.combust.bundle.BundleFile
import org.apache.spark.ml.bundle.SparkBundleContext
import resource.managed
object Training{
def main(args:Array[String]):Unit = {
val spark = SparkSession.builder().master("local[*]").getOrCreate()
train("/tmp/model.zip")(spark)
}
def train(path:String)(implicit spark:SparkSession):Unit = {
//(機械学習モデルの構築は中略。要はPipeLineを構築するだけ)
//学習
val model = validator.fit(train)
//検証
val predict = model.transform(test)
// ↓をやらないと、モデルを保存できてもロードで失敗した。
val sbc = SparkBundleContext().withDataset(predict)
//マニュアル通りに保存
for (bundle <- managed(BundleFile("jar:file:" + path))) {
model.writeBundle.save(bundle)(sbc)
}
}
}
予測処理
モデルがZIP形式で保存できたら、あとはこれを Sparkに依存せず 予測処理を行う。ほとんどSparkと同じ書き方でできる。
Predictor.scala
import Training.IrisData
import ml.combust.bundle.BundleFile
import ml.combust.mleap.runtime.MleapSupport._
import ml.combust.mleap.runtime.frame.{DefaultLeapFrame, Row}
import ml.combust.mleap.core.types._
import resource.managed
import scala.util.Try
object Predictor {
//モデルを読み込み
val bundle = (for (bundleFile <- managed(BundleFile("jar:file:/tmp/model.zip"))) yield {
bundleFile.loadMleapBundle().get
}).tried.get
// このStructTypeは`org.apache.spark.sql.types.StructType`とは別で、mleapのStructType
val schema = StructType(
StructField("f0", ScalarType.Double),
StructField("f1", ScalarType.Double),
StructField("f2", ScalarType.Double),
StructField("f3", ScalarType.Double),
StructField("target", ScalarType.Int)).get
def predict(iris:IrisData):Try[Int] = {
//Mleapのデータフレームを作成して
val data = DefaultLeapFrame(schema, Seq(
Row(iris.f0,iris.f1,iris.f2,iris.f3,iris.target)
))
//モデル
val pipeline = bundle.root
//予測処理
for {
transformed <- pipeline.transform(data) //予測処理
prediction <- transformed.select("prediction") //予測結果列だけ取り出し
} yield {
prediction.dataset(0).getDouble(0).toInt //先頭の1行1列目を取り出してInt変換
}
}
}
WebAPI側
こちらは、特に変わらず。(前回記事よりちょこっと修正しているけれど)
WebApp.scala
import Training.IrisData
import org.http4s.HttpRoutes
import org.http4s.implicits._
import org.http4s.dsl.io._
import scala.util.{Failure,Success}
object IrisPredictorService {
case class ReqJson(f0:Double,f1:Double,f2:Double,f3:Double){
def toIrisData():IrisData = IrisData(f0,f1,f2,f3,0)
}
case class ResJson(target:Int)
val service = HttpRoutes.of[IO]{
case req @ POST -> Root / "iris" / "predict" =>
import org.http4s.circe.{jsonOf,jsonEncoderOf}
import io.circe.generic.auto._
implicit val decoder = jsonOf[IO,ReqJson]
implicit val encoder = jsonEncoderOf[IO,ResJson]
for {
inputJson <- req.as[ReqJson]
test = inputJson.toIrisData
res <- Predictor.predict(test) match {
case Success(value) => Ok(ResJson(value))
case Failure(_) => InternalServerError()
}
} yield (res)
}.orNotFound
}
サーバ起動処理
前回から変わらず。相変わらず猫がモナモナしているところは、よく解ってないけど。。。
import cats.effect.{ContextShift, IO,Timer}
import scala.concurrent.ExecutionContext
import org.http4s.server.blaze.BlazeServerBuilder
object WebApp {
def main(args:Array[String]):Unit = {
implicit val cs:ContextShift[IO] = IO.contextShift(ExecutionContext.global)
implicit val timer:Timer[IO] = IO.timer(ExecutionContext.global)
val serverBuilder = BlazeServerBuilder[IO]
.bindLocal(9999)
.withHttpApp(IrisPredictorService.service)
val fiber = serverBuilder
.resource
.use(_ => IO.never)
.start
.unsafeRunSync()
scala.io.StdIn.readLine()
fiber.cancel.unsafeRunSync()
}
}