Table of Contents
- What is CenterSample
- The new heads layers
- Loss class modifications
- Model training and test results
All of the code has been uploaded to my github repository: https://github.com/zgcr/pytorch-ImageNet-CIFAR-COCO-VOC-training
If you find it useful, please give it a star!
All of the code below has been tested on PyTorch 1.4 and confirmed to run correctly.
What is CenterSample

In the original FCOS implementation, every anchor point that falls inside a ground-truth box is counted as a positive sample. However, the region near the edges of a ground-truth box is often still background, so these "positive" anchor points actually sit on background and confuse the model during training. When training the original FCOS earlier, we saw that at resize=667 the model needs 24 epochs to approach the performance that RetinaNet reaches in 12 epochs. At the larger resize=1000, performance rises during the first 10 epochs and then starts to drop, which suggests that with large input resolutions the positive samples near the box edges hurt learning even more (because there are more such anchor points at higher resolution).
CenterSample restricts positives to a circular region inside the ground-truth box: the circle is centered at the box center and is smaller than the box, and only anchor points inside this circle are counted as positive samples. Most anchor points near the box edges, which actually fall on background, are then labeled as negatives, so the FCOS model converges faster and reaches better final performance. Concretely, we define a hyperparameter center_sample_radius as the base radius of this circle; depending on which FPN level a ground-truth box is assigned to, we multiply center_sample_radius by that level's stride to obtain the actual radius of the circular region inside the box (see the sketch below).
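For intuition, here is a minimal sketch of the center-sample test for a single FPN level. It is an illustration rather than the repository's actual implementation, and the names `points`, `gt_boxes`, `stride` and the default radius of 1.5 are assumptions chosen for this example:

```python
import torch

def center_sample_mask(points, gt_boxes, stride, center_sample_radius=1.5):
    """Return a [num_points, num_gts] bool mask of positive candidates.

    points:   [num_points, 2] anchor point (x, y) coordinates on the input image.
    gt_boxes: [num_gts, 4] ground-truth boxes as (x1, y1, x2, y2).
    stride:   stride of the FPN level these points belong to.
    """
    # circle center = ground-truth box center, radius = base radius * stride
    centers = (gt_boxes[:, 0:2] + gt_boxes[:, 2:4]) / 2  # [num_gts, 2]
    radius = center_sample_radius * stride

    # distance from every point to every box center
    distances = torch.cdist(points, centers)             # [num_points, num_gts]

    # a point is a positive candidate for a box only if it lies inside
    # both the box itself and the center circle
    x, y = points[:, 0:1], points[:, 1:2]
    inside_box = ((x > gt_boxes[:, 0]) & (x < gt_boxes[:, 2]) &
                  (y > gt_boxes[:, 1]) & (y < gt_boxes[:, 3]))
    inside_circle = distances < radius

    return inside_box & inside_circle
```

On the P3 level (stride 8), for example, a base radius of 1.5 means only points within 12 pixels of a box center can become positives for that box, no matter how large the box is.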
In addition, I also add group normalization layers to the head layers in the code, so the resulting FCOS matches the configuration of the FCOS model in the paper with all improvements enabled.
The new heads layers

The classification head, regression head, and centerness head are all implemented in a single class, and the centerness head shares its features with the regression head.
The heads are implemented as follows:
import math

import torch.nn as nn


class FCOSClsRegCntHead(nn.Module):
    def __init__(self,
                 inplanes,
                 num_classes,
                 num_layers=4,
                 prior=0.01,
                 use_gn=True,
                 cnt_on_reg=True):
        super(FCOSClsRegCntHead, self).__init__()
        self.cnt_on_reg = cnt_on_reg

        # classification branch: num_layers conv blocks (conv -> GN -> ReLU)
        cls_layers = []
        for _ in range(num_layers):
            cls_layers.append(
                nn.Conv2d(inplanes,
                          inplanes,
                          kernel_size=3,
                          stride=1,
                          padding=1))
            if use_gn:
                cls_layers.append(nn.GroupNorm(32, inplanes))
            cls_layers.append(nn.ReLU(inplace=True))
        self.cls_head = nn.Sequential(*cls_layers)

        # regression branch: same structure as the classification branch
        reg_layers = []
        for _ in range(num_layers):
            reg_layers.append(
                nn.Conv2d(inplanes,
                          inplanes,
                          kernel_size=3,
                          stride=1,
                          padding=1))
            if use_gn:
                reg_layers.append(nn.GroupNorm(32, inplanes))
            reg_layers.append(nn.ReLU(inplace=True))
        self.reg_head = nn.Sequential(*reg_layers)

        # output layers: per-class scores, l/t/r/b regression, center-ness
        self.cls_out = nn.Conv2d(inplanes,
                                 num_classes,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)
        self.reg_out = nn.Conv2d(inplanes,
                                 4,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)
        self.center_out = nn.Conv2d(inplanes,
                                    1,
                                    kernel_size=3,
                                    stride=1,
                                    padding=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, val=0)

        # initialize the classification bias with the focal loss prior
        b = -math.log((1 - prior) / prior)
        self.cls_out.bias.data.fill_(b)

    def forward(self, x):
        cls_x = self.cls_head(x)
        reg_x = self.reg_head(x)

        del x

        cls_output = self.cls_out(cls_x)
        reg_output = self.reg_out(reg_x)

        # the center-ness head shares features with the regression branch
        # when cnt_on_reg=True, otherwise with the classification branch
        if self.cnt_on_reg:
            center_output = self.center_out(reg_x)
        else:
            center_output = self.center_out(cls_x)

        return cls_output, reg_output, center_output
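For reference, below is a minimal usage sketch of the head above. The 256 channels, the 80-class COCO setting and the dummy FPN feature shapes are assumptions chosen for this example, not values taken from the post:

```python
import torch

# assume 5 FPN levels (P3-P7, strides 8/16/32/64/128) with 256 channels each
# for a batch of two 640x640 images
fpn_features = [
    torch.randn(2, 256, 640 // s, 640 // s) for s in [8, 16, 32, 64, 128]
]

head = FCOSClsRegCntHead(inplanes=256,
                         num_classes=80,
                         num_layers=4,
                         prior=0.01,
                         use_gn=True,
                         cnt_on_reg=True)

# the same head (shared weights) is applied to every FPN level
for feature in fpn_features:
    cls_out, reg_out, center_out = head(feature)
    # cls_out:    [2, 80, H, W]
    # reg_out:    [2, 4, H, W]
    # center_out: [2, 1, H, W]
    print(cls_out.shape, reg_out.shape, center_out.shape)
```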
Loss class modifications

In the loss class, only the ground-truth assignment step needs to change to add the center sample mechanism, i.e., only the get_batch_position_annotations function is modified. The modified get_batch_position_annotations is implemented as follows:
def get_batch_position_annotations(self, cls_heads, reg_heads,
                                   center_heads, batch_positions,
                                   annotations):
    """
    Assign a ground truth target for each position on feature map
    """
    device = annotations.device

    # expand mi (the scale range of each FPN level) and the level stride
    # into per-position tensors
    batch_mi, batch_stride = [], []
    for reg_head, mi, stride in zip(reg_heads, self.mi, self.strides):
        mi = torch.tensor(mi).to(device)
        B, H, W, _ = reg_head.shape
        per_level_mi = torch.zeros(B, H, W, 2).to(device)
        per_level_mi = per_level_mi + mi
        batch_mi.append(per_level_mi)
        per_level_stride = torch.zeros(B, H, W, 1).to(device)
        per_level_stride = per_level_stride + stride
        batch_stride.append(per_level_stride)

    # flatten predictions, positions, mi and stride of all FPN levels
    cls_preds, reg_preds, center_preds, all_points_position, all_points_mi, all_points_stride = [], [], [], [], [], []
    for cls_pred, reg_pred, center_pred, per_level_position, per_level_mi, per_level_stride in zip(
            cls_heads, reg_heads, center_heads, batch_positions, batch_mi,
            batch_stride):
        cls_pred = cls_pred.view(cls_pred.shape[0], -1, cls_pred.shape[-1])
        reg_pred = reg_pred.view(reg_pred.shape[0], -1, reg_pred.shape[-1])
        center_pred = center_pred.view(center_pred.shape[0], -1,
                                       center_pred.shape[-1])
        per_level_position = per_level_position.view(
            per_level_position.shape[0], -1, per_level_position.shape[-1])
        per_level_mi = per_level_mi.view(per_level_mi.shape[0], -1,
                                         per_level_mi.shape[-1])
        per_level_stride = per_level_stride.view(
            per_level_stride.shape[0], -1, per_level_stride.shape[-1])

        cls_preds.append(cls_pred)
        reg_preds.append(reg_pred)
        center_preds.append(center_pred)
        all_points_position.append(per_level_position)
        all_points_mi.append(per_level_mi)
        all_points_stride.append(per_level_stride)

    cls_preds = torch.cat(cls_preds, axis=1)
    reg_preds = torch.cat(reg_preds, axis=1)
    center_preds = torch.cat(center_preds, axis=1)
    all_points_position = torch.cat(all_points_position, axis=1)
    all_points_mi = torch.cat(all_points_mi, axis=1)
    all_points_stride = torch.cat(all_points_stride, axis=1)

    batch_targets = []
    for per_image_position, per_image_mi, per_image_stride, per_image_annotations in zip(
            all_points_position, all_points_mi, all_points_stride,
            annotations):
        per_image_annotations = per_image_annotations[
            per_image_annotations[:, 4] >= 0]
        points_num = per_image_position.shape[0]

        if per_image_annotations.shape[0] == 0:
            # 6:l,t,r,b,class_index,center-ness_gt
            per_image_targets = torch.zeros([points_num, 6], device=device)
        else:
            annotaion_num = per_image_annotations.shape[0]
            per_image_gt_bboxes = per_image_annotations[:, 0:4]
            candidates = torch.zeros([points_num, annotaion_num, 4],
                                     device=device)
            candidates = candidates + per_image_gt_bboxes.unsqueeze(0)
            per_image_position = per_image_position.unsqueeze(1).repeat(
                1, annotaion_num, 1)

            if self.use_center_sample:
                # center sample: the circle center is the gt box center,
                # the radius is center_sample_radius * level stride
                candidates_center = (candidates[:, :, 2:4] +
                                     candidates[:, :, 0:2]) / 2
                judge_distance = per_image_stride * self.center_sample_radius
                judge_distance = judge_distance.repeat(1, annotaion_num)

            # convert gt boxes to l,t,r,b targets for every point/gt pair
            candidates[:, :, 0:2] = per_image_position[:, :, 0:2] - candidates[:, :, 0:2]
            candidates[:, :, 2:4] = candidates[:, :, 2:4] - per_image_position[:, :, 0:2]

            candidates_min_value, _ = candidates.min(axis=-1, keepdim=True)
            sample_flag = (candidates_min_value[:, :, 0] >
                           0).int().unsqueeze(-1)
            # get all negative reg targets which points ctr out of gt box
            candidates = candidates * sample_flag

            # if use center sample, get all negative reg targets for points
            # not in the center circle
            if self.use_center_sample:
                compute_distance = torch.sqrt(
                    (per_image_position[:, :, 0] -
                     candidates_center[:, :, 0])**2 +
                    (per_image_position[:, :, 1] -
                     candidates_center[:, :, 1])**2)
                center_sample_flag = (compute_distance <
                                      judge_distance).int().unsqueeze(-1)
                candidates = candidates * center_sample_flag

            # get all negative reg targets whose assigned ground truth is not
            # in the range of mi
            candidates_max_value, _ = candidates.max(axis=-1, keepdim=True)
            per_image_mi = per_image_mi.unsqueeze(1).repeat(
                1, annotaion_num, 1)
            m1_negative_flag = (candidates_max_value[:, :, 0] >
                                per_image_mi[:, :, 0]).int().unsqueeze(-1)
            candidates = candidates * m1_negative_flag
            m2_negative_flag = (candidates_max_value[:, :, 0] <
                                per_image_mi[:, :, 1]).int().unsqueeze(-1)
            candidates = candidates * m2_negative_flag

            final_sample_flag = candidates.sum(axis=-1).sum(axis=-1)
            final_sample_flag = final_sample_flag > 0
            positive_index = (final_sample_flag == True).nonzero().squeeze(
                dim=-1)

            # if no positive sample is assigned
            if len(positive_index) == 0:
                del candidates
                # 6:l,t,r,b,class_index,center-ness_gt
                per_image_targets = torch.zeros([points_num, 6],
                                                device=device)
            else:
                positive_candidates = candidates[positive_index]

                del candidates

                sample_box_gts = per_image_annotations[:, 0:4].unsqueeze(0)
                sample_box_gts = sample_box_gts.repeat(
                    positive_candidates.shape[0], 1, 1)
                sample_class_gts = per_image_annotations[:, 4].unsqueeze(
                    -1).unsqueeze(0)
                sample_class_gts = sample_class_gts.repeat(
                    positive_candidates.shape[0], 1, 1)

                # 6:l,t,r,b,class_index,center-ness_gt
                per_image_targets = torch.zeros([points_num, 6],
                                                device=device)

                if positive_candidates.shape[1] == 1:
                    # if there is only one candidate for each positive sample,
                    # assign l,t,r,b,class_index,center_ness_gt ground truth
                    # class_index value from 1 to 80 represents 80 positive classes
                    # class_index value 0 represents the negative class
                    positive_candidates = positive_candidates.squeeze(1)
                    sample_class_gts = sample_class_gts.squeeze(1)
                    per_image_targets[positive_index,
                                      0:4] = positive_candidates
                    per_image_targets[positive_index,
                                      4:5] = sample_class_gts + 1

                    l, t, r, b = per_image_targets[
                        positive_index, 0:1], per_image_targets[
                            positive_index, 1:2], per_image_targets[
                                positive_index, 2:3], per_image_targets[
                                    positive_index, 3:4]
                    per_image_targets[positive_index, 5:6] = torch.sqrt(
                        (torch.min(l, r) / torch.max(l, r)) *
                        (torch.min(t, b) / torch.max(t, b)))
                else:
                    # if a positive point sample has several object candidates,
                    # choose the smallest-area object candidate as the ground
                    # truth for this positive point sample
                    gts_w_h = sample_box_gts[:, :, 2:4] - sample_box_gts[:, :, 0:2]
                    gts_area = gts_w_h[:, :, 0] * gts_w_h[:, :, 1]
                    positive_candidates_value = positive_candidates.sum(
                        axis=2)

                    # make sure all negative candidate areas == 100000000,
                    # so the .min() operation wouldn't choose negative candidates
                    INF = 100000000
                    inf_tensor = torch.ones_like(gts_area) * INF
                    gts_area = torch.where(
                        torch.eq(positive_candidates_value, 0.),
                        inf_tensor, gts_area)

                    # get the smallest object candidate index
                    _, min_index = gts_area.min(axis=1)
                    candidate_indexes = (
                        torch.linspace(1, positive_candidates.shape[0],
                                       positive_candidates.shape[0]) -
                        1).long()
                    final_candidate_reg_gts = positive_candidates[
                        candidate_indexes, min_index, :]
                    final_candidate_cls_gts = sample_class_gts[
                        candidate_indexes, min_index]

                    # assign l,t,r,b,class_index,center_ness_gt ground truth
                    per_image_targets[positive_index,
                                      0:4] = final_candidate_reg_gts
                    per_image_targets[positive_index,
                                      4:5] = final_candidate_cls_gts + 1

                    l, t, r, b = per_image_targets[
                        positive_index, 0:1], per_image_targets[
                            positive_index, 1:2], per_image_targets[
                                positive_index, 2:3], per_image_targets[
                                    positive_index, 3:4]
                    per_image_targets[positive_index, 5:6] = torch.sqrt(
                        (torch.min(l, r) / torch.max(l, r)) *
                        (torch.min(t, b) / torch.max(t, b)))

        per_image_targets = per_image_targets.unsqueeze(0)
        batch_targets.append(per_image_targets)

    batch_targets = torch.cat(batch_targets, axis=0)
    batch_targets = torch.cat([batch_targets, all_points_position], axis=2)

    # batch_targets shape:[batch_size, points_num, 8]
    # 8:l,t,r,b,class_index,center-ness_gt,point_ctr_x,point_ctr_y
    return cls_preds, reg_preds, center_preds, batch_targets
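The least obvious part of the function above is the tie-breaking when one point is a positive candidate for several ground-truth boxes: pairs that were zeroed out get their area replaced by a huge constant, and the remaining box with the smallest area is kept. The toy example below isolates that idea; all shapes and values are made up for illustration:

```python
import torch

# 3 positive points, 2 ground-truth boxes; each entry holds candidate l,t,r,b
# targets, and an all-zero row means "not a candidate for this box"
positive_candidates = torch.tensor([
    [[10., 10., 20., 20.], [0., 0., 0., 0.]],   # point 0: only box 0
    [[30., 30., 40., 40.], [5., 5., 8., 8.]],   # point 1: both boxes
    [[0., 0., 0., 0.], [12., 12., 15., 15.]],   # point 2: only box 1
])
gt_boxes = torch.tensor([[0., 0., 100., 100.], [20., 20., 50., 50.]])

# ground-truth box areas, broadcast to [num_points, num_gts]
wh = gt_boxes[:, 2:4] - gt_boxes[:, 0:2]
gts_area = (wh[:, 0] * wh[:, 1]).unsqueeze(0).repeat(3, 1)

# replace the area of non-candidate pairs with a huge constant so that
# .min() never selects them
INF = 100000000
is_candidate = positive_candidates.sum(dim=2) > 0
gts_area = torch.where(is_candidate, gts_area, torch.full_like(gts_area, INF))

# for every point keep the candidate ground truth with the smallest area
_, min_index = gts_area.min(dim=1)
point_index = torch.arange(positive_candidates.shape[0])
final_reg_gts = positive_candidates[point_index, min_index]

print(min_index)      # tensor([0, 1, 1])
print(final_reg_gts)  # the surviving l,t,r,b target for each point
```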
Model training and test results

We retrain the FCOS model. First, the results at resize=667:
| Network | batch | gpu-num | apex | syncbn | epoch5-mAP-mAR-loss | epoch10-mAP-mAR-loss | epoch12-mAP-mAR-loss |
| --- | --- | --- | --- | --- | --- | --- | --- |
| ResNet50-FCOS-myresize667-fastdecode | 24 | 2 | yes | no | 0.272,0.399,1.15 | 0.293,0.422,1.07 | 0.312,0.445,1.06 |
| ResNet101-FCOS-myresize667-fastdecode | 16 | 2 | yes | no | 0.261,0.390,1.14 | 0.307,0.438,1.06 | 0.325,0.455,1.05 |
And the results at resize=1000:

| Network | batch | gpu-num | apex | syncbn | epoch12-mAP-mAR-loss | epoch24-mAP-mAR-loss |
| --- | --- | --- | --- | --- | --- | --- |
| ResNet50-FCOS-myresize1000-fastdecode | 16 | 2 | yes | no | 0.352,0.490,1.03 | 0.352,0.491,1.01 |