gather函数
今天在用PyTorch复现softmax的时候,参考的书籍为《Dive into DL Pytorch》。在书里面,关于gather函数原文是这样叙述的:
上一节中,我们介绍了softmax回归使用的交叉熵损失函数。为了得到标签的预测概率,我们可以使用gather函数。在下面的例子中,变量y_hat是2个样本在3个类别的预测概率,变量y是这2个样本的标签类别。通过使用gather函数,我们得到了2个样本的标签的预测概率。在代码中,标签类别的离散值是从0开始逐一递增的。
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = torch.LongTensor([0, 2])
y_hat.gather(1, y.view(-1, 1))
输出为:
tensor([[0.1000],
[0.5000]])
看到这里的时候,不太能理解两点:
- torch.LongTensor是什么样的数据格式;
- gather函数到底在中间有什么样的作用。
torch.LongTensor(2, 3) # 构建一个2 * 3 Long类型的张量
所以torch.LongTensor指示的数据类型为张量,但里面的元素为Long类型
2. gather函数的作用 很显然在这段代码里面:gather函数的第一个参数’1‘,指定的是dim,即维度,也就是对哪个维度进行操作,此处为对the first dim(或者说第1轴)进行操作。第二个参数y.view(-1, 1), 其展开应该为:
torch.LongTensor([[0],
[2]])
【Pytorch学习|深入理解PyTorch中的gather函数】这里 y.view(-1, 1) 里面的元素应该是作为index也即索引的意思。所以这句代码的理解就是对 y_hat 在第一轴上根据y.view(-1, 1)提供的索引进行取值,最后得到的输出就是:
tensor([[0.1000],
[0.5000]])
推荐阅读
- Pytorch学习笔记|【Pytorch学习笔记】4.细讲Pytorch的gather函数是什么——从Softmax回归中交叉熵损失函数定义的角度讲述
- 深度学习|pytorch学习三、softmax回归
- 用python做一个文本翻译器,自动将中文翻译成英文,超方便的
- 蓝桥杯|2021年第十二届蓝桥杯省赛第二场Python组(真题+解析+代码)(城邦)
- 蓝桥杯|2021年第十二届蓝桥杯省赛第二场Python组(真题+解析+代码)(格点)
- python 包之 httpx 请求操作教程
- Python中的七个小技巧
- 深度学习|掌握神经网络的法宝(一)
- python|python数据分析基础007 -利用pandas带你玩转excel表格(中上篇)