比較 CDist 與 scipy

以下範例聚焦於單一運算子 CDist，並比較其在 *onnxruntime* 與 *scipy* 中的執行時間。

包含 CDist 的 ONNX 計算圖

scipy 的 cdist 函數會計算兩組輸入中每一對向量之間的距離（即成對距離）。

from pprint import pprint
from timeit import Timer
import numpy as np
from scipy.spatial.distance import cdist
from tqdm import tqdm
from pandas import DataFrame
import onnx
import onnxruntime as rt
from onnxruntime import InferenceSession
import skl2onnx
from skl2onnx.algebra.custom_ops import OnnxCDist
from skl2onnx.common.data_types import FloatTensorType

# Two tiny float32 point sets: every row of X is all ones and every row of Y
# is all twos, so each Euclidean distance is sqrt(4 * (2 - 1) ** 2) = 2.
X = np.ones((2, 4), dtype=np.float32)
Y = np.full((3, 4), 2, dtype=np.float32)
print(cdist(X, Y, metric="euclidean"))
[[2. 2. 2.]
 [2. 2. 2.]]

ONNX

# Wrap the com.microsoft CDist contrib operator in a one-node ONNX graph
# whose inputs are X and Y and whose single output is Z.
op = OnnxCDist(
    "X",
    "Y",
    metric="euclidean",
    output_names=["Z"],
    op_version=12,
)
onx = op.to_onnx(
    {"X": X, "Y": Y},
    outputs=[("Z", FloatTensorType())],
)
print(onx)
ir_version: 8
opset_import {
  domain: "com.microsoft"
  version: 1
}
producer_name: "skl2onnx"
producer_version: "1.15.0"
domain: "ai.onnx"
model_version: 0
graph {
  node {
    input: "X"
    input: "Y"
    output: "Z"
    name: "CD_CDist"
    op_type: "CDist"
    domain: "com.microsoft"
    attribute {
      name: "metric"
      type: STRING
      s: "euclidean"
    }
  }
  name: "OnnxCDist"
  input {
    name: "X"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
          }
          dim {
            dim_value: 4
          }
        }
      }
    }
  }
  input {
    name: "Y"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
          }
          dim {
            dim_value: 4
          }
        }
      }
    }
  }
  output {
    name: "Z"
    type {
      tensor_type {
        elem_type: 1
      }
    }
  }
}

CDist 和 onnxruntime

我們使用 onnxruntime 計算 CDist 運算子的輸出。

# Load the serialized model into an onnxruntime session restricted to the
# CPU execution provider, then evaluate CDist on the small X / Y pair.
sess = InferenceSession(
    onx.SerializeToString(),
    providers=["CPUExecutionProvider"],
)
res = sess.run(output_names=None, input_feed={"X": X, "Y": Y})
print(res)
[array([[1.9999999, 1.9999999, 1.9999999],
       [1.9999999, 2.       , 2.       ]], dtype=float32)]

基準測試

讓我們比較 onnxruntime 和 scipy。

def measure_time(name, stmt, context, repeat=100, number=20):
    """Time ``stmt`` with :class:`timeit.Timer` and summarise the samples.

    Parameters
    ----------
    name : str
        Label stored under the ``"name"`` key of the returned dict.
    stmt : str
        Python statement to time.
    context : dict
        Namespace used as the globals of the timed statement. It must
        contain ``"X"`` and ``"Y"`` entries with a ``shape`` attribute; they
        are used to report ``nrows`` (rows of X) and ``ncols`` (columns of Y).
    repeat : int
        Number of timing samples collected.
    number : int
        Number of executions of ``stmt`` inside each sample.

    Returns
    -------
    dict
        ``average``, ``deviation`` (population standard deviation),
        ``min_exec`` and ``max_exec`` are per-execution times in seconds.
        ``repeat``, ``number``, ``nrows``, ``ncols`` and ``name`` echo the
        inputs.
    """
    tim = Timer(stmt, globals=context)
    # Each sample measures `number` runs; divide to get one run's duration.
    res = np.array(tim.repeat(repeat=repeat, number=number)) / number
    return dict(
        average=np.mean(res),
        # np.std is numerically stable. The previous sqrt(E[x^2] - E[x]^2)
        # could dip below zero through floating-point cancellation when the
        # samples are nearly identical, which produced NaN.
        deviation=np.std(res),
        min_exec=np.min(res),
        max_exec=np.max(res),
        repeat=repeat,
        number=number,
        nrows=context["X"].shape[0],
        ncols=context["Y"].shape[1],
        name=name,
    )

