在ViT等基础之上继续演变的Swin刚刚拿到了ICCV2021的 best paper,经过实际使用体验来看,确实效果较好,从语义分割角度来看,Swin不仅在ADE20K取得了sota的效果,在各个其他场景数据集下都有极为优秀的表现,精度相比PSPnet和deeplabv3+等基于CNN的分割算法都有较大提升(优点:精度高,缺点:实时性较差,极度依赖预训练模型,由于tf较新,在嵌入式端部署可能会存在问题,目前嵌入式端推理框架还都是基于常规卷积做加速。)。
论文地址:https://arxiv.org/pdf/2103.14030.pdf
代码:https://github.com/microsoft/Swin-Transformer
在介绍Swin之前简单先介绍一下Vit:
受NLP中Transformer扩展成功的启发,ViT作者尝试将标准Transformer直接应用于图像,并进行最少的修改。为此,将图像拆分为小块,并提供这些小块的线性嵌入序列作为transformer的输入。图像图块与NLP应用程序中的token(words)的处理方式相同,以监督方式对模型进行图像分类训练。
文章图片
从上图可以看出Vit首先将图像切成固定数量的patch,标准的Transformer模块,要求输入的是token,以ViT-B/16为例,将输入图片(224x224)按照16x16大小的Patch进行划分,划分后会得到( 224 / 16 ) 2 =196个Patches,每个Patch数据shape为[16, 16, 3]通过映射得到一个长度为768的token([196, 768]),每个token向量长度为768 。接着通过线性映射成将每个Patch映射到一维向量
文章图片
在刚刚得到的一堆tokens中插入一个专门用于分类的[class]token,这个[class]token是一个可训练的参数,数据格式和其他token一样都是一个向量,以ViT-B/16为例,就是一个长度为768的向量,与之前从图片中生成的tokens拼接在一起,Cat([1, 768], [196, 768]) -> [197, 768]。
在Self-Attention模块中,输入a1,a2,a3输出b1,b2,b3,对于a1而言,a2和a3理他一样近而且没有先后顺序,假设输入序列变成 a1,a3,a2,对结果没有影响。为了引入位置信息需要做positional encoding,positional encoding有两种方法,一种固定位置,一种 可训练的位置编码,作者尝试后发现两个效果差不多
Position Embedding采用的是一个可训练的参数(1D Pos. Emb.),是直接叠加在tokens上的(add),所以shape也一样。以ViT-B/16为例,刚刚拼接[class]token后shape是[197, 768],那么这里的Position Embedding的shape也是[197, 768]
文章图片
这里我们只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768]。接着我们通过MLP Head得到我们最终的分类结果
大家这里可以看出整个Transformer核心就是这个Muti-Head-Attention。
多头Attention介绍:
多头Attention本质是学这三个矩阵,Wq,Wk,Wv。
这里详细参考视觉注意力机制 | Non-local模块与Self-attention的之间的关系与区别? - 知乎
Self-attention结构自上而下分为三个分支,分别是query、key和value。计算时通常分为三步:
- 第一步是将query和每个key进行相似度计算得到权重,常用的相似度函数有点积,拼接,感知机等;
- 第二步一般是使用一个softmax函数对这些权重进行归一化;
- 第三步将权重和相应的键值value进行加权求和得到最后的attention。
三个输入
文章图片
文章图片
特点是输入向量与输出向量shape一致,对于QKV矩阵原理
这里参考知乎超详细图解Self-Attention - 知乎
简单来说就是向量的内积表征两个向量的夹角,表征一个向量在另一个向量上的投影。投影的值大,说明两个向量相关度高。这其实就是做Attention的本质,以SeNet为例,基于通道Attention是考虑到通道数量很多,但是并不是每个通道都那么重要,所以需要算每个通道的权重进行加权。
得到特征图第j个元素对第i个元素的影响,从而实现全局上下文任意两个元素的依赖关系。
有了这些前置条件,下面就来介绍Swin:
swin三作观点:在Attention is all you need那篇文章出来之后,就一直在思考一个问题:从建模的基本单元来看,self-attention module到底在vision领域能做什么?从现在回头看,主要尝试的就是两个方向:
1、作为convolution的补充。绝大多数工作基本上都是从这个角度出发的,比如relation networks、non-local networks、DETR。其中一部分是从long-range dependency引入,某种程度上是在弥补convolution is too local;另一部分例如建模物体之间或物体与像素之间的关系,也是在做一些conv做不了的事。
2、 替代convolution。在这个方向上尝试不多,早期有LocalRelationNet、Stand-alone Self-attention Net。如果仅看结果,这些工作基本上已经可以做到替换掉3x3 conv不掉点,但有一个通病就是速度慢
到这个时候(2020年左右),其实有一种到了瓶颈期的感觉,作为conv的补充好像做的差不多了,后续的工作也都大同小异,替代conv因为速度的问题难以解决而遥遥无期。
没想到的是,Vision Transformer(ViT)在2020年10月横空出世。
ViT的出现改变了很多固有认知,我的理解主要有两点:
1. locality(局部性);
2. translation invariance(平移不变性)(需要大量数据训练)
基于这些理解,作者提出了一个通用的视觉骨干网络,Swin Transformer:
文章图片
前面输入还是图像切成patch,喂入transform
文章图片
这里看的非常复杂。在翻看源码之后,其实实现非常的简单,仅仅只用了一个2Dconv就实现了:
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
输入kernel_size和stride相同,为parch size,默认为4
以Swin base为例,输入512x512x3图像,上面2D卷积进行4倍下采样,卷积核个数为patch的个数
swin base默认为128,所以输出为1x128(channel)x128(h)x128(w)大小
然后通过flatten对hw维度压平,形成类似vit的128x128大小的token
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
ViT的patch数量不会变化,而Swin Transformer随着网络深度的加深数量会逐渐减少并且每个patch的感知范围会扩大(patch个数不变)(类似resnet,但是resnet主要依靠卷积核,swin依靠transform),这个设计是为了方便Swin Transformer的层级构建,并且能够适应视觉任务的多尺度。
文章图片
这块就是整个Swin的核心部分,Swin为了减少MSA的计算量,采用的基于Window的MSA策略(W-MAS),以Swin-base为例,Window为7,即整张图切成7x7大小的Window,在4个stage的计算中,Window的大小是始终不变的。一直都是7x7=49个,那么以第一个stage为例,输入
batchx128x(128x128),128x128是宽高,也就是feature map的面积,切成7x7后,向上取整就是19x19=361,所以第一次进行MSA计算输入的shape就是batchx361x49x128,同理,这里的128类似前面token的数量。也不变
整个MAS输入shape由于每个stage后,w和h都减少2倍,所以window的size会减少2x2=4倍
整体来看:
输入 batchx3x512x512
PatchEmbed后
batch x(128*128) x 128
128/7向上取整=133(注意这里取整) 133/window size(7)=19
开始:
stage1: batchx361x49x128133/7=19 19X19=361即本stage的window面积大小
stage2:batchx100x49x128133/2=66.5 66.5/7向上取整=70/7=10 本stage 10x10=100
stage3:batchx25x49x128同理 70/2=35 35/7=5不用取整 本stage window面积 5x5=25
stage4:batchx9x49x128 35/2=17.5 17.5/7 取整 21/7=3本stage window面积 3x3=9
这里注意每个stage的下采样比较特殊:
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
用的一个Linear做的
创新点:Window Attention&Shifted Window Attention
传统的Transformer都是基于全局来计算注意力的,因此计算复杂度十分高。而Swin Transformer则将注意力的计算限制在每个窗口内,进而减少了计算量(类似resnet思想)。
原生 Transformer 对 N 个 token 做 Self-Attention ,复杂度为N^2,Swin Transformer 将 N 个 token 拆为 N/n 组, 每组 n 个token 进行计算,复杂度降为 N*n^2
文章图片
分组计算的方式虽然大大降低了 Self-Attention 的复杂度,但与此同时,还是会有两个问题:
1.是分组后 Transformer 的视野局限于 n 个token,看不到全局信息,
2. 组与组之间的信息缺乏交互。
问题1解决:Hierarchical
每个 stage 后对 2x2 组的特征向量进行融合和压缩,这样视野就和 CNN-based 的结构一样,随着 stage 逐渐变大
问题2解决:Shifted Windows
Shifted Windows 图像划分为不重合的windows会使得W-MSA缺乏不 同windows之间的connections,限制了模型的感受野 在相邻的两个ST Block中,后者对windows进行top-left方向 的cyclic shift,其中步长为2,但这一做法又引入了新的问 题,就是会产生更多的windows 这种方法引起的问题总结起来就是: 如何使得每个windows都保持原来的尺寸(4*4)
文章图片
论文中使用了pad和mask的方法解决了这一问题,如图中cyclic shift部分,对边缘部分尺寸较小的windows进行了填充(图中蓝色、绿色和黄色部分),使得每个windows都能够保持原来的大小,并且论文还采用了mask的方法来使得模型只在除了pad的部分做self-attention计算,这样一来就能够解决上面所提到的问题
文章图片
如何使得windows数量与原来保持一致,减少计算? 解决办法:作者通过对特征图移位,并给Attention设置mask来间接实现的。 能在保持原有的window个数下,最后的计算结果等价
首先我们对Shift Window后的每个窗口都给上index, 计算Attention的时候,让具有相同index QK进行计算 而忽略不同index QK计算结果
当然这么说可能还是不明确,这里可以参考
Swin Transformer对CNN的降维打击 - 知乎 这篇,里面对下面步骤有更详细的解释:
文章图片
文章图片
文章图片
文章图片
文章图片
文章图片
文章图片
文章图片
分类结果:
文章图片
文章图片
最后分割头接的是upernet,基于pspnet改进,在PPM融合后的特征,再分别和conv2-conv5分别做4次融合,融合方式类似于fpn, 最后融合这么多次融合出一个fused feature map
具体地,在4个stage中。每次stage后都会将output的feature map append进一个list
对于每个stage:
分别是:
1x128x128x128
1x256x64x64
1x512x32x32
1x1024x16x16
四组feature map喂入对应upernet的分支中做下图的特征融合,可以看到和检测的FPN非常的类似
文章图片
参数SETR太大了,相比swin-L要多三分一参数,mIOU要低3个点,训练:也是一样的transform系列极端依赖预训练模型,从头开始训基本训不动
【算法|语义分割算法分享之Swin-Transformer】
文章图片
整体来看,整个Swin是发现问题---->解决问题的形式:
问题:SETR太大
解决:采用基于Windows的局部注意力(W-MSA)
问题:不同Windows之间缺乏信息交互(同时期的Segformer直接加了overlap)
解决:对W-MSA进行改进,通过SW-MSA解决
最后best paper到手
最后,swin v2 11月底论文出来了。代码还没放出来,放出来同样第一时间进行分享
推荐阅读
- YOLO|YOLOX网络结构
- python|数据增强操作(旋转、翻转、裁剪、色彩变化、高斯噪声等)
- 深度学习|r3det 配环境避雷指南(pytorch版)
- 论文阅读|R3Det: Refined Single-Stage Detector with Feature Refinementfor Rotating Object论文学习
- 【扫盲】R3Det旋转目标检测训练(win10)
- 【扫盲】R3Det旋转目标检测训练
- 深度学习|pytorch计算分类验证精度acc1,acc5代码
- 深度学习|paddle.nn.functional.cross_entropy中的soft_label时间消耗问题
- 深度学习|使用tensorboard时踩的坑