MatchNet论文复现过程记录 【深度学习|MatchNet论文复现过程记录】原文为《Matchnet: Unifying feature and metric learning for patch-based matching》1:本文复现基于PyTorch深度学习框架,版本(1.7.1+cu110)。
I.Network architecture
文章图片
根据论文中描述,MatchNet包括:
A. Feature network
该特征提取网络类似AlexNet2,具体结构如下:
文章图片
其中,PS: patch size for convolution and pooling layers;
S: stride. Layer types: C: convolution, MP: max-pooling, FC: fully-connected.
B. Metric network
包括三个全连接层,FC3后接Softmax作为输出。
C. MatchNet in training
基于patch的匹配任务通常假设patch在计算相似度之前,先经过相同的特征编码。因此,论文中采用Two-tower structure with tied parameters结构,即,仅采用一个特征提取网络,在训练过程中,可以理解为同时使用了两个参数共享的特征提取网络去连接度量网络,更新任何一个特征提取网络,将会使得两个网络的参数都发生变化。(这里直接讲比较难理解,具体可以看代码实现。)
具体代码实现如下:
import torch
import torch.nn as nnclass FeatureNet(nn.Module):
"""特征提取网络
"""
def __init__(self):
super(FeatureNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=24, kernel_size=7, padding=3, stride=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
nn.Conv2d(in_channels=24, out_channels=64, kernel_size=5, padding=2, stride=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, padding=1, stride=1),
nn.ReLU(),
nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, padding=1, stride=1),
nn.ReLU(),
nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, padding=1, stride=1),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=0)
)def forward(self, x):
return self.features(x)class MetricNet(nn.Module):
"""度量网络
"""
def __init__(self):
super(MetricNet, self).__init__()
self.features = nn.Sequential(
nn.Linear(in_features=6272, out_features=1024),
nn.ReLU(),
nn.Linear(in_features=1024, out_features=1024),
nn.ReLU(),
nn.Linear(in_features=1024, out_features=2),
# nn.Softmax(dim=1)
''' 这里原本应该接Softmax,但损失函数采用的是交叉熵损失,
而Pytorch中的torch.nn.CrossEntropyLoss()方法包括Softmax,
具体可参考文档https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?highlight=nn%20crossentropyloss#torch.nn.CrossEntropyLoss
'''
)def forward(self, x):
return self.features(x)class MatchNet(nn.Module):
def __init__(self):
super(MatchNet, self).__init__()# 只添加一个特征提取网络
self.input_ = FeatureNet()
self.input_.apply(weights_init)self.matric_network = MetricNet()
self.matric_network.apply(weights_init)def forward(self, x):
"""x.shape = (2, C, H, W),即两个patch
"""
# 两个patch进入同一个FeatureNet,相当于two-tower sharing same parameters
feature1 = self.input_(x[0]).reshape((x[0].shape[0], -1)) #[256, 3136]
feature2 = self.input_(x[1]).reshape((x[1].shape[0], -1))features = torch.cat((feature1, feature2), 1) #[256, 6272]return self.matric_network(features)def weights_init(m):
'''
自定义权重初始化
'''
if isinstance(m, nn.Conv2d):
nn.init.orthogonal_(m.weight.data, gain=0.6)
try:
nn.init.constant_(m.bias.data, 0.01)
except Exception:
pass
return
参考文献
- Han等, 《Matchnet: Unifying feature and metric learning for patch-based matching》. ??
- A. Krizhevsky, I. Sutskever, and G. E. Hinton. ImageNet classification with deep convolutional neural networks. In NIPS, 2012. ??
推荐阅读
- 暑期深度学习入门|第二周学习(卷积神经网络)
- 学习日志|Python MINIST手写集的识别,卷积神经网络,CNN(最简单PyTorch的使用)
- ROS从入门到精通|ROS从入门到精通(十) TF坐标变换原理,为什么需要TF变换()
- 机器学习图像处理|从零开始实现一个简单的CycleGAN项目
- 数据库|网易游戏基于 Flink 的流式 ETL 建设
- 努力学习人工智能|基于图像识别的跌倒检测
- 人工智能|CV目标检测模型小抄(1)
- 算法|魔改YOLOv5!一种实时多尺度交通标志检测网络
- 神经网络|YOLO7 姿势识别实例