はじめに
最近、CNTKを使用してみています。CNTKはONNXに対応しており、ONNX形式のモデルを入出力できます。ONNXは機械学習のモデルを異なるフレームワーク間で共有するツールです。Chainerで作成したモデルをCNTKで利用する、など相互に活用できるようになります。
今回はirisの分類をCNTKで行い、ONNX形式でモデルを出力してみようと思います。
環境
環境 | バージョン |
---|---|
OS | Ubuntu 16.04 |
Python | 3.6.8 |
CUDA | 10.01 |
CNTK | 2.7 |
CNTKでiris分類の概要
以下のような簡単サンプルプログラムを使い、CNTKとONNXの使い方をご紹介します。
本記事ではソースコードの一部を抜粋しています。ソースコードの全量は以下に公開しています。
https://github.com/t2hk/cntk_onnx_iris
- iris分類を行うモデルを作成する
- 作成したモデルをONNXで出力する
- ONNXモデルを読み込み、irisの推論を行う
GPUを使う
CNTKでGPUとCPUのどちらを優先的に使用するかは"try_set_default_device"で設定できます。
from cntk.device import try_set_default_device, gpu
try_set_default_device(gpu(0))
ネットワーク構築
入出力となる学習データ、分類ラベルの変数定義を行い、ネットワークを構築します。
irisの分類なので、入力となる特徴データは4つ、出力となる分類は3つとなります。
# 各種パラメータ定義
n_input = 4
n_hidden = 10
n_output = 3
# 学習データと分類ラベルの入力変数を定義する。
features = C.input_variable((n_input))
label = C.input_variable((n_output))
# ネットワークを構築する。
model = Sequential([
Dense(n_hidden, activation=C.relu),
Dense(n_hidden, activation=C.relu),
Dense(n_hidden, activation=C.relu),
Dense(n_output)])(features)
ce = C.cross_entropy_with_softmax(model, label)
pe = C.classification_error(model, label)
学習
学習データと正解ラベルのデータを用意し、訓練器を作成します。
なお、irisのデータセットの読み込みは省略します。
x_train_batch = [ミニバッチ用の学習データ]
t_train_batch = [ミニバッチ用の正解ラベル]
minibatch = C.learning_parameter_schedule(0.125)
trainer = C.Trainer(model, (ce, pe), [sgd(model.parameters, lr=minibatch)])
trainer.train_minibatch({features : x_train_batch, label : t_train_batch})
sample_count = trainer.previous_minibatch_sample_count
aggregate_loss += trainer.previous_minibatch_loss_average * sample_count
モデルのONNX出力
作成したモデルをONNX形式で保存します。formatで形式を指定するだけです。
output_file_path = R"[出力するモデルのファイルパス]"
model.save(output_file_path, format=C.ModelFormat.ONNX)
ONNXモデルの読み込み
既存のONNXモデルを使用する際もformatでONNXを指定するだけです。
load_model = C.Function.load(output_file_path, device=C.device.gpu(0), format=C.ModelFormat.ONNX)
ロードしたONNXモデルで推論する
ロードしたモデルから分類器を作成し、evalで推論します。
classifier = C.softmax(load_model)
for i, test in enumerate(x_test):
infer = np.argmax(classifier.eval([test]))
print("[{}] in:{} correct:{} infer:{}".format(i,test, t_test[i], infer))
おしまい
CNTKを使って簡単にONNX形式のモデルを出力できました。次回はこのモデルをMarkLogic10のCNTKで使用してみたいと思います。
参考
以下のサイトを参考にさせて頂きました。ありがとうございます。