ONNX 形狀推斷¶
ONNX 在 ONNX 圖表上提供形狀推斷的選擇性實作。此實作涵蓋每個核心運算子,並提供可擴充性的介面。因此,您可以選擇在圖表上叫用現有的形狀推斷功能,或是定義形狀推斷實作來搭配您的自訂運算子 (或兩者都做!)。形狀推斷函式會儲存為 OpSchema 物件的成員。
在 ONNX 1.10 版本中,符號產生和傳播以及形狀資料傳播已新增至 ONNX 圖表層級的形狀推斷。詳細的提案位於 此處
背景¶
如需檢閱靜態張量形狀,請參閱 此 IR.md 的章節。特別是,靜態張量形狀 (由 TensorShapeProto
表示) 與執行階段張量形狀不同。當靜態 (也就是在編譯時期) 未知確切執行階段張量形狀時,通常會使用此功能。
使用具有未定義
shape
欄位的Tensor
來表示未知階數的張量。使用具有已定義
shape
的Tensor
來表示已知階數的張量。在
TensorShapeProto
的每個Dimension
中,可以有已知的整數值 (由dim_value
欄位表示),或可以有未知的值 (由符號識別碼表示) (dim_param
欄位),或者可能未定義任何欄位 (在這種情況下,它表示匿名的未知值)。
叫用形狀推斷¶
可以透過 C++ 或 Python 叫用形狀推斷。此處說明 Python API,並附帶範例:此處。
C++ API 由單一函式組成
shape_inference::InferShapes(
ModelProto& m,
const ISchemaRegistry* schema_registry);
第一個引數是要執行形狀推斷的 ModelProto
,並以形狀資訊就地註釋。第二個引數為選擇性。
限制¶
不保證形狀推斷完整。特別是,某些動態行為會阻擋形狀推斷的流程,例如對動態提供的形狀執行重塑。此外,並非所有運算子都必須具有形狀推斷實作。
形狀推斷僅適用於常數和簡單變數。它不支援包含變數的算術運算式。例如,可以推斷 (5, 2)
和 (7, 2)
形狀的張量上的 Concat
會產生 (12, 2)
形狀的結果,但 (5, 2)
和 (N, 2)
形狀的張量上的 Concat
只會產生 (M, 2)
,而不是包含 N+5
的表示法。請注意,會傳播不同的未知符號值,因此此處的 M
代表未知數量,其與 M
的其他出現位置相同。
這些限制是目前實作的特性,而非基本限制 - 如果您需要更進階的功能,請務必告知我們!
為運算子實作形狀推斷¶
您可以使用以下方法將形狀推斷函式新增至運算子的 Schema:
OpSchema& Opschema::TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);
InferenceFunction
定義於 shape_inference.h 中,以及核心介面結構 InferenceContext
和各種協助程式方法。InferenceContext
是提供給推斷函式的核心結構。它允許存取運算子輸入的相關資訊,並且也允許寫出推斷的資訊。
若要查看許多範例,請在程式碼庫中搜尋 TypeAndShapeInferenceFunction
的出現次數。其中一個相對複雜的是 Concat
的實作,位於 onnx/defs/tensor/defs.cc 中。
實作運算子的形狀推斷方法時,請注意以下幾點,以避免常見錯誤
在存取任何輸入的
shape
之前,程式碼必須檢查 shape 是否可用。如果不可用,應將其視為 rank 未知的動態張量,並適當地處理。通常,shape 推斷邏輯會以呼叫hasInputShape
或hasNInputShapes
來保護。在存取任何維度的
dim_value
或dim_param
之前,程式碼必須檢查這些欄位是否有值。特別是,程式碼必須處理維度可能沒有靜態已知值的可能性。
shape_inference.h 中有幾個實用函式可用於處理各種常見情況。
對於必須具有固定 rank 的輸入,請使用
checkInputRank
。(請參閱RoiAlign
的推斷作為範例。)當多個輸入維度預期相同,以及當輸入維度傳播到特定輸出維度時,可以使用
unifyInputDim
和unifyDim
和updateOutputShape
。(請參閱RoiAlign
的推斷作為範例。)當使用算術從輸入維度計算輸出維度時,可以在符號維度上使用重載的運算子
*
和/
。(請參閱SpaceToDepth
的推斷作為範例。)
這些實用工具可以安全地處理遺失的 shape 和維度。
範例:考慮一個簡單的矩陣乘法運算,它預期輸入的 shape 為 [M,K]
和 [K,N]
,並返回 shape 為 [M,N]
的輸出。這可以用以下方式編碼
// Check that input 0 has rank 2 (if its rank is known).
checkInputRank(ctx, 0, 2);
// Check that input 1 has rank 2 (if its rank is known).
checkInputRank(ctx, 1, 2);
Dim M, K, N;
// Check various dimensions, handling missing dimensions/shapes safely.
unifyInputDim(ctx, 0, 0, M);
unifyInputDim(ctx, 0, 1, K);
unifyInputDim(ctx, 1, 0, K);
unifyInputDim(ctx, 1, 1, N);
updateOutputShape(ctx, 0, {M. N});