pytorch分类模型绘制混淆矩阵以及可视化详解

目录

  • Step 1. 获取混淆矩阵
  • Step 2. 混淆矩阵可视化
  • 其它分类指标的获取
  • 总结

Step 1. 获取混淆矩阵
#首先定义一个 分类数*分类数 的空混淆矩阵 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds) # 使用torch.no_grad()可以显著降低测试用例的GPU占用with torch.no_grad():for step, (imgs, targets) in enumerate(test_loader):# imgs:torch.Size([50, 3, 200, 200])torch.FloatTensor# targets:torch.Size([50, 1]),torch.LongTensor多了一维,所以我们要把其去掉targets = targets.squeeze()# [50,1] ----->[50]# 将变量转为gputargets = targets.cuda()imgs = imgs.cuda()# print(step,imgs.shape,imgs.type(),targets.shape,targets.type())out = model(imgs)#记录混淆矩阵参数conf_matrix = confusion_matrix(out, targets, conf_matrix)conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:
def confusion_matrix(preds, labels, conf_matrix):preds = torch.argmax(preds, 1)for p, t in zip(preds, labels):conf_matrix[p, t] += 1return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:
conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到npcorrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num)) print(conf_matrix) # 获取每种Emotion的识别准确率 print("每种情感总个数:",per_kinds) print("每种情感预测正确的个数:",corrects) print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:
pytorch分类模型绘制混淆矩阵以及可视化详解
文章图片


Step 2. 混淆矩阵可视化 对上边求得的混淆矩阵可视化
# 绘制混淆矩阵Emotion=8#这个数值是具体的分类数,大家可以自行修改labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签# 显示数据plt.imshow(conf_matrix, cmap=plt.cm.Blues)# 在图中标注数量/概率信息thresh = conf_matrix.max() / 2 #数值颜色阈值,如果数值超过这个,就颜色加深。for x in range(Emotion_kinds):for y in range(Emotion_kinds):# 注意这里的matrix[y, x]不是matrix[x, y]info = int(conf_matrix[y, x])plt.text(x, y, info,verticalalignment='center',horizontalalignment='center',color="white" if info > thresh else "black")plt.tight_layout()#保证图不重叠plt.yticks(range(Emotion_kinds), labels)plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°plt.show()plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:
pytorch分类模型绘制混淆矩阵以及可视化详解
文章图片


其它分类指标的获取 例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~
pytorch分类模型绘制混淆矩阵以及可视化详解
文章图片


总结 【pytorch分类模型绘制混淆矩阵以及可视化详解】到此这篇关于pytorch分类模型绘制混淆矩阵以及可视化详的文章就介绍到这了,更多相关pytorch绘制混淆矩阵内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!

    推荐阅读