群組正規化

群組正規化 - 21

版本

  • 名稱群組正規化 (GitHub)

  • 網域main

  • 自從版本21

  • 函數True

  • 支援層級SupportType.COMMON

  • 形狀推斷False

此運算子的版本已可用自從版本 21

摘要

群組正規化函數。如論文 https://arxiv.org/abs/1803.08494 中所述執行群組正規化

此運算子會根據下列公式轉換輸入

y = scale * (x - mean) / sqrt(variance + epsilon) + bias,

其中平均數和變異數是針對每個通道群組的每個執行個體計算,且應該為每個通道指定 scalebias。群組數 num_groups 應可被通道數整除,以便每個群組有相等數量的通道。

整體計算有兩個階段:第一階段會正規化元素,使每個群組中的每個執行個體具有零平均數和單位變異數,第二階段會縮放並移動第一階段的結果。第一階段中使用的浮點數精確度由 stash_type 屬性決定。例如,如果 stash_type 為 1,運算子會將所有輸入變數轉換為 32 位元浮點數,執行計算,並最終將正規化結果轉換回 X 的原始類型。第二階段不依賴 stash_type

當群組數與通道數相同時,此運算子等同於執行個體正規化。當只有一個群組時,此運算子等同於層正規化。

屬性

  • epsilon - FLOAT (預設值為 '1e-05')

    用於避免除以零的 epsilon 值。

  • num_groups - INT (必要)

    通道的群組數。它應該是通道數 C 的因數。

  • stash_type - INT (預設值為 '1')

    計算第一階段中使用的浮點數精確度。

輸入

  • X (異質) - T

    輸入資料張量。影像案例的維度為 (N x C x H x W),其中 N 是批次大小,C 是通道數,而 HW 是資料的高度和寬度。統計資料是針對 CHW 的每個通道群組計算。對於非影像案例,維度的形式為 (N x C x D1 x D2 ... Dn)

  • scale (異質) - T

    形狀為 (C) 的縮放張量。

  • bias (異質) - T

    形狀為 (C) 的偏差張量。

輸出

  • Y (異質) - T

    X 相同形狀的輸出張量。

類型限制

  • T 在 ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) )

    將輸入和輸出類型限制為浮點數張量。

群組正規化 - 18

版本

  • 名稱群組正規化 (GitHub)

  • 網域main

  • 自從版本18

  • 函數True

  • 支援層級SupportType.COMMON

  • 形狀推斷False

此運算子的版本已自從版本 18 起被棄用。

摘要

群組正規化函數。如論文 https://arxiv.org/abs/1803.08494 中所述執行群組正規化

此運算子會根據下列公式轉換輸入

y = scale * (x - mean) / sqrt(variance + epsilon) + bias,

其中平均數和變異數是針對每個通道群組的每個執行個體計算,且應該為每個通道群組指定 scalebias。群組數 num_groups 應可被通道數整除,以便每個群組有相等數量的通道。

當群組數與通道數相同時,此運算子等同於執行個體正規化。當只有一個群組時,此運算子等同於層正規化。

屬性

  • epsilon - FLOAT (預設值為 '1e-05')

    用於避免除以零的 epsilon 值。

  • num_groups - INT (必要)

    通道的群組數。它應該是通道數 C 的因數。

輸入

  • X (異質) - T

    輸入資料張量。影像案例的維度為 (N x C x H x W),其中 N 是批次大小,C 是通道數,而 HW 是資料的高度和寬度。統計資料是針對 CHW 的每個通道群組計算。對於非影像案例,維度的形式為 (N x C x D1 x D2 ... Dn)

  • scale (異質) - T

    形狀為 (num_groups) 的縮放張量。

  • bias (異質) - T

    形狀為 (num_groups) 的偏差張量。

輸出

  • Y (異質) - T

    X 相同形狀的輸出張量。

類型限制

  • T 在 ( tensor(bfloat16), tensor(double), tensor(float), tensor(float16) )

    將輸入和輸出類型限制為浮點數張量。