[Dissecting the Ox] Implementing RetinaNet from Scratch (6): Training and Testing RetinaNet
Table of Contents
- Training RetinaNet
- Testing RetinaNet on the COCO dataset
- Testing RetinaNet on the VOC dataset
- Complete training and testing code
- Evaluating the reproduction
All the code has been uploaded to my GitHub repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested under PyTorch 1.4 and confirmed to run correctly.
In parts (1) through (5) of this series I completed a full reproduction of RetinaNet. The guiding idea was to split the detector into three independent parts: the forward network, the loss computation, and the decode step. If you look closely you will find some duplicated code between the loss part and the decode part. I deliberately did not factor it out for reuse: keeping the three parts highly cohesive and loosely coupled means we can apply the latest detection techniques to any one of them independently, and then assemble an improved detector like building blocks. With that in place, we can now train and test RetinaNet.
Training RetinaNet

In the RetinaNet paper (https://arxiv.org/pdf/1708.02002.pdf), the standard training recipe is: SGD with momentum=0.9 and weight_decay=0.0001, batch_size=16, with synchronized BN across GPUs. Training runs for 90000 iterations, with an initial learning rate of 0.01 that is divided by 10 at iteration 60000 and again at iteration 80000.
My training procedure differs slightly from this, but not by much. Multiplying 16 by 90000 and dividing by 118287 (the number of images in COCO2017_train) gives roughly 12.17 epochs, so we train for 12 epochs. For simplicity I use the Adam optimizer, which adapts the learning rate automatically. In my experience Adam converges faster than SGD early in training, but its final result is slightly worse (the local optimum it settles into is not as good as SGD's); the gap is small, though, and for RetinaNet it generally costs no more than 0.5 points of mAP.
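For reference, the 12-epoch budget is just the paper's iteration schedule converted into epochs:

# Convert the paper's 90000-iteration schedule (batch_size=16) into epochs
# over the 118287 images of COCO2017_train.
iterations, batch_size, num_train_images = 90000, 16, 118287
epochs = iterations * batch_size / num_train_images
print(round(epochs, 2))  # 12.17, so we stop at 12 epochs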
In the Detectron and Detectron2 frameworks, the standard recipe above is referred to as 1x_training. Analogously, multiplying the total iteration count and the learning-rate decay milestones by 2 or 3 gives 2x_training and 3x_training.
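As a quick illustration (my own summary of the convention, not code from either framework), the schedules scale linearly off the 1x baseline:

# Detectron-style schedules: total iterations and the milestones at which
# the learning rate is divided by 10, all scaled from the 1x baseline.
schedules = {
    "1x": {"max_iter": 90000, "lr_decay_iters": (60000, 80000)},
    "2x": {"max_iter": 180000, "lr_decay_iters": (120000, 160000)},
    "3x": {"max_iter": 270000, "lr_decay_iters": (180000, 240000)},
}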
Testing RetinaNet on the COCO dataset

To evaluate RetinaNet on COCO we can directly use the API of the COCOeval class from pycocotools.cocoeval. We feed the forward outputs of the RetinaNet class (together with the anchors) into the RetinaDecoder class, then scale the decoded bboxes by scale back to the original image size (the decoded bbox coordinates are relative to the resized image). We then filter out the invalid detections in each image (those with class_index of -1), write the remaining detections to a json file in the required format, and call COCOeval to compute the metrics.
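Each detection is written as one dict in the standard COCO results format that loadRes expects; a sketch of a single entry (the values here are illustrative):

# One entry of the bbox results json consumed by COCO's loadRes.
# bbox is [x_min, y_min, w, h] in original-image pixels.
image_result = {
    'image_id': 139,   # image id from the annotation file
    'category_id': 1,  # COCO category id, not the 0~79 training label
    'score': 0.87,     # detection confidence
    'bbox': [412.8, 157.6, 53.1, 138.0],
}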
The COCOeval class provides 12 performance metrics:
self.maxDets = [1, 10, 100]  # i.e. the max_detection_num mentioned in the decoder article
stats[0] = _summarize(1)
stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
The meaning of each entry is as follows:
# Unless stated otherwise, the COCO performance quoted in detection papers is stats[0]. Whether it is measured on coco2017_val or coco2017_test depends on the paper, but the two usually differ by only 0.2~0.5 points.
stats[0] : IoU=0.5:0.95,area=all,maxDets=100,mAP
stats[1] : IoU=0.5,area=all,maxDets=100,mAP
stats[2] : IoU=0.75,area=all,maxDets=100,mAP
stats[3] : IoU=0.5:0.95,area=small,maxDets=100,mAP
stats[4] : IoU=0.5:0.95,area=medium,maxDets=100,mAP
stats[5] : IoU=0.5:0.95,area=large,maxDets=100,mAP
stats[6] : IoU=0.5:0.95,area=all,maxDets=1,mAR
stats[7] : IoU=0.5:0.95,area=all,maxDets=10,mAR
stats[8] : IoU=0.5:0.95,area=all,maxDets=100,mAR
stats[9] : IoU=0.5:0.95,area=small,maxDets=100,mAR
stats[10]:IoU=0.5:0.95,area=medium,maxDets=100,mAR
stats[11]:IoU=0.5:0.95,area=large,maxDets=100,mAR
The evaluation code for the COCO dataset is implemented as follows:
def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_eval_result = evaluate_coco(val_dataset, model, decoder)

    return all_eval_result


def evaluate_coco(val_dataset, model, decoder):
    results, image_ids = [], []
    for index in range(len(val_dataset)):
        data = val_dataset[index]
        scale = data['scale']
        cls_heads, reg_heads, batch_anchors = model(data['img'].cuda().permute(
            2, 0, 1).float().unsqueeze(dim=0))
        scores, classes, boxes = decoder(cls_heads, reg_heads, batch_anchors)
        scores, classes, boxes = scores.cpu(), classes.cpu(), boxes.cpu()
        # decoded boxes are relative to the resized image, scale them back
        boxes /= scale

        # make sure decode batch_size=1
        # scores shape:[1,max_detection_num]
        # classes shape:[1,max_detection_num]
        # boxes shape:[1,max_detection_num,4]
        assert scores.shape[0] == 1

        scores = scores.squeeze(0)
        classes = classes.squeeze(0)
        boxes = boxes.squeeze(0)

        # for coco_eval, we need [x_min,y_min,w,h] format pred boxes
        boxes[:, 2:] -= boxes[:, :2]

        for object_score, object_class, object_box in zip(
                scores, classes, boxes):
            object_score = float(object_score)
            object_class = int(object_class)
            object_box = object_box.tolist()
            # the decoder pads invalid detections with class_index=-1
            if object_class == -1:
                break

            image_result = {
                'image_id': val_dataset.image_ids[index],
                'category_id':
                val_dataset.find_category_id_from_coco_label(object_class),
                'score': object_score,
                'bbox': object_box,
            }
            results.append(image_result)

        image_ids.append(val_dataset.image_ids[index])

        print('{}/{}'.format(index, len(val_dataset)), end='\r')

    if not len(results):
        print("No target detected in test set images")
        return

    json.dump(results,
              open('{}_bbox_results.json'.format(val_dataset.set_name), 'w'),
              indent=4)

    # load results in COCO evaluation tool
    coco_true = val_dataset.coco
    coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(
        val_dataset.set_name))
    coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    all_eval_result = coco_eval.stats

    return all_eval_result
When training and testing on COCO we follow the dataset setup of the RetinaNet paper: train on coco_2017_train and evaluate on coco_2017_val. We report mAP at IoU=0.5:0.95 with at most 100 detections kept per image and objects of all sizes included (i.e. the stats[0] value from the _summarizeDets function of pycocotools' COCOeval class) as the model's headline performance.
Testing RetinaNet on the VOC dataset

When training and testing on VOC we follow detectron2's practice for training Faster R-CNN on VOC (https://github.com/facebookresearch/detectron2/blob/master/MODEL_ZOO.md): train on VOC2007 trainval + VOC2012 trainval and test on VOC2007 test. At test time, mAP is computed with the VOC2007 11-point metric.
The test code is the classic VOC evaluation code with only the inputs and outputs adapted:
def compute_voc_ap(recall, precision, use_07_metric=True):
    if use_07_metric:
        # use voc 2007 11 point metric
        ap = 0.
        for t in np.arange(0., 1.1, 0.1):
            if np.sum(recall >= t) == 0:
                p = 0
            else:
                # get max precision for recall >= t
                p = np.max(precision[recall >= t])
            # average 11 recall point precision
            ap = ap + p / 11.
    else:
        # use voc>=2010 metric, average all different recall precision as ap
        # recall add first value 0. and last value 1.
        mrecall = np.concatenate(([0.], recall, [1.]))
        # precision add first value 0. and last value 0.
        mprecision = np.concatenate(([0.], precision, [0.]))

        # compute the precision envelope
        for i in range(mprecision.size - 1, 0, -1):
            mprecision[i - 1] = np.maximum(mprecision[i - 1], mprecision[i])

        # to calculate area under PR curve, look for points where X axis (recall) changes value
        i = np.where(mrecall[1:] != mrecall[:-1])[0]

        # sum (\Delta recall) * prec
        ap = np.sum((mrecall[i + 1] - mrecall[i]) * mprecision[i + 1])

    return ap


def compute_ious(a, b):
    """
    :param a: [N,(x1,y1,x2,y2)]
    :param b: [M,(x1,y1,x2,y2)]
    :return: IoU [N,M]
    """
    a = np.expand_dims(a, axis=1)  # [N,1,4]
    b = np.expand_dims(b, axis=0)  # [1,M,4]

    overlap = np.maximum(0.0,
                         np.minimum(a[..., 2:], b[..., 2:]) -
                         np.maximum(a[..., :2], b[..., :2]))  # [N,M,(w,h)]
    overlap = np.prod(overlap, axis=-1)  # [N,M]

    area_a = np.prod(a[..., 2:] - a[..., :2], axis=-1)
    area_b = np.prod(b[..., 2:] - b[..., :2], axis=-1)

    iou = overlap / (area_a + area_b - overlap)

    return iou


def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_ap, mAP = evaluate_voc(val_dataset,
                                   model,
                                   decoder,
                                   num_classes=20,
                                   iou_thread=0.5)

    return all_ap, mAP


def evaluate_voc(val_dataset, model, decoder, num_classes=20, iou_thread=0.5):
    preds, gts = [], []
    for index in tqdm(range(len(val_dataset))):
        data = val_dataset[index]
        img, gt_annot, scale = data['img'], data['annot'], data['scale']

        gt_bboxes, gt_classes = gt_annot[:, 0:4], gt_annot[:, 4]
        gt_bboxes /= scale
        gts.append([gt_bboxes, gt_classes])

        cls_heads, reg_heads, batch_anchors = model(img.cuda().permute(
            2, 0, 1).float().unsqueeze(dim=0))
        preds_scores, preds_classes, preds_boxes = decoder(
            cls_heads, reg_heads, batch_anchors)
        preds_scores = preds_scores.cpu()
        preds_classes = preds_classes.cpu()
        preds_boxes = preds_boxes.cpu()
        preds_boxes /= scale

        # make sure decode batch_size=1
        # preds_scores shape:[1,max_detection_num]
        # preds_classes shape:[1,max_detection_num]
        # preds_boxes shape:[1,max_detection_num,4]
        assert preds_scores.shape[0] == 1

        preds_scores = preds_scores.squeeze(0)
        preds_classes = preds_classes.squeeze(0)
        preds_boxes = preds_boxes.squeeze(0)

        # drop padded detections (class_index=-1)
        preds_scores = preds_scores[preds_classes > -1]
        preds_boxes = preds_boxes[preds_classes > -1]
        preds_classes = preds_classes[preds_classes > -1]

        preds.append([preds_boxes, preds_classes, preds_scores])

    print("all val sample decode done.")

    all_ap = {}
    for class_index in tqdm(range(num_classes)):
        per_class_gt_boxes = [
            image[0][image[1] == class_index] for image in gts
        ]
        per_class_pred_boxes = [
            image[0][image[1] == class_index] for image in preds
        ]
        per_class_pred_scores = [
            image[2][image[1] == class_index] for image in preds
        ]

        fp = np.zeros((0, ))
        tp = np.zeros((0, ))
        scores = np.zeros((0, ))
        total_gts = 0

        # loop for each sample
        for per_image_gt_boxes, per_image_pred_boxes, per_image_pred_scores in zip(
                per_class_gt_boxes, per_class_pred_boxes,
                per_class_pred_scores):
            total_gts = total_gts + len(per_image_gt_boxes)
            # one gt can only be assigned to one predicted bbox
            assigned_gt = []
            # loop for each predicted bbox
            for index in range(len(per_image_pred_boxes)):
                scores = np.append(scores, per_image_pred_scores[index])
                if per_image_gt_boxes.shape[0] == 0:
                    # if no gts found for the predicted bbox, assign the bbox to fp
                    fp = np.append(fp, 1)
                    tp = np.append(tp, 0)
                    continue
                pred_box = np.expand_dims(per_image_pred_boxes[index], axis=0)
                iou = compute_ious(per_image_gt_boxes, pred_box)
                gt_for_box = np.argmax(iou, axis=0)
                max_overlap = iou[gt_for_box, 0]
                if max_overlap >= iou_thread and gt_for_box not in assigned_gt:
                    fp = np.append(fp, 0)
                    tp = np.append(tp, 1)
                    assigned_gt.append(gt_for_box)
                else:
                    fp = np.append(fp, 1)
                    tp = np.append(tp, 0)

        # sort by score
        indices = np.argsort(-scores)
        fp = fp[indices]
        tp = tp[indices]
        # compute cumulative false positives and true positives
        fp = np.cumsum(fp)
        tp = np.cumsum(tp)
        # compute recall and precision
        recall = tp / total_gts
        precision = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
        ap = compute_voc_ap(recall, precision)
        all_ap[class_index] = ap

    mAP = 0.
    for _, class_mAP in all_ap.items():
        mAP += float(class_mAP)
    mAP /= num_classes

    return all_ap, mAP
Note that in compute_voc_ap, use_07_metric=True selects the VOC2007 11-point metric, while use_07_metric=False selects the newer mAP calculation used from VOC2010 onwards.
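A toy example of the two modes (the recall/precision arrays here are made up and assumed to already be sorted by descending score, as produced in evaluate_voc):

import numpy as np

# Three predictions for a class with 2 ground truths (TP, FP, TP),
# giving these cumulative recall/precision values.
recall = np.array([0.5, 0.5, 1.0])
precision = np.array([1.0, 0.5, 0.6667])

print(compute_voc_ap(recall, precision, use_07_metric=True))   # 11-point AP
print(compute_voc_ap(recall, precision, use_07_metric=False))  # area under the PR curve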
Complete training and testing code

We train for 12 epochs in total, evaluating the model every 5 epochs and once more when training finishes.
The complete training and testing code is given below (this version trains and tests on COCO; only minor changes are needed for VOC).
config.py:
import os
import sys

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)

from public.path import COCO2017_path
from public.detection.dataset.cocodataset import CocoDetection, Resize, RandomFlip, RandomCrop, RandomTranslate

import torchvision.transforms as transforms
import torchvision.datasets as datasets


class Config(object):
    log = './log'  # path to save log
    checkpoint_path = './checkpoints'  # path to store checkpoint model
    resume = './checkpoints/latest.pth'  # load checkpoint model
    evaluate = None  # evaluate model path
    train_dataset_path = os.path.join(COCO2017_path, 'images/train2017')
    val_dataset_path = os.path.join(COCO2017_path, 'images/val2017')
    dataset_annotations_path = os.path.join(COCO2017_path, 'annotations')

    network = "resnet50_retinanet"
    pretrained = False
    num_classes = 80
    seed = 0
    input_image_size = 600

    train_dataset = CocoDetection(image_root_dir=train_dataset_path,
                                  annotation_root_dir=dataset_annotations_path,
                                  set="train2017",
                                  transform=transforms.Compose([
                                      RandomFlip(flip_prob=0.5),
                                      RandomCrop(crop_prob=0.5),
                                      RandomTranslate(translate_prob=0.5),
                                      Resize(resize=input_image_size),
                                  ]))
    val_dataset = CocoDetection(image_root_dir=val_dataset_path,
                                annotation_root_dir=dataset_annotations_path,
                                set="val2017",
                                transform=transforms.Compose([
                                    Resize(resize=input_image_size),
                                ]))

    epochs = 12
    batch_size = 15
    lr = 1e-4
    num_workers = 4
    print_interval = 100
    apex = True
train.py:
import sys
import os
import argparse
import random
import shutil
import time
import warnings
import json

BASE_DIR = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(BASE_DIR)
warnings.filterwarnings('ignore')

import numpy as np
from thop import profile
from thop import clever_format
from apex import amp
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torchvision import transforms
from config import Config
from public.detection.dataset.cocodataset import COCODataPrefetcher, collater
from public.detection.models.loss import RetinaLoss
from public.detection.models.decode import RetinaDecoder
from public.detection.models.retinanet import resnet50_retinanet
from public.imagenet.utils import get_logger
from pycocotools.cocoeval import COCOeval


def parse_args():
    parser = argparse.ArgumentParser(
        description='PyTorch COCO Detection Training')
    parser.add_argument('--network',
                        type=str,
                        default=Config.network,
                        help='name of network')
    parser.add_argument('--lr',
                        type=float,
                        default=Config.lr,
                        help='learning rate')
    parser.add_argument('--epochs',
                        type=int,
                        default=Config.epochs,
                        help='num of training epochs')
    parser.add_argument('--batch_size',
                        type=int,
                        default=Config.batch_size,
                        help='batch size')
    parser.add_argument('--pretrained',
                        type=bool,
                        default=Config.pretrained,
                        help='load pretrained model params or not')
    parser.add_argument('--num_classes',
                        type=int,
                        default=Config.num_classes,
                        help='model classification num')
    parser.add_argument('--input_image_size',
                        type=int,
                        default=Config.input_image_size,
                        help='input image size')
    parser.add_argument('--num_workers',
                        type=int,
                        default=Config.num_workers,
                        help='number of workers to load data')
    parser.add_argument('--resume',
                        type=str,
                        default=Config.resume,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkpoints',
                        type=str,
                        default=Config.checkpoint_path,
                        help='path for saving trained models')
    parser.add_argument('--log',
                        type=str,
                        default=Config.log,
                        help='path to save log')
    parser.add_argument('--evaluate',
                        type=str,
                        default=Config.evaluate,
                        help='path for evaluate model')
    parser.add_argument('--seed', type=int, default=Config.seed, help='seed')
    parser.add_argument('--print_interval',
                        type=int,
                        default=Config.print_interval,
                        help='print interval')
    parser.add_argument('--apex',
                        type=bool,
                        default=Config.apex,
                        help='use apex or not')

    return parser.parse_args()


def validate(val_dataset, model, decoder):
    model = model.module
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        all_eval_result = evaluate_coco(val_dataset, model, decoder)

    return all_eval_result


def evaluate_coco(val_dataset, model, decoder):
    results, image_ids = [], []
    for index in range(len(val_dataset)):
        data = val_dataset[index]
        scale = data['scale']
        cls_heads, reg_heads, batch_anchors = model(data['img'].cuda().permute(
            2, 0, 1).float().unsqueeze(dim=0))
        scores, classes, boxes = decoder(cls_heads, reg_heads, batch_anchors)
        scores, classes, boxes = scores.cpu(), classes.cpu(), boxes.cpu()
        boxes /= scale

        # make sure decode batch_size=1
        # scores shape:[1,max_detection_num]
        # classes shape:[1,max_detection_num]
        # boxes shape:[1,max_detection_num,4]
        assert scores.shape[0] == 1

        scores = scores.squeeze(0)
        classes = classes.squeeze(0)
        boxes = boxes.squeeze(0)

        # for coco_eval, we need [x_min,y_min,w,h] format pred boxes
        boxes[:, 2:] -= boxes[:, :2]

        for object_score, object_class, object_box in zip(
                scores, classes, boxes):
            object_score = float(object_score)
            object_class = int(object_class)
            object_box = object_box.tolist()
            if object_class == -1:
                break

            image_result = {
                'image_id': val_dataset.image_ids[index],
                'category_id':
                val_dataset.find_category_id_from_coco_label(object_class),
                'score': object_score,
                'bbox': object_box,
            }
            results.append(image_result)

        image_ids.append(val_dataset.image_ids[index])

        print('{}/{}'.format(index, len(val_dataset)), end='\r')

    if not len(results):
        print("No target detected in test set images")
        return

    json.dump(results,
              open('{}_bbox_results.json'.format(val_dataset.set_name), 'w'),
              indent=4)

    # load results in COCO evaluation tool
    coco_true = val_dataset.coco
    coco_pred = coco_true.loadRes('{}_bbox_results.json'.format(
        val_dataset.set_name))
    coco_eval = COCOeval(coco_true, coco_pred, 'bbox')
    coco_eval.params.imgIds = image_ids
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()
    all_eval_result = coco_eval.stats

    return all_eval_result


def train(train_loader, model, criterion, optimizer, scheduler, epoch, logger,
          args):
    cls_losses, reg_losses, losses = [], [], []

    # switch to train mode
    model.train()

    iters = len(train_loader.dataset) // args.batch_size
    prefetcher = COCODataPrefetcher(train_loader)
    images, annotations = prefetcher.next()
    iter_index = 1

    while images is not None:
        images, annotations = images.cuda().float(), annotations.cuda()
        cls_heads, reg_heads, batch_anchors = model(images)
        cls_loss, reg_loss = criterion(cls_heads, reg_heads, batch_anchors,
                                       annotations)
        loss = cls_loss + reg_loss
        if cls_loss == 0.0 or reg_loss == 0.0:
            optimizer.zero_grad()
            # fetch the next batch before skipping, otherwise the loop
            # would spin on the same batch forever
            images, annotations = prefetcher.next()
            continue

        if args.apex:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        optimizer.zero_grad()

        cls_losses.append(cls_loss.item())
        reg_losses.append(reg_loss.item())
        losses.append(loss.item())

        images, annotations = prefetcher.next()

        if iter_index % args.print_interval == 0:
            logger.info(
                f"train: epoch {epoch:0>3d}, iter [{iter_index:0>5d}, {iters:0>5d}], cls_loss: {cls_loss.item():.2f}, reg_loss: {reg_loss.item():.2f}, loss_total: {loss.item():.2f}"
            )

        iter_index += 1

    scheduler.step(np.mean(losses))

    return np.mean(cls_losses), np.mean(reg_losses), np.mean(losses)


def main(logger, args):
    if not torch.cuda.is_available():
        raise Exception("need gpu to train network!")

    torch.cuda.empty_cache()

    if args.seed is not None:
        random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        cudnn.deterministic = True

    gpus = torch.cuda.device_count()
    logger.info(f'use {gpus} gpus')
    logger.info(f"args: {args}")

    cudnn.benchmark = True
    cudnn.enabled = True
    start_time = time.time()

    # dataset and dataloader
    logger.info('start loading data')
    train_loader = DataLoader(Config.train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.num_workers,
                              collate_fn=collater)
    logger.info('finish loading data')

    model = resnet50_retinanet(**{
        "pretrained": args.pretrained,
        "num_classes": args.num_classes,
    })

    for name, param in model.named_parameters():
        logger.info(f"{name},{param.requires_grad}")

    flops_input = torch.randn(1, 3, args.input_image_size,
                              args.input_image_size)
    flops, params = profile(model, inputs=(flops_input, ))
    flops, params = clever_format([flops, params], "%.3f")
    logger.info(f"model: '{args.network}', flops: {flops}, params: {params}")

    criterion = RetinaLoss(image_w=args.input_image_size,
                           image_h=args.input_image_size).cuda()
    decoder = RetinaDecoder(image_w=args.input_image_size,
                            image_h=args.input_image_size).cuda()

    model = model.cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           patience=3,
                                                           verbose=True)

    if args.apex:
        amp.register_float_function(torch, 'sigmoid')
        amp.register_float_function(torch, 'softmax')
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    model = nn.DataParallel(model)

    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            raise Exception(
                f"{args.evaluate} is not a file, please check it again")
        logger.info('start only evaluating')
        logger.info(f"start resuming model from {args.evaluate}")
        checkpoint = torch.load(args.evaluate,
                                map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])
        all_eval_result = validate(Config.val_dataset, model, decoder)
        if all_eval_result is not None:
            logger.info(
                f"val: epoch: {checkpoint['epoch']:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
            )

        return

    best_map = 0.0
    start_epoch = 1
    # resume training
    if os.path.exists(args.resume):
        logger.info(f"start resuming model from {args.resume}")
        checkpoint = torch.load(args.resume, map_location=torch.device('cpu'))
        start_epoch += checkpoint['epoch']
        best_map = checkpoint['best_map']
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        logger.info(
            f"finish resuming model from {args.resume}, epoch {checkpoint['epoch']}, best_map: {checkpoint['best_map']}, "
            f"loss: {checkpoint['loss']:3f}, cls_loss: {checkpoint['cls_loss']:2f}, reg_loss: {checkpoint['reg_loss']:2f}"
        )

    if not os.path.exists(args.checkpoints):
        os.makedirs(args.checkpoints)

    logger.info('start training')
    for epoch in range(start_epoch, args.epochs + 1):
        cls_losses, reg_losses, losses = train(train_loader, model, criterion,
                                               optimizer, scheduler, epoch,
                                               logger, args)
        logger.info(
            f"train: epoch {epoch:0>3d}, cls_loss: {cls_losses:.2f}, reg_loss: {reg_losses:.2f}, loss: {losses:.2f}"
        )

        if epoch % 5 == 0 or epoch == args.epochs:
            all_eval_result = validate(Config.val_dataset, model, decoder)
            logger.info(f"eval done.")
            if all_eval_result is not None:
                logger.info(
                    f"val: epoch: {epoch:0>5d}, IoU=0.5:0.95,area=all,maxDets=100,mAP:{all_eval_result[0]:.3f}, IoU=0.5,area=all,maxDets=100,mAP:{all_eval_result[1]:.3f}, IoU=0.75,area=all,maxDets=100,mAP:{all_eval_result[2]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAP:{all_eval_result[3]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAP:{all_eval_result[4]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAP:{all_eval_result[5]:.3f}, IoU=0.5:0.95,area=all,maxDets=1,mAR:{all_eval_result[6]:.3f}, IoU=0.5:0.95,area=all,maxDets=10,mAR:{all_eval_result[7]:.3f}, IoU=0.5:0.95,area=all,maxDets=100,mAR:{all_eval_result[8]:.3f}, IoU=0.5:0.95,area=small,maxDets=100,mAR:{all_eval_result[9]:.3f}, IoU=0.5:0.95,area=medium,maxDets=100,mAR:{all_eval_result[10]:.3f}, IoU=0.5:0.95,area=large,maxDets=100,mAR:{all_eval_result[11]:.3f}"
                )
                if all_eval_result[0] > best_map:
                    torch.save(model.module.state_dict(),
                               os.path.join(args.checkpoints, "best.pth"))
                    best_map = all_eval_result[0]

        torch.save(
            {
                'epoch': epoch,
                'best_map': best_map,
                'cls_loss': cls_losses,
                'reg_loss': reg_losses,
                'loss': losses,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
            }, os.path.join(args.checkpoints, 'latest.pth'))

    logger.info(f"finish training, best_map: {best_map:.3f}")
    training_time = (time.time() - start_time) / 3600
    logger.info(
        f"finish training, total training time: {training_time:.2f} hours")


if __name__ == '__main__':
    args = parse_args()
    logger = get_logger(__name__, args.log)
    main(logger, args)
The code above trains in nn.DataParallel mode; the hyperparameters in config.py and train.py correspond to the ResNet50-RetinaNet-aug entry (batch 15, 1 GPU, apex on) in the evaluation table below. I will implement distributed training in the next article. To start training, simply run python train.py.
Evaluating the reproduction

Reviewing how the six articles reproduce each part of RetinaNet, three issues still stand between my numbers and the paper's:
- The ImageNet-pretrained ResNet50 weights used by Detectron and Detectron2 were trained by those teams themselves and are probably better than mine (my ResNet50 pretrained weights reach top1-error=23.488%). In my experience, the better the pretrained weights, the better the finetuned result (not a linear relationship, but a positive correlation).
- The training above uses nn.DataParallel, which cannot use synchronized BN across GPUs, whereas Detectron and Detectron2 both use distributed training with cross-GPU synchronized BN. Without it, BN updates its mean and standard deviation from the per-GPU batch only; if the per-GPU batch size is below the paper's 16, the BN statistics are less accurate than with cross-GPU synchronization and model performance drops accordingly (see the sketch after this list).
- I have not read all of the Detectron and Detectron2 code, so those frameworks may contain further training tricks that I have not yet discovered.
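For issue 2, PyTorch's built-in synchronized BN can be enabled once training moves to DistributedDataParallel; a minimal sketch (not part of this article's code, which stays on nn.DataParallel, and assuming torch.distributed has already been initialized):

import torch
import torch.nn as nn
from public.detection.models.retinanet import resnet50_retinanet

# Assumption: torch.distributed.init_process_group(...) has been called.
local_rank = 0  # placeholder; in practice supplied by the distributed launcher
model = resnet50_retinanet(pretrained=True, num_classes=80)
# Replace every BatchNorm layer with its cross-GPU synchronized version.
# SyncBatchNorm only works under DistributedDataParallel, which is why the
# nn.DataParallel training above cannot use it.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model).cuda()
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])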
The model's performance on COCO is shown below (input resolution 600, roughly equivalent to resolution 450 in the RetinaNet paper):
| Network | batch | gpu-num | apex | epoch5 mAP,loss | epoch10 mAP,loss | epoch12 mAP,loss | one-epoch training time |
| --- | --- | --- | --- | --- | --- | --- | --- |
| ResNet50-RetinaNet | 16 | 2 | no | 0.251,0.60 | 0.266,0.49 | 0.272,0.46 | 2h38min |
| ResNet50-RetinaNet | 15 | 1 | yes | 0.251,0.59 | 0.272,0.48 | 0.276,0.45 | 2h31min |
| ResNet50-RetinaNet-aug | 15 | 1 | yes | 0.254,0.62 | 0.274,0.53 | 0.283,0.51 | 2h31min |
From these results, with the same data augmentation as the paper my trained RetinaNet (0.276) is about 3.5 points lower than the paper's (extrapolating the paper's numbers to resolution 450 gives roughly 0.311). The gap should come from replacing the SGD optimizer with Adam, together with issues 1, 2, and 3 above. In the next article I will train RetinaNet with distributed training, which addresses issue 2.