scipy

# Baseline: scipy's cdist on the small X / Y pair.
time_scipy = measure_time(
    "scipy",
    "cdist(X, Y)",
    context=dict(cdist=cdist, X=X, Y=Y),
)
pprint(time_scipy)
{'average': 9.84484999992219e-06,
 'deviation': 4.450414501003498e-06,
 'max_exec': 4.325500000277316e-05,
 'min_exec': 3.990000004705508e-06,
 'name': 'scipy',
 'ncols': 4,
 'nrows': 2,
 'number': 20,
 'repeat': 100}

onnxruntime

# Same inputs, this time evaluated through the onnxruntime session.
time_ort = measure_time(
    "ort",
    "sess.run(None, {'X': X, 'Y': Y})",
    context=dict(sess=sess, X=X, Y=Y),
)
pprint(time_ort)
{'average': 1.90378500000179e-05,
 'deviation': 9.14980003126398e-06,
 'max_exec': 5.8449999994536486e-05,
 'min_exec': 1.1034999999992579e-05,
 'name': 'ort',
 'ncols': 4,
 'nrows': 2,
 'number': 20,
 'repeat': 100}

規模更大的基準測試

# Grow the number of rows of X geometrically and time both implementations.
# A seeded Generator makes the random inputs reproducible from run to run.
rng = np.random.default_rng(0)
metrics = []
for dim in tqdm([10, 100, 1000, 10000]):
    # The number of columns must stay at 4: the ONNX graph declares a fixed
    # second dimension of 4 for X and Y, so another width needs a new graph.
    X = rng.standard_normal((dim, 4)).astype(np.float32)
    Y = rng.standard_normal((10, 4)).astype(np.float32)

    time_scipy = measure_time(
        "scipy", "cdist(X, Y)", context={"cdist": cdist, "X": X, "Y": Y}
    )
    time_ort = measure_time(
        "ort",
        "sess.run(None, {'X': X, 'Y': Y})",
        context={"sess": sess, "X": X, "Y": Y},
    )
    metric = dict(N=dim, scipy=time_scipy["average"], ort=time_ort["average"])
    metrics.append(metric)

df = DataFrame(metrics)
# A ratio above 1 means onnxruntime took less time than scipy for that size.
df["scipy/ort"] = df["scipy"] / df["ort"]
print(df)

# N grows by powers of ten, so a logarithmic x-axis spaces the points evenly
# instead of squeezing the small sizes against the left edge.
df.plot(x="N", y=["scipy/ort"], logx=True)
plot benchmark cdist
  0%|          | 0/4 [00:00<?, ?it/s]
 50%|█████     | 2/4 [00:00<00:00, 19.40it/s]
100%|██████████| 4/4 [00:02<00:00,  1.41it/s]
100%|██████████| 4/4 [00:02<00:00,  1.64it/s]
       N     scipy       ort  scipy/ort
0     10  0.000010  0.000014   0.695364
1    100  0.000012  0.000015   0.824828
2   1000  0.000095  0.000041   2.312891
3  10000  0.000741  0.000287   2.578285

此範例使用的版本

# Library versions used to produce the timings above.
for _label, _version in (
    ("numpy:", np.__version__),
    ("onnx: ", onnx.__version__),
    ("onnxruntime: ", rt.__version__),
    ("skl2onnx: ", skl2onnx.__version__),
):
    print(_label, _version)
del _label, _version
numpy: 1.23.5
onnx:  1.15.0
onnxruntime:  1.16.0+cu118
skl2onnx:  1.15.0

腳本總執行時間：（0 分鐘 2.761 秒）

由 Sphinx-Gallery 產生的範例集