pytorch中的gather函数_Pytorch中的torch.gather函数的理解

Pytorch中的torch.gather函数的理解
Pytorch中的torch.gather函数
pytorch比tensorflow更加编程友好,准备用pytorch试着做一些实验。
先看一下简单的用法示例代码,然后结合官方示例来解读:
b = torch.Tensor([[1,2,3],[4,5,6]])
print b
index_0 = torch.LongTensor([[1],[2]])
index_1 = torch.LongTensor([[0,1],[2,0]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print (torch.gather(b, dim=1, index=index_0))
print (torch.gather(b, dim=1, index=index_1))
print (torch.gather(b, dim=0, index=index_2))
输出结果:
1 2 3
4 5 6
[torch.FloatTensor of size 2x3]
tensor([[2.],
[6.]])
1 2
6 4
[torch.FloatTensor of size 2x2]
1 5 6
1 2 3
[torch.FloatTensor of size 2x3]
结合上面的例子来看官方解读及示例,官方解读是给了三个公式:
torch.gather(input, dim, index, out=None) → Tensor
'''
Gathers values along an axis specified by dim.
For a 3-D tensor the output is specified by:
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2
Parameters:
input (Tensor) – The source tensor
dim (int) – The axis along which to index
index (LongTensor) – The indices of elements to gather
out (Tensor, optional) – Destination tensor
Example:
'''
>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
1 1
4 3
[torch.FloatTensor of size 2x2]
首先,可以看出output的形状和index的一致,且位置一 一对应。
dim=0时,
out[i][j][k] = input[index[i][j][k]][j][k]
out的取值为input[index[i][j][k]] [j] [k],为input值,output行(dim=0)的取值是index张量的元素值,列(dim=1)和index张量里面的列对应,(dim=2)维度也和index的一致。
例1:
b = torch.Tensor([[1,2,3],[4,5,6]])
index_2 = torch.LongTensor([[0,1,1],[0,0,0]])
print (torch.gather(b, dim=0, index=index_2))
'''
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
输出结果:
1 5 6
1 2 3
[torch.FloatTensor of size 2x3]
'''
一二三四五六
dim=0, index_2为两行三列的张量,output也为两行三列的张量。index中的【0,1,1】在第一行,【0,0,0】在第二行。
取值时,out第一个值取input第一行(由由【0,1,1】的0指定)第一列(这个是与index对应,0在第一列)的元素1;
第二个数取input第二行(由【0,1,1】的第一个1指定)第二列(这个是与index对应,第一个1在第二列)的元素5;
第三个数取input第二行(由【0,1,1】的第二个1指定)第三列(这个是与index对应,第二个1在第三列)的元素6。
第四个数取input第一行(由【0,0,0】的第一个0指定)第一列(这个是与index对应,第一个0在第一列)的元素1。
后面的同理。
dim=1时,
out[i][j][k] = input[i] [index[i][j][k]] [k] # dim=1
out的取值为input[i] [index[i][j][k]] [k],为input值,行(dim=0)和index张量里面的行对应,列(dim=1)是index张量的元素值,(dim=2)维度也和index的(dim=2)维度 对应。
例2:
b = torch.Tensor([[1,2,3],[4,5,6]])
index_1 = torch.LongTensor([[1,2],[2,0]])
print (torch.gather(b, dim=1, index=index_1))
'''
out[i][j][k] = input[i] [index[i][j][k]] [k] # dim=1
输出结果:
2 3
6 4
[torch.FloatTensor of size 2x2]
'''
一二三四
dim=1, index_1为两行两列的张量,output也为两行两列的张量。index中的【1,2】在第一行,【2,0】在第二行。
第一个数取input第一行(这个是与index对应,同在第一行)第二列(由【1,2】中的1指定 )的元素2,
第二个数取input第一行(这个是与index对应,同在第一行)第三列(由【1,2】中的2指定)的元素3。
第三个数取input第二行(这个是与index对应,同在第二行)第三列(由【2,0】中的2指定)的元素6。
第四个数取input第二行(这个是与index对应,同在第二行)第一列(由【2,0】中的0指定)的元素4。
还可以看出index的形状和input的形状是一致的,都是二维的,里面的index数值不能超过input的界限,比如行的不能超过1,列的不能超过2。
【pytorch中的gather函数_Pytorch中的torch.gather函数的理解】理解了这几个式子也就记住了这个方法的用法。

    推荐阅读