深度学习|pytorch计算分类验证精度acc1,acc5代码

def accuracy(output, label, topk=(1,)): maxk = max(topk) batch_size = output.size(0) # 在输出结果中取前maxk个最大概率作为预测结果,并获取其下标,当topk=(1, 5)时取5就可以了。 _, pred = torch.topk(output, k=maxk, dim=1, largest=True, sorted=True)# 将得到的k个预测结果的矩阵进行转置,方便后续和label作比较 pred = pred.T # 将label先拓展成为和pred相同的形状,和pred进行对比,输出结果 correct = torch.eq(pred, label.contiguous().view(1,-1).expand_as(pred)) # 例: # 若label为:[1,2,3,4], topk = (1, 5)时 # 则label.contiguous().view(1,-1).expand_as(pred)为: # [[1, 2, 3, 4], #[1, 2, 3, 4], #[1, 2, 3, 4], #[1, 2, 3, 4], #[1, 2, 3, 4]] res = []for k in topk: # 取前k个预测正确的结果进行求和 correct_k = correct[:k].contiguous().view(-1).float().sum(dim=0, keepdim=True) # 计算平均精度, 将结果加入res中 res.append(correct_k*100/batch_size)return res

当topk=(1, 5)时同时返回acc1, 和acc5
【深度学习|pytorch计算分类验证精度acc1,acc5代码】在验证时的实例代码:
class AverageMeter(): def __init__(self): self.reset()def reset(self): self.val = 0 self.sum = 0 self.avg = 0 self.count = 0def update(self, val, n): self.sum += float(val)*n self.count += n self.avg = self.sum / self.countdef validation(val_dataloader, num_batch_val, criterion, model, device, total_epochs, logger, debug_step = 100):model.eval() acc1_val = AverageMeter() acc5_val = AverageMeter() loss_val = AverageMeter()start_time = time.time() with torch.no_grad(): for batch_id, data in enumerate(val_dataloader): image = data[0] label = data[1] image = Variable(image.to(device), requires_grad=False) label = Variable(label.to(device), requires_grad=False) image = image.flatten(1) # logger.info(f"the image size :{image.size()}") # logger.info(f"the label size :{label.size()}") output = model(image) loss = criterion(output, label) acc1, acc5 = accuracy(output=output, label=label, topk=(1, 5))loss_val.update(loss.data, image.size(0)) acc1_val.update(acc1[0], image.size(0)) acc5_val.update(acc5, image.size(0))if batch_id % debug_step == 0 or batch_id == num_batch_val: logger.info(f"Val Step:[{batch_id:03d}/{num_batch_val:03d}], "+ f"Avg Loss:{loss_val.avg:.4f}, "+ f"Avg Acc1:{acc1_val.avg:.4f}, "+ f"Avg Acc5:{acc5_val.avg:.4f}") end_time = time.time()return acc1_val.avg, acc5_val.avg, loss_val.avg, end_time-start_time

    推荐阅读