语义分割|Attention UNet结构及pytorch实现

注意力机制可以说是深度学习研究领域上的一个热门领域,它在很多模型上都有着不错的表现,比如说BERT模型中的自注意力机制。本博客仅作为本人在看了一些Attention UNet相关文章后所作的笔记,希望能给各位带来一点思考,注意力机制是怎么被应用在医学图像分割的。
参考文章:

  1. 【语义分割系列:七】Attention Unet 论文阅读翻译笔记 医学图像 python实现
  2. 医学图像分割-Attention Unet
  3. 如果不知道什么是注意力机制,可以看看这篇博客:浅谈Attention-based Model【原理篇】
Attention UNet网络结构 【语义分割|Attention UNet结构及pytorch实现】UNet是一个用于分割领域的架构,自2015年被提出以来,在医学图像领域取得了不错的表现,成为了不少医疗影像语义分割任务的baseline。感兴趣的可以去看一下这一篇博客:Unet神经网络为什么会在医学图像分割表现好?
UNet的网络结构并不复杂,最主要的特点便是U型结构和skip-connection。而Attention UNet则是使用了标准的UNet的网络架构,并在这基础上整合进去了Attention机制。更准确来说,是将Attention机制整合进了跳远连接(skip-connection)。
整个网络架构如下, 注意力block已用红色框出:语义分割|Attention UNet结构及pytorch实现
文章图片

与标准的UNet相比,整体结构是很相似的,唯一不同的是在红框内增加了注意力门。为了公式化这个过程,我们将跳远连接的输入称为x,来自前一个block的输入称为g,那么整个模块就可以用以下公式来表示了:
语义分割|Attention UNet结构及pytorch实现
文章图片

在这个公式里面,Attention就是注意力门,upsample是一个简单上采样模块,采用最近邻插值,而ConvBlock只是由两个(convolution + batch norm + ReLU)块组成的序列。唯一需要解释的是注意力。
接下来让我们看一下整个注意力门是怎么实现的,整个结构图如下:
语义分割|Attention UNet结构及pytorch实现
文章图片

整个过程不难理解 ,需要注意一下几点:
  1. x和g都被送入到1x1卷积中,将它们变为相同数量的通道数,而不改变大小
  2. 在上采样操作后(有相同的大小),他们被累加并通过ReLU
  3. 通过另一个1x1的卷积和一个sigmoid,得到一个0到1的重要性分数,分配给特征图的每个部分
  4. 然后用这个注意力图乘以skip输入,产生这个注意力块的最终输出
pytorch实现 下面的代码定义了注意力块(简化版)和用于UNet扩展路径的“up-block”。“down-block”与原UNet一样。
class AttentionBlock(nn.Module): def __init__(self, in_channels_x, in_channels_g, int_channels): super(AttentionBlock, self).__init__() self.Wx = nn.Sequential(nn.Conv2d(in_channels_x, int_channels, kernel_size = 1), nn.BatchNorm2d(int_channels)) self.Wg = nn.Sequential(nn.Conv2d(in_channels_g, int_channels, kernel_size = 1), nn.BatchNorm2d(int_channels)) self.psi = nn.Sequential(nn.Conv2d(int_channels, 1, kernel_size = 1), nn.BatchNorm2d(1), nn.Sigmoid())def forward(self, x, g): # apply the Wx to the skip connection x1 = self.Wx(x) # after applying Wg to the input, upsample to the size of the skip connection g1 = nn.functional.interpolate(self.Wg(g), x1.shape[2:], mode = 'bilinear', align_corners = False) out = self.psi(nn.ReLU()(x1 + g1)) return out*xclass AttentionUpBlock(nn.Module): def __init__(self, in_channels, out_channels): super(AttentionUpBlock, self).__init__() self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size = 2, stride = 2) self.attention = AttentionBlock(out_channels, in_channels, int(out_channels / 2)) self.conv_bn1 = ConvBatchNorm(in_channels+out_channels, out_channels) self.conv_bn2 = ConvBatchNorm(out_channels, out_channels)def forward(self, x, x_skip): # note : x_skip is the skip connection and x is the input from the previous block # apply the attention block to the skip connection, using x as context x_attention = self.attention(x_skip, x) # upsample x to have th same size as the attention map x = nn.functional.interpolate(x, x_skip.shape[2:], mode = 'bilinear', align_corners = False) # stack their channels to feed to both convolution blocks x = torch.cat((x_attention, x), dim = 1) x = self.conv_bn1(x) return self.conv_bn2(x)

整个网络架构完整版实现可以参考 【语义分割系列:七】Attention Unet 论文阅读翻译笔记 医学图像 python实现。

    推荐阅读