Pytorch学习|深入理解PyTorch中的gather函数

gather函数
今天在用PyTorch复现softmax的时候,参考的书籍为《Dive into DL Pytorch》。在书里面,关于gather函数原文是这样叙述的:

上一节中,我们介绍了softmax回归使用的交叉熵损失函数。为了得到标签的预测概率,我们可以使用gather函数。在下面的例子中,变量y_hat是2个样本在3个类别的预测概率,变量y是这2个样本的标签类别。通过使用gather函数,我们得到了2个样本的标签的预测概率。在代码中,标签类别的离散值是从0开始逐一递增的。
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]]) y = torch.LongTensor([0, 2]) y_hat.gather(1, y.view(-1, 1))

输出为:
tensor([[0.1000], [0.5000]])

看到这里的时候,不太能理解两点:
  1. torch.LongTensor是什么样的数据格式;
  2. gather函数到底在中间有什么样的作用。
1. torch.LongTensor
torch.LongTensor(2, 3) # 构建一个2 * 3 Long类型的张量

所以torch.LongTensor指示的数据类型为张量,但里面的元素为Long类型
2. gather函数的作用 很显然在这段代码里面:gather函数的第一个参数’1‘,指定的是dim,即维度,也就是对哪个维度进行操作,此处为对the first dim(或者说第1轴)进行操作。第二个参数y.view(-1, 1), 其展开应该为:
torch.LongTensor([[0], [2]])

【Pytorch学习|深入理解PyTorch中的gather函数】这里 y.view(-1, 1) 里面的元素应该是作为index也即索引的意思。所以这句代码的理解就是对 y_hat 在第一轴上根据y.view(-1, 1)提供的索引进行取值,最后得到的输出就是:
tensor([[0.1000], [0.5000]])

    推荐阅读