群組正規化¶
群組正規化 - 21¶
版本¶
網域:
main
自從版本:
21
函數:
True
支援層級:
SupportType.COMMON
形狀推斷:
False
此運算子的版本已可用自從版本 21。
摘要¶
群組正規化函數。如論文 https://arxiv.org/abs/1803.08494 中所述執行群組正規化
此運算子會根據下列公式轉換輸入
y = scale * (x - mean) / sqrt(variance + epsilon) + bias,
其中平均數和變異數是針對每個通道群組的每個執行個體計算,且應該為每個通道指定 scale
和 bias
。群組數 num_groups
應可被通道數整除,以便每個群組有相等數量的通道。
整體計算有兩個階段:第一階段會正規化元素,使每個群組中的每個執行個體具有零平均數和單位變異數,第二階段會縮放並移動第一階段的結果。第一階段中使用的浮點數精確度由 stash_type
屬性決定。例如,如果 stash_type
為 1,運算子會將所有輸入變數轉換為 32 位元浮點數,執行計算,並最終將正規化結果轉換回 X
的原始類型。第二階段不依賴 stash_type
。
當群組數與通道數相同時,此運算子等同於執行個體正規化。當只有一個群組時,此運算子等同於層正規化。
屬性¶
epsilon - FLOAT (預設值為
'1e-05'
)用於避免除以零的 epsilon 值。
num_groups - INT (必要)
通道的群組數。它應該是通道數
C
的因數。stash_type - INT (預設值為
'1'
)計算第一階段中使用的浮點數精確度。
輸入¶
X (異質) - T
輸入資料張量。影像案例的維度為
(N x C x H x W)
,其中N
是批次大小,C
是通道數,而H
和W
是資料的高度和寬度。統計資料是針對C
、H
和W
的每個通道群組計算。對於非影像案例,維度的形式為(N x C x D1 x D2 ... Dn)
。scale (異質) - T
形狀為
(C)
的縮放張量。bias (異質) - T
形狀為
(C)
的偏差張量。
輸出¶
Y (異質) - T
與
X
相同形狀的輸出張量。
類型限制¶
T 在 (
tensor(bfloat16)
,tensor(double)
,tensor(float)
,tensor(float16)
)將輸入和輸出類型限制為浮點數張量。
群組正規化 - 18¶
版本¶
網域:
main
自從版本:
18
函數:
True
支援層級:
SupportType.COMMON
形狀推斷:
False
此運算子的版本已自從版本 18 起被棄用。
摘要¶
群組正規化函數。如論文 https://arxiv.org/abs/1803.08494 中所述執行群組正規化
此運算子會根據下列公式轉換輸入
y = scale * (x - mean) / sqrt(variance + epsilon) + bias,
其中平均數和變異數是針對每個通道群組的每個執行個體計算,且應該為每個通道群組指定 scale
和 bias
。群組數 num_groups
應可被通道數整除,以便每個群組有相等數量的通道。
當群組數與通道數相同時,此運算子等同於執行個體正規化。當只有一個群組時,此運算子等同於層正規化。
屬性¶
epsilon - FLOAT (預設值為
'1e-05'
)用於避免除以零的 epsilon 值。
num_groups - INT (必要)
通道的群組數。它應該是通道數
C
的因數。
輸入¶
X (異質) - T
輸入資料張量。影像案例的維度為
(N x C x H x W)
,其中N
是批次大小,C
是通道數,而H
和W
是資料的高度和寬度。統計資料是針對C
、H
和W
的每個通道群組計算。對於非影像案例,維度的形式為(N x C x D1 x D2 ... Dn)
。scale (異質) - T
形狀為
(num_groups)
的縮放張量。bias (異質) - T
形狀為
(num_groups)
的偏差張量。
輸出¶
Y (異質) - T
與
X
相同形狀的輸出張量。
類型限制¶
T 在 (
tensor(bfloat16)
,tensor(double)
,tensor(float)
,tensor(float16)
)將輸入和輸出類型限制為浮點數張量。