【Dissecting the Ox】Implementing RetinaNet from Scratch (6): Training and Testing RetinaNet


Table of Contents

  • Training RetinaNet
  • Testing RetinaNet on the COCO dataset
  • Testing RetinaNet on the VOC dataset
  • Complete training and testing code
  • Evaluating the reproduction

All the code has been uploaded to my github repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested under PyTorch 1.4 and confirmed to run correctly.
In parts (1) through (5) of this series I have fully reproduced RetinaNet. The reproduction splits the detector into three independent parts: the forward network, the loss computation, and the box decoding. A careful look reveals some duplicated code between the loss and decode parts, but I deliberately did not factor it out for reuse: keeping the three parts highly cohesive and loosely coupled lets us apply the latest detection techniques to any one of the three parts independently and assemble an improved detector like building blocks. With that in place, we can now train and test RetinaNet.
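To see how the three parts plug together like building blocks, here is a minimal sketch using the class names from the earlier articles (the constructor arguments mirror the training code later in this post; the input is a dummy batch, not real data):

```python
import torch

from public.detection.models.retinanet import resnet50_retinanet
from public.detection.models.loss import RetinaLoss
from public.detection.models.decode import RetinaDecoder

# the three independent parts: forward network, loss computation, decoding
model = resnet50_retinanet(pretrained=False, num_classes=80).cuda()
criterion = RetinaLoss(image_w=600, image_h=600).cuda()
decoder = RetinaDecoder(image_w=600, image_h=600).cuda()

images = torch.randn(2, 3, 600, 600).cuda()  # dummy batch of 2 images
cls_heads, reg_heads, batch_anchors = model(images)  # forward network

# training: the loss consumes the heads, anchors and annotations [B,N,5]
# cls_loss, reg_loss = criterion(cls_heads, reg_heads, batch_anchors, annotations)

# inference: the decoder turns heads + anchors into detections
scores, classes, boxes = decoder(cls_heads, reg_heads, batch_anchors)
```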
Training RetinaNet

In the RetinaNet paper (https://arxiv.org/pdf/1708.02002.pdf), the standard training recipe is: an SGD optimizer with momentum=0.9 and weight_decay=0.0001, batch_size=16, with batch normalization synchronized across GPUs. Training runs for 90000 iterations in total, starting from a learning rate of 0.01 that is divided by 10 at iteration 60000 and again at iteration 80000.
My training procedure differs slightly from the above, but not by much. Multiplying 16 by 90000 and dividing by 118287 (the number of images in COCO2017_train) gives roughly 12.17 epochs, so I train for 12 epochs. For simplicity I use the Adam optimizer, which adapts its per-parameter step sizes automatically. In my experience Adam converges faster than SGD early in training, but its final result is slightly worse (the local optimum it settles into is not quite as good as SGD's); the gap is small, and for RetinaNet it is generally no more than 0.5 points of mAP.
In the Detectron and Detectron2 frameworks, the standard recipe from the RetinaNet paper described above is called 1x_training. Analogously, multiplying the total iteration count and the learning-rate decay milestones by 2 or 3 gives 2x_training and 3x_training.
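To make the two recipes concrete, here is a minimal sketch contrasting the paper's iteration-based 1x SGD schedule with the epoch-based AdamW setup used in train.py below (`model` is a hypothetical stand-in for the detector):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 10)  # hypothetical stand-in for the detector

# --- the paper's 1x schedule: SGD, lr divided by 10 at iters 60000/80000 ---
sgd = torch.optim.SGD(model.parameters(),
                      lr=0.01,
                      momentum=0.9,
                      weight_decay=1e-4)
sgd_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    sgd, milestones=[60000, 80000], gamma=0.1)  # 2x/3x: scale these by 2 or 3
# sgd_scheduler.step() is called once per iteration, 90000 times in total

# --- my setup: AdamW + ReduceLROnPlateau, stepped once per epoch ---
adamw = torch.optim.AdamW(model.parameters(), lr=1e-4)
adamw_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(adamw,
                                                             patience=3,
                                                             verbose=True)
# adamw_scheduler.step(mean_epoch_loss) is called once per epoch
```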
Testing RetinaNet on the COCO Dataset

To measure RetinaNet's performance on COCO we can directly use the API of the COCOeval class from pycocotools.cocoeval. We simply feed the forward outputs of the RetinaNet class (together with the anchors) into the RetinaDecoder class, then scale the decoded bboxes back up to their size on the original image (the decoded bbox coordinates are relative to the resized image). Next we filter out the invalid detections in each image (those with class_index -1), write the remaining ones to a json file in the required format, and call COCOeval to compute the metrics.
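Each record written to the json file follows COCO's detection result format. A minimal example of one record (the values here are hypothetical; bbox is [x_min, y_min, w, h] in original-image pixels):

```python
# one detection record as expected by COCO's loadRes; values are made up
image_result = {
    'image_id': 139,     # id of the image in the annotation file
    'category_id': 1,    # COCO category id, not the 0~79 training label
    'score': 0.95,       # classification confidence
    'bbox': [412.8, 157.6, 53.0, 138.9],  # [x_min, y_min, w, h]
}
```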
The COCOeval class provides 12 performance metrics:
```python
self.maxDets = [1, 10, 100]  # i.e. the max_detection_num mentioned in the decoder article
stats[0] = _summarize(1)
stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
```

The meaning of each entry is as follows:
```python
# Unless a paper states otherwise, the COCO performance it reports is stats[0].
# Whether it is measured on coco2017_val or coco2017_test depends on the paper,
# but the two usually differ by only 0.2~0.5 points.
stats[0] : IoU=0.5:0.95, area=all,    maxDets=100, mAP
stats[1] : IoU=0.5,      area=all,    maxDets=100, mAP
stats[2] : IoU=0.75,     area=all,    maxDets=100, mAP
stats[3] : IoU=0.5:0.95, area=small,  maxDets=100, mAP
stats[4] : IoU=0.5:0.95, area=medium, maxDets=100, mAP
stats[5] : IoU=0.5:0.95, area=large,  maxDets=100, mAP
stats[6] : IoU=0.5:0.95, area=all,    maxDets=1,   mAR
stats[7] : IoU=0.5:0.95, area=all,    maxDets=10,  mAR
stats[8] : IoU=0.5:0.95, area=all,    maxDets=100, mAR
stats[9] : IoU=0.5:0.95, area=small,  maxDets=100, mAR
stats[10]: IoU=0.5:0.95, area=medium, maxDets=100, mAR
stats[11]: IoU=0.5:0.95, area=large,  maxDets=100, mAR
```

The code for testing on the COCO dataset is implemented as follows:
```python
def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_eval_result = evaluate_coco(val_dataset, model, decoder)

    return all_eval_result


def evaluate_coco(val_dataset, model, decoder):
    results, image_ids = [], []
    for index in range(len(val_dataset)):
        data = val_dataset[index]
        scale = data['scale']
        cls_heads, reg_heads, batch_anchors = model(data['img'].cuda().permute(
            2, 0, 1).float().unsqueeze(dim=0))
        scores, classes, boxes = decoder(cls_heads, reg_heads, batch_anchors)
        scores, classes, boxes = scores.cpu(), classes.cpu(), boxes.cpu()
        boxes /= scale

        # make sure decode batch_size=1
        # scores shape:[1,max_detection_num]
        # classes shape:[1,max_detection_num]
        # bboxes shape:[1,max_detection_num,4]
        assert scores.shape[0] == 1

        scores = scores.squeeze(0)
        classes = classes.squeeze(0)
        boxes = boxes.squeeze(0)

        # for coco_eval, we need [x_min,y_min,w,h] format pred boxes
        boxes[:, 2:] -= boxes[:, :2]

        for object_score, object_class, object_box in zip(
                scores, classes, boxes):
            object_score = float(object_score)
            object_class = int(object_class)
            object_box = object_box.tolist()
            if object_class == -1:
                break

            image_result = {
                'image_id':
                val_dataset.image_ids[index],
                'category_id':
                val_dataset.find_category_id_from_coco_label(object_class),
                'score':
                object_score,
                'bbox':
                object_box,
            }
            results.append(image_result)

        image_ids.append(val_dataset.image_ids[index])

        print('{}/{}'.format(index, len(val_dataset)), end='\r')

    if not len(results):
        print("No target detected in test set images")
        return

    json.dump(results,
              open('{}_bbox_results.json'.format(val_dataset.set_name), 'w'),
              indent=4)

    # load results in COCO evaluation tool
    coco_true = val_dataset.coco
    coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(
        val_dataset.set_name))

    coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    all_eval_result = coco_eval.stats

    return all_eval_result
```

When training and testing on COCO we follow the dataset setup in the RetinaNet paper: the model is trained on coco_2017_train and tested on coco_2017_val. We report the mAP at IoU=0.5:0.95, keeping at most 100 detections per image and counting objects of all sizes (i.e. the stats[0] value from the _summarizeDets function of pycocotools.cocoeval's COCOeval class) as the model's performance.
Testing RetinaNet on the VOC Dataset

When training and testing on the VOC dataset we follow detectron2's practice for training and testing Faster R-CNN on VOC (https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md): the model is trained on VOC2007trainval+VOC2012trainval and tested on VOC2007test. At test time, mAP is computed with the VOC2007 11-point metric.
The test code is the classic VOC evaluation code, with only the inputs and outputs adapted:
```python
def compute_voc_ap(recall, precision, use_07_metric=True):
    if use_07_metric:
        # use voc 2007 11 point metric
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(recall >= t) == 0:
                p = 0
            else:
                # get max precision for recall >= t
                p = np.max(precision[recall >= t])
            # average 11 recall point precision
            ap = ap + p / 11.
    else:
        # use voc>=2010 metric, average all different recall precision as ap
        # recall add first value 0. and last value 1.
        mrecall = np.concatenate(([0.], recall, [1.]))
        # precision add first value 0. and last value 0.
        mprecision = np.concatenate(([0.], precision, [0.]))

        # compute the precision envelope
        for i in range(mprecision.size - 1, 0, -1):
            mprecision[i - 1] = np.maximum(mprecision[i - 1], mprecision[i])

        # to calculate area under PR curve, look for points where X axis (recall) changes value
        i = np.where(mrecall[1:] != mrecall[:-1])[0]

        # sum (\Delta recall) * prec
        ap = np.sum((mrecall[i + 1] - mrecall[i]) * mprecision[i + 1])

    return ap


def compute_ious(a, b):
    """
    :param a: [N,(x1,y1,x2,y2)]
    :param b: [M,(x1,y1,x2,y2)]
    :return: IoU [N,M]
    """
    a = np.expand_dims(a, axis=1)  # [N,1,4]
    b = np.expand_dims(b, axis=0)  # [1,M,4]

    overlap = np.maximum(0.0,
                         np.minimum(a[..., 2:], b[..., 2:]) -
                         np.maximum(a[..., :2], b[..., :2]))  # [N,M,(w,h)]

    overlap = np.prod(overlap, axis=-1)  # [N,M]

    area_a = np.prod(a[..., 2:] - a[..., :2], axis=-1)
    area_b = np.prod(b[..., 2:] - b[..., :2], axis=-1)

    iou = overlap / (area_a + area_b - overlap)

    return iou


def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_ap, mAP = evaluate_voc(val_dataset,
                                   model,
                                   decoder,
                                   num_classes=20,
                                   iou_thread=0.5)

    return all_ap, mAP


def evaluate_voc(val_dataset, model, decoder, num_classes=20, iou_thread=0.5):
    preds, gts = [], []
    for index in tqdm(range(len(val_dataset))):
        data = val_dataset[index]
        img, gt_annot, scale = data['img'], data['annot'], data['scale']

        gt_bboxes, gt_classes = gt_annot[:, 0:4], gt_annot[:, 4]
        gt_bboxes /= scale

        gts.append([gt_bboxes, gt_classes])

        cls_heads, reg_heads, batch_anchors = model(img.cuda().permute(
            2, 0, 1).float().unsqueeze(dim=0))
        preds_scores, preds_classes, preds_boxes = decoder(
            cls_heads, reg_heads, batch_anchors)
        preds_scores, preds_classes, preds_boxes = preds_scores.cpu(
        ), preds_classes.cpu(), preds_boxes.cpu()
        preds_boxes /= scale

        # make sure decode batch_size=1
        # preds_scores shape:[1,max_detection_num]
        # preds_classes shape:[1,max_detection_num]
        # preds_bboxes shape:[1,max_detection_num,4]
        assert preds_scores.shape[0] == 1

        preds_scores = preds_scores.squeeze(0)
        preds_classes = preds_classes.squeeze(0)
        preds_boxes = preds_boxes.squeeze(0)

        preds_scores = preds_scores[preds_classes > -1]
        preds_boxes = preds_boxes[preds_classes > -1]
        preds_classes = preds_classes[preds_classes > -1]

        preds.append([preds_boxes, preds_classes, preds_scores])

    print("all val sample decode done.")

    all_ap = {}
    for class_index in tqdm(range(num_classes)):
        per_class_gt_boxes = [
            image[0][image[1] == class_index] for image in gts
        ]
        per_class_pred_boxes = [
            image[0][image[1] == class_index] for image in preds
        ]
        per_class_pred_scores = [
            image[2][image[1] == class_index] for image in preds
        ]

        fp = np.zeros((0, ))
        tp = np.zeros((0, ))
        scores = np.zeros((0, ))
        total_gts = 0

        # loop for each sample
        for per_image_gt_boxes, per_image_pred_boxes, per_image_pred_scores in zip(
                per_class_gt_boxes, per_class_pred_boxes,
                per_class_pred_scores):
            total_gts = total_gts + len(per_image_gt_boxes)
            # one gt can only be assigned to one predicted bbox
            assigned_gt = []
            # loop for each predicted bbox
            for index in range(len(per_image_pred_boxes)):
                scores = np.append(scores, per_image_pred_scores[index])
                if per_image_gt_boxes.shape[0] == 0:
                    # if no gts found for the predicted bbox, assign the bbox to fp
                    fp = np.append(fp, 1)
                    tp = np.append(tp, 0)
                    continue
                pred_box = np.expand_dims(per_image_pred_boxes[index], axis=0)
                iou = compute_ious(per_image_gt_boxes, pred_box)
                gt_for_box = np.argmax(iou, axis=0)
                max_overlap = iou[gt_for_box, 0]
                if max_overlap >= iou_thread and gt_for_box not in assigned_gt:
                    fp = np.append(fp, 0)
                    tp = np.append(tp, 1)
                    assigned_gt.append(gt_for_box)
                else:
                    fp = np.append(fp, 1)
                    tp = np.append(tp, 0)

        # sort by score
        indices = np.argsort(-scores)
        fp = fp[indices]
        tp = tp[indices]
        # compute cumulative false positives and true positives
        fp = np.cumsum(fp)
        tp = np.cumsum(tp)
        # compute recall and precision
        recall = tp / total_gts
        precision = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
        ap = compute_voc_ap(recall, precision)
        all_ap[class_index] = ap

    mAP = 0.
    for _, class_mAP in all_ap.items():
        mAP += float(class_mAP)
    mAP /= num_classes

    return all_ap, mAP
```

Note that in compute_voc_ap, use_07_metric=True selects the VOC2007 11-point metric, while use_07_metric=False selects the newer mAP computation used from VOC2010 onward.
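As a quick sanity check, here is a hedged usage example on a hypothetical PR curve (toy values, not real detector output), showing how the two settings differ:

```python
import numpy as np

# toy monotone-increasing recall and its precision values (made-up numbers)
recall = np.array([0.1, 0.2, 0.4, 0.6, 0.8])
precision = np.array([1.0, 0.9, 0.8, 0.6, 0.5])

# VOC2007: average the max precision at the 11 recall points 0.0, 0.1, ..., 1.0
ap_07 = compute_voc_ap(recall, precision, use_07_metric=True)

# VOC2010+: exact area under the interpolated PR curve
ap_10 = compute_voc_ap(recall, precision, use_07_metric=False)

print(f"11-point AP: {ap_07:.4f}, area-under-curve AP: {ap_10:.4f}")
```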
Complete Training and Testing Code

We train for 12 epochs in total, evaluating the model every 5 epochs and once more when training finishes.
The complete training and testing code is given below (this is the COCO version; training and testing on VOC requires only minor modifications).
config.py:
```python
import os
import sys

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)

from public.path import COCO2017_path
from public.detection.dataset.cocodataset import CocoDetection, Resize, RandomFlip, RandomCrop, RandomTranslate

import torchvision.transforms as transforms
import torchvision.datasets as datasets


class Config(object):
    log = './log'  # Path to save log
    checkpoint_path = './checkpoints'  # Path to store checkpoint model
    resume = './checkpoints/latest.pth'  # load checkpoint model
    evaluate = None  # evaluate model path
    train_dataset_path = os.path.join(COCO2017_path, 'images/train2017')
    val_dataset_path = os.path.join(COCO2017_path, 'images/val2017')
    dataset_annotations_path = os.path.join(COCO2017_path, 'annotations')

    network = "resnet50_retinanet"
    pretrained = False
    num_classes = 80
    seed = 0
    input_image_size = 600

    train_dataset = CocoDetection(image_root_dir=train_dataset_path,
                                  annotation_root_dir=dataset_annotations_path,
                                  set="train2017",
                                  transform=transforms.Compose([
                                      RandomFlip(flip_prob=0.5),
                                      RandomCrop(crop_prob=0.5),
                                      RandomTranslate(translate_prob=0.5),
                                      Resize(resize=input_image_size),
                                  ]))
    val_dataset = CocoDetection(image_root_dir=val_dataset_path,
                                annotation_root_dir=dataset_annotations_path,
                                set="val2017",
                                transform=transforms.Compose([
                                    Resize(resize=input_image_size),
                                ]))

    epochs = 12
    batch_size = 15
    lr = 1e-4
    num_workers = 4
    print_interval = 100
    apex = True
```

train.py:
```python
import sys
import os
import argparse
import random
import shutil
import time
import warnings
import json

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
warnings.filterwarnings('ignore')

import numpy as np
from thop import profile
from thop import clever_format
from apex import amp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import transforms
from config import Config
from public.detection.dataset.cocodataset import COCODataPrefetcher, collater
from public.detection.models.loss import RetinaLoss
from public.detection.models.decode import RetinaDecoder
from public.detection.models.retinanet import resnet50_retinanet
from public.imagenet.utils import get_logger
from pycocotools.cocoeval import COCOeval


def parse_args():
    parser = argparse.ArgumentParser(
        description='PyTorch COCO Detection Training')
    parser.add_argument('--network',
                        type=str,
                        default=Config.network,
                        help='name of network')
    parser.add_argument('--lr',
                        type=float,
                        default=Config.lr,
                        help='learning rate')
    parser.add_argument('--epochs',
                        type=int,
                        default=Config.epochs,
                        help='num of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=Config.batch_size,
                        help='batch size')
    parser.add_argument('--pretrained',
                        type=bool,
                        default=Config.pretrained,
                        help='load pretrained model params or not')
    parser.add_argument('--num_classes',
                        type=int,
                        default=Config.num_classes,
                        help='model classification num')
    parser.add_argument('--input_image_size',
                        type=int,
                        default=Config.input_image_size,
                        help='input image size')
    parser.add_argument('--num_workers',
                        type=int,
                        default=Config.num_workers,
                        help='number of worker to load data')
    parser.add_argument('--resume',
                        type=str,
                        default=Config.resume,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkpoints',
                        type=str,
                        default=Config.checkpoint_path,
                        help='path for saving trained models')
    parser.add_argument('--log',
                        type=str,
                        default=Config.log,
                        help='path to save log')
    parser.add_argument('--evaluate',
                        type=str,
                        default=Config.evaluate,
                        help='path for evaluate model')
    parser.add_argument('--seed', type=int, default=Config.seed, help='seed')
    parser.add_argument('--print_interval',
                        type=int,
                        default=Config.print_interval,
                        help='print interval')
    parser.add_argument('--apex',
                        type=bool,
                        default=Config.apex,
                        help='use apex or not')

    return parser.parse_args()


def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_eval_result = evaluate_coco(val_dataset, model, decoder)

    return all_eval_result


def evaluate_coco(val_dataset, model, decoder):
    results, image_ids = [], []
    for index in range(len(val_dataset)):
        data = val_dataset[index]
        scale = data['scale']
        cls_heads, reg_heads, batch_anchors = model(data['img'].cuda().permute(
            2, 0, 1).float().unsqueeze(dim=0))
        scores, classes, boxes = decoder(cls_heads, reg_heads, batch_anchors)
        scores, classes, boxes = scores.cpu(), classes.cpu(), boxes.cpu()
        boxes /= scale

        # make sure decode batch_size=1
        # scores shape:[1,max_detection_num]
        # classes shape:[1,max_detection_num]
        # bboxes shape:[1,max_detection_num,4]
        assert scores.shape[0] == 1

        scores = scores.squeeze(0)
        classes = classes.squeeze(0)
        boxes = boxes.squeeze(0)

        # for coco_eval, we need [x_min,y_min,w,h] format pred boxes
        boxes[:, 2:] -= boxes[:, :2]

        for object_score, object_class, object_box in zip(
                scores, classes, boxes):
            object_score = float(object_score)
            object_class = int(object_class)
            object_box = object_box.tolist()
            if object_class == -1:
                break

            image_result = {
                'image_id':
                val_dataset.image_ids[index],
                'category_id':
                val_dataset.find_category_id_from_coco_label(object_class),
                'score':
                object_score,
                'bbox':
                object_box,
            }
            results.append(image_result)

        image_ids.append(val_dataset.image_ids[index])

        print('{}/{}'.format(index, len(val_dataset)), end='\r')

    if not len(results):
        print("No target detected in test set images")
        return

    json.dump(results,
              open('{}_bbox_results.json'.format(val_dataset.set_name), 'w'),
              indent=4)

    # load results in COCO evaluation tool
    coco_true = val_dataset.coco
    coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(
        val_dataset.set_name))

    coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    all_eval_result = coco_eval.stats

    return all_eval_result


def train(train_loader, model, criterion, optimizer, scheduler, epoch, logger,
          args):
    cls_losses, reg_losses, losses = [], [], []

    # switch to train mode
    model.train()

    iters = len(train_loader.dataset) // args.batch_size
    prefetcher = COCODataPrefetcher(train_loader)
    images, annotations = prefetcher.next()
    iter_index = 1

    while images is not None:
        images, annotations = images.cuda().float(), annotations.cuda()
        cls_heads, reg_heads, batch_anchors = model(images)
        cls_loss, reg_loss = criterion(cls_heads, reg_heads, batch_anchors,
                                       annotations)
        loss = cls_loss + reg_loss
        if cls_loss == 0.0 or reg_loss == 0.0:
            # skip this batch, but still advance the prefetcher and the iter
            # counter, otherwise the loop would re-process the same batch
            optimizer.zero_grad()
            images, annotations = prefetcher.next()
            iter_index += 1
            continue

        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        optimizer.zero_grad()

        cls_losses.append(cls_loss.item())
        reg_losses.append(reg_loss.item())
        losses.append(loss.item())

        images, annotations = prefetcher.next()

        if iter_index % args.print_interval == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>5d}, {iters:0>5d}], cls_loss: {cls_loss.item():.2f}, reg_loss: {reg_loss.item():.2f}, loss_total: {loss.item():.2f}"
            )

        iter_index += 1

    scheduler.step(np.mean(losses))

    return np.mean(cls_losses), np.mean(reg_losses), np.mean(losses)


def main(logger, args):
    if not torch.cuda.is_available():
        raise Exception("need gpu to train network!")

    torch.cuda.empty_cache()

    if args.seed is not None:
        random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    gpus = torch.cuda.device_count()
    logger.info(f'use {gpus} gpus')
    logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    logger.info('start loading data')
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=collater)
    logger.info('finish loading data')

    model = resnet50_retinanet(**{
        "pretrained": args.pretrained,
        "num_classes": args.num_classes,
    })

    for name, param in model.named_parameters():
        logger.info(f"{name},{param.requires_grad}")

    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    logger.info(f"model: '{args.network}', flops: {flops}, params: {params}")

    criterion = RetinaLoss(image_w=args.input_image_size,
                           image_h=args.input_image_size).cuda()
    decoder = RetinaDecoder(image_w=args.input_image_size,
                            image_h=args.input_image_size).cuda()

    model = model.cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    if args.apex:
        amp.register_float_function(torch, 'sigmoid')
        amp.register_float_function(torch, 'softmax')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    model = nn.DataParallel(model)

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            raise Exception(
                f"{args.evaluate} is not a file, please check it again")
        logger.info('start only evaluating')
        logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        all_eval_result = validate(Config.val_dataset, model, decoder)
        if all_eval_result is not None:
            logger.info(
                f"val: epoch: {checkpoint['epoch']:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
            )

        return

    best_map = 0.0
    start_epoch = 1
    # resume training
    if os.path.exists(args.resume):
        logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        best_map = checkpoint['best_map']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        logger.info(
            f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, best_map: {checkpoint['best_map']}, "
            f"loss: {checkpoint['loss']:3f}, cls_loss: {checkpoint['cls_loss']:2f}, reg_loss: {checkpoint['reg_loss']:2f}"
        )

    if not os.path.exists(args.checkpoints):
        os.makedirs(args.checkpoints)

    logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        cls_losses, reg_losses, losses = train(train_loader, model, criterion,
                                               optimizer, scheduler, epoch,
                                               logger, args)
        logger.info(
            f"train: epoch {epoch:0>3d}, cls_loss: {cls_losses:.2f}, reg_loss: {reg_losses:.2f}, loss: {losses:.2f}"
        )

        if epoch % 5 == 0 or epoch == args.epochs:
            all_eval_result = validate(Config.val_dataset, model, decoder)
            logger.info("eval done.")
            if all_eval_result is not None:
                logger.info(
                    f"val: epoch: {epoch:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
                )
                if all_eval_result[0] > best_map:
                    torch.save(model.module.state_dict(),
                               os.path.join(args.checkpoints, "best.pth"))
                    best_map = all_eval_result[0]

        torch.save(
            {
                'epoch': epoch,
                'best_map': best_map,
                'cls_loss': cls_losses,
                'reg_loss': reg_losses,
                'loss': losses,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, os.path.join(args.checkpoints, 'latest.pth'))

    logger.info(f"finish training, best_map: {best_map:.3f}")
    training_time = (time.time() - start_time) / 3600
    logger.info(
        f"finish training, total training time: {training_time:.2f} hours")


if __name__ == '__main__':
    args = parse_args()
    logger = get_logger(__name__, args.log)
    main(logger, args)
```

The code above trains in nn.DataParallel mode; the hyperparameters in config.py and train.py correspond to the ResNet50-RetinaNet-apex-aug entry in the model evaluation below. Distributed training will be implemented in the next article. To start training, just run python train.py.
Evaluating the Reproduction

Looking back over how these six articles reproduce each aspect of RetinaNet, three issues remain before the numbers can match the RetinaNet model in the paper:
  1. The ImageNet-pretrained ResNet50 weights used in Detectron and Detectron2 were trained by their authors, and they are probably better than my ResNet50 pretrained weights (mine reach top1-error=23.488%). In my experience, the better the pretrained model performs, the better the fine-tuned result (the relationship is not linear, but it is positively correlated).
  2. The training above runs in nn.DataParallel mode, which cannot use cross-GPU synchronized BN, whereas Detectron and Detectron2 both use distributed training with synchronized BN. Without it, each BN layer updates its mean and standard deviation from the per-GPU batch only; if the per-GPU batch size is smaller than the paper's 16, the BN statistics are less accurate than with synchronized BN and the model's performance drops accordingly (see the sketch after this list).
  3. Since I have not read all of the Detectron and Detectron2 code, those frameworks may contain further training tricks that I have not yet discovered.
Issue 1 cannot be verified for now. Issue 2 will be resolved in the next article by using distributed training with cross-GPU synchronized BN. As for issue 3, I currently do not have the time to read through all of Detectron and Detectron2, so corrections from readers are welcome.
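For issue 2, a minimal sketch of the direction the next article will take, using PyTorch's built-in synchronized BN conversion (assumes torch.distributed has already been initialized, e.g. via init_process_group; `local_rank` is a hypothetical per-process GPU index):

```python
import torch
import torch.nn as nn

from public.detection.models.retinanet import resnet50_retinanet

local_rank = 0  # hypothetical; normally passed in per process

model = resnet50_retinanet(pretrained=False, num_classes=80)

# replace every BatchNorm layer with a cross-GPU synchronized version
model = nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda()

# SyncBatchNorm only takes effect under DistributedDataParallel,
# not under nn.DataParallel
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
```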
The model's performance on the COCO dataset is as follows (input resolution 600, roughly equivalent to resolution 450 in the RetinaNet paper):
| Network | batch | gpu-num | apex | epoch5 mAP/loss | epoch10 mAP/loss | epoch12 mAP/loss | one-epoch training time |
|---|---|---|---|---|---|---|---|
| ResNet50-RetinaNet | 16 | 2 | no | 0.251, 0.60 | 0.266, 0.49 | 0.272, 0.46 | 2h38min |
| ResNet50-RetinaNet | 15 | 1 | yes | 0.251, 0.59 | 0.272, 0.48 | 0.276, 0.45 | 2h31min |
| ResNet50-RetinaNet-aug | 15 | 1 | yes | 0.254, 0.62 | 0.274, 0.53 | 0.283, 0.51 | 2h31min |
All results listed above were obtained in nn.DataParallel mode, without cross-GPU synchronized BN (which is unavailable in nn.DataParallel mode). Every experiment uses RandomFlip+Resize data augmentation during training and plain Resize at test time; the -aug suffix means RandomCrop and RandomTranslate were additionally used during training. All GPUs are RTX 2080ti. An entry like 0.251, 0.60 means an mAP of 0.251 with a total loss of 0.60 at that point; 2h38min means 2 hours 38 minutes.
Based on these results, with the same data augmentation as the paper, my RetinaNet (0.276) is about 3.5 points lower than the paper's model (extrapolating the paper's numbers to resolution 450 gives roughly 0.311). The gap should come from using the Adam optimizer instead of SGD, together with issues 1, 2, and 3 above. In the next article I will train RetinaNet with distributed training, which resolves issue 2.
