pytorch中torch.topk()函数的快速理解
目录
- 函数作用:
- 举个栗子:
- 实例演示
- 总结
函数作用:
文章图片
文章图片
该函数的作用即按字面意思理解,topk:取数组的前k个元素进行排序。
通常该函数返回2个值,第一个值为排序的数组,第二个值为该数组中获取到的元素在原数组中的位置标号。
举个栗子:
import numpy as npimport torchimport torch.utils.data.dataset as Datasetfrom torch.utils.data import Dataset,DataLoader####################准备一个数组#########################tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],[3,4,5,1,1,1,1,1,1,1,1],[7,8,9,1,1,1,1,1,1,1,1],[1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)####################打印这个原数组#########################print('tensor1:')print(tensor1)#################使用torch.topk()这个函数##################print('使用torch.topk()这个函数得到:')'''k=3代表从原数组中取得3个元素,dim=1表示从原数组中的第一维获取元素(在本例中是分别从[10,1,2,1,1,1,1,1,1,1,10]、[3,4,5,1,1,1,1,1,1,1,1]、[7,8,9,1,1,1,1,1,1,1,1]、[1,4,7,1,1,1,1,1,1,1,1]这四个数组中获取3个元素)其中largest=True表示从大到小取元素'''print(torch.topk(tensor1, k=3, dim=1, largest=True))#################打印这个函数第一个返回值####################print('函数第一个返回值topk[0]如下')print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])#################打印这个函数第二个返回值####################print('函数第二个返回值topk[1]如下')print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])'''#######################运行结果##########################tensor1:tensor([[10.,1.,2.,1.,1.,1.,1.,1.,1.,1., 10.],[ 3.,4.,5.,1.,1.,1.,1.,1.,1.,1.,1.],[ 7.,8.,9.,1.,1.,1.,1.,1.,1.,1.,1.],[ 1.,4.,7.,1.,1.,1.,1.,1.,1.,1.,1.]])使用torch.topk()这个函数得到:'得到的values是原数组dim=1的四组从大到小的三个元素值;得到的indices是获取到的元素值在原数组dim=1中的位置。'torch.return_types.topk(values=tensor([[10., 10.,2.],[ 5.,4.,3.],[ 9.,8.,7.],[ 7.,4.,1.]]),indices=tensor([[ 0, 10,2],[ 2,1,0],[ 2,1,0],[ 2,1,0]]))函数第一个返回值topk[0]如下tensor([[10., 10.,2.],[ 5.,4.,3.],[ 9.,8.,7.],[ 7.,4.,1.]])函数第二个返回值topk[1]如下tensor([[ 0, 10,2],[ 2,1,0],[ 2,1,0],[ 2,1,0]])'''
该函数功能经常用来获取张量或者数组中最大或者最小的元素以及索引位置,是一个经常用到的基本函数。
实例演示 任务一:
取top1(最大值):
pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145,0.4053],[ 0.7265,1.4164,1.3443,1.2035,1.8823],[-0.4451,0.1673,1.2590, -2.0757,1.7255],[ 0.2021,0.3041,0.1383,0.3849, -1.6311]])print(pred)values, indices = pred.topk(1, dim=0, largest=True, sorted=True)print(indices)print(values)# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。_, indices_max = pred.max(dim=0, keepdim=True)print(indices_max)print(indices_max == indices)输出:tensor([[-0.5816, -0.3873, -1.0215, -1.0145,0.4053],[ 0.7265,1.4164,1.3443,1.2035,1.8823],[-0.4451,0.1673,1.2590, -2.0757,1.7255],[ 0.2021,0.3041,0.1383,0.3849, -1.6311]])tensor([[1, 1, 1, 1, 1]])tensor([[0.7265, 1.4164, 1.3443, 1.2035, 1.8823]])tensor([[1, 1, 1, 1, 1]])tensor([[True, True, True, True, True]])
任务二:
按行取出topk,将小于topk的置为inf:
pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145,0.4053],[ 0.7265,1.4164,1.3443,1.2035,1.8823],[-0.4451,0.1673,1.2590, -2.0757,1.7255],[ 0.2021,0.3041,0.1383,0.3849, -1.6311]])print(pred)top_k = 2# 按行求出每一行的最大的前两个值filter_value=https://www.it610.com/article/-float('Inf')indices_to_remove = pred < torch.topk(pred, top_k)[0][..., -1, None]print(indices_to_remove)pred[indices_to_remove] = filter_value# 对于topk之外的其他元素的logits值设为负无穷print(pred) 输出:tensor([[-0.5816, -0.3873, -1.0215, -1.0145,0.4053],[ 0.7265,1.4164,1.3443,1.2035,1.8823],[-0.4451,0.1673,1.2590, -2.0757,1.7255],[ 0.2021,0.3041,0.1383,0.3849, -1.6311]])tensor([[4],[4],[4],[3]])tensor([[0.4053],[1.8823],[1.7255],[0.3849]])tensor([[ True, False,True,True, False],[ True, False,True,True, False],[ True,True, False,True, False],[ True, False,True, False,True]])tensor([[-inf, -0.3873,-inf,-inf,0.4053],[-inf,1.4164,-inf,-inf,1.8823],[-inf,-inf,1.2590,-inf,1.7255],[-inf,0.3041,-inf,0.3849,-inf]])
任务三:
import numpy as npimport torchimport torch.utils.data.dataset as Datasetfrom torch.utils.data import Dataset,DataLoadertensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],[3,4,5,1,1,1,1,1,1,1,1],[7,8,9,1,1,1,1,1,1,1,1],[1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)# tensor2=torch.tensor([[3,2,1],#[6,5,4],#[1,4,7],#[9,8,7]],dtype=torch.float32)#print('tensor1:')print(tensor1)print('直接输出topk,会得到两个东西,我们需要的是第二个indices')print(torch.topk(tensor1, k=3, dim=1, largest=True))print('topk[0]如下')print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])print('topk[1]如下')print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])'''tensor1:tensor([[10.,1.,2.,1.,1.,1.,1.,1.,1.,1., 10.],[ 3.,4.,5.,1.,1.,1.,1.,1.,1.,1.,1.],[ 7.,8.,9.,1.,1.,1.,1.,1.,1.,1.,1.],[ 1.,4.,7.,1.,1.,1.,1.,1.,1.,1.,1.]])直接输出topk,会得到两个东西,我们需要的是第二个indicestorch.return_types.topk(values=tensor([[10., 10.,2.],[ 5.,4.,3.],[ 9.,8.,7.],[ 7.,4.,1.]]),indices=tensor([[ 0, 10,2],[ 2,1,0],[ 2,1,0],[ 2,1,0]]))topk[0]如下tensor([[10., 10.,2.],[ 5.,4.,3.],[ 9.,8.,7.],[ 7.,4.,1.]])topk[1]如下tensor([[ 0, 10,2],[ 2,1,0],[ 2,1,0],[ 2,1,0]])'''
总结 【pytorch中torch.topk()函数的快速理解】到此这篇关于pytorch中torch.topk()函数快速理解的文章就介绍到这了,更多相关pytorch torch.topk()函数理解内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
推荐阅读
- sqlserver和java将resultSet中的记录转换为学生对象
- React实现复制代码块到codeMirror代码编辑器中
- ASP.NET|ASP.NET MVC从控制器中获取URL值
- java|产品硬件成本分析_硬件项目中的错误成本
- 逆境之中见真章,JASMINER在众多出海企业中脱颖而出
- 人工智能|李飞飞最新访谈(我每天都在对AI的担忧中醒来)
- 李飞飞之后谷歌再失华裔高管,谷歌AI中国中心总裁李佳离职创业
- 观点|观点 | 李飞飞最新访谈(我每天都在对AI的担忧中醒来)
- JavaScript中常见的高阶函数总结
- 解决react中label标签for报错问题