Pytorch|Pytorch中的torch.gather函数详解,从抽象到具体


目录

      • 从官方文档出发
      • 更具体的二维例子
      • 更宏观的理解
      • 参考

从官方文档出发
这个官方文档写得看起来非常不人道,但实际上它已经包含了我们理解所需要的信息,我们来仔细看一下。
首先来理解一下输入:
torch.gather(input, dim, index, *, sparse_grad=False, out=None)

此处的input是一个张量,假设是个三维张量吧。
dim是指定的维度。
index也是一个张量,它必须和input有一样的维度,所以这里也假设是三维。
?
假设我们设置dim=1
out[i][j][k] = input[i][index[i][j][k]][k]

乍一眼看上去很复杂,实际上很简单
可以看到out输出的索引,和index的索引是完全一致的
也就是说out的形状肯定和index一模一样,而且每一个位置都是互相一一对应的。设我们目前的位置是m,m相当于位置 [ i ] [ j ] [ k ] [i][j][k] [i][j][k]。那么也就是说
out{m} = input[i][index{m}][k]

【Pytorch|Pytorch中的torch.gather函数详解,从抽象到具体】这里的m用了一个大括号,因为它代表着整个 [ i ] [ j ] [ k ] [i][j][k] [i][j][k]这个索引位置,所以没有去用方括号来混淆。
而这每一个位置具体的数字,我们从input取出来。
怎么取出来?
答:dim是什么,我们就看input的哪个dim。
那我们来到那个dim之后,用什么数字去索引?
答:用index的{m}位置的数去作为一个索引值。
那input的其它dim我们用什么数字去索引?
答:其它dim使用的索引值和{m}看齐就行。
也就是说,除了我们指定的那个dim要用index{m}来进行索引,其它位置直接延用{m}的索引值就行。
也就是说,这里的input只是一个仓库,它只是用来让out按照index进行数字的取用的。
举一反三,dim=0时:
out[i][j][k] = input[index[i][j][k]][j][k] # 也就是说 out{m} = input[index{m}][j][k]

dim=2时:
out[i][j][k] = input[i][j][index[i][j][k]] #也就是说 out{m} = input[i][j][index{m}]

更具体的二维例子
假设我们的input是 [ [ 1 , 2 ] , [ 3 , 4 ] ] [[1, 2], [3, 4]] [[1,2],[3,4]]
torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))

这样操作的话,
对于 o u t [ 0 ] [ 0 ] out[0][0] out[0][0]这个位置,怎么去input里取数呢?
取 i n p u t [ 0 ] [ i n d e x { m } ] input[0][index\{m\}] input[0][index{m}], 也就是 i n p u t [ 0 ] [ 0 ] input[0][0] input[0][0],可以看到对应的是1这个数字
(此时 { m } = [ 0 ] [ 0 ] \{m\}=[0][0] {m}=[0][0])
?
对于 o u t p u t [ 0 ] [ 1 ] output[0][1] output[0][1]
取 i n p u t [ 0 ] [ i n d e x { m } ] input[0][index\{m\}] input[0][index{m}], 也就是 i n p u t [ 0 ] [ 0 ] input[0][0] input[0][0],可以看到对应的是1这个数字
(此时 { m } = [ 0 ] [ 1 ] \{m\}=[0][1] {m}=[0][1])
?
对于 o u t p u t [ 1 ] [ 0 ] output[1][0] output[1][0]
取 i n p u t [ 1 ] [ i n d e x { m } ] input[1][index\{m\}] input[1][index{m}], 也就是 i n p u t [ 1 ] [ 1 ] input[1][1] input[1][1],可以看到对应的是4这个数字
(此时 { m } = [ 1 ] [ 0 ] \{m\}=[1][0] {m}=[1][0])
?
对于 o u t p u t [ 1 ] [ 1 ] output[1][1] output[1][1]
取 i n p u t [ 1 ] [ i n d e x { m } ] input[1][index\{m\}] input[1][index{m}], 也就是 i n p u t [ 1 ] [ 0 ] input[1][0] input[1][0],可以看到对应的是3这个数字
(此时 { m } = [ 1 ] [ 1 ] \{m\}=[1][1] {m}=[1][1])
?
所以o u t = [ [ 1 , 1 ] , [ 4 , 3 ] ] out=[[1,1],[4,3]] out=[[1,1],[4,3]]
更宏观的理解
也就是说,对于out中的每一个{m}位置,我们都去input这个仓库里找到对应的那一个点!
找到那一个点之后,我们从这个点出发,发射一条只照亮"dim"这条线的光线,在这条光线上再根据index找到我们真正需要的那个点就行了!
参考
  1. Pytorch torch.gather官方文档

    推荐阅读