PixPro自监督论文解读

PixPro自监督论文解读
文章图片

PixPro是第一个通过像素级对比学习来进行特征表示学习
上图是整个算法流程图,接下来进行详细解析
前向传播 input是输入图像,维度尺寸是(b, c, h, w)
augmentation:通过对同一张input进行随机大小、位置裁剪并缩放到统一大小224*224,并基于一定概率下进行随机水平翻转、color distortion、高斯模糊和solarization操作,最后生成两张不同视图view #1和view #2,大小都是(b, c, 224, 224)
backbone+projection:view #1和view #2分别送入两个网络分支,上下两分支中都含有相同结构的backbone+projection模块,其中backbone模块使用了Resnet,输出最后一层特征图,大小为(b, c1, 7, 7)。
projection模块是一个conv1*1+BN+Relu+conv1*1结构,先进行升维,再降维到256大小,这样就得到了两个输出大小为(b, 256, 7, 7)的特征$x$和$x^{,}$,projection模块代码如下:

class MLP2d(nn.Module): def __init__(self, in_dim, inner_dim=4096, out_dim=256): super(MLP2d, self).__init__()self.linear1 = conv1x1(in_dim, inner_dim) self.bn1 = nn.BatchNorm2d(inner_dim) self.relu1 = nn.ReLU(inplace=True)self.linear2 = conv1x1(inner_dim, out_dim)def forward(self, x): x = self.linear1(x) x = self.bn1(x) x = self.relu1(x)x = self.linear2(x)return xdef conv1x1(in_planes, out_planes): """1x1 convolution""" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=True)

PPM:是一个自注意力模块,针对(b, 256, 7, 7)的输入特征图$x$
PixPro自监督论文解读
文章图片

首先根据cosine相似度计算出attention图,大小为(b, 49, 49),表示每个特征点与其他特征点的相似度。再对输入特征图进行特征融合,得到输出大小为(b, 256, 7, 7)的特征图$y$,PPM代码如下:
def featprop(self, feat): N, C, H, W = feat.shape# Value transformation feat_value = https://www.it610.com/article/self.value_transform(feat)# 1*1卷积操作 feat_value = F.normalize(feat_value, dim=1) feat_value = feat_value.view(N, C, -1)# Similarity calculation feat = F.normalize(feat, dim=1)# [N, C, H * W] feat = feat.view(N, C, -1)# [N, H * W, H * W] attention = torch.bmm(feat.transpose(1, 2), feat) attention = torch.clamp(attention, min=self.pixpro_clamp_value) if self.pixpro_p < 1.: attention = attention + 1e-6 attention = attention ** self.pixpro_p# pixpro_p控制注意力的范围,默认为1# [N, C, H * W] feat = torch.bmm(feat_value, attention.transpose(1, 2))return feat.view(N, C, H, W)

Loss:计算$x^,$和$y$之间的loss。$x^,$和$y$的空间位置示意图如下所示:
PixPro自监督论文解读
文章图片

在数据增强augmentation过程中,可以获取裁剪图像的左上角和右下键坐标,由于输出特征图$x^,$和$y$大小为(b, 256, 7, 7),所以每个特征图中有7*7个特征点,根据插值法就可以获取输出特征图$x^,$和$y$的每个特征点的空间坐标,大小为(b, 2, 7, 7)。
首先计算出不同视图中每个特征点彼此之间的距离,可以得到大小为(b, 49, 49)的距离矩阵D,步骤如下:
  1. 特征图$x^{,}$的x坐标$X_{x^{,}}$:(b, 7, 7)->(b, 49, 1), y坐标$Y_{x^{,}}$:(b, 7, 7)->(b, 49, 1)
  2. 特征图$y$中x坐标$X_{y}$:(b, 7, 7)->(b, 1, 49), y坐标$Y_{y}$:(b, 7, 7)->(b, 1,49)
  3. 距离矩阵D=$\sqrt{(X_{x^{,}}-X_{y})^2+(Y_{x^{,}}-Y_{y})^2}/max\_bin$(max_bin是相邻特征点之间的最大距离,目的是为了"归一化")
