
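Users may want more from ONNX than the converted model, for several reasons. They may need intermediate results, that is, the output of every node in the graph. They may need to alter the ONNX graph to remove some nodes. Transfer learning usually means removing the last layers of a deep neural network. Another reason is debugging. The runtime frequently fails to compute predictions because of a shape mismatch, and retrieving the shape of every intermediate result then becomes useful. This example looks into two ways of doing so.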


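The first way is tricky: it overloads the methods *transform*, *predict* and *predict_proba* to keep a copy of their inputs and outputs. It then goes through every step of the pipeline. If the pipeline has *n* steps, it converts the pipeline with step 1 only, then the pipeline with steps 1 and 2, then the pipeline with steps 1, 2 and 3, and so on.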
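
As a rough illustration of the capture mechanism described above, the following sketch wraps the *transform* method of each step so that it stores its last input and output, then lists the step prefixes that would each be converted. The classes `Scale` and `Shift` and the helper `record_calls` are hypothetical stand-ins written for this sketch. They are not part of scikit-learn or skl2onnx, and the actual implementation behind `collect_intermediate_steps` may differ.

```python
class Scale:
    """Hypothetical stand-in for a fitted transformer (not scikit-learn)."""

    def __init__(self, factor):
        self.factor = factor

    def transform(self, rows):
        return [[v * self.factor for v in row] for row in rows]


class Shift:
    """Second hypothetical step."""

    def __init__(self, offset):
        self.offset = offset

    def transform(self, rows):
        return [[v + self.offset for v in row] for row in rows]


def record_calls(step, method_name="transform"):
    """Replace step.<method_name> with a wrapper keeping the last input and output."""
    original = getattr(step, method_name)
    step.captured = {}

    def wrapper(rows):
        out = original(rows)
        step.captured[method_name] = {"input": rows, "output": out}
        return out

    setattr(step, method_name, wrapper)
    return step


def run_pipeline(steps, rows):
    """Apply every step in order, as a pipeline's transform would."""
    for _, step in steps:
        rows = step.transform(rows)
    return rows


steps = [("scale", record_calls(Scale(2.0))), ("shift", record_calls(Shift(1.0)))]
final = run_pipeline(steps, [[1.0, 2.0]])

# Prefixes that would each be converted separately: step 1, then steps 1 and 2.
prefixes = [[name for name, _ in steps[: i + 1]] for i in range(len(steps))]

for name, step in steps:
    print(name, step.captured["transform"])
print(prefixes)
```

Running the pipeline once is enough to populate every cache, because each wrapped method stores its data as a side effect of the normal call.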

import numpy
from onnx.reference import ReferenceEvaluator
from onnxruntime import InferenceSession
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from skl2onnx import to_onnx
from skl2onnx.helpers import collect_intermediate_steps
from skl2onnx.common.data_types import FloatTensorType


data = load_iris()
X = data.data

pipe = Pipeline(steps=[("std", StandardScaler()), ("km", KMeans(3, n_init=3))])
pipe.fit(X)
Pipeline(steps=[('std', StandardScaler()),
                ('km', KMeans(n_clusters=3, n_init=3))])

The function goes through every step, overloads the method *transform* and returns an ONNX graph for each step.

steps = collect_intermediate_steps(
    pipe, "pipeline", [("X", FloatTensorType([None, X.shape[1]]))], target_opset=17
)

We call the method transform to populate the cache that the overloaded *transform* methods keep.

pipe.transform(X)
array([[3.12119834, 0.21295824, 3.98940603],
       [2.6755083 , 0.99604549, 4.01793312],
       [2.97416665, 0.65198444, 4.19343668],
       [2.88014429, 0.9034561 , 4.19784749],
       [3.30022609, 0.40215457, 4.11157152],
       [3.50554424, 1.21154793, 3.89893116],
       [3.14856384, 0.50244932, 4.21638048],
       [2.99184826, 0.09132468, 3.97313411],
       [2.92515933, 1.42174651, 4.40757189],
       [2.79398956, 0.78993078, 4.05764261],
       [3.32125333, 0.78999385, 3.92088109],
       [3.0493632 , 0.27618123, 4.07853631],
       [2.80635045, 1.03497888, 4.16440431],
       [3.21220972, 1.33482453, 4.63069748],
       [3.88834965, 1.63865558, 4.14619343],
       [4.4998303 , 2.39898792, 4.49547518],
       [3.60978017, 1.20748818, 4.02966144],
       [3.05594182, 0.21618828, 3.91388548],
       [3.34493953, 1.20986655, 3.72562039],
       [3.50065397, 0.86706182, 4.10101938],
       [2.80825681, 0.50401564, 3.66383713],
       [3.27800809, 0.66826437, 3.94496718],
       [3.58990876, 0.68658071, 4.51061335],
       [2.55934697, 0.47945627, 3.57996434],
       [2.96493153, 0.36345425, 3.98817445],
       [2.55682739, 0.99023912, 3.88431906],
       [2.8279719 , 0.22683089, 3.79088782],
       [3.05970831, 0.2947186 , 3.89539875],
       [2.95425291, 0.25361098, 3.88085622],
       [2.87745051, 0.65019824, 4.09851673],
       [2.73238773, 0.80138328, 4.01796142],
       [2.73361981, 0.52309257, 3.57350896],
       [4.11853014, 1.57658655, 4.5037664 ],
       [4.22845606, 1.87652483, 4.4465301 ],
       [2.71452112, 0.76858489, 3.97906378],
       [2.86508665, 0.54896332, 4.01986385],
       [3.0573692 , 0.63079314, 3.80064093],
       [3.40284985, 0.45982568, 4.25136846],
       [3.00742655, 1.2336976 , 4.42052558],
       [2.95472117, 0.14580827, 3.90865188],
       [3.12324651, 0.20261743, 4.01192633],
       [2.90164193, 2.67055552, 4.64398605],
       [3.15411688, 0.90927099, 4.42154566],
       [2.8613548 , 0.50081008, 3.70483773],
       [3.34606471, 0.92159916, 3.9078554 ],
       [2.65231058, 1.01946042, 4.01421067],
       [3.53206587, 0.86953764, 4.14238152],
       [2.99813103, 0.72275914, 4.23577398],
       [3.34116935, 0.72324305, 3.97409784],
       [2.90222887, 0.30295342, 3.97223984],
       [1.9003878 , 3.43619989, 0.95288059],
       [1.41851492, 2.97232682, 0.99352148],
       [1.68457079, 3.51850037, 0.72661726],
       [0.96940962, 3.33264308, 2.69898424],
       [0.9112523 , 3.35747592, 1.11074501],
       [0.35721918, 2.77550662, 1.8143491 ],
       [1.59351202, 3.01808184, 1.00650285],
       [1.50213315, 2.77360088, 3.31296552],
       [1.11632078, 3.21148368, 1.14114175],
       [0.77921299, 2.66294828, 2.42994048],
       [1.97194958, 3.62389817, 3.73666782],
       [0.77530513, 2.70011145, 1.45918639],
       [1.25941769, 3.53658932, 2.74268279],
       [0.66155141, 2.98813829, 1.28976474],
       [0.73833453, 2.32311723, 2.05251547],
       [1.46572707, 3.14311522, 0.98780965],
       [0.80185102, 2.68234835, 1.67700171],
       [0.568386  , 2.63954211, 2.12682734],
       [1.19987895, 3.97369206, 2.33743839],
       [0.67881532, 2.87494798, 2.46667974],
       [1.34222961, 3.03853641, 1.1880022 ],
       [0.53061062, 2.8022861 , 1.63233668],
       [0.79234309, 3.68305664, 1.65142259],
       [0.57371215, 2.96833851, 1.54593744],
       [0.90589785, 2.9760862 , 1.2933375 ],
       [1.22490527, 3.13002382, 1.03085926],
       [1.26783271, 3.56679427, 1.09304603],
       [1.42114042, 3.5903606 , 0.52050254],
       [0.58974672, 2.93839428, 1.34712856],
       [0.76432091, 2.58203512, 2.44164622],
       [0.89738242, 2.99796537, 2.69027665],
       [0.98549851, 2.92597852, 2.76965187],
       [0.3921368 , 2.68907313, 2.02829879],
       [0.54223583, 3.42215998, 1.4211892 ],
       [0.90567816, 2.62771445, 1.88799766],
       [1.70872911, 2.75915071, 1.39853465],
       [1.48190142, 3.30075052, 0.78009974],
       [1.06129323, 3.73017167, 2.2083069 ],
       [0.81863359, 2.37943811, 1.87666989],
       [0.599882  , 2.98789866, 2.41035271],
       [0.4914813 , 2.89079656, 2.26782134],
       [0.84409423, 2.86642713, 1.25085451],
       [0.38941349, 2.86642575, 2.11791607],
       [1.53271026, 2.96966239, 3.35089399],
       [0.30831638, 2.77003779, 2.05312152],
       [0.81726253, 2.38255534, 1.83091351],
       [0.56428027, 2.55559903, 1.80454586],
       [0.72672271, 2.8455521 , 1.39825227],
       [1.28805849, 2.56987887, 3.06324547],
       [0.38163798, 2.64007308, 1.89861511],
       [2.31271244, 4.24274589, 1.0584579 ],
       [0.76585766, 3.57067982, 1.5185265 ],
       [2.14762671, 4.44150237, 0.52472   ],
       [1.17645413, 3.69480186, 0.77236486],
       [1.73594932, 4.11613683, 0.53031563],
       [2.78128346, 5.03326801, 1.2022172 ],
       [1.22550604, 3.3503222 , 2.74462238],
       [2.2426558 , 4.577021  , 0.92275933],
       [1.50462864, 4.363498  , 1.40314162],
       [3.22975724, 4.79334275, 1.48323372],
       [1.71837714, 3.62749566, 0.4787491 ],
       [1.10409694, 3.89360823, 1.0325986 ],
       [1.80475907, 4.1132966 , 0.27818948],
       [0.94858807, 3.82688169, 1.91870424],
       [1.39433359, 3.91538879, 1.49910975],
       [1.90677079, 3.89835633, 0.68622715],
       [1.39713702, 3.70128288, 0.46463058],
       [3.85224062, 5.18341242, 2.10127163],
       [2.95786451, 5.58136629, 1.83092395],
       [1.17790381, 4.02615768, 2.37017622],
       [2.27442972, 4.31907679, 0.52540209],
       [0.91211061, 3.4288432 , 1.62249456],
       [2.77937737, 5.19031307, 1.47042293],
       [0.84735471, 3.64273089, 1.15814207],
       [2.15695444, 4.00723617, 0.520093  ],
       [2.33581345, 4.2637671 , 0.66660166],
       [0.79774043, 3.45930032, 1.08324891],
       [1.022307  , 3.27575645, 0.94925151],
       [1.3842265 , 4.05342943, 0.84098317],
       [2.03854964, 4.1585729 , 0.75748198],
       [2.28297732, 4.71100584, 1.07124861],
       [3.88774921, 5.12224641, 2.17345728],
       [1.47357101, 4.13401784, 0.87682321],
       [0.7964005 , 3.39830644, 1.11534598],
       [0.80521086, 3.63719075, 1.59782917],
       [2.8607372 , 5.08776655, 1.25982873],
       [2.3101089 , 4.00416552, 1.07214028],
       [1.46990247, 3.58815834, 0.51434392],
       [0.97017134, 3.19454679, 1.0762733 ],
       [1.97333575, 4.09907253, 0.23050145],
       [2.07939567, 4.28416057, 0.57373487],
       [2.06609741, 4.17402084, 0.51130902],
       [0.76585766, 3.57067982, 1.5185265 ],
       [2.24723796, 4.32128686, 0.54141867],
       [2.42521977, 4.3480018 , 0.85128501],
       [1.82594618, 4.1240495 , 0.52475835],
       [1.03093862, 3.97564407, 1.52100812],
       [1.44892686, 3.7539635 , 0.44371189],
       [2.17585453, 3.7969924 , 1.08437101],
       [1.00508668, 3.25638099, 1.13739231]])

We compute every step and compare the ONNX and scikit-learn outputs.

for step in steps:
    onnx_step = step["onnx_step"]
    sess = InferenceSession(
        onnx_step.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    onnx_outputs = sess.run(None, {"X": X.astype(numpy.float32)})
    onnx_output = onnx_outputs[-1]
    skl_outputs = step["model"]._debug.outputs["transform"]

    # comparison
    diff = numpy.abs(skl_outputs.ravel() - onnx_output.ravel()).max()
    print("difference", diff)

# That was the first way: dynamically overwrite
# every method transform or predict in a scikit-learn
# pipeline to capture the input and output of every step,
# compare them to the output produced by truncated ONNX
# graphs built from the first one.
difference 4.799262827148709e-07
KMeans(n_clusters=3, n_init=3)
difference 1.095537650763756e-06
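
Differences of this size are plausibly explained by precision alone: the ONNX graph is fed float32 inputs, while scikit-learn works in float64. The sketch below is only an illustration of that assumption. It rounds the first row of the scikit-learn output printed above to float32 and measures the gap. The helpers `to_float32`, `max_abs_diff` and `outputs_match`, and the tolerance `atol=1e-5`, are choices made for this sketch and not part of any library.

```python
import struct


def to_float32(value):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", value))[0]


def max_abs_diff(expected, got):
    """Largest absolute element-wise difference between two flat sequences."""
    return max(abs(e - g) for e, g in zip(expected, got))


def outputs_match(expected, got, atol=1e-5):
    """atol is an assumed tolerance, picked for float32 results of this magnitude."""
    return max_abs_diff(expected, got) <= atol


# First row of the scikit-learn output printed above (float64).
skl_row = [3.12119834, 0.21295824, 3.98940603]
# Same values after a round trip through float32: the best a float32 runtime can do.
f32_row = [to_float32(v) for v in skl_row]

print("difference", max_abs_diff(skl_row, f32_row))
print("match", outputs_match(skl_row, f32_row))
```

A check like `outputs_match` turns the printed difference into a pass or fail result, which is convenient when the comparison runs inside a test suite.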

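Walk through every node with the Python runtime

The Python runtime makes it easy to look into every node of the ONNX graph. This option can be used to find out where a computation fails because of nan values or a dimension mismatch.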

onx = to_onnx(pipe, X[:1].astype(numpy.float32), target_opset=17)

oinf = ReferenceEvaluator(onx, verbose=1)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})
[array([1, 1]), array([[3.1211984 , 0.21295893, 3.9894059 ],
       [2.675508  , 0.99604493, 4.017933  ]], dtype=float32)]


oinf = ReferenceEvaluator(onx, verbose=3)
oinf.run(None, {"X": X[:2].astype(numpy.float32)})

# This way is usually better if you need to investigate
# issues within the code of the runtime for an operator.
 +C Ad_Addcst: float32:(3,) in [0.9830552339553833, 5.035177230834961]
 +C Ge_Gemmcst: float32:(3, 4) in [-1.3049873113632202, 1.1359702348709106]
 +C Mu_Mulcst: float32:(1,) in [0.0, 0.0]
 +I X: float32:(2, 4) in [0.20000000298023224, 5.099999904632568]
Scaler(X) -> variable
 + variable: float32:(2, 4) in [-1.340226411819458, 1.0190045833587646]
ReduceSumSquare(variable) -> Re_reduced0
 + Re_reduced0: float32:(2, 1) in [4.850505828857422, 5.376197338104248]
Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
 + Mu_C0: float32:(2, 1) in [0.0, 0.0]
Gemm(variable, Ge_Gemmcst, Mu_C0) -> Ge_Y0
 + Ge_Y0: float32:(2, 3) in [-10.366023063659668, 7.967348575592041]
Add(Re_reduced0, Ge_Y0) -> Ad_C01
 + Ad_C01: float32:(2, 3) in [-4.98982572555542, 12.817853927612305]
Add(Ad_Addcst, Ad_C01) -> Ad_C0
 + Ad_C0: float32:(2, 3) in [0.045351505279541016, 16.143783569335938]
ArgMin(Ad_C0) -> label
 + label: int64:(2,) in [1, 1]
Sqrt(Ad_C0) -> scores
 + scores: float32:(2, 3) in [0.2129589319229126, 4.017932891845703]

[array([1, 1]), array([[3.1211984 , 0.21295893, 3.9894059 ],
       [2.675508  , 0.99604493, 4.017933  ]], dtype=float32)]
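
To make the idea behind such a trace concrete, here is a toy sketch of a node-by-node evaluation that logs the shape of each result and whether it contains nan. It does not use the ONNX API: the graph is a plain list of tuples, the operators `Center` and `RowSumSquare` are ordinary Python functions, and `run_verbose`, `shape_of` and `flatten` are hypothetical helpers written for this illustration.

```python
import math


def shape_of(value):
    """Shape of a non-empty nested list of numbers (rows assumed equal in length)."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0]
    return tuple(shape)


def flatten(value):
    """Yield every number contained in a nested list."""
    if isinstance(value, list):
        for item in value:
            yield from flatten(item)
    else:
        yield value


def run_verbose(nodes, feeds):
    """Evaluate nodes in order; after each one, log its shape and nan presence."""
    results = dict(feeds)
    report = []
    for op_name, fn, input_names, output_name in nodes:
        out = fn(*[results[n] for n in input_names])
        results[output_name] = out
        has_nan = any(math.isnan(v) for v in flatten(out))
        report.append((output_name, shape_of(out), has_nan))
        print(f"{op_name}({', '.join(input_names)}) -> {output_name}")
        print(f" + {output_name}: {shape_of(out)} nan={has_nan}")
    return results, report


# Toy two-node graph; the second input row carries a nan on purpose.
nodes = [
    ("Center", lambda x: [[v - 1.0 for v in row] for row in x], ["X"], "centered"),
    ("RowSumSquare", lambda x: [[sum(v * v for v in row)] for row in x],
     ["centered"], "sq"),
]
results, report = run_verbose(nodes, {"X": [[1.0, 2.0], [3.0, float("nan")]]})
```

The first entry of `report` whose nan flag is set points at the node where the bad value first appears, which is the same kind of information the verbose trace above provides for a real ONNX graph.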

Total running time of the script: (0 minutes 0.291 seconds)
