實作 ONNX 後端

什麼是 ONNX 後端

ONNX 後端是一個可以執行 ONNX 模型的函式庫。由於已經存在許多深度學習框架,您可能不需要從頭開始建立所有內容。相反地,您可能會建立一個轉換器,將 ONNX 模型轉換為對應的框架特定表示法,然後將執行委派給該框架。例如,onnx-caffe2 (作為 caffe2 的一部分)onnx-coremlonnx-tensorflow 都實作為轉換器。

統一的後端介面

ONNX 在 onnx/backend/base.py 定義了一個統一的 (Python) 後端介面。

此介面中有三個核心概念:DeviceBackendBackendRep

  • Device 是對各種硬體的輕量級抽象,例如 CPU、GPU 等。

  • Backend 是一個實體,它將接收具有輸入的 ONNX 模型,執行計算,然後傳回輸出。

    對於一次性的執行,使用者可以使用 run_noderun_model 來快速獲取結果。

    對於重複執行,使用者應該使用 prepare,其中 Backend 會執行所有準備工作以重複執行模型 (例如,載入初始化程式),並傳回 BackendRep 控制代碼。

  • BackendRepBackend 在準備好重複執行模型後傳回的控制代碼。然後,使用者會將輸入傳遞給 BackendReprun 函數以檢索相應的結果。

請注意,即使 ONNX 統一後端介面是在 Python 中定義的,您的後端也不需要在 Python 中實作。例如,您的後端可以在 C++ 中建立,並且可以使用 pybind11cython 等工具來滿足介面。

ONNX 後端測試

ONNX 提供標準的後端測試套件,以協助後端實作驗證。強烈建議每個 ONNX 後端都執行此測試。

將 ONNX 後端測試套件整合到您的 CI 中很簡單。以下是一些範例,示範後端如何執行整合

如果您已安裝 pytest,您可以在執行 ONNX 後端測試後取得涵蓋率報告,以了解您的後端表現如何

---------- onnx coverage: ----------
Operators (passed/loaded/total): 21/21/70
------------------------------------
╒════════════════════╤════════════════════╕
│ Operator           │ Attributes         │
│                    │ (name: #values)    │
╞════════════════════╪════════════════════╡
│ Slice              │ axes: 2            │
│                    │ ends: 3            │
│                    │ starts: 3          │
├────────────────────┼────────────────────┤
│ Constant           │ value: 1           │
├────────────────────┼────────────────────┤
│ Concat             │ axis: 0            │
├────────────────────┼────────────────────┤
│ Conv               │ group: 6           │
│                    │ kernel_shape: 5    │
│                    │ pads: 4            │
│                    │ strides: 3         │
│                    │ auto_pad: 0        │
│                    │ dilations: 0       │
├────────────────────┼────────────────────┤
│ Reshape            │ shape: 9           │
├────────────────────┼────────────────────┤
│ BatchNormalization │ consumed_inputs: 1 │
│                    │ epsilon: 2         │
│                    │ is_test: 1         │
│                    │ momentum: 0        │
│                    │ spatial: 0         │
├────────────────────┼────────────────────┤
│ Dropout            │ is_test: 1         │
│                    │ ratio: 2           │
├────────────────────┼────────────────────┤
│ MaxPool            │ kernel_shape: 2    │
│                    │ pads: 3            │
│                    │ strides: 2         │
│                    │ auto_pad: 0        │
│                    │ dilations: 0       │
├────────────────────┼────────────────────┤
│ Transpose          │ perm: 1            │
├────────────────────┼────────────────────┤
│ MatMul             │ No attributes      │
├────────────────────┼────────────────────┤
│ Relu               │ No attributes      │
├────────────────────┼────────────────────┤
│ LRN                │ alpha: 2           │
│                    │ beta: 1            │
│                    │ bias: 2            │
│                    │ size: 1            │
├────────────────────┼────────────────────┤
│ Add                │ axis: 1            │
│                    │ broadcast: 1       │
├────────────────────┼────────────────────┤
│ Abs                │ No attributes      │
├────────────────────┼────────────────────┤
│ Pad                │ mode: 3            │
│                    │ paddings: 2        │
│                    │ value: 1           │
├────────────────────┼────────────────────┤
│ Softmax            │ axis: 0            │
├────────────────────┼────────────────────┤
│ GlobalAveragePool  │ No attributes      │
├────────────────────┼────────────────────┤
│ Mul                │ axis: 1            │
│                    │ broadcast: 1       │
├────────────────────┼────────────────────┤
│ Sum                │ No attributes      │
├────────────────────┼────────────────────┤
│ Gemm               │ broadcast: 1       │
│                    │ transB: 1          │
│                    │ alpha: 0           │
│                    │ beta: 0            │
│                    │ transA: 0          │
├────────────────────┼────────────────────┤
│ AveragePool        │ kernel_shape: 3    │
│                    │ pads: 3            │
│                    │ strides: 2         │
│                    │ auto_pad: 0        │
╘════════════════════╧════════════════════╛

Operators (passed/loaded/total): 21/21/70 中的數字表示您的後端在所有測試案例中涵蓋的 21 個運算符已通過,ONNX 後端測試的所有測試案例中涵蓋了 21 個運算符,而 ONNX 總共有 70 個運算符。