維度符號

維度符號是一種實驗性的嘗試,旨在給予張量軸語義描述,從而對它們進行類型化,並後續根據它們執行驗證步驟。

動機

這種機制的動機可以透過一個簡單的範例來說明。在下面的線性神經網路規範中,我們假設一個NCHW模型輸入

input_in_NCHW -> Transpose(input, perm=[0, 2, 1, 3]) -> AveragePool(input, ...)

在這個神經網路中,使用者錯誤地構建了一個將NCHW輸入轉置為奇怪的NHCW格式的神經網路,並通過了假設NCHW輸入格式的空間池化。雖然這明顯是一個錯誤,但現有的基礎架構不會向使用者報告錯誤。對於那些嚴重依賴類型檢查作為程式正確性保證不可或缺一部分的程式設計師來說,這應該是令人不安的。本提案旨在解決當前神經網路規範範式中固有的這種缺乏適當類型檢查的真空。

本提案包含三個關鍵組成部分:符號定義、符號傳播和符號驗證,每個部分都將詳細討論。

符號定義

首先,我們為張量類型定義一組類型。這些類型是根據以下原則定義的

  1. 足夠細緻,可以消除潛在的陷阱。例如,動機部分說明的上述範例要求我們區分通道維度和空間特徵維度,以確保AveragePool運算的執行正確性。

  2. 足夠粗略,可以減輕使用者的精神負擔。例如,在上述範例中,區分寬度維度和高度維度的需求明顯較少,因為池化和卷積等運算通常不會區分各種空間維度。因此,我們將所有空間維度總結為特徵維度。

  3. 作為第2點的重要推論,與模型無關。例如,循環神經網路(RNN)中特徵維度的語義和卷積神經網路(CNN)中空間維度的語義幾乎無法區分,因此我們允許使用者和開發人員將其描述為特徵維度。

具體來說,在我們的第一個提案中,我們定義了以下一組標準符號

  1. DATA_BATCH描述了訓練資料的批次維度。這對應於更常用的張量格式符號 NCHW 中的 N 維度。

  2. DATA_CHANNEL描述了訓練資料的通道維度。這對應於 C 維度。

  3. DATA_TIME描述了時間維度。

  4. DATA_FEATURE描述了特徵維度。這對應於 HW 維度或 RNN 中的特徵維度。

  5. FILTER_IN_CHANNEL描述了濾波器輸入通道維度。這是大小與輸入影像特徵圖的通道維度相同的維度。

  6. FILTER_OUT_CHANNEL描述了濾波器輸出通道維度。這是大小與輸出影像特徵圖的通道維度相同的維度。

  7. FILTER_SPATIAL描述了濾波器空間維度。

符號傳播

當運算根據其輸入張量排列、銷毀或建立維度時,就會發生符號傳播。在這種情況下,我們將實作自訂的、特定於運算的函式,以根據輸入張量維度符號推斷輸出張量維度符號。發生符號傳播的一個範例運算是轉置運算,其中輸出維度符號推斷的虛擬碼可以公式化為輸入維度符號的函式

for i, j in enumerate(perm):
    out_dim_denotaion[i] = in_dim_denotation[j]

符號驗證

當運算預期其輸入以特定格式到達時,就會發生符號驗證。發生符號驗證的一個範例運算是AveragePool運算,其中輸入(如果使用維度符號註釋)在2D情況下應該具有符號 [DATA_BATCHDATA_CHANNELDATA_FEATUREDATA_FEATURE]。如果預期的維度符號和實際的維度符號之間存在不匹配,則應報告錯誤。

類型符號

請參閱 類型符號文件,以獲取有關如何描述影像和其他類型的更多詳細資訊。