注意力机制可以说是深度学习研究领域上的一个热门领域,它在很多模型上都有着不错的表现,比如说BERT模型中的自注意力机制。本博客仅作为本人在看了一些Attention UNet相关文章后所作的笔记,希望能给各位带来一点思考,注意力机制是怎么被应用在医学图像分割的。
参考文章:
- 【语义分割系列:七】Attention Unet 论文阅读翻译笔记 医学图像 python实现
- 医学图像分割-Attention Unet
- 如果不知道什么是注意力机制,可以看看这篇博客:浅谈Attention-based Model【原理篇】
UNet的网络结构并不复杂,最主要的特点便是U型结构和skip-connection。而Attention UNet则是使用了标准的UNet的网络架构,并在这基础上整合进去了Attention机制。更准确来说,是将Attention机制整合进了跳远连接(skip-connection)。
整个网络架构如下, 注意力block已用红色框出:
文章图片
与标准的UNet相比,整体结构是很相似的,唯一不同的是在红框内增加了注意力门。为了公式化这个过程,我们将跳远连接的输入称为x,来自前一个block的输入称为g,那么整个模块就可以用以下公式来表示了:
文章图片
在这个公式里面,Attention就是注意力门,upsample是一个简单上采样模块,采用最近邻插值,而ConvBlock只是由两个(convolution + batch norm + ReLU)块组成的序列。唯一需要解释的是注意力。
接下来让我们看一下整个注意力门是怎么实现的,整个结构图如下:
文章图片
整个过程不难理解 ,需要注意一下几点:
- x和g都被送入到1x1卷积中,将它们变为相同数量的通道数,而不改变大小
- 在上采样操作后(有相同的大小),他们被累加并通过ReLU
- 通过另一个1x1的卷积和一个sigmoid,得到一个0到1的重要性分数,分配给特征图的每个部分
- 然后用这个注意力图乘以skip输入,产生这个注意力块的最终输出
class AttentionBlock(nn.Module):
def __init__(self, in_channels_x, in_channels_g, int_channels):
super(AttentionBlock, self).__init__()
self.Wx = nn.Sequential(nn.Conv2d(in_channels_x, int_channels, kernel_size = 1),
nn.BatchNorm2d(int_channels))
self.Wg = nn.Sequential(nn.Conv2d(in_channels_g, int_channels, kernel_size = 1),
nn.BatchNorm2d(int_channels))
self.psi = nn.Sequential(nn.Conv2d(int_channels, 1, kernel_size = 1),
nn.BatchNorm2d(1),
nn.Sigmoid())def forward(self, x, g):
# apply the Wx to the skip connection
x1 = self.Wx(x)
# after applying Wg to the input, upsample to the size of the skip connection
g1 = nn.functional.interpolate(self.Wg(g), x1.shape[2:], mode = 'bilinear', align_corners = False)
out = self.psi(nn.ReLU()(x1 + g1))
return out*xclass AttentionUpBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(AttentionUpBlock, self).__init__()
self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size = 2, stride = 2)
self.attention = AttentionBlock(out_channels, in_channels, int(out_channels / 2))
self.conv_bn1 = ConvBatchNorm(in_channels+out_channels, out_channels)
self.conv_bn2 = ConvBatchNorm(out_channels, out_channels)def forward(self, x, x_skip):
# note : x_skip is the skip connection and x is the input from the previous block
# apply the attention block to the skip connection, using x as context
x_attention = self.attention(x_skip, x)
# upsample x to have th same size as the attention map
x = nn.functional.interpolate(x, x_skip.shape[2:], mode = 'bilinear', align_corners = False)
# stack their channels to feed to both convolution blocks
x = torch.cat((x_attention, x), dim = 1)
x = self.conv_bn1(x)
return self.conv_bn2(x)
整个网络架构完整版实现可以参考 【语义分割系列:七】Attention Unet 论文阅读翻译笔记 医学图像 python实现。
推荐阅读
- 游戏|直播新玩法背后的音视频技术演进
- 【3】人体姿态估计研究|【HigherHRNet】 HigherHRNet 详解之 HigherHRNet的热图回归代码
- 姿态估计|HigherHRnet详解之论文详解
- 计算机网络|计算机网络——数据链路层
- 计算机网络|计算机网络——物理层
- 计算机网络|计算机网络——概论
- vue.js|vue入门开始,做完得物App的用户登录
- python|Python+AI智能编辑人脸
- 计算机视觉|干货收藏!基于深度学习目标姿态估计的论文一览(2017-2020)