ONNX 形狀推斷

ONNX 在 ONNX 圖表上提供形狀推斷的選擇性實作。此實作涵蓋每個核心運算子,並提供可擴充性的介面。因此,您可以選擇在圖表上叫用現有的形狀推斷功能,或是定義形狀推斷實作來搭配您的自訂運算子 (或兩者都做!)。形狀推斷函式會儲存為 OpSchema 物件的成員。

在 ONNX 1.10 版本中,符號產生和傳播以及形狀資料傳播已新增至 ONNX 圖表層級的形狀推斷。詳細的提案位於 此處

背景

如需檢閱靜態張量形狀,請參閱 IR.md 的章節。特別是,靜態張量形狀 (由 TensorShapeProto 表示) 與執行階段張量形狀不同。當靜態 (也就是在編譯時期) 未知確切執行階段張量形狀時,通常會使用此功能。

  • 使用具有未定義 shape 欄位的 Tensor 來表示未知階數的張量。

  • 使用具有已定義 shapeTensor 來表示已知階數的張量。

  • TensorShapeProto 的每個 Dimension 中,可以有已知的整數值 (由 dim_value 欄位表示),或可以有未知的值 (由符號識別碼表示) (dim_param 欄位),或者可能未定義任何欄位 (在這種情況下,它表示匿名的未知值)。

叫用形狀推斷

可以透過 C++ 或 Python 叫用形狀推斷。此處說明 Python API,並附帶範例:此處

C++ API 由單一函式組成

shape_inference::InferShapes(
    ModelProto& m,
    const ISchemaRegistry* schema_registry);

第一個引數是要執行形狀推斷的 ModelProto,並以形狀資訊就地註釋。第二個引數為選擇性。

限制

不保證形狀推斷完整。特別是,某些動態行為會阻擋形狀推斷的流程,例如對動態提供的形狀執行重塑。此外,並非所有運算子都必須具有形狀推斷實作。

形狀推斷僅適用於常數和簡單變數。它不支援包含變數的算術運算式。例如,可以推斷 (5, 2)(7, 2) 形狀的張量上的 Concat 會產生 (12, 2) 形狀的結果,但 (5, 2)(N, 2) 形狀的張量上的 Concat 只會產生 (M, 2),而不是包含 N+5 的表示法。請注意,會傳播不同的未知符號值,因此此處的 M 代表未知數量,其與 M 的其他出現位置相同。

這些限制是目前實作的特性,而非基本限制 - 如果您需要更進階的功能,請務必告知我們!

為運算子實作形狀推斷

您可以使用以下方法將形狀推斷函式新增至運算子的 Schema:

OpSchema& Opschema::TypeAndShapeInferenceFunction(InferenceFunction inferenceFunction);

InferenceFunction 定義於 shape_inference.h 中,以及核心介面結構 InferenceContext 和各種協助程式方法。InferenceContext 是提供給推斷函式的核心結構。它允許存取運算子輸入的相關資訊,並且也允許寫出推斷的資訊。

若要查看許多範例,請在程式碼庫中搜尋 TypeAndShapeInferenceFunction 的出現次數。其中一個相對複雜的是 Concat 的實作,位於 onnx/defs/tensor/defs.cc 中。

實作運算子的形狀推斷方法時,請注意以下幾點,以避免常見錯誤

  • 在存取任何輸入的 shape 之前,程式碼必須檢查 shape 是否可用。如果不可用,應將其視為 rank 未知的動態張量,並適當地處理。通常,shape 推斷邏輯會以呼叫 hasInputShapehasNInputShapes 來保護。

  • 在存取任何維度的 dim_valuedim_param 之前,程式碼必須檢查這些欄位是否有值。特別是,程式碼必須處理維度可能沒有靜態已知值的可能性。

shape_inference.h 中有幾個實用函式可用於處理各種常見情況。

  • 對於必須具有固定 rank 的輸入,請使用 checkInputRank。(請參閱 RoiAlign 的推斷作為範例。)

  • 當多個輸入維度預期相同,以及當輸入維度傳播到特定輸出維度時,可以使用 unifyInputDimunifyDimupdateOutputShape。(請參閱 RoiAlign 的推斷作為範例。)

  • 當使用算術從輸入維度計算輸出維度時,可以在符號維度上使用重載的運算子 */。(請參閱 SpaceToDepth 的推斷作為範例。)

這些實用工具可以安全地處理遺失的 shape 和維度。

範例:考慮一個簡單的矩陣乘法運算,它預期輸入的 shape 為 [M,K][K,N],並返回 shape 為 [M,N] 的輸出。這可以用以下方式編碼

   // Check that input 0 has rank 2 (if its rank is known).
   checkInputRank(ctx, 0, 2);
   // Check that input 1 has rank 2 (if its rank is known).
   checkInputRank(ctx, 1, 2);
   Dim M, K, N;
   // Check various dimensions, handling missing dimensions/shapes safely.
   unifyInputDim(ctx, 0, 0, M);
   unifyInputDim(ctx, 0, 1, K);
   unifyInputDim(ctx, 1, 0, K);
   unifyInputDim(ctx, 1, 1, N);
   updateOutputShape(ctx, 0, {M. N});