注意力机制代码_pytorch中加入注意力机制(CBAM),以ResNet为例


注意力机制代码_pytorch中加入注意力机制(CBAM),以ResNet为例
文章图片
对于注意力机制的个人理解:

  1. 网络越深、越宽、结构越复杂,注意力机制对网络的影响就越小。
  2. 在网络中加上CBAM不一定带来性能上的提升,对性能影响因素有数据集、网络自身、注意力所在的位置等等。
  3. 建议直接在网络中加上SE系列,大部分情况下性能都会有提升的。
CBAM的 解析:
heu御林军:CBAM:卷积注意力机制模块?zhuanlan.zhihu.com 注意力机制代码_pytorch中加入注意力机制(CBAM),以ResNet为例
文章图片
贴出一些和SE相关的:
初识CV:SE-Inception v3架构的模型搭建(keras代码实现)?zhuanlan.zhihu.com 注意力机制代码_pytorch中加入注意力机制(CBAM),以ResNet为例
文章图片
PyTorch Hub发布!一行代码调用所有模型:torch.hub?blog.csdn.net 注意力机制代码_pytorch中加入注意力机制(CBAM),以ResNet为例
文章图片
源码位置: 初识CV:ResNet_CBAM源码?zhuanlan.zhihu.com 注意力机制代码_pytorch中加入注意力机制(CBAM),以ResNet为例
文章图片
第一步:找到ResNet源代码
在里面添加通道注意力机制和空间注意力机制
所需库
import torch.nn as nn import math try: from torch.hub import load_state_dict_from_url except ImportError: from torch.utils.model_zoo import load_url as load_state_dict_from_url import torch

通道注意力机制
class ChannelAttention(nn.Module): def __init__(self, in_planes, ratio=16): super(ChannelAttention, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc1= nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) self.relu1 = nn.ReLU() self.fc2= nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x): avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) out = avg_out + max_out return self.sigmoid(out)

空间注意力机制
class SpatialAttention(nn.Module): def __init__(self, kernel_size=7): super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7' padding = 3 if kernel_size == 7 else 1self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) self.sigmoid = nn.Sigmoid()def forward(self, x): avg_out = torch.mean(x, dim=1, keepdim=True) max_out, _ = torch.max(x, dim=1, keepdim=True) x = torch.cat([avg_out, max_out], dim=1) x = self.conv1(x) return self.sigmoid(x)

在ResNet网络中添加注意力机制
注意点:因为不能改变ResNet的网络结构,所以CBAM不能加在block里面,因为加进去网络结构发生了变化,所以不能用预训练参数。加在最后一层卷积和第一层卷积不改变网络,可以用预训练参数
class ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, groups=1, width_per_group=64, replace_stride_with_dilation=None, norm_layer=None): super(ResNet, self).__init__() if norm_layer is None: norm_layer = nn.BatchNorm2d self._norm_layer = norm_layerself.inplanes = 64 self.dilation = 1 if replace_stride_with_dilation is None: # each element in the tuple indicates if we should replace # the 2x2 stride with a dilated convolution instead replace_stride_with_dilation = [False, False, False] if len(replace_stride_with_dilation) != 3: raise ValueError("replace_stride_with_dilation should be None " "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) self.groups = groups self.base_width = width_per_group self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = norm_layer(self.inplanes) self.relu = nn.ReLU(inplace=True)# 网络的第一层加入注意力机制 self.ca = ChannelAttention(self.inplanes) self.sa = SpatialAttention()self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) # 网络的卷积层的最后一层加入注意力机制 self.ca1 = ChannelAttention(self.inplanes) self.sa1 = SpatialAttention()self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0)# Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, and each residual block behaves like an identity. # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 if zero_init_residual: for m in self.modules(): if isinstance(m, Bottleneck): nn.init.constant_(m.bn3.weight, 0) elif isinstance(m, BasicBlock): nn.init.constant_(m.bn2.weight, 0)def _make_layer(self, block, planes, blocks, stride=1, dilate=False): norm_layer = self._norm_layer downsample = None previous_dilation = self.dilation if dilate: self.dilation *= stride stride = 1 if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride), norm_layer(planes * block.expansion), )layers = [] layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer)) self.inplanes = planes * block.expansion for _ in range(1, blocks): layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer))return nn.Sequential(*layers)def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x)x = self.ca(x) * x x = self.sa(x) * xx = self.maxpool(x)x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x)x = self.ca1(x) * x x = self.sa1(x) * xx = self.avgpool(x) x = x.reshape(x.size(0), -1) x = self.fc(x)return x

请详细阅读代码加的位置:
# 网络的第一层加入注意力机制 self.ca = ChannelAttention(self.inplanes) self.sa = SpatialAttention()

【注意力机制代码_pytorch中加入注意力机制(CBAM),以ResNet为例】
# 网络的卷积层的最后一层加入注意力机制 self.ca1 = ChannelAttention(self.inplanes) self.sa1 = SpatialAttention()

forWord部分代码
x = self.ca(x) * x x = self.sa(x) * xx = self.maxpool(x)x = self.layer1(x) x = self.layer2(x) x = self.layer3(x) x = self.layer4(x)x = self.ca1(x) * x x = self.sa1(x) * x

请大家详细阅读,一定能看懂的。

    推荐阅读