Train, convert and predict a model

Training and deploying a model usually involves these three steps:

  • train a pipeline with scikit-learn,

  • convert it into ONNX with sklearn-onnx,

  • predict with onnxruntime.

Train a model

A basic example using a random forest and the iris dataset.

import skl2onnx
import onnx
import sklearn
from sklearn.linear_model import LogisticRegression
import numpy
import onnxruntime as rt
from skl2onnx.common.data_types import FloatTensorType
from skl2onnx import convert_sklearn
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

# Load the iris data and hold out a test set.
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y)
# Fit a random forest on the training split.
clr = RandomForestClassifier()
clr.fit(X_train, y_train)
print(clr)
RandomForestClassifier()

Convert a model into ONNX

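# One float input: None leaves the batch size open, 4 is the number of features.
initial_type = [("float_input", FloatTensorType([None, 4]))]
onx = convert_sklearn(clr, initial_types=initial_type, target_opset=12)

# Save the ONNX model to disk.
with open("rf_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())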

Compute the prediction with ONNX Runtime

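sess = rt.InferenceSession("rf_iris.onnx", providers=["CPUExecutionProvider"])
# The first output of a converted classifier holds the predicted labels.
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
# The input must be float32 to match FloatTensorType.
pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]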
print(pred_onx)
[0 0 0 0 1 2 1 0 0 0 1 1 1 0 1 0 1 1 2 0 1 2 2 1 2 2 2 1 1 0 0 0 2 2 0 2 2
 0]

Full example with a logistic regression

# Same steps with a logistic regression; the input width is taken from the data.
clr = LogisticRegression()
clr.fit(X_train, y_train)
initial_type = [("float_input", FloatTensorType([None, X_train.shape[1]]))]
onx = convert_sklearn(clr, initial_types=initial_type, target_opset=12)
with open("logreg_iris.onnx", "wb") as f:
    f.write(onx.SerializeToString())

sess = rt.InferenceSession("logreg_iris.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0]
print(pred_onx)
[0 0 0 0 1 2 1 0 0 0 1 1 1 0 1 0 1 1 2 0 1 2 2 1 2 2 2 1 1 0 0 0 2 2 0 2 2
 0]

Versions used for this example

print("numpy:", numpy.__version__)
print("scikit-learn:", sklearn.__version__)
print("onnx: ", onnx.__version__)
print("onnxruntime: ", rt.__version__)
print("skl2onnx: ", skl2onnx.__version__)
numpy: 1.23.5
scikit-learn: 1.4.dev0
onnx:  1.15.0
onnxruntime:  1.16.0+cu118
skl2onnx:  1.16.0

Total running time of the script: (0 minutes 0.172 seconds)