不同视图中距离较近的特征点特征应该具有一致性consistency,所以对距离特征D根据阈值ratio进行二分来获取距离较近的特征点掩码M=(D 再计算出$x^,$和$y$的特征相似度图logit,大小为(b, 49, 49),这步与PPM中计算注意力相似度类似
最后根据特征相似图和掩码矩阵计算loss:
$loss = logit * M$
整个loss计算完整过程的代码如下:
def regression_loss(q, k, coord_q, coord_k, pos_ratio=0.5): """ q, k: N * C * H * W coord_q, coord_k: N * 4 (x_upper_left, y_upper_left, x_lower_right, y_lower_right) """ N, C, H, W = q.shape # [bs, feat_dim, 49] q = q.view(N, C, -1) k = k.view(N, C, -1)# generate center_coord, width, height # [1, 7, 7] x_array = torch.arange(0., float(W), dtype=coord_q.dtype, device=coord_q.device).view(1, 1, -1).repeat(1, H, 1) y_array = torch.arange(0., float(H), dtype=coord_q.dtype, device=coord_q.device).view(1, -1, 1).repeat(1, 1, W) # [bs, 1, 1] q_bin_width = ((coord_q[:, 2] - coord_q[:, 0]) / W).view(-1, 1, 1) q_bin_height = ((coord_q[:, 3] - coord_q[:, 1]) / H).view(-1, 1, 1) k_bin_width = ((coord_k[:, 2] - coord_k[:, 0]) / W).view(-1, 1, 1) k_bin_height = ((coord_k[:, 3] - coord_k[:, 1]) / H).view(-1, 1, 1) # [bs, 1, 1] q_start_x = coord_q[:, 0].view(-1, 1, 1) q_start_y = coord_q[:, 1].view(-1, 1, 1) k_start_x = coord_k[:, 0].view(-1, 1, 1) k_start_y = coord_k[:, 1].view(-1, 1, 1)# [bs, 1, 1] q_bin_diag = torch.sqrt(q_bin_width ** 2 + q_bin_height ** 2) k_bin_diag = torch.sqrt(k_bin_width ** 2 + k_bin_height ** 2) max_bin_diag = torch.max(q_bin_diag, k_bin_diag)# [bs, 7, 7] center_q_x = (x_array + 0.5) * q_bin_width + q_start_x center_q_y = (y_array + 0.5) * q_bin_height + q_start_y center_k_x = (x_array + 0.5) * k_bin_width + k_start_x center_k_y = (y_array + 0.5) * k_bin_height + k_start_y# [bs, 49, 49] dist_center = torch.sqrt((center_q_x.view(-1, H * W, 1) - center_k_x.view(-1, 1, H * W)) ** 2 + (center_q_y.view(-1, H * W, 1) - center_k_y.view(-1, 1, H * W)) ** 2) / max_bin_diag pos_mask = (dist_center < pos_ratio).float().detach()# [bs, 49, 49] logit = torch.bmm(q.transpose(1, 2), k)loss = (logit * pos_mask).sum(-1).sum(-1) / (pos_mask.sum(-1).sum(-1) + 1e-6)return -2 * loss.mean()

反向传播 下分支网络不参与直接训练,其中所有的权重参数不具有梯度值。其参数$param\_k$更新方式是基于上分支网络参数$param\_q$动量更新。训练开始前,上下分支网络初始权重保持一致。
$$ param\_k.data = https://www.it610.com/article/param/_k.data * momentum + param/_q.data * (1-momentum) $$
其中,momentum是动量值,整个训练过程从0.99逐渐增大到1.0
实验 优化器:LARS,weight_decay=1e-5
lr_scheduler:cosine, warmup
total_batchsize:1024
world size:8 V100 GPUs
与其他基于实例级自监督算法在下游检测分割任务上的比较结果
PixPro自监督论文解读
文章图片

不同超参数下的实验结果
PixPro自监督论文解读
文章图片

PixPro和ProContrast结果比较
PixPro自监督论文解读
文章图片

结合实例级模块的结果
PixPro自监督论文解读
文章图片

有无FPN、head、实例级模块的实验比较结果
【PixPro自监督论文解读】PixPro自监督论文解读
文章图片

    推荐阅读