onnx-mlir

Logo

MLIR 編譯器基礎架構中 ONNX 模型的表示和參考降低

在 GitHub 上檢視專案 onnx/onnx-mlir

操作指南

使用 Python 進行推論
使用 C/C++ 進行推論
使用 Java 進行推論

參考資料

ONNX 方言
OMTensor C99 執行階段 API
OMTensorList C99 執行階段 API
OMTensor Java 執行階段 API
OMTensorList Java 執行階段 API
產生 ONNX 方言
關於文件

開發

新增運算
測試指南
錯誤處理
命令列選項
儀器化
常數傳播
新增加速器

工具

工具

RunONNXModel.py
DocCheck

這個專案由 onnx 維護

託管於 GitHub Pages — 主題由 orderedlist 提供

ONNX 運算的常數傳播

本文檔描述了 --constprop-onnx 傳遞,用於對 ONNX 方言中的運算進行常數傳播。

原始碼.

範例

給定以下程式碼

func @foo() -> tensor<1xf32> {
  %0 = "onnx.Constant"() {value = dense<[1.0]> : tensor<1xf32>} : () -> tensor<1xf32>
  %1 = "onnx.Constant"() {value = dense<[2.0]> : tensor<1xf32>} : () -> tensor<1xf32>
  %2 = "onnx.Add"(%0, %1) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
  %3 = "onnx.Constant"() {value = dense<[3.0]> : tensor<1xf32>} : () -> tensor<1xf32>
  %4 = "onnx.Add"(%2, %3) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
  "std.return"(%4) : (tensor<1xf32>) -> ()
}

如果我們呼叫 onnx-mlir-op --constprop-onnx,我們會得到

func @foo() -> tensor<1xf32> {
  %0 = "onnx.Constant"() {value = dense<[6.0]> : tensor<1xf32>} : () -> tensor<1xf32>
  "std.return"(%0) : (tensor<1xf32>) -> ()
}

備註

ONNXConstantOp 使用 MLIR DenseElementsAttr 來儲存常數值。 重要的是要注意,一旦建立 DenseElementsAttr,它就會一直存在並消耗記憶體直到編譯結束。 在範例中,三個 ONNXConstantOp 中的所有三個 DenseElementsAttr 都會存在直到編譯結束。 特別是,透過折疊兩個 ONNXAddOp 而產生的兩個 ONNXConstantOp 中的兩個中間 DenseElementsAttr 也會存在。 對於真實世界的模型,中間 DenseElementsAttr 的數量將會迅速增加,這會在編譯期間導致大量的記憶體佔用。

為了避免在 --constprop-onnx 期間為中間 ONNXConstantOp 建立太多 DenseElementsAttr,我們設計了一種機制,可以動態地為中間 ONNXConstantOp 分配和釋放緩衝區,並且僅在常數傳播和其他 ONNX 方言傳遞之後,在降低到 Krnl(或任何其他目標方言)之前才建立 DenseElementsAttr。

這是透過自訂屬性 DisposableElementsAttr 完成的,它在非複雜純量元素類型(布林值、整數和浮點類型)的常見情況下充當 DenseElementsAttr 的替代品。 DisposableElementsAttr 實現與 DenseElementsAttr 相同的 ElementsAttr 介面,並且在大多數情況下,它們在功能上是相同的,並且周圍的程式碼不需要區分。 它只需要使用 OnnxElementsAttrBuilder 類別和 ElementsAttrHelper 函式來建構和存取 ElementsAttr 實例,以獲得記憶體佔用和效能優勢。

DisposableElementsAttr 緩衝區的釋放發生在編譯器傳遞之間的 DisposableGarbageCollector 中,它由 PassManager 在「模組」傳遞之間(保證「停止世界」,沒有其他傳遞並行執行)作為「儀器化」執行。

DisposableElementsAttr 提供其他記憶體和速度優勢,這些優勢在類別原始碼檔案中的註解中概述,並在 2022 年 11 月的簡報中說明,該簡報連結在 會議 wiki 頁面

為常數傳播撰寫規則

我們使用 MLIR 宣告式重寫規則 (DRR) 來撰寫常數傳播的模式。 用於定義模式的 DRR 定義如下所示

class Pattern<
   dag sourcePattern,
   list<dag> resultPatterns,
   list<dag> additionalConstraints = [],
   list<dag> supplementalPatterns = [],
   dag benefitsAdded = (addBenefit 0)
>;

有關 DRR 的更多資訊,請參閱 這裡

現在,我們將逐步介紹一個簡單的範例,該範例為 ONNXAddOp 新增常數傳播。

步驟 1:撰寫 DRR 模式

我們首先將一個模式新增到 ConstProp.td

// Constant Propagation for Add
def AddConstProp : Pat<
    // source patten: From add(lhs, rhs).
    (ONNXAddOp:$addOp (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_),
                      (ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)),
    // result pattern: To c = lhs + rhs
    (CreateAddOfTwoConst $addOp, $lhs, $rhs),
    // Additional constraints: if both lhs and rhs are dense constants.
    [(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs)]>;

上面的模式將會以在編譯時相加輸入的新常數取代其輸入為常數的 ONNXAddOp。 為了檢查輸入是否為常數,僅使用 ONNXConstantOp 是不夠的,因為常數張量可能是稀疏的,而我們現在只支援密集常數張量。 我們需要使用 IsFromDenseONNXConstantOp 來額外檢查密集常數張量。

在結果模式中,為了產生 ONNXConstantOp,我們將在編譯時相加 lhsrhs,並發出 ONNXConstantOp。 為了最大程度地減少記憶體佔用,此 ONNXConstantOp 具有 DisposableElementsAttr 而不是傳統的 DenseElementsAttr。

函式 CreateAddOfTwoConst 將在編譯時進行相加並傳回 ONNXConstantOp。

def CreateAddOfTwoConst :
   NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;

步驟 2:為輸入和結果準備陣列緩衝區

模式中的函式 CreateAddOfTwoConst 呼叫 ConstProp.cpp 中的 ConstPropElementwiseBinary,其內容如下。

template <typename ElementwiseBinaryOp>
Value ConstPropElementwiseBinary(PatternRewriter &rewriter,
    Value replacingValue, Value lhsValue, Value rhsValue) {
  ConstPropCounters::count("ElementwiseBinary", {lhsValue, rhsValue});
  Type replacingType = mlir::cast<ShapedType>(replacingValue.getType());

  // Get lhs and rhs ElementsAttr from the values' defining constant ops.
  ElementsAttr lhs = getConstValueElements(lhsValue);
  ElementsAttr rhs = getConstValueElements(rhsValue);

  Type operandsElemType = lhs.getElementType();
  assert(operandsElemType == rhs.getElementType() &&
         "all element-wise binary ops have matching operands element types");
  OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext());
  ElementsAttr resultElements = elementsBuilder.combine(lhs, rhs, replacingType,
      combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType));

  // Construct and return a new ONNXConstantOp with the resultElements attribute.
  return createReplacingConstantOp(rewriter, replacingValue, resultElements)
      .getResult();
}

其中 OnnxElementsAttrBuilder.combine(...) 會在需要時廣播 lhs 和 rhs 元素,並建構新的 (Disposable) ElementsAttr,其元素是二元函式 combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType) 的元素應用結果,該函式將 ElementwiseBinaryOp ONNX 運算對應到 c++ 運算子。

TODO:描述如何為新的運算新增 OnnxElementsAttrBuilder 建構器方法

有關常數傳播的更多資訊,請參閱 ConstProp.tdConstProp.cpp