GatherElements

GatherElements - 13

版本

  • 名稱: GatherElements (GitHub)

  • 網域: main

  • since_version: 13

  • function: False

  • 支援層級: SupportType.COMMON

  • 形狀推斷: True

此版本的運算子自版本 13 開始可用。

摘要

GatherElements 接受兩個輸入 dataindices,兩者具有相同的階數 r >= 1,以及一個可選的屬性 axis,用於識別 data 的軸 (預設為最外層的軸,即軸 0)。它是一個索引運算,透過在 indices 張量的元素所決定的索引位置,對輸入資料張量進行索引,進而產生輸出。其輸出形狀與 indices 的形狀相同,並且針對 indices 中的每個元素包含一個值 (從 data 中收集而來)。

例如,在 3 維的情況下 (r = 3),產生的輸出由以下方程式決定

out[i][j][k] = input[index[i][j][k]][j][k] if axis = 0,
out[i][j][k] = input[i][index[i][j][k]][k] if axis = 1,
out[i][j][k] = input[i][j][index[i][j][k]] if axis = 2,

此運算子也是 ScatterElements 的反運算。它類似於 Torch 的 gather 運算。

範例 1

data = [
    [1, 2],
    [3, 4],
]
indices = [
    [0, 0],
    [1, 0],
]
axis = 1
output = [
    [1, 1],
    [4, 3],
]

範例 2

data = [
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
]
indices = [
    [1, 2, 0],
    [2, 0, 0],
]
axis = 0
output = [
    [4, 8, 3],
    [7, 2, 3],
]

屬性

  • axis - INT (預設為 '0')

    要收集的軸。負值表示從後面計算維度。接受的範圍是 [-r, r-1],其中 r = rank(data)。

輸入

  • data (異質) - T

    階數 r >= 1 的張量。

  • indices (異質) - Tind

    int32/int64 索引的張量,與輸入具有相同的階數 r。所有索引值都應在沿著大小為 s 的軸的範圍 [-s, s-1] 內。如果任何索引值超出範圍,則會發生錯誤。

輸出

  • output (異質) - T

    與索引具有相同形狀的張量。

類型限制

  • T in ( tensor(bfloat16), tensor(bool), tensor(complex128), tensor(complex64), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8) )

    將輸入和輸出類型限制為任何張量類型。

  • Tind in ( tensor(int32), tensor(int64) )

    將索引限制為整數類型

GatherElements - 11

版本

  • 名稱: GatherElements (GitHub)

  • 網域: main

  • since_version: 11

  • function: False

  • 支援層級: SupportType.COMMON

  • 形狀推斷: True

此版本的運算子自版本 11 開始可用。

摘要

GatherElements 接受兩個輸入 dataindices,兩者具有相同的階數 r >= 1,以及一個可選的屬性 axis,用於識別 data 的軸 (預設為最外層的軸,即軸 0)。它是一個索引運算,透過在 indices 張量的元素所決定的索引位置,對輸入資料張量進行索引,進而產生輸出。其輸出形狀與 indices 的形狀相同,並且針對 indices 中的每個元素包含一個值 (從 data 中收集而來)。

例如,在 3 維的情況下 (r = 3),產生的輸出由以下方程式決定

  out[i][j][k] = input[index[i][j][k]][j][k] if axis = 0,
  out[i][j][k] = input[i][index[i][j][k]][k] if axis = 1,
  out[i][j][k] = input[i][j][index[i][j][k]] if axis = 2,

此運算子也是 ScatterElements 的反運算。它類似於 Torch 的 gather 運算。

範例 1

  data = [
      [1, 2],
      [3, 4],
  ]
  indices = [
      [0, 0],
      [1, 0],
  ]
  axis = 1
  output = [
      [
        [1, 1],
        [4, 3],
      ],
  ]

範例 2

  data = [
      [1, 2, 3],
      [4, 5, 6],
      [7, 8, 9],
  ]
  indices = [
      [1, 2, 0],
      [2, 0, 0],
  ]
  axis = 0
  output = [
      [
        [4, 8, 3],
        [7, 2, 3],
      ],
  ]

屬性

  • axis - INT (預設為 '0')

    要收集的軸。負值表示從後面計算維度。接受的範圍是 [-r, r-1],其中 r = rank(data)。

輸入

  • data (異質) - T

    階數 r >= 1 的張量。

  • indices (異質) - Tind

    int32/int64 索引的張量,與輸入具有相同的階數 r。所有索引值都應在沿著大小為 s 的軸的範圍 [-s, s-1] 內。如果任何索引值超出範圍,則會發生錯誤。

輸出

  • output (異質) - T

    與索引具有相同形狀的張量。

類型限制

  • T in ( tensor(bool), tensor(complex128), tensor(complex64), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8) )

    將輸入和輸出類型限制為任何張量類型。

  • Tind in ( tensor(int32), tensor(int64) )

    將索引限制為整數類